Commit 88ab9a1d  Author: imClumsyPanda

update webui.py and local_doc_qa.py

Parent: daafe8d5
local_doc_qa.py
@@ -1,9 +1,8 @@
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
-# from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from chains.lib.embeddings import MyEmbeddings
-# from langchain.vectorstores import FAISS
-from chains.lib.vectorstores import FAISSVS
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from langchain.vectorstores import FAISS
+from langchain.vectorstores.base import VectorStoreRetriever
 from langchain.document_loaders import UnstructuredFileLoader
 from models.chatglm_llm import ChatGLM
 import sentence_transformers
@@ -12,6 +11,7 @@ from configs.model_config import *
 import datetime
 from typing import List
 from textsplitter import ChineseTextSplitter
+from langchain.docstore.document import Document

 # return top-k text chunk from vector store
 VECTOR_SEARCH_TOP_K = 6
@@ -21,7 +21,10 @@ LLM_HISTORY_LEN = 3


 def load_file(filepath):
-    if filepath.lower().endswith(".pdf"):
+    if filepath.lower().endswith(".md"):
+        loader = UnstructuredFileLoader(filepath, mode="elements")
+        docs = loader.load()
+    elif filepath.lower().endswith(".pdf"):
         loader = UnstructuredFileLoader(filepath)
         textsplitter = ChineseTextSplitter(pdf=True)
         docs = loader.load_and_split(textsplitter)
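For Markdown files the loader now runs Unstructured in "elements" mode, which partitions the file into typed elements (Title, NarrativeText, ListItem, ...) and returns one Document per element instead of one Document per file. A minimal sketch of what that produces, assuming a local test.md exists (file name illustrative):

    from langchain.document_loaders import UnstructuredFileLoader

    # "elements" mode: one Document per structural element, with the element
    # type recorded in metadata["category"]
    loader = UnstructuredFileLoader("test.md", mode="elements")
    for doc in loader.load()[:3]:
        print(doc.metadata.get("category"), "|", doc.page_content[:40])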
@@ -32,6 +35,22 @@ def load_file(filepath):
     return docs

+
+def get_relevant_documents(self, query: str) -> List[Document]:
+    if self.search_type == "similarity":
+        docs = self.vectorstore._similarity_search_with_relevance_scores(query, **self.search_kwargs)
+        for doc in docs:
+            doc[0].metadata["score"] = doc[1]
+        docs = [doc[0] for doc in docs]
+    elif self.search_type == "mmr":
+        docs = self.vectorstore.max_marginal_relevance_search(
+            query, **self.search_kwargs
+        )
+    else:
+        raise ValueError(f"search_type of {self.search_type} not allowed.")
+    return docs
+
+
 class LocalDocQA:
     llm: object = None
     embeddings: object = None
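The new get_relevant_documents mirrors LangChain's VectorStoreRetriever.get_relevant_documents but additionally stamps each hit's relevance score into doc.metadata["score"] before returning it. It is written as a method (note the self parameter) so it can be monkey-patched onto VectorStoreRetriever; in this commit the patch itself is left commented out in get_knowledge_based_answer below. A minimal sketch of how the patch would be applied, assuming an already-built vector_store (query text illustrative):

    # assumption: NOT active in this commit -- the assignment stays commented
    # out in get_knowledge_based_answer
    VectorStoreRetriever.get_relevant_documents = get_relevant_documents

    retriever = vector_store.as_retriever(search_type="similarity",
                                          search_kwargs={"k": VECTOR_SEARCH_TOP_K})
    for doc in retriever.get_relevant_documents("什么是知识库问答?"):
        print(doc.metadata.get("score"), doc.metadata.get("source"))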
@@ -52,7 +71,7 @@ class LocalDocQA:
                                 use_ptuning_v2=use_ptuning_v2)
         self.llm.history_len = llm_history_len

-        self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model],
+        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
                                        model_kwargs={'device': embedding_device})
         # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
         #                                                                    device=embedding_device)
@@ -99,12 +118,12 @@ class LocalDocQA:
                 print(f"{file} 未能成功加载")
         if len(docs) > 0:
             if vs_path and os.path.isdir(vs_path):
-                vector_store = FAISSVS.load_local(vs_path, self.embeddings)
+                vector_store = FAISS.load_local(vs_path, self.embeddings)
                 vector_store.add_documents(docs)
             else:
                 if not vs_path:
                     vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
-                vector_store = FAISSVS.from_documents(docs, self.embeddings)
+                vector_store = FAISS.from_documents(docs, self.embeddings)
             vector_store.save_local(vs_path)
         return vs_path, loaded_files
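The custom FAISSVS wrapper is dropped in favour of LangChain's stock FAISS vector store; the build/save/load round trip above is the standard API. A minimal sketch, assuming embeddings is an initialized HuggingFaceEmbeddings instance and docs / more_docs are lists of Document objects (paths and names illustrative):

    vector_store = FAISS.from_documents(docs, embeddings)  # embed + index in memory
    vector_store.save_local("vector_store/samples_FAISS_20230430_120000")

    # later: reload the persisted index from disk and grow it incrementally
    vector_store = FAISS.load_local("vector_store/samples_FAISS_20230430_120000", embeddings)
    vector_store.add_documents(more_docs)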
@@ -129,10 +148,13 @@ class LocalDocQA:
             input_variables=["context", "question"]
         )
         self.llm.history = chat_history
-        vector_store = FAISSVS.load_local(vs_path, self.embeddings)
+        vector_store = FAISS.load_local(vs_path, self.embeddings)
+        vs_r = vector_store.as_retriever(search_type="mmr",
+                                         search_kwargs={"k": self.top_k})
+        # VectorStoreRetriever.get_relevant_documents = get_relevant_documents
         knowledge_chain = RetrievalQA.from_llm(
             llm=self.llm,
-            retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
+            retriever=vs_r,
             prompt=prompt
         )
         knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
@@ -140,7 +162,6 @@ class LocalDocQA:
         )
         knowledge_chain.return_source_documents = True
         result = knowledge_chain({"query": query})
-
         self.llm.history[-1][0] = query
         return result, self.llm.history
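Retrieval switches from plain similarity search to maximal marginal relevance (MMR): instead of simply taking the k nearest chunks, which can be near-duplicates of one another, MMR re-ranks a candidate pool to balance closeness to the query against diversity among the selected chunks. A sketch of the two retriever configurations side by side, assuming an existing vector_store (k value illustrative):

    # before: the k most similar chunks, duplicates and all
    sim_retriever = vector_store.as_retriever(search_kwargs={"k": 6})

    # after: k chunks chosen for relevance AND mutual diversity
    mmr_retriever = vector_store.as_retriever(search_type="mmr",
                                              search_kwargs={"k": 6})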
models/chatglm_llm.py
@@ -72,16 +72,16 @@ class ChatGLM(LLM):
              stream=True) -> str:
         if stream:
             self.history = self.history + [[None, ""]]
-            response, _ = self.model.stream_chat(
+            for response, history in self.model.stream_chat(
                 self.tokenizer,
                 prompt,
                 history=self.history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
                 temperature=self.temperature,
-            )
-            torch_gc()
-            self.history[-1][-1] = response
-            yield response
+            ):
+                torch_gc()
+                self.history[-1][-1] = response
+                yield response
         else:
             response, _ = self.model.chat(
                 self.tokenizer,
...
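With this change _call no longer produces a single response at the end of generation: in streaming mode it iterates ChatGLM's stream_chat and yields the cumulative response text after each step, updating self.history in place as it goes. A minimal consumption sketch, assuming an already-loaded ChatGLM instance named llm (illustrative):

    # each yielded value is the full answer generated so far, so printing with
    # a carriage return shows the text growing in place
    for partial in llm._call("你好"):
        print(partial, end="\r")
    print()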
webui.py
@@ -30,19 +30,28 @@ local_doc_qa = LocalDocQA()


 def get_answer(query, vs_path, history, mode):
-    if vs_path and mode == "知识库问答":
-        resp, history = local_doc_qa.get_knowledge_based_answer(
-            query=query, vs_path=vs_path, chat_history=history)
-        source = "".join([f"""<details> <summary>出处 {i + 1}</summary>
-{doc.page_content}
-
-<b>所属文件:</b>{doc.metadata["source"]}
-</details>""" for i, doc in enumerate(resp["source_documents"])])
-        history[-1][-1] += source
+    if mode == "知识库问答":
+        if vs_path:
+            for resp, history in local_doc_qa.get_knowledge_based_answer(
+                    query=query, vs_path=vs_path, chat_history=history):
+                # source = "".join([f"""<details> <summary>出处 {i + 1}</summary>
+                # {doc.page_content}
+                #
+                # <b>所属文件:</b>{doc.metadata["source"]}
+                # </details>""" for i, doc in enumerate(resp["source_documents"])])
+                # history[-1][-1] += source
+                yield history, ""
+        else:
+            history = history + [[query, ""]]
+            for resp in local_doc_qa.llm._call(query):
+                history[-1][-1] = resp + (
+                    "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
+                yield history, ""
     else:
-        resp = local_doc_qa.llm._call(query)
-        history = history + [[query, resp + ("\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")]]
-    return history, ""
+        history = history + [[query, ""]]
+        for resp in local_doc_qa.llm._call(query):
+            history[-1][-1] = resp
+            yield history, ""


 def update_status(history, status):
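Because get_answer now yields (chatbot_history, "") pairs instead of returning once, Gradio can stream each partial answer into the chatbot as it arrives: Gradio event handlers accept generator functions and re-render their outputs on every yield. A minimal wiring sketch, assuming component names like those used elsewhere in this webui.py (names illustrative):

    query.submit(get_answer,
                 inputs=[query, vs_path, chatbot, mode],
                 outputs=[chatbot, query])  # each yield updates the chat and clears the box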
@@ -62,7 +71,7 @@ def init_model():
         print(e)
         reply = """模型未成功加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
         if str(e) == "Unknown platform: darwin":
             print("报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI,具体方法请参考项目 README 中本地部署方法及常见问题:"
                   " https://github.com/imClumsyPanda/langchain-ChatGLM")
         else:
             print(reply)
...