Commit 7863e0fe by imClumsyPanda

update MyFAISS

Parent 27a9bf24
@@ -4,9 +4,7 @@ from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLoader
 from configs.model_config import *
 import datetime
 from textsplitter import ChineseTextSplitter
-from typing import List, Tuple, Dict
-from langchain.docstore.document import Document
-import numpy as np
+from typing import List
 from utils import torch_gc
 from tqdm import tqdm
 from pypinyin import lazy_pinyin
@@ -22,27 +22,6 @@ class MyFAISS(FAISS, VectorStore):
                    index_to_docstore_id=index_to_docstore_id,
                    normalize_L2=normalize_L2)
-    # def similarity_search_with_score_by_vector(
-    #         self, embedding: List[float], k: int = 4
-    # ) -> List[Tuple[Document, float]]:
-    #     faiss = dependable_faiss_import()
-    #     vector = np.array([embedding], dtype=np.float32)
-    #     if self._normalize_L2:
-    #         faiss.normalize_L2(vector)
-    #     scores, indices = self.index.search(vector, k)
-    #     docs = []
-    #     for j, i in enumerate(indices[0]):
-    #         if i == -1:
-    #             # This happens when not enough docs are returned.
-    #             continue
-    #         _id = self.index_to_docstore_id[i]
-    #         doc = self.docstore.search(_id)
-    #         if not isinstance(doc, Document):
-    #             raise ValueError(f"Could not find document for id {_id}, got {doc}")
-    #
-    #         docs.append((doc, scores[0][j]))
-    #     return docs
     def seperate_list(self, ls: List[int]) -> List[List[int]]:
         # TODO: add a check for whether chunks belong to the same document
         lists = []
@@ -59,7 +38,11 @@ class MyFAISS(FAISS, VectorStore):
     def similarity_search_with_score_by_vector(
             self, embedding: List[float], k: int = 4
     ) -> List[Document]:
-        scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
+        faiss = dependable_faiss_import()
+        vector = np.array([embedding], dtype=np.float32)
+        if self._normalize_L2:
+            faiss.normalize_L2(vector)
+        scores, indices = self.index.search(vector, k)
         docs = []
         id_set = set()
         store_len = len(self.index_to_docstore_id)
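
Review note on the hunk above: the query vector is now passed through faiss.normalize_L2 whenever the index was built with normalize_L2=True, so the search no longer compares normalized stored vectors against an unnormalized query. The underlying identity, as a minimal numpy sketch (an assumed illustration, not code from this commit): once vectors have unit L2 norm, their inner product equals their cosine similarity.

import numpy as np

def l2_normalize(v: np.ndarray) -> np.ndarray:
    # Scale a vector to unit L2 norm; guard the all-zero vector.
    norm = np.linalg.norm(v)
    return v if norm == 0.0 else v / norm

a = l2_normalize(np.array([3.0, 4.0], dtype=np.float32))
b = l2_normalize(np.array([6.0, 8.0], dtype=np.float32))
print(float(a @ b))  # ≈ 1.0: same direction, so cosine similarity is maximal
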
@@ -69,7 +52,7 @@ class MyFAISS(FAISS, VectorStore):
                 continue
             _id = self.index_to_docstore_id[i]
             doc = self.docstore.search(_id)
-            if (not self.chunk_conent) or ("add_context" in doc.metadata and not doc.metadata["add_context"]):
+            if (not self.chunk_conent) or ("context_expand" in doc.metadata and not doc.metadata["context_expand"]):
                 if not isinstance(doc, Document):
                     raise ValueError(f"Could not find document for id {_id}, got {doc}")
                 doc.metadata["score"] = int(scores[0][j])
@@ -79,11 +62,17 @@ class MyFAISS(FAISS, VectorStore):
             docs_len = len(doc.page_content)
             for k in range(1, max(i, store_len - i)):
                 break_flag = False
-                for l in [i + k, i - k]:
-                    if 0 <= l < len(self.index_to_docstore_id):
+                if "context_expand_method" in doc.metadata and doc.metadata["context_expand_method"] == "forward":
+                    expand_range = [i + k]
+                elif "context_expand_method" in doc.metadata and doc.metadata["context_expand_method"] == "backward":
+                    expand_range = [i - k]
+                else:
+                    expand_range = [i + k, i - k]
+                for l in expand_range:
+                    if l not in id_set and 0 <= l < len(self.index_to_docstore_id):
                         _id0 = self.index_to_docstore_id[l]
                         doc0 = self.docstore.search(_id0)
-                        if docs_len + len(doc0.page_content) > self.chunk_size:
+                        if docs_len + len(doc0.page_content) > self.chunk_size or doc0.metadata["source"] != doc.metadata["source"]:
                             break_flag = True
                             break
                         elif doc0.metadata["source"] == doc.metadata["source"]:
@@ -91,7 +80,7 @@ class MyFAISS(FAISS, VectorStore):
                             id_set.add(l)
                 if break_flag:
                     break
-        if (not self.chunk_conent) or ("add_context" in doc.metadata and doc.metadata["add_context"] == False):
+        if (not self.chunk_conent) or ("add_context" in doc.metadata and not doc.metadata["add_context"]):
             return docs
         if len(id_set) == 0 and self.score_threshold > 0:
             return []
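
Review note: the hunks above make context expansion direction-aware, skip indices already collected in id_set, and stop expanding once a neighbor comes from a different source file. The direction logic restated as a standalone helper (a sketch for review purposes, not code in this commit):

from typing import Dict, List

def expand_candidates(i: int, k: int, metadata: Dict) -> List[int]:
    # "forward" probes the chunk k positions ahead, "backward" the one k
    # positions behind; any other value falls back to both neighbors.
    method = metadata.get("context_expand_method")
    if method == "forward":
        return [i + k]
    if method == "backward":
        return [i - k]
    return [i + k, i - k]

print(expand_candidates(10, 1, {"context_expand_method": "forward"}))  # [11]
print(expand_candidates(10, 2, {}))                                    # [12, 8]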