提交 f1f742ce 作者: imClumsyPanda

add self-defined class MyFAISS

上级 89b986c3
from langchain.embeddings.huggingface import HuggingFaceEmbeddings from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS from vectorstores import MyFAISS
from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLoader from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLoader
from configs.model_config import * from configs.model_config import *
import datetime import datetime
...@@ -32,7 +32,7 @@ HuggingFaceEmbeddings.__hash__ = _embeddings_hash ...@@ -32,7 +32,7 @@ HuggingFaceEmbeddings.__hash__ = _embeddings_hash
# will keep CACHED_VS_NUM of vector store caches # will keep CACHED_VS_NUM of vector store caches
@lru_cache(CACHED_VS_NUM) @lru_cache(CACHED_VS_NUM)
def load_vector_store(vs_path, embeddings): def load_vector_store(vs_path, embeddings):
return FAISS.load_local(vs_path, embeddings) return MyFAISS.load_local(vs_path, embeddings)
def tree(filepath, ignore_dir_names=None, ignore_file_names=None): def tree(filepath, ignore_dir_names=None, ignore_file_names=None):
...@@ -107,78 +107,6 @@ def generate_prompt(related_docs: List[str], ...@@ -107,78 +107,6 @@ def generate_prompt(related_docs: List[str],
return prompt return prompt
def seperate_list(ls: List[int]) -> List[List[int]]:
lists = []
ls1 = [ls[0]]
for i in range(1, len(ls)):
if ls[i - 1] + 1 == ls[i]:
ls1.append(ls[i])
else:
lists.append(ls1)
ls1 = [ls[i]]
lists.append(ls1)
return lists
def similarity_search_with_score_by_vector(
self, embedding: List[float], k: int = 4
) -> List[Tuple[Document, float]]:
scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
docs = []
id_set = set()
store_len = len(self.index_to_docstore_id)
for j, i in enumerate(indices[0]):
if i == -1 or 0 < self.score_threshold < scores[0][j]:
# This happens when not enough docs are returned.
continue
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
if not self.chunk_conent:
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
doc.metadata["score"] = int(scores[0][j])
docs.append(doc)
continue
id_set.add(i)
docs_len = len(doc.page_content)
for k in range(1, max(i, store_len - i)):
break_flag = False
for l in [i + k, i - k]:
if 0 <= l < len(self.index_to_docstore_id):
_id0 = self.index_to_docstore_id[l]
doc0 = self.docstore.search(_id0)
if docs_len + len(doc0.page_content) > self.chunk_size:
break_flag = True
break
elif doc0.metadata["source"] == doc.metadata["source"]:
docs_len += len(doc0.page_content)
id_set.add(l)
if break_flag:
break
if not self.chunk_conent:
return docs
if len(id_set) == 0 and self.score_threshold > 0:
return []
id_list = sorted(list(id_set))
id_lists = seperate_list(id_list)
for id_seq in id_lists:
for id in id_seq:
if id == id_seq[0]:
_id = self.index_to_docstore_id[id]
doc = self.docstore.search(_id)
else:
_id0 = self.index_to_docstore_id[id]
doc0 = self.docstore.search(_id0)
doc.page_content += " " + doc0.page_content
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
doc_score = min([scores[0][id] for id in [indices[0].tolist().index(i) for i in id_seq if i in indices[0]]])
doc.metadata["score"] = int(doc_score)
docs.append(doc)
torch_gc()
return docs
def search_result2docs(search_results): def search_result2docs(search_results):
docs = [] docs = []
for result in search_results: for result in search_results:
...@@ -263,7 +191,7 @@ class LocalDocQA: ...@@ -263,7 +191,7 @@ class LocalDocQA:
if not vs_path: if not vs_path:
vs_path = os.path.join(VS_ROOT_PATH, vs_path = os.path.join(VS_ROOT_PATH,
f"""{"".join(lazy_pinyin(os.path.splitext(file)[0]))}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""") f"""{"".join(lazy_pinyin(os.path.splitext(file)[0]))}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
vector_store = FAISS.from_documents(docs, self.embeddings) # docs 为Document列表 vector_store = MyFAISS.from_documents(docs, self.embeddings) # docs 为Document列表
torch_gc() torch_gc()
vector_store.save_local(vs_path) vector_store.save_local(vs_path)
...@@ -281,11 +209,11 @@ class LocalDocQA: ...@@ -281,11 +209,11 @@ class LocalDocQA:
if not one_content_segmentation: if not one_content_segmentation:
text_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size) text_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
docs = text_splitter.split_documents(docs) docs = text_splitter.split_documents(docs)
if os.path.isdir(vs_path) and os.path.isfile(vs_path+"/index.faiss"): if os.path.isdir(vs_path) and os.path.isfile(vs_path + "/index.faiss"):
vector_store = load_vector_store(vs_path, self.embeddings) vector_store = load_vector_store(vs_path, self.embeddings)
vector_store.add_documents(docs) vector_store.add_documents(docs)
else: else:
vector_store = FAISS.from_documents(docs, self.embeddings) ##docs 为Document列表 vector_store = MyFAISS.from_documents(docs, self.embeddings) ##docs 为Document列表
torch_gc() torch_gc()
vector_store.save_local(vs_path) vector_store.save_local(vs_path)
return vs_path, [one_title] return vs_path, [one_title]
...@@ -295,13 +223,12 @@ class LocalDocQA: ...@@ -295,13 +223,12 @@ class LocalDocQA:
def get_knowledge_based_answer(self, query, vs_path, chat_history=[], streaming: bool = STREAMING): def get_knowledge_based_answer(self, query, vs_path, chat_history=[], streaming: bool = STREAMING):
vector_store = load_vector_store(vs_path, self.embeddings) vector_store = load_vector_store(vs_path, self.embeddings)
FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
vector_store.chunk_size = self.chunk_size vector_store.chunk_size = self.chunk_size
vector_store.chunk_conent = self.chunk_conent vector_store.chunk_conent = self.chunk_conent
vector_store.score_threshold = self.score_threshold vector_store.score_threshold = self.score_threshold
related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k) related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)
torch_gc() torch_gc()
if len(related_docs_with_score)>0: if len(related_docs_with_score) > 0:
prompt = generate_prompt(related_docs_with_score, query) prompt = generate_prompt(related_docs_with_score, query)
else: else:
prompt = query prompt = query
...@@ -326,7 +253,7 @@ class LocalDocQA: ...@@ -326,7 +253,7 @@ class LocalDocQA:
score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
vector_search_top_k=VECTOR_SEARCH_TOP_K, chunk_size=CHUNK_SIZE): vector_search_top_k=VECTOR_SEARCH_TOP_K, chunk_size=CHUNK_SIZE):
vector_store = load_vector_store(vs_path, self.embeddings) vector_store = load_vector_store(vs_path, self.embeddings)
FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector # FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
vector_store.chunk_conent = chunk_conent vector_store.chunk_conent = chunk_conent
vector_store.score_threshold = score_threshold vector_store.score_threshold = score_threshold
vector_store.chunk_size = chunk_size vector_store.chunk_size = chunk_size
...@@ -381,8 +308,8 @@ if __name__ == "__main__": ...@@ -381,8 +308,8 @@ if __name__ == "__main__":
streaming=True): streaming=True):
print(resp["result"][last_print_len:], end="", flush=True) print(resp["result"][last_print_len:], end="", flush=True)
last_print_len = len(resp["result"]) last_print_len = len(resp["result"])
source_text = [f"""出处 [{inum + 1}] {doc.metadata['source'] if doc.metadata['source'].startswith("http") source_text = [f"""出处 [{inum + 1}] {doc.metadata['source'] if doc.metadata['source'].startswith("http")
else os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n""" else os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
# f"""相关度:{doc.metadata['score']}\n\n""" # f"""相关度:{doc.metadata['score']}\n\n"""
for inum, doc in for inum, doc in
enumerate(resp["source_documents"])] enumerate(resp["source_documents"])]
......
from langchain.vectorstores import FAISS
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.faiss import dependable_faiss_import
from typing import Any, Callable, List, Tuple, Dict
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
import numpy as np
class MyFAISS(FAISS, VectorStore):
def __init__(
self,
embedding_function: Callable,
index: Any,
docstore: Docstore,
index_to_docstore_id: Dict[int, str],
normalize_L2: bool = False,
):
super().__init__(embedding_function=embedding_function,
index=index,
docstore=docstore,
index_to_docstore_id=index_to_docstore_id,
normalize_L2=normalize_L2)
# def similarity_search_with_score_by_vector(
# self, embedding: List[float], k: int = 4
# ) -> List[Tuple[Document, float]]:
# faiss = dependable_faiss_import()
# vector = np.array([embedding], dtype=np.float32)
# if self._normalize_L2:
# faiss.normalize_L2(vector)
# scores, indices = self.index.search(vector, k)
# docs = []
# for j, i in enumerate(indices[0]):
# if i == -1:
# # This happens when not enough docs are returned.
# continue
# _id = self.index_to_docstore_id[i]
# doc = self.docstore.search(_id)
# if not isinstance(doc, Document):
# raise ValueError(f"Could not find document for id {_id}, got {doc}")
#
# docs.append((doc, scores[0][j]))
# return docs
def seperate_list(self, ls: List[int]) -> List[List[int]]:
# TODO: 增加是否属于同一文档的判断
lists = []
ls1 = [ls[0]]
for i in range(1, len(ls)):
if ls[i - 1] + 1 == ls[i]:
ls1.append(ls[i])
else:
lists.append(ls1)
ls1 = [ls[i]]
lists.append(ls1)
return lists
def similarity_search_with_score_by_vector(
self, embedding: List[float], k: int = 4
) -> List[Document]:
scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
docs = []
id_set = set()
store_len = len(self.index_to_docstore_id)
for j, i in enumerate(indices[0]):
if i == -1 or 0 < self.score_threshold < scores[0][j]:
# This happens when not enough docs are returned.
continue
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
if (not self.chunk_conent) or ("add_context" in doc.metadata and not doc.metadata["add_context"]):
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
doc.metadata["score"] = int(scores[0][j])
docs.append(doc)
continue
id_set.add(i)
docs_len = len(doc.page_content)
for k in range(1, max(i, store_len - i)):
break_flag = False
for l in [i + k, i - k]:
if 0 <= l < len(self.index_to_docstore_id):
_id0 = self.index_to_docstore_id[l]
doc0 = self.docstore.search(_id0)
if docs_len + len(doc0.page_content) > self.chunk_size:
break_flag = True
break
elif doc0.metadata["source"] == doc.metadata["source"]:
docs_len += len(doc0.page_content)
id_set.add(l)
if break_flag:
break
if (not self.chunk_conent) or ("add_context" in doc.metadata and doc.metadata["add_context"] == False):
return docs
if len(id_set) == 0 and self.score_threshold > 0:
return []
id_list = sorted(list(id_set))
id_lists = self.seperate_list(id_list)
for id_seq in id_lists:
for id in id_seq:
if id == id_seq[0]:
_id = self.index_to_docstore_id[id]
doc = self.docstore.search(_id)
else:
_id0 = self.index_to_docstore_id[id]
doc0 = self.docstore.search(_id0)
doc.page_content += " " + doc0.page_content
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
doc_score = min([scores[0][id] for id in [indices[0].tolist().index(i) for i in id_seq if i in indices[0]]])
doc.metadata["score"] = int(doc_score)
docs.append(doc)
return docs
from .MyFAISS import MyFAISS
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论