提交 8ae84c6c 作者: imClumsyPanda

update local_doc_qa.py

上级 88941d39
......@@ -9,6 +9,7 @@ import os
from configs.model_config import *
import datetime
from typing import List
from textsplitter import ChineseTextSplitter
# number of top-k text chunks to return from the vector store
VECTOR_SEARCH_TOP_K = 6
......@@ -17,6 +18,18 @@ VECTOR_SEARCH_TOP_K = 6
LLM_HISTORY_LEN = 3
def load_file(filepath):
    """Load the file at ``filepath`` and split it into documents.

    Files whose name ends in ``.pdf`` (case-insensitive) are read with the
    loader's default mode and split with ``ChineseTextSplitter(pdf=True)``.
    Every other file is read with ``mode="elements"`` and split with
    ``ChineseTextSplitter(pdf=False)``.

    Returns whatever ``UnstructuredFileLoader.load_and_split`` produces.
    """
    is_pdf = filepath.lower().endswith(".pdf")
    # Only non-PDF inputs ask the loader for element-wise parsing.
    loader_options = {} if is_pdf else {"mode": "elements"}
    loader = UnstructuredFileLoader(filepath, **loader_options)
    splitter = ChineseTextSplitter(pdf=is_pdf)
    return loader.load_and_split(text_splitter=splitter)
class LocalDocQA:
llm: object = None
embeddings: object = None
......@@ -48,10 +61,10 @@ class LocalDocQA:
elif os.path.isfile(filepath):
file = os.path.split(filepath)[-1]
try:
loader = UnstructuredFileLoader(filepath, mode="elements")
docs = loader.load()
docs = load_file(filepath)
print(f"{file} 已成功加载")
except:
except Exception as e:
print(e)
print(f"{file} 未能成功加载")
return None
elif os.path.isdir(filepath):
......@@ -59,25 +72,25 @@ class LocalDocQA:
for file in os.listdir(filepath):
fullfilepath = os.path.join(filepath, file)
try:
loader = UnstructuredFileLoader(fullfilepath, mode="elements")
docs += loader.load()
docs += load_file(fullfilepath)
print(f"{file} 已成功加载")
except:
except Exception as e:
print(e)
print(f"{file} 未能成功加载")
else:
docs = []
for file in filepath:
try:
loader = UnstructuredFileLoader(file, mode="elements")
docs += loader.load()
docs += load_file(file)
print(f"{file} 已成功加载")
except:
except Exception as e:
print(e)
print(f"{file} 未能成功加载")
vector_store = FAISS.from_documents(docs, self.embeddings)
vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
vector_store.save_local(vs_path)
return vs_path if len(docs)>0 else None
return vs_path if len(docs) > 0 else None
def get_knowledge_based_answer(self,
query,
......
......@@ -2,7 +2,7 @@ from configs.model_config import *
from chains.local_doc_qa import LocalDocQA
# number of top-k text chunks to return from the vector store
VECTOR_SEARCH_TOP_K = 10
VECTOR_SEARCH_TOP_K = 6
# LLM input history length
LLM_HISTORY_LEN = 3
......
from .chinese_text_splitter import *
\ No newline at end of file
from langchain.text_splitter import CharacterTextSplitter
import re
from typing import List
class ChineseTextSplitter(CharacterTextSplitter):
    """Split text into sentences using Chinese (and some Western) punctuation.

    Args:
        pdf: When true, whitespace in the input is normalised before
            splitting (see ``split_text``).
        **kwargs: Forwarded unchanged to ``CharacterTextSplitter``.
    """

    # Sentence boundary: one terminator character optionally followed by up
    # to two closing quotes, OR a zero-width position right before one or two
    # opening quotes or at the end of the text. The capturing group makes
    # ``split`` return the separators too, so they can be re-attached below.
    # Compiled once at class-definition time instead of on every call.
    # Original author's note on this pattern: "del :;".
    _SENT_SEP_PATTERN = re.compile(
        '([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))'
    )

    def __init__(self, pdf: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.pdf = pdf

    def split_text(self, text: str) -> List[str]:
        """Return ``text`` broken into a list of sentence strings.

        In PDF mode, runs of three or more newlines are collapsed to one and
        every whitespace character is then replaced by a single space before
        the sentence split is applied.
        """
        if self.pdf:
            text = re.sub(r"\n{3,}", "\n", text)
            # Raw string: '\s' in a plain literal is an invalid escape
            # sequence and triggers a warning on recent Python versions.
            text = re.sub(r"\s", " ", text)
            # NOTE(review): no newline can remain after the substitution
            # above, so this replace is effectively a no-op; kept so the
            # behaviour is unchanged.
            text = text.replace("\n\n", "")
        sent_list: List[str] = []
        for ele in self._SENT_SEP_PATTERN.split(text):
            if self._SENT_SEP_PATTERN.match(ele) and sent_list:
                # The pattern matches at the start of this piece (e.g. it is
                # a captured terminator): glue it onto the previous sentence.
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论