Commit b4aefca5 by imClumsyPanda

add stream support to cli_demo.py

Parent 88ab9a1d
@@ -2,7 +2,6 @@ from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
-from langchain.vectorstores.base import VectorStoreRetriever
 from langchain.document_loaders import UnstructuredFileLoader
 from models.chatglm_llm import ChatGLM
 import sentence_transformers
@@ -34,22 +33,20 @@ def load_file(filepath):
         docs = loader.load_and_split(text_splitter=textsplitter)
     return docs


-def get_relevant_documents(self, query: str) -> List[Document]:
-    if self.search_type == "similarity":
-        docs = self.vectorstore._similarity_search_with_relevance_scores(query, **self.search_kwargs)
-        for doc in docs:
-            doc[0].metadata["score"] = doc[1]
-        docs = [doc[0] for doc in docs]
-    elif self.search_type == "mmr":
-        docs = self.vectorstore.max_marginal_relevance_search(
-            query, **self.search_kwargs
-        )
-    else:
-        raise ValueError(f"search_type of {self.search_type} not allowed.")
-    return docs
+def generate_prompt(related_docs: List[str],
+                    query: str,
+                    prompt_template=PROMPT_TEMPLATE) -> str:
+    context = "\n".join([doc.page_content for doc in related_docs])
+    prompt = prompt_template.replace("{question}", query).replace("{context}", context)
+    return prompt
+
+
+def get_docs_with_score(docs_with_score):
+    docs = []
+    for doc, score in docs_with_score:
+        doc.metadata["score"] = score
+        docs.append(doc)
+    return docs


 class LocalDocQA:
     llm: object = None
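The two helpers added above replace the removed retriever patch: get_docs_with_score copies each FAISS similarity score into the document's metadata, and generate_prompt splices the retrieved passages and the query into a template. A minimal sketch of how they compose, assuming LangChain's Document class and an inline template passed through the prompt_template parameter (the PROMPT_TEMPLATE default is defined outside the hunks shown here):

# Illustrative only: fake retrieval hits and an inline template stand in for real
# FAISS results and the PROMPT_TEMPLATE constant, which this diff does not show.
from langchain.docstore.document import Document

hits_with_score = [
    (Document(page_content="ChatGLM-6B 支持流式输出。", metadata={"source": "a.md"}), 0.31),
    (Document(page_content="cli_demo.py 逐字打印回答。", metadata={"source": "b.md"}), 0.42),
]

docs = get_docs_with_score(hits_with_score)   # copies each score into doc.metadata["score"]
prompt = generate_prompt(
    docs,
    "如何开启流式输出?",
    prompt_template="已知内容:\n{context}\n\n问题:\n{question}",
)
print(prompt)   # the two page_content strings joined by newlines, followed by the question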
@@ -73,8 +70,6 @@ class LocalDocQA:
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
                                                 model_kwargs={'device': embedding_device})
-        # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
-        #                                                                    device=embedding_device)
         self.top_k = top_k

     def init_knowledge_vector_store(self,
@@ -134,34 +129,30 @@ class LocalDocQA:
     def get_knowledge_based_answer(self,
                                    query,
                                    vs_path,
-                                   chat_history=[], ):
-        prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
-如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
-
-已知内容:
-{context}
-
-问题:
-{question}"""
-        prompt = PromptTemplate(
-            template=prompt_template,
-            input_variables=["context", "question"]
-        )
-        self.llm.history = chat_history
+                                   chat_history=[],
+                                   streaming=True):
+        self.llm.streaming = streaming
         vector_store = FAISS.load_local(vs_path, self.embeddings)
-        vs_r = vector_store.as_retriever(search_type="mmr",
-                                         search_kwargs={"k": self.top_k})
-        # VectorStoreRetriever.get_relevant_documents = get_relevant_documents
-        knowledge_chain = RetrievalQA.from_llm(
-            llm=self.llm,
-            retriever=vs_r,
-            prompt=prompt
-        )
-        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
-            input_variables=["page_content"], template="{page_content}"
-        )
-        knowledge_chain.return_source_documents = True
-        result = knowledge_chain({"query": query})
-        self.llm.history[-1][0] = query
-        return result, self.llm.history
+        related_docs_with_score = vector_store.similarity_search_with_score(query,
+                                                                            k=self.top_k)
+        related_docs = get_docs_with_score(related_docs_with_score)
+        prompt = generate_prompt(related_docs, query)
+
+        if streaming:
+            for result, history in self.llm._call(prompt=prompt,
+                                                  history=chat_history):
+                history[-1] = list(history[-1])
+                history[-1][0] = query
+                response = {"query": query,
+                            "result": result,
+                            "source_documents": related_docs}
+                yield response, history
+        else:
+            result, history = self.llm._call(prompt=prompt,
+                                             history=chat_history)
+            history[-1] = list(history[-1])
+            history[-1][0] = query
+            response = {"query": query,
+                        "result": result,
+                        "source_documents": related_docs}
+            return response, history
@@ -28,10 +28,16 @@ if __name__ == "__main__":
     history = []
     while True:
         query = input("Input your question 请输入问题:")
-        resp, history = local_doc_qa.get_knowledge_based_answer(query=query,
-                                                                vs_path=vs_path,
-                                                                chat_history=history)
+        last_print_len = 0
+        for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+                                                                     vs_path=vs_path,
+                                                                     chat_history=history,
+                                                                     streaming=True):
+            print(resp["result"][last_print_len:], end="", flush=True)
+            last_print_len = len(resp["result"])
         if REPLY_WITH_SOURCE:
-            print(resp)
-        else:
-            print(resp["result"])
+            source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+                           # f"""相关度:{doc.metadata['score']}\n\n"""
+                           for inum, doc in
+                           enumerate(resp["source_documents"])]
+            print("\n\n" + "\n\n".join(source_text))
@@ -5,7 +5,8 @@ from langchain.llms.utils import enforce_stop_tokens
 from transformers import AutoTokenizer, AutoModel, AutoConfig
 import torch
 from configs.model_config import LLM_DEVICE
+from langchain.callbacks.base import CallbackManager
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from typing import Dict, Tuple, Union, Optional

 DEVICE = LLM_DEVICE
@@ -54,10 +55,12 @@ class ChatGLM(LLM):
     max_token: int = 10000
     temperature: float = 0.01
     top_p = 0.9
-    history = []
+    # history = []
     tokenizer: object = None
     model: object = None
     history_len: int = 10
+    streaming: bool = True
+    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

     def __init__(self):
         super().__init__()
@@ -68,46 +71,45 @@ class ChatGLM(LLM):
     def _call(self,
               prompt: str,
-              stop: Optional[List[str]] = None,
-              stream=True) -> str:
-        if stream:
-            self.history = self.history + [[None, ""]]
-            for response, history in self.model.stream_chat(
+              history: List[List[str]] = [],
+              stop: Optional[List[str]] = None) -> str:
+        if self.streaming:
+            history = history + [[None, ""]]
+            for stream_resp, history in self.model.stream_chat(
                     self.tokenizer,
                     prompt,
-                    history=self.history[-self.history_len:] if self.history_len > 0 else [],
+                    history=history[-self.history_len:] if self.history_len > 0 else [],
                     max_length=self.max_token,
                     temperature=self.temperature,
             ):
-                torch_gc()
-                self.history[-1][-1] = response
-                yield response
+                yield stream_resp, history
         else:
             response, _ = self.model.chat(
                 self.tokenizer,
                 prompt,
-                history=self.history[-self.history_len:] if self.history_len > 0 else [],
+                history=history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
                 temperature=self.temperature,
             )
             torch_gc()
             if stop is not None:
                 response = enforce_stop_tokens(response, stop)
-            self.history = self.history + [[None, response]]
-            return response
+            history = history + [[None, response]]
+            return response, history

-    def chat(self,
-             prompt: str) -> str:
-        response, _ = self.model.chat(
-            self.tokenizer,
-            prompt,
-            history=self.history[-self.history_len:] if self.history_len > 0 else [],
-            max_length=self.max_token,
-            temperature=self.temperature,
-        )
-        torch_gc()
-        self.history = self.history + [[None, response]]
-        return response
+    # def chat(self,
+    #          prompt: str) -> str:
+    #     response, _ = self.model.chat(
+    #         self.tokenizer,
+    #         prompt,
+    #         history=self.history[-self.history_len:] if self.history_len > 0 else [],
+    #         max_length=self.max_token,
+    #         temperature=self.temperature,
+    #     )
+    #     torch_gc()
+    #     self.history = self.history + [[None, response]]
+    #     return response

     def load_model(self,
                    model_name_or_path: str = "THUDM/chatglm-6b",
@@ -149,7 +151,13 @@ class ChatGLM(LLM):
             else:
                 from accelerate import dispatch_model

-                model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True, **kwargs).half()
+                model = (
+                    AutoModel.from_pretrained(
+                        model_name_or_path,
+                        trust_remote_code=True,
+                        config=model_config,
+                        **kwargs)
+                    .half())
                 # 可传入device_map自定义每张卡的部署情况
                 if device_map is None:
                     device_map = auto_configure_device_map(num_gpus)
@@ -160,7 +168,8 @@ class ChatGLM(LLM):
                 AutoModel.from_pretrained(
                     model_name_or_path,
                     config=model_config,
-                    trust_remote_code=True)
+                    trust_remote_code=True,
+                    **kwargs)
                 .float()
                 .to(llm_device)
             )
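With these changes, _call takes the conversation history explicitly and, when self.streaming is set, yields (partial_response, history) pairs straight from stream_chat; get_knowledge_based_answer and cli_demo.py iterate it as shown in the hunks above. A minimal stand-alone sketch of that streaming path, with an illustrative model path and prompt:

# Illustrative consumption of the streaming path; the model path and prompt are
# examples, not part of this commit. stream_chat yields the cumulative reply each step.
llm = ChatGLM()
llm.load_model(model_name_or_path="THUDM/chatglm-6b")
llm.streaming = True

answer, history = "", []
for answer, history in llm._call(prompt="你好", history=[]):
    pass                      # answer grows with each yielded chunk
print(answer)                 # the full reply once the generator is exhausted
print(history[-1])            # (prompt, reply) pair from stream_chat; get_knowledge_based_answer
                              # later rewrites its first element to the raw user query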