提交 1c51d6ca 作者: imClumsyPanda

update cli_demo.py

上级 5bd66482
...@@ -28,7 +28,8 @@ class LocalDocQA: ...@@ -28,7 +28,8 @@ class LocalDocQA:
embedding_device=EMBEDDING_DEVICE, embedding_device=EMBEDDING_DEVICE,
llm_history_len: int = LLM_HISTORY_LEN, llm_history_len: int = LLM_HISTORY_LEN,
llm_model: str = LLM_MODEL, llm_model: str = LLM_MODEL,
llm_device=LLM_DEVICE llm_device=LLM_DEVICE,
top_k=VECTOR_SEARCH_TOP_K,
): ):
self.llm = ChatGLM() self.llm = ChatGLM()
self.llm.load_model(model_name_or_path=llm_model_dict[llm_model], self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
...@@ -38,6 +39,7 @@ class LocalDocQA: ...@@ -38,6 +39,7 @@ class LocalDocQA:
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], ) self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], )
self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name, self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
device=embedding_device) device=embedding_device)
self.top_k = top_k
def init_knowledge_vector_store(self, def init_knowledge_vector_store(self,
filepath: str): filepath: str):
...@@ -65,15 +67,14 @@ class LocalDocQA: ...@@ -65,15 +67,14 @@ class LocalDocQA:
print(f"{file} 未能成功加载") print(f"{file} 未能成功加载")
vector_store = FAISS.from_documents(docs, self.embeddings) vector_store = FAISS.from_documents(docs, self.embeddings)
vs_path = f"""./vector_store/{os.path.splitext(file)}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""" vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
vector_store.save_local(vs_path) vector_store.save_local(vs_path)
return vs_path return vs_path
def get_knowledge_based_answer(self, def get_knowledge_based_answer(self,
query, query,
vs_path, vs_path,
chat_history=[], chat_history=[],):
top_k=VECTOR_SEARCH_TOP_K):
prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。 prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。 如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
...@@ -90,7 +91,7 @@ class LocalDocQA: ...@@ -90,7 +91,7 @@ class LocalDocQA:
vector_store = FAISS.load_local(vs_path, self.embeddings) vector_store = FAISS.load_local(vs_path, self.embeddings)
knowledge_chain = RetrievalQA.from_llm( knowledge_chain = RetrievalQA.from_llm(
llm=self.llm, llm=self.llm,
retriever=vector_store.as_retriever(search_kwargs={"k": top_k}), retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
prompt=prompt prompt=prompt
) )
knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate( knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
......
...@@ -15,7 +15,8 @@ if __name__ == "__main__": ...@@ -15,7 +15,8 @@ if __name__ == "__main__":
local_doc_qa.init_cfg(llm_model=LLM_MODEL, local_doc_qa.init_cfg(llm_model=LLM_MODEL,
embedding_model=EMBEDDING_MODEL, embedding_model=EMBEDDING_MODEL,
embedding_device=EMBEDDING_DEVICE, embedding_device=EMBEDDING_DEVICE,
llm_history_len=LLM_HISTORY_LEN) llm_history_len=LLM_HISTORY_LEN,
top_k=VECTOR_SEARCH_TOP_K)
vs_path = None vs_path = None
while not vs_path: while not vs_path:
filepath = input("Input your local knowledge file path 请输入本地知识文件路径:") filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论