Commit a4e67a67 by imClumsyPanda

update local_doc_qa.py

Parent c98c4888
@@ -82,15 +82,19 @@ def similarity_search_with_score_by_vector(
         id_set.add(i)
         docs_len = len(doc.page_content)
         for k in range(1, max(i, len(docs) - i)):
+            break_flag = False
             for l in [i + k, i - k]:
                 if 0 <= l < len(self.index_to_docstore_id):
                     _id0 = self.index_to_docstore_id[l]
                     doc0 = self.docstore.search(_id0)
                     if docs_len + len(doc0.page_content) > self.chunk_size:
+                        break_flag=True
                         break
                     elif doc0.metadata["source"] == doc.metadata["source"]:
                         docs_len += len(doc0.page_content)
                         id_set.add(l)
+            if break_flag:
+                break
     id_list = sorted(list(id_set))
     id_lists = seperate_list(id_list)
     for id_seq in id_lists:
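For context on the hunk above: the added break_flag lines let the size check inside the loop over [i + k, i - k] also stop the outer window-expansion loop once adding a neighbouring chunk would exceed self.chunk_size; before this change only the inner loop broke, so the window kept growing for larger k. A minimal, self-contained sketch of the same two-level break pattern (the function name expand_window and the sample data are illustrative, not taken from local_doc_qa.py):

def expand_window(lengths, start, chunk_size):
    """Grow a set of indices around `start`, stopping as soon as adding the
    next neighbour would push the accumulated length past `chunk_size`."""
    selected = {start}
    total = lengths[start]
    for k in range(1, len(lengths)):
        break_flag = False
        for l in (start + k, start - k):
            if 0 <= l < len(lengths):
                if total + lengths[l] > chunk_size:
                    break_flag = True   # a bare `break` only exits this inner loop...
                    break
                total += lengths[l]
                selected.add(l)
        if break_flag:                  # ...so the flag is what stops the outer expansion
            break
    return sorted(selected)

print(expand_window([100, 120, 90, 200, 80], start=2, chunk_size=300))  # -> [2, 3]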
@@ -225,8 +229,8 @@ class LocalDocQA:
 if __name__ == "__main__":
     local_doc_qa = LocalDocQA()
     local_doc_qa.init_cfg()
-    query = "你好"
-    vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/123"
+    query = "本项目使用的embedding模型是什么,消耗多少显存"
+    vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/aaa"
     last_print_len = 0
     for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
                                                                  vs_path=vs_path,
@@ -234,9 +238,14 @@ if __name__ == "__main__":
                                                                  streaming=True):
         print(resp["result"][last_print_len:], end="", flush=True)
         last_print_len = len(resp["result"])
-    for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
-                                                                 vs_path=vs_path,
-                                                                 chat_history=[],
-                                                                 streaming=False):
-        print(resp["result"])
+    source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+                   # f"""相关度:{doc.metadata['score']}\n\n"""
+                   for inum, doc in
+                   enumerate(resp["source_documents"])]
+    print("\n\n" + "\n\n".join(source_text))
+    # for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+    #                                                              vs_path=vs_path,
+    #                                                              chat_history=[],
+    #                                                              streaming=False):
+    #     print(resp["result"])
     pass
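For context on the second hunk: after the streamed answer finishes, the last resp still holds the retrieved documents, and the new list comprehension formats one "出处 [n] filename" block per source document in place of the removed non-streaming loop. A minimal sketch of that formatting with a stand-in document class (FakeDoc, the sample paths, and the page contents are illustrative placeholders, not repository data):

import os
from dataclasses import dataclass

@dataclass
class FakeDoc:
    page_content: str
    metadata: dict

# Stand-in for the last `resp` yielded by get_knowledge_based_answer(...)
resp = {
    "source_documents": [
        FakeDoc("sample passage one", {"source": "/tmp/docs/README.md"}),
        FakeDoc("sample passage two", {"source": "/tmp/docs/FAQ.md"}),
    ],
}

# Same formatting as the added lines: numbered source, file name only, then content
source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
               for inum, doc in enumerate(resp["source_documents"])]
print("\n\n" + "\n\n".join(source_text))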