Commit 4df9d76f, author: imClumsyPanda

add streaming option in configs/model_config.py

Parent 0e8cc0d1
@@ -8,6 +8,7 @@ from textsplitter import ChineseTextSplitter
 from typing import List, Tuple
 from langchain.docstore.document import Document
 import numpy as np
+from utils import torch_gc
 
 # return top-k text chunk from vector store
 VECTOR_SEARCH_TOP_K = 6
@@ -15,6 +16,10 @@ VECTOR_SEARCH_TOP_K = 6
 # LLM input history length
 LLM_HISTORY_LEN = 3
 
+DEVICE_ = EMBEDDING_DEVICE
+DEVICE_ID = "0" if torch.cuda.is_available() else None
+DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
 
 def load_file(filepath):
     if filepath.lower().endswith(".md"):
@@ -30,6 +35,7 @@ def load_file(filepath):
         docs = loader.load_and_split(text_splitter=textsplitter)
     return docs
 
 
 def generate_prompt(related_docs: List[str],
                     query: str,
                     prompt_template=PROMPT_TEMPLATE) -> str:
@@ -39,7 +45,7 @@ def generate_prompt(related_docs: List[str],
 
 def get_docs_with_score(docs_with_score):
-    docs=[]
+    docs = []
     for doc, score in docs_with_score:
         doc.metadata["score"] = score
         docs.append(doc)
@@ -50,7 +56,7 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
     lists = []
     ls1 = [ls[0]]
     for i in range(1, len(ls)):
-        if ls[i-1] + 1 == ls[i]:
+        if ls[i - 1] + 1 == ls[i]:
             ls1.append(ls[i])
         else:
             lists.append(ls1)
@@ -59,49 +65,48 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
     return lists
 
 
 def similarity_search_with_score_by_vector(
         self,
         embedding: List[float],
         k: int = 4,
 ) -> List[Tuple[Document, float]]:
     scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
     docs = []
     id_set = set()
     for j, i in enumerate(indices[0]):
         if i == -1:
             # This happens when not enough docs are returned.
             continue
         _id = self.index_to_docstore_id[i]
         doc = self.docstore.search(_id)
         id_set.add(i)
         docs_len = len(doc.page_content)
-        for k in range(1, max(i, len(docs)-i)):
-            for l in [i+k, i-k]:
+        for k in range(1, max(i, len(docs) - i)):
+            for l in [i + k, i - k]:
                 if 0 <= l < len(self.index_to_docstore_id):
                     _id0 = self.index_to_docstore_id[l]
                     doc0 = self.docstore.search(_id0)
                     if docs_len + len(doc0.page_content) > self.chunk_size:
                         break
                     elif doc0.metadata["source"] == doc.metadata["source"]:
                         docs_len += len(doc0.page_content)
                         id_set.add(l)
     id_list = sorted(list(id_set))
     id_lists = seperate_list(id_list)
     for id_seq in id_lists:
         for id in id_seq:
             if id == id_seq[0]:
                 _id = self.index_to_docstore_id[id]
                 doc = self.docstore.search(_id)
             else:
                 _id0 = self.index_to_docstore_id[id]
                 doc0 = self.docstore.search(_id0)
                 doc.page_content += doc0.page_content
         if not isinstance(doc, Document):
             raise ValueError(f"Could not find document for id {_id}, got {doc}")
         docs.append((doc, scores[0][j]))
+    torch_gc(DEVICE)
     return docs
 class LocalDocQA:
@@ -116,12 +121,10 @@ class LocalDocQA:
                  llm_history_len: int = LLM_HISTORY_LEN,
                  llm_model: str = LLM_MODEL,
                  llm_device=LLM_DEVICE,
-                 streaming=STREAMING,
                  top_k=VECTOR_SEARCH_TOP_K,
                  use_ptuning_v2: bool = USE_PTUNING_V2
                  ):
         self.llm = ChatGLM()
-        self.llm.streaming = streaming
         self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
                             llm_device=llm_device,
                             use_ptuning_v2=use_ptuning_v2)
@@ -174,10 +177,12 @@ class LocalDocQA:
         if vs_path and os.path.isdir(vs_path):
             vector_store = FAISS.load_local(vs_path, self.embeddings)
             vector_store.add_documents(docs)
+            torch_gc(DEVICE)
         else:
             if not vs_path:
                 vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
             vector_store = FAISS.from_documents(docs, self.embeddings)
+            torch_gc(DEVICE)
         vector_store.save_local(vs_path)
         return vs_path, loaded_files
@@ -188,28 +193,50 @@ class LocalDocQA:
     def get_knowledge_based_answer(self,
                                    query,
                                    vs_path,
-                                   chat_history=[]):
+                                   chat_history=[],
+                                   streaming: bool = STREAMING):
         vector_store = FAISS.load_local(vs_path, self.embeddings)
         FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
-        vector_store.chunk_size=self.chunk_size
+        vector_store.chunk_size = self.chunk_size
         related_docs_with_score = vector_store.similarity_search_with_score(query,
                                                                             k=self.top_k)
         related_docs = get_docs_with_score(related_docs_with_score)
         prompt = generate_prompt(related_docs, query)
-        if self.llm.streaming:
-            for result, history in self.llm._call(prompt=prompt,
-                                                   history=chat_history):
-                history[-1][0] = query
-                response = {"query": query,
-                            "result": result,
-                            "source_documents": related_docs}
-                yield response, history
-        else:
-            result, history = self.llm._call(prompt=prompt,
-                                              history=chat_history)
+        # if streaming:
+        #     for result, history in self.llm._stream_call(prompt=prompt,
+        #                                                  history=chat_history):
+        #         history[-1][0] = query
+        #         response = {"query": query,
+        #                     "result": result,
+        #                     "source_documents": related_docs}
+        #         yield response, history
+        # else:
+        for result, history in self.llm._call(prompt=prompt,
+                                              history=chat_history,
+                                              streaming=streaming):
             history[-1][0] = query
             response = {"query": query,
                         "result": result,
                         "source_documents": related_docs}
-            return response, history
+            yield response, history
+
+
+if __name__ == "__main__":
+    local_doc_qa = LocalDocQA()
+    local_doc_qa.init_cfg()
+    query = "你好"
+    vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/123"
+    last_print_len = 0
+    for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+                                                                 vs_path=vs_path,
+                                                                 chat_history=[],
+                                                                 streaming=True):
+        print(resp["result"][last_print_len:], end="", flush=True)
+        last_print_len = len(resp["result"])
+    for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+                                                                 vs_path=vs_path,
+                                                                 chat_history=[],
+                                                                 streaming=False):
+        print(resp["result"])
+    pass
@@ -32,9 +32,12 @@ if __name__ == "__main__":
     for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
                                                                  vs_path=vs_path,
                                                                  chat_history=history,
-                                                                 streaming=True):
-        print(resp["result"][last_print_len:], end="", flush=True)
-        last_print_len = len(resp["result"])
+                                                                 streaming=STREAMING):
+        if STREAMING:
+            print(resp["result"][last_print_len:], end="", flush=True)
+            last_print_len = len(resp["result"])
+        else:
+            print(resp["result"])
     if REPLY_WITH_SOURCE:
         source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
                        # f"""相关度:{doc.metadata['score']}\n\n"""
......
@@ -4,21 +4,15 @@ from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
 from transformers import AutoTokenizer, AutoModel, AutoConfig
 import torch
-from configs.model_config import LLM_DEVICE
+from configs.model_config import *
 from langchain.callbacks.base import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from typing import Dict, Tuple, Union, Optional
+from utils import torch_gc
 
-DEVICE = LLM_DEVICE
+DEVICE_ = LLM_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None
-CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
-
-def torch_gc():
-    if torch.cuda.is_available():
-        with torch.cuda.device(CUDA_DEVICE):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
 
 def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
@@ -59,7 +53,6 @@ class ChatGLM(LLM):
     tokenizer: object = None
     model: object = None
     history_len: int = 10
-    streaming: bool = True
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
 
     def __init__(self):
@@ -72,8 +65,8 @@ class ChatGLM(LLM):
     def _call(self,
               prompt: str,
              history: List[List[str]] = [],
-              stop: Optional[List[str]] = None) -> str:
-        if self.streaming:
+              streaming: bool = STREAMING):  # -> Tuple[str, List[List[str]]]:
+        if streaming:
             for inum, (stream_resp, _) in enumerate(self.model.stream_chat(
                     self.tokenizer,
                     prompt,
@@ -81,25 +74,23 @@ class ChatGLM(LLM):
                     max_length=self.max_token,
                     temperature=self.temperature,
             )):
+                torch_gc(DEVICE)
                 if inum == 0:
                     history += [[prompt, stream_resp]]
                 else:
                     history[-1] = [prompt, stream_resp]
                 yield stream_resp, history
         else:
             response, _ = self.model.chat(
                 self.tokenizer,
                 prompt,
                 history=history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
                 temperature=self.temperature,
             )
-            torch_gc()
-            if stop is not None:
-                response = enforce_stop_tokens(response, stop)
-            history = history + [[None, response]]
-            return response, history
+            torch_gc(DEVICE)
+            history += [[prompt, response]]
+            yield response, history
 
     # def chat(self,
     #          prompt: str) -> str:
@@ -191,3 +182,16 @@ class ChatGLM(LLM):
             print("加载PrefixEncoder模型参数失败")
         self.model = self.model.eval()
+
+
+if __name__ == "__main__":
+    llm = ChatGLM()
+    llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
+                   llm_device=LLM_DEVICE, )
+    last_print_len = 0
+    for resp, history in llm._call("你好", streaming=True):
+        print(resp[last_print_len:], end="", flush=True)
+        last_print_len = len(resp)
+    for resp, history in llm._call("你好", streaming=False):
+        print(resp)
+    pass
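The refactored `_call` is now always a generator: with `streaming=True` it yields cumulative partial replies from `stream_chat`, with `streaming=False` it yields exactly one final reply, so callers handle both modes with the same `for` loop (as the `__main__` block above does). The toy stand-in below mirrors that control flow without loading any model; `fake_call` and its token list are invented for illustration:

```python
from typing import Iterator, List, Tuple


def fake_call(prompt: str,
              history: List[List[str]],
              streaming: bool = True) -> Iterator[Tuple[str, List[List[str]]]]:
    partials = ["你", "你好", "你好!"]  # stand-in for model.stream_chat output
    if streaming:
        for inum, partial in enumerate(partials):
            if inum == 0:
                history += [[prompt, partial]]   # first chunk opens a new turn
            else:
                history[-1] = [prompt, partial]  # later chunks overwrite it
            yield partial, history
    else:
        response = partials[-1]                  # stand-in for model.chat output
        history += [[prompt, response]]
        yield response, history


last_print_len = 0
for resp, history in fake_call("你好", [], streaming=True):
    print(resp[last_print_len:], end="", flush=True)  # print only the new suffix
    last_print_len = len(resp)
print()
```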
+import torch.cuda
+import torch.mps
+import torch.backends
+
+
+def torch_gc(DEVICE):
+    if torch.cuda.is_available():
+        with torch.cuda.device(DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+    elif torch.backends.mps.is_available():
+        torch.mps.empty_cache()
\ No newline at end of file
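A brief usage sketch for the new `torch_gc` helper. The device string passed in is whatever the caller built from the config (for example "cuda:0" when a GPU is visible, or the bare device name otherwise); the value below is assumed for illustration:

```python
from utils import torch_gc

DEVICE = "cuda:0"  # assumed example; the patch derives this from LLM_DEVICE / EMBEDDING_DEVICE
# Release cached GPU memory after a heavy step (vector-store build, generation).
# On a machine with neither CUDA nor MPS available, the call is a no-op.
torch_gc(DEVICE)
```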
@@ -29,23 +29,14 @@ llm_model_dict_list = list(llm_model_dict.keys())
 local_doc_qa = LocalDocQA()
 
-def get_answer(query, vs_path, history, mode):
+def get_answer(query, vs_path, history, mode,
+               streaming: bool = STREAMING):
     if mode == "知识库问答" and vs_path:
-        if local_doc_qa.llm.streaming:
-            for resp, history in local_doc_qa.get_knowledge_based_answer(
-                    query=query, vs_path=vs_path, chat_history=history):
-                source = "\n\n"
-                source += "".join(
-                    [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
-                     f"""{doc.page_content}\n"""
-                     f"""</details>"""
-                     for i, doc in
-                     enumerate(resp["source_documents"])])
-                history[-1][-1] += source
-                yield history, ""
-        else:
-            resp, history = local_doc_qa.get_knowledge_based_answer(
-                query=query, vs_path=vs_path, chat_history=history)
+        for resp, history in local_doc_qa.get_knowledge_based_answer(
+                query=query,
+                vs_path=vs_path,
+                chat_history=history,
+                streaming=streaming):
             source = "\n\n"
             source += "".join(
                 [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
@@ -54,18 +45,13 @@ def get_answer(query, vs_path, history, mode):
                  for i, doc in
                  enumerate(resp["source_documents"])])
             history[-1][-1] += source
-            return history, ""
+            yield history, ""
     else:
-        if local_doc_qa.llm.streaming:
-            for resp, history in local_doc_qa.llm._call(query, history):
-                history[-1][-1] = resp + (
-                    "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
-                yield history, ""
-        else:
-            resp, history = local_doc_qa.llm._call(query, history)
+        for resp, history in local_doc_qa.llm._call(query, history,
+                                                    streaming=streaming):
             history[-1][-1] = resp + (
                 "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
-        return history, ""
+            yield history, ""
 
 
 def update_status(history, status):
@@ -76,7 +62,7 @@ def update_status(history, status):
 def init_model():
     try:
-        local_doc_qa.init_cfg(streaming=STREAMING)
+        local_doc_qa.init_cfg()
         local_doc_qa.llm._call("你好")
         reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(reply)
@@ -98,8 +84,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
                               use_ptuning_v2=use_ptuning_v2,
-                              top_k=top_k,
-                              streaming=STREAMING)
+                              top_k=top_k,)
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(model_status)
     except Exception as e:
......
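Because `get_answer` is now a generator in both modes (it always ends with `yield history, ""`), the UI wiring no longer needs separate streaming and non-streaming paths. Below is a minimal, self-contained sketch of that pattern in a Gradio 3.x-style Blocks app; `fake_get_answer` and the component layout are invented for illustration and are not the project's actual webui.py:

```python
import gradio as gr


def fake_get_answer(query, history):
    # Toy stand-in for get_answer: open a new chat turn, then yield the
    # growing history once per "token" so the Chatbot re-renders each time.
    history = history + [[query, ""]]
    for token in ["你", "好", "!"]:
        history[-1][1] += token
        yield history, ""


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    query = gr.Textbox()
    query.submit(fake_get_answer, [query, chatbot], [chatbot, query])

# Generator callbacks require the request queue to be enabled.
demo.queue().launch()
```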