Commit 87420019 Author: fxjhello
@@ -20,7 +20,46 @@ from models.loader import LoaderCheckPoint
import models.shared as shared
from agent import bing_search
from langchain.docstore.document import Document
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from sklearn.neighbors import NearestNeighbors
import numpy as np  # used by SemanticSearch.get_text_embedding (np.vstack)

class SemanticSearch:
    """KNN-based semantic re-ranker over sentence-transformer embeddings."""

    def __init__(self):
        # text2vec encoder used for both corpus and query embeddings
        self.use = SentenceTransformer('GanymedeNil_text2vec-large-chinese')
        self.fitted = False

    def fit(self, data, batch=100, n_neighbors=10):
        # Embed the corpus and build a nearest-neighbor index over it.
        self.data = data
        self.embeddings = self.get_text_embedding(data, batch=batch)
        n_neighbors = min(n_neighbors, len(self.embeddings))
        self.nn = NearestNeighbors(n_neighbors=n_neighbors)
        self.nn.fit(self.embeddings)
        self.fitted = True

    def __call__(self, text, return_data=True):
        # Embed the query and return its nearest chunks (or their indices).
        inp_emb = self.use.encode([text])
        neighbors = self.nn.kneighbors(inp_emb, return_distance=False)[0]
        if return_data:
            return [self.data[i] for i in neighbors]
        else:
            return neighbors

    def get_text_embedding(self, texts, batch=100):
        # Encode the texts in batches and stack into a single (n, dim) array.
        embeddings = []
        for i in range(0, len(texts), batch):
            text_batch = texts[i:(i + batch)]
            emb_batch = self.use.encode(text_batch)
            embeddings.append(emb_batch)
        embeddings = np.vstack(embeddings)
        return embeddings
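
A minimal usage sketch of the class above (illustrative only; the three-passage corpus and the query string are made up, not part of the patch):

    searcher = SemanticSearch()
    searcher.fit(["first passage", "second passage", "third passage"], n_neighbors=2)
    top_chunks = searcher("query text")  # the 2 passages nearest to the query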
def get_docs_with_score(docs_with_score):
    # Unpack (Document, score) pairs, recording each score in the doc metadata.
    docs = []
    for doc, score in docs_with_score:
        doc.metadata["score"] = score
        docs.append(doc)
    return docs
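
The helper text_to_chunks called further down is not included in this hunk. A minimal sketch of a compatible implementation, assuming simple fixed-size token chunks (the chunk_size value and the body are assumptions, not the committed code; start_page is accepted only to match the call site):

    def text_to_chunks(text, chunk_size=150, start_page=1):
        # Split the merged document string into fixed-size chunks of tokens.
        tokens = text.split()
        return [" ".join(tokens[i:i + chunk_size])
                for i in range(0, len(tokens), chunk_size)]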
def load_file(filepath, sentence_size=SENTENCE_SIZE):
    if filepath.lower().endswith(".md"):
@@ -262,9 +301,41 @@ class LocalDocQA:
        vector_store.chunk_conent = self.chunk_conent
        vector_store.score_threshold = self.score_threshold
        related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)
        # Re-ranking stage. The FAISS search above is only the coarse pass;
        # set VECTOR_SEARCH_TOP_K = 300 in model_config so it returns a large
        # candidate pool. Pipeline: (1) coarse recall with FAISS + semantic
        # search fetches many related documents, (2) the documents are merged,
        # de-duplicated and re-chunked, (3) a second KNN + semantic-search pass
        # picks the best chunks, which go into the prompt.
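        # Illustrative config change (the configs/model_config.py path is an
        # assumption about this repo's layout; only the parameter name and
        # value come from the note above):
        #
        #     VECTOR_SEARCH_TOP_K = 300  # coarse recall depth for the re-ranker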
        # Extract the documents from the (doc, score) pairs.
        related_docs = get_docs_with_score(related_docs_with_score)
        # Collect the document bodies, stripping all spaces from each.
        text_batch0 = []
        for doc in related_docs:
            text_batch0.append(doc.page_content.replace(" ", ""))
        # De-duplicate the documents with a stack: skip exact repeats, and pop
        # the stacked entry when a lexicographically smaller one arrives while
        # the stacked entry still occurs later in the list.
        text_batch_new = []
        for i in range(len(text_batch0)):
            if text_batch0[i] in text_batch_new:
                continue
            while text_batch_new and text_batch_new[-1] > text_batch0[i] \
                    and text_batch_new[-1] in text_batch0[i + 1:]:
                text_batch_new.pop()  # pop the top of the stack
            text_batch_new.append(text_batch0[i])
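        # Worked example (illustrative): with text_batch0 = ["b", "a", "b"],
        # i=0 pushes "b"; i=1 pops "b" ("b" > "a" and "b" recurs later) and
        # pushes "a"; i=2 pushes "b" again -> text_batch_new == ["a", "b"].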
        text_batch_new0 = "\n".join(text_batch_new)
        # Fine-grained ranking: KNN + semantic search over the re-chunked text.
        recommender = SemanticSearch()
        chunks = text_to_chunks(text_batch_new0, start_page=1)
        recommender.fit(chunks)
        topn_chunks = recommender(query)
        torch_gc()
        # Strip the spaces from the selected chunks (text_to_chunks joins
        # tokens with spaces, which are artifacts for Chinese text).
        topn_chunks0 = []
        for chunk in topn_chunks:
            topn_chunks0.append(chunk.replace(" ", ""))
        # Generate the prompt from the re-ranked chunks.
        prompt = generate_prompt(topn_chunks0, query)
        for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
                                                      streaming=streaming):
            resp = answer_result.llm_output["answer"]
......