Commit 0e8cc0d1 by imClumsyPanda

add streaming option in configs/model_config.py

Parent 2ebcd136
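In effect, the commit turns streaming from a per-call keyword argument of get_knowledge_based_answer into a module-level STREAMING constant that init_cfg threads onto the ChatGLM instance. A minimal sketch of the resulting call flow, using names from the diff below (the chains.local_doc_qa module path and the vs_path value are placeholders, not verified against the repo):

```python
from configs.model_config import STREAMING
from chains.local_doc_qa import LocalDocQA  # module path assumed from the repo layout

local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(streaming=STREAMING)  # stored on the model as self.llm.streaming

vs_path = "/path/to/vector_store"  # placeholder: a previously built FAISS store

if local_doc_qa.llm.streaming:
    # streaming: the method yields successive (response, history) pairs
    for resp, history in local_doc_qa.get_knowledge_based_answer(
            query="你好", vs_path=vs_path, chat_history=[]):
        print(history[-1][-1])  # partial answer accumulated in the chat history
else:
    # non-streaming: a single (response, history) pair comes back, with
    # resp["source_documents"] carrying the retrieved chunks (per the webui diff)
    resp, history = local_doc_qa.get_knowledge_based_answer(
        query="你好", vs_path=vs_path, chat_history=[])
    print(resp["source_documents"])
```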
@@ -116,10 +116,12 @@ class LocalDocQA:
                  llm_history_len: int = LLM_HISTORY_LEN,
                  llm_model: str = LLM_MODEL,
                  llm_device=LLM_DEVICE,
+                 streaming=STREAMING,
                  top_k=VECTOR_SEARCH_TOP_K,
                  use_ptuning_v2: bool = USE_PTUNING_V2
                  ):
         self.llm = ChatGLM()
+        self.llm.streaming = streaming
         self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
                             llm_device=llm_device,
                             use_ptuning_v2=use_ptuning_v2)
@@ -186,9 +188,7 @@ class LocalDocQA:
     def get_knowledge_based_answer(self,
                                    query,
                                    vs_path,
-                                   chat_history=[],
-                                   streaming=True):
-        self.llm.streaming = streaming
+                                   chat_history=[]):
         vector_store = FAISS.load_local(vs_path, self.embeddings)
         FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
         vector_store.chunk_size=self.chunk_size
@@ -197,7 +197,7 @@ class LocalDocQA:
         related_docs = get_docs_with_score(related_docs_with_score)
         prompt = generate_prompt(related_docs, query)
-        if streaming:
+        if self.llm.streaming:
             for result, history in self.llm._call(prompt=prompt,
                                                   history=chat_history):
                 history[-1][0] = query
...
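Both branches above assume ChatGLM._call changes shape with the flag: an iterable of (result, history) pairs when streaming, a single pair otherwise. The wrapper itself is not part of this diff; a minimal sketch of one way to meet that contract (stream_chat and chat are the actual chatglm-6b Hugging Face methods; everything else here is an assumption, not the repo's code):

```python
class ChatGLM:
    """Sketch of an LLM wrapper with a streaming toggle (illustrative only)."""
    streaming: bool = True

    def _call(self, prompt, history=None):
        history = history or []
        if self.streaming:
            # hand back a generator object; the caller iterates partial replies
            return self._stream(prompt, history)
        # blocking call: chatglm-6b's chat() returns (response, history)
        response, history = self.model.chat(self.tokenizer, prompt, history=history)
        return response, history

    def _stream(self, prompt, history):
        # chatglm-6b's stream_chat() yields progressively longer (response, history)
        for response, history in self.model.stream_chat(self.tokenizer, prompt, history=history):
            yield response, history
```

Keeping `yield` out of `_call` itself is what lets it return either a generator object or a plain tuple, matching how the webui consumes it in both branches.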
@@ -27,6 +27,9 @@ llm_model_dict = {
 
 # LLM model name
 LLM_MODEL = "chatglm-6b"
 
+# LLM streaming response
+STREAMING = True
+
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
...
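With the switch living in configuration rather than at call sites, disabling streaming becomes a one-line edit:

```python
# configs/model_config.py
STREAMING = False  # UI falls back to blocking, whole-response calls
```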
@@ -30,8 +30,8 @@ local_doc_qa = LocalDocQA()
 def get_answer(query, vs_path, history, mode):
-    if mode == "知识库问答":
-        if vs_path:
+    if mode == "知识库问答" and vs_path:
+        if local_doc_qa.llm.streaming:
             for resp, history in local_doc_qa.get_knowledge_based_answer(
                     query=query, vs_path=vs_path, chat_history=history):
                 source = "\n\n"
@@ -44,14 +44,28 @@ def get_answer(query, vs_path, history, mode):
                 history[-1][-1] += source
                 yield history, ""
         else:
+            resp, history = local_doc_qa.get_knowledge_based_answer(
+                query=query, vs_path=vs_path, chat_history=history)
+            source = "\n\n"
+            source += "".join(
+                [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
+                 f"""{doc.page_content}\n"""
+                 f"""</details>"""
+                 for i, doc in
+                 enumerate(resp["source_documents"])])
+            history[-1][-1] += source
+            return history, ""
+    else:
+        if local_doc_qa.llm.streaming:
             for resp, history in local_doc_qa.llm._call(query, history):
                 history[-1][-1] = resp + (
                     "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
                 yield history, ""
-    else:
-        for resp, history in local_doc_qa.llm._call(query, history):
-            history[-1][-1] = resp
-            yield history, ""
+        else:
+            resp, history = local_doc_qa.llm._call(query, history)
+            history[-1][-1] = resp + (
+                "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
+            return history, ""
 
 
 def update_status(history, status):
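One Python subtlety in the new get_answer: because its streaming branches contain yield, the whole function compiles as a generator, so the `return history, ""` statements in the non-streaming branches surface only as a StopIteration value rather than as an ordinary return value. A sketch of the usual workaround, shown on the plain-chat branch (illustrative, not what this commit ships):

```python
# pattern sketch: make every branch of a generator function yield,
# so Gradio (or any caller) can always iterate the result
def get_answer_sketch(query, history):
    if local_doc_qa.llm.streaming:
        for resp, history in local_doc_qa.llm._call(query, history):
            history[-1][-1] = resp
            yield history, ""
    else:
        resp, history = local_doc_qa.llm._call(query, history)
        history[-1][-1] = resp
        yield history, ""  # a single yield instead of `return history, ""`
```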
@@ -62,7 +76,7 @@ def update_status(history, status):
 def init_model():
     try:
-        local_doc_qa.init_cfg()
+        local_doc_qa.init_cfg(streaming=STREAMING)
         local_doc_qa.llm._call("你好")
         reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(reply)
@@ -84,7 +98,8 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
                               use_ptuning_v2=use_ptuning_v2,
-                              top_k=top_k)
+                              top_k=top_k,
+                              streaming=STREAMING)
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(model_status)
     except Exception as e:
...
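Taken together, a quick smoke test of both modes after this change (module paths assumed from the repo layout):

```python
from configs.model_config import STREAMING
from chains.local_doc_qa import LocalDocQA  # assumed module path

qa = LocalDocQA()
qa.init_cfg(streaming=STREAMING)

qa.llm.streaming = False
resp, history = qa.llm._call("你好", [])       # blocking: one (response, history) pair
print(resp)

qa.llm.streaming = True
for resp, history in qa.llm._call("你好", []):  # streaming: successive partial responses
    print(resp)
```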