提交 a0b312d7 作者: glide-the

Merge remote-tracking branch 'origin/dev' into dev

# Conflicts:
#	configs/model_config.py
...@@ -14,8 +14,14 @@ ...@@ -14,8 +14,14 @@
![实现原理图](img/langchain+chatglm.png) ![实现原理图](img/langchain+chatglm.png)
从文档处理角度来看,实现流程如下:
![实现原理图2](img/langchain+chatglm2.png)
🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。 🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。
🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/imClumsyPanda/langchain-ChatGLM/langchain-ChatGLM)
📓 [ModelWhale 在线运行项目](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59) 📓 [ModelWhale 在线运行项目](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59)
## 变更日志 ## 变更日志
...@@ -166,6 +172,6 @@ Web UI 可以实现如下功能: ...@@ -166,6 +172,6 @@ Web UI 可以实现如下功能:
- [ ] 实现调用 API 的 Web UI Demo - [ ] 实现调用 API 的 Web UI Demo
## 项目交流群 ## 项目交流群
![二维码](img/qr_code_9.jpg) ![二维码](img/qr_code_10.jpg)
🎉 langchain-ChatGLM 项目交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。 🎉 langchain-ChatGLM 项目交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
...@@ -8,6 +8,7 @@ from textsplitter import ChineseTextSplitter ...@@ -8,6 +8,7 @@ from textsplitter import ChineseTextSplitter
from typing import List, Tuple from typing import List, Tuple
from langchain.docstore.document import Document from langchain.docstore.document import Document
import numpy as np import numpy as np
from utils import torch_gc
# return top-k text chunk from vector store # return top-k text chunk from vector store
VECTOR_SEARCH_TOP_K = 6 VECTOR_SEARCH_TOP_K = 6
...@@ -15,6 +16,10 @@ VECTOR_SEARCH_TOP_K = 6 ...@@ -15,6 +16,10 @@ VECTOR_SEARCH_TOP_K = 6
# LLM input history length # LLM input history length
LLM_HISTORY_LEN = 3 LLM_HISTORY_LEN = 3
DEVICE_ = EMBEDDING_DEVICE
DEVICE_ID = "0" if torch.cuda.is_available() else None
DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
def load_file(filepath): def load_file(filepath):
if filepath.lower().endswith(".md"): if filepath.lower().endswith(".md"):
...@@ -30,6 +35,7 @@ def load_file(filepath): ...@@ -30,6 +35,7 @@ def load_file(filepath):
docs = loader.load_and_split(text_splitter=textsplitter) docs = loader.load_and_split(text_splitter=textsplitter)
return docs return docs
def generate_prompt(related_docs: List[str], def generate_prompt(related_docs: List[str],
query: str, query: str,
prompt_template=PROMPT_TEMPLATE) -> str: prompt_template=PROMPT_TEMPLATE) -> str:
...@@ -39,7 +45,7 @@ def generate_prompt(related_docs: List[str], ...@@ -39,7 +45,7 @@ def generate_prompt(related_docs: List[str],
def get_docs_with_score(docs_with_score): def get_docs_with_score(docs_with_score):
docs=[] docs = []
for doc, score in docs_with_score: for doc, score in docs_with_score:
doc.metadata["score"] = score doc.metadata["score"] = score
docs.append(doc) docs.append(doc)
...@@ -50,7 +56,7 @@ def seperate_list(ls: List[int]) -> List[List[int]]: ...@@ -50,7 +56,7 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
lists = [] lists = []
ls1 = [ls[0]] ls1 = [ls[0]]
for i in range(1, len(ls)): for i in range(1, len(ls)):
if ls[i-1] + 1 == ls[i]: if ls[i - 1] + 1 == ls[i]:
ls1.append(ls[i]) ls1.append(ls[i])
else: else:
lists.append(ls1) lists.append(ls1)
...@@ -59,49 +65,52 @@ def seperate_list(ls: List[int]) -> List[List[int]]: ...@@ -59,49 +65,52 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
return lists return lists
def similarity_search_with_score_by_vector( def similarity_search_with_score_by_vector(
self, self,
embedding: List[float], embedding: List[float],
k: int = 4, k: int = 4,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k) scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
docs = [] docs = []
id_set = set() id_set = set()
for j, i in enumerate(indices[0]): for j, i in enumerate(indices[0]):
if i == -1: if i == -1:
# This happens when not enough docs are returned. # This happens when not enough docs are returned.
continue continue
_id = self.index_to_docstore_id[i] _id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id) doc = self.docstore.search(_id)
id_set.add(i) id_set.add(i)
docs_len = len(doc.page_content) docs_len = len(doc.page_content)
for k in range(1, max(i, len(docs)-i)): for k in range(1, max(i, len(docs) - i)):
for l in [i+k, i-k]: break_flag = False
if 0 <= l < len(self.index_to_docstore_id): for l in [i + k, i - k]:
_id0 = self.index_to_docstore_id[l] if 0 <= l < len(self.index_to_docstore_id):
doc0 = self.docstore.search(_id0) _id0 = self.index_to_docstore_id[l]
if docs_len + len(doc0.page_content) > self.chunk_size:
break
elif doc0.metadata["source"] == doc.metadata["source"]:
docs_len += len(doc0.page_content)
id_set.add(l)
id_list = sorted(list(id_set))
id_lists = seperate_list(id_list)
for id_seq in id_lists:
for id in id_seq:
if id == id_seq[0]:
_id = self.index_to_docstore_id[id]
doc = self.docstore.search(_id)
else:
_id0 = self.index_to_docstore_id[id]
doc0 = self.docstore.search(_id0) doc0 = self.docstore.search(_id0)
doc.page_content += doc0.page_content if docs_len + len(doc0.page_content) > self.chunk_size:
if not isinstance(doc, Document): break_flag=True
raise ValueError(f"Could not find document for id {_id}, got {doc}") break
docs.append((doc, scores[0][j])) elif doc0.metadata["source"] == doc.metadata["source"]:
return docs docs_len += len(doc0.page_content)
id_set.add(l)
if break_flag:
break
id_list = sorted(list(id_set))
id_lists = seperate_list(id_list)
for id_seq in id_lists:
for id in id_seq:
if id == id_seq[0]:
_id = self.index_to_docstore_id[id]
doc = self.docstore.search(_id)
else:
_id0 = self.index_to_docstore_id[id]
doc0 = self.docstore.search(_id0)
doc.page_content += doc0.page_content
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
docs.append((doc, scores[0][j]))
torch_gc(DEVICE)
return docs
class LocalDocQA: class LocalDocQA:
...@@ -172,10 +181,12 @@ class LocalDocQA: ...@@ -172,10 +181,12 @@ class LocalDocQA:
if vs_path and os.path.isdir(vs_path): if vs_path and os.path.isdir(vs_path):
vector_store = FAISS.load_local(vs_path, self.embeddings) vector_store = FAISS.load_local(vs_path, self.embeddings)
vector_store.add_documents(docs) vector_store.add_documents(docs)
torch_gc(DEVICE)
else: else:
if not vs_path: if not vs_path:
vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""" vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
vector_store = FAISS.from_documents(docs, self.embeddings) vector_store = FAISS.from_documents(docs, self.embeddings)
torch_gc(DEVICE)
vector_store.save_local(vs_path) vector_store.save_local(vs_path)
return vs_path, loaded_files return vs_path, loaded_files
...@@ -187,29 +198,54 @@ class LocalDocQA: ...@@ -187,29 +198,54 @@ class LocalDocQA:
query, query,
vs_path, vs_path,
chat_history=[], chat_history=[],
streaming=True): streaming: bool = STREAMING):
self.llm.streaming = streaming
vector_store = FAISS.load_local(vs_path, self.embeddings) vector_store = FAISS.load_local(vs_path, self.embeddings)
FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
vector_store.chunk_size=self.chunk_size vector_store.chunk_size = self.chunk_size
related_docs_with_score = vector_store.similarity_search_with_score(query, related_docs_with_score = vector_store.similarity_search_with_score(query,
k=self.top_k) k=self.top_k)
related_docs = get_docs_with_score(related_docs_with_score) related_docs = get_docs_with_score(related_docs_with_score)
prompt = generate_prompt(related_docs, query) prompt = generate_prompt(related_docs, query)
if streaming: # if streaming:
for result, history in self.llm._call(prompt=prompt, # for result, history in self.llm._stream_call(prompt=prompt,
history=chat_history): # history=chat_history):
history[-1][0] = query # history[-1][0] = query
response = {"query": query, # response = {"query": query,
"result": result, # "result": result,
"source_documents": related_docs} # "source_documents": related_docs}
yield response, history # yield response, history
else: # else:
result, history = self.llm._call(prompt=prompt, for result, history in self.llm._call(prompt=prompt,
history=chat_history) history=chat_history,
streaming=streaming):
history[-1][0] = query history[-1][0] = query
response = {"query": query, response = {"query": query,
"result": result, "result": result,
"source_documents": related_docs} "source_documents": related_docs}
return response, history yield response, history
if __name__ == "__main__":
local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg()
query = "本项目使用的embedding模型是什么,消耗多少显存"
vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/aaa"
last_print_len = 0
for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
vs_path=vs_path,
chat_history=[],
streaming=True):
print(resp["result"][last_print_len:], end="", flush=True)
last_print_len = len(resp["result"])
source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
# f"""相关度:{doc.metadata['score']}\n\n"""
for inum, doc in
enumerate(resp["source_documents"])]
print("\n\n" + "\n\n".join(source_text))
# for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
# vs_path=vs_path,
# chat_history=[],
# streaming=False):
# print(resp["result"])
pass
...@@ -32,9 +32,12 @@ if __name__ == "__main__": ...@@ -32,9 +32,12 @@ if __name__ == "__main__":
for resp, history in local_doc_qa.get_knowledge_based_answer(query=query, for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
vs_path=vs_path, vs_path=vs_path,
chat_history=history, chat_history=history,
streaming=True): streaming=STREAMING):
print(resp["result"][last_print_len:], end="", flush=True) if STREAMING:
last_print_len = len(resp["result"]) print(resp["result"][last_print_len:], end="", flush=True)
last_print_len = len(resp["result"])
else:
print(resp["result"])
if REPLY_WITH_SOURCE: if REPLY_WITH_SOURCE:
source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n""" source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
# f"""相关度:{doc.metadata['score']}\n\n""" # f"""相关度:{doc.metadata['score']}\n\n"""
......
...@@ -27,6 +27,9 @@ llm_model_dict = { ...@@ -27,6 +27,9 @@ llm_model_dict = {
# LLM model name # LLM model name
LLM_MODEL = "chatglm-6b" LLM_MODEL = "chatglm-6b"
# LLM streaming reponse
STREAMING = True
# Use p-tuning-v2 PrefixEncoder # Use p-tuning-v2 PrefixEncoder
USE_PTUNING_V2 = False USE_PTUNING_V2 = False
...@@ -38,14 +41,10 @@ VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_ ...@@ -38,14 +41,10 @@ VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_
UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "") UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "")
# 基于上下文的prompt模版,请务必保留"{question}"和"{context}" # 基于上下文的prompt模版,请务必保留"{question}"和"{context}"
PROMPT_TEMPLATE = """已知信息在下方"="包裹的段落,基于以下已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。 PROMPT_TEMPLATE = """已知信息:
====================================已知信息=====================================================
{context} {context}
================================================================================================
问题:"{question}" 根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}"""
答案:"""
# 匹配后单段上下文长度 # 匹配后单段上下文长度
CHUNK_SIZE = 500 CHUNK_SIZE = 500
\ No newline at end of file
...@@ -4,21 +4,15 @@ from typing import Optional, List ...@@ -4,21 +4,15 @@ from typing import Optional, List
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
from transformers import AutoTokenizer, AutoModel, AutoConfig from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch import torch
from configs.model_config import LLM_DEVICE from configs.model_config import *
from langchain.callbacks.base import CallbackManager from langchain.callbacks.base import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from typing import Dict, Tuple, Union, Optional from typing import Dict, Tuple, Union, Optional
from utils import torch_gc
DEVICE = LLM_DEVICE DEVICE_ = LLM_DEVICE
DEVICE_ID = "0" if torch.cuda.is_available() else None DEVICE_ID = "0" if torch.cuda.is_available() else None
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
def torch_gc():
if torch.cuda.is_available():
with torch.cuda.device(CUDA_DEVICE):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def auto_configure_device_map(num_gpus: int) -> Dict[str, int]: def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
...@@ -59,7 +53,6 @@ class ChatGLM(LLM): ...@@ -59,7 +53,6 @@ class ChatGLM(LLM):
tokenizer: object = None tokenizer: object = None
model: object = None model: object = None
history_len: int = 10 history_len: int = 10
streaming: bool = True
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()]) callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
def __init__(self): def __init__(self):
...@@ -72,8 +65,8 @@ class ChatGLM(LLM): ...@@ -72,8 +65,8 @@ class ChatGLM(LLM):
def _call(self, def _call(self,
prompt: str, prompt: str,
history: List[List[str]] = [], history: List[List[str]] = [],
stop: Optional[List[str]] = None) -> str: streaming: bool = STREAMING): # -> Tuple[str, List[List[str]]]:
if self.streaming: if streaming:
for inum, (stream_resp, _) in enumerate(self.model.stream_chat( for inum, (stream_resp, _) in enumerate(self.model.stream_chat(
self.tokenizer, self.tokenizer,
prompt, prompt,
...@@ -81,25 +74,23 @@ class ChatGLM(LLM): ...@@ -81,25 +74,23 @@ class ChatGLM(LLM):
max_length=self.max_token, max_length=self.max_token,
temperature=self.temperature, temperature=self.temperature,
)): )):
torch_gc(DEVICE)
if inum == 0: if inum == 0:
history += [[prompt, stream_resp]] history += [[prompt, stream_resp]]
else: else:
history[-1] = [prompt, stream_resp] history[-1] = [prompt, stream_resp]
yield stream_resp, history yield stream_resp, history
else: else:
response, _ = self.model.chat( response, _ = self.model.chat(
self.tokenizer, self.tokenizer,
prompt, prompt,
history=history[-self.history_len:] if self.history_len > 0 else [], history=history[-self.history_len:] if self.history_len > 0 else [],
max_length=self.max_token, max_length=self.max_token,
temperature=self.temperature, temperature=self.temperature,
) )
torch_gc() torch_gc(DEVICE)
if stop is not None: history += [[prompt, response]]
response = enforce_stop_tokens(response, stop) yield response, history
history = history + [[None, response]]
return response, history
# def chat(self, # def chat(self,
# prompt: str) -> str: # prompt: str) -> str:
...@@ -191,3 +182,16 @@ class ChatGLM(LLM): ...@@ -191,3 +182,16 @@ class ChatGLM(LLM):
print("加载PrefixEncoder模型参数失败") print("加载PrefixEncoder模型参数失败")
self.model = self.model.eval() self.model = self.model.eval()
if __name__ == "__main__":
llm = ChatGLM()
llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
llm_device=LLM_DEVICE, )
last_print_len=0
for resp, history in llm._call("你好", streaming=True):
print(resp[last_print_len:], end="", flush=True)
last_print_len = len(resp)
for resp, history in llm._call("你好", streaming=False):
print(resp)
pass
import torch.cuda
import torch.mps
import torch.backends
def torch_gc(DEVICE):
if torch.cuda.is_available():
with torch.cuda.device(DEVICE):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
elif torch.backends.mps.is_available():
torch.mps.empty_cache()
\ No newline at end of file
...@@ -29,28 +29,28 @@ llm_model_dict_list = list(llm_model_dict.keys()) ...@@ -29,28 +29,28 @@ llm_model_dict_list = list(llm_model_dict.keys())
local_doc_qa = LocalDocQA() local_doc_qa = LocalDocQA()
def get_answer(query, vs_path, history, mode): def get_answer(query, vs_path, history, mode,
if mode == "知识库问答": streaming: bool = STREAMING):
if vs_path: if mode == "知识库问答" and vs_path:
for resp, history in local_doc_qa.get_knowledge_based_answer( for resp, history in local_doc_qa.get_knowledge_based_answer(
query=query, vs_path=vs_path, chat_history=history): query=query,
source = "\n\n" vs_path=vs_path,
source += "".join( chat_history=history,
[f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n""" streaming=streaming):
f"""{doc.page_content}\n""" source = "\n\n"
f"""</details>""" source += "".join(
for i, doc in [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
enumerate(resp["source_documents"])]) f"""{doc.page_content}\n"""
history[-1][-1] += source f"""</details>"""
yield history, "" for i, doc in
else: enumerate(resp["source_documents"])])
for resp, history in local_doc_qa.llm._call(query, history): history[-1][-1] += source
history[-1][-1] = resp + ( yield history, ""
"\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
yield history, ""
else: else:
for resp, history in local_doc_qa.llm._call(query, history): for resp, history in local_doc_qa.llm._call(query, history,
history[-1][-1] = resp streaming=streaming):
history[-1][-1] = resp + (
"\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
yield history, "" yield history, ""
...@@ -84,7 +84,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to ...@@ -84,7 +84,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to
embedding_model=embedding_model, embedding_model=embedding_model,
llm_history_len=llm_history_len, llm_history_len=llm_history_len,
use_ptuning_v2=use_ptuning_v2, use_ptuning_v2=use_ptuning_v2,
top_k=top_k) top_k=top_k,)
model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话""" model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
print(model_status) print(model_status)
except Exception as e: except Exception as e:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论