提交 7863e0fe 作者: imClumsyPanda

update MyFAISS

上级 27a9bf24
......@@ -4,9 +4,7 @@ from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLo
from configs.model_config import *
import datetime
from textsplitter import ChineseTextSplitter
from typing import List, Tuple, Dict
from langchain.docstore.document import Document
import numpy as np
from typing import List
from utils import torch_gc
from tqdm import tqdm
from pypinyin import lazy_pinyin
......
......@@ -22,27 +22,6 @@ class MyFAISS(FAISS, VectorStore):
index_to_docstore_id=index_to_docstore_id,
normalize_L2=normalize_L2)
# def similarity_search_with_score_by_vector(
# self, embedding: List[float], k: int = 4
# ) -> List[Tuple[Document, float]]:
# faiss = dependable_faiss_import()
# vector = np.array([embedding], dtype=np.float32)
# if self._normalize_L2:
# faiss.normalize_L2(vector)
# scores, indices = self.index.search(vector, k)
# docs = []
# for j, i in enumerate(indices[0]):
# if i == -1:
# # This happens when not enough docs are returned.
# continue
# _id = self.index_to_docstore_id[i]
# doc = self.docstore.search(_id)
# if not isinstance(doc, Document):
# raise ValueError(f"Could not find document for id {_id}, got {doc}")
#
# docs.append((doc, scores[0][j]))
# return docs
def seperate_list(self, ls: List[int]) -> List[List[int]]:
# TODO: 增加是否属于同一文档的判断
lists = []
......@@ -59,7 +38,11 @@ class MyFAISS(FAISS, VectorStore):
def similarity_search_with_score_by_vector(
self, embedding: List[float], k: int = 4
) -> List[Document]:
scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
faiss = dependable_faiss_import()
vector = np.array([embedding], dtype=np.float32)
if self._normalize_L2:
faiss.normalize_L2(vector)
scores, indices = self.index.search(vector, k)
docs = []
id_set = set()
store_len = len(self.index_to_docstore_id)
......@@ -69,7 +52,7 @@ class MyFAISS(FAISS, VectorStore):
continue
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
if (not self.chunk_conent) or ("add_context" in doc.metadata and not doc.metadata["add_context"]):
if (not self.chunk_conent) or ("context_expand" in doc.metadata and not doc.metadata["context_expand"]):
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
doc.metadata["score"] = int(scores[0][j])
......@@ -79,11 +62,17 @@ class MyFAISS(FAISS, VectorStore):
docs_len = len(doc.page_content)
for k in range(1, max(i, store_len - i)):
break_flag = False
for l in [i + k, i - k]:
if 0 <= l < len(self.index_to_docstore_id):
if "context_expand_method" in doc.metadata and doc.metadata["context_expand_method"] == "forward":
expand_range = [i + k]
elif "context_expand_method" in doc.metadata and doc.metadata["context_expand_method"] == "backward":
expand_range = [i - k]
else:
expand_range = [i + k, i - k]
for l in expand_range:
if l not in id_set and 0 <= l < len(self.index_to_docstore_id):
_id0 = self.index_to_docstore_id[l]
doc0 = self.docstore.search(_id0)
if docs_len + len(doc0.page_content) > self.chunk_size:
if docs_len + len(doc0.page_content) > self.chunk_size or doc0.metadata["source"] != doc.metadata["source"]:
break_flag = True
break
elif doc0.metadata["source"] == doc.metadata["source"]:
......@@ -91,7 +80,7 @@ class MyFAISS(FAISS, VectorStore):
id_set.add(l)
if break_flag:
break
if (not self.chunk_conent) or ("add_context" in doc.metadata and doc.metadata["add_context"] == False):
if (not self.chunk_conent) or ("add_context" in doc.metadata and not doc.metadata["add_context"]):
return docs
if len(id_set) == 0 and self.score_threshold > 0:
return []
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论