提交 c613c41d 作者: imClumsyPanda

update local_doc_qa.py

上级 e8b2ddea
...@@ -10,6 +10,7 @@ from langchain.docstore.document import Document ...@@ -10,6 +10,7 @@ from langchain.docstore.document import Document
import numpy as np import numpy as np
from utils import torch_gc from utils import torch_gc
from tqdm import tqdm from tqdm import tqdm
from pypinyin import lazy_pinyin
DEVICE_ = EMBEDDING_DEVICE DEVICE_ = EMBEDDING_DEVICE
...@@ -76,14 +77,14 @@ def similarity_search_with_score_by_vector( ...@@ -76,14 +77,14 @@ def similarity_search_with_score_by_vector(
doc = self.docstore.search(_id) doc = self.docstore.search(_id)
id_set.add(i) id_set.add(i)
docs_len = len(doc.page_content) docs_len = len(doc.page_content)
for k in range(1, max(i, store_len-i)): for k in range(1, max(i, store_len - i)):
break_flag = False break_flag = False
for l in [i + k, i - k]: for l in [i + k, i - k]:
if 0 <= l < len(self.index_to_docstore_id): if 0 <= l < len(self.index_to_docstore_id):
_id0 = self.index_to_docstore_id[l] _id0 = self.index_to_docstore_id[l]
doc0 = self.docstore.search(_id0) doc0 = self.docstore.search(_id0)
if docs_len + len(doc0.page_content) > self.chunk_size: if docs_len + len(doc0.page_content) > self.chunk_size:
break_flag=True break_flag = True
break break
elif doc0.metadata["source"] == doc.metadata["source"]: elif doc0.metadata["source"] == doc.metadata["source"]:
docs_len += len(doc0.page_content) docs_len += len(doc0.page_content)
...@@ -166,7 +167,7 @@ class LocalDocQA: ...@@ -166,7 +167,7 @@ class LocalDocQA:
if len(failed_files) > 0: if len(failed_files) > 0:
print("以下文件未能成功加载:") print("以下文件未能成功加载:")
for file in failed_files: for file in failed_files:
print(file,end="\n") print(file, end="\n")
else: else:
docs = [] docs = []
...@@ -187,7 +188,7 @@ class LocalDocQA: ...@@ -187,7 +188,7 @@ class LocalDocQA:
else: else:
if not vs_path: if not vs_path:
vs_path = os.path.join(VS_ROOT_PATH, vs_path = os.path.join(VS_ROOT_PATH,
f"""{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""") f"""{"".join(lazy_pinyin(os.path.splitext(file)[0]))}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
vector_store = FAISS.from_documents(docs, self.embeddings) vector_store = FAISS.from_documents(docs, self.embeddings)
torch_gc() torch_gc()
......
...@@ -13,4 +13,5 @@ gradio==3.28.3 ...@@ -13,4 +13,5 @@ gradio==3.28.3
fastapi fastapi
uvicorn uvicorn
peft peft
pypinyin
#detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2 #detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论