提交 059fe828 作者: wangxinkai 提交者: imClumsyPanda

feat: 添加mmr相似度搜索,支持返回相似度分数

上级 b3f83060
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.vectorstores import FAISS
+# from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from chains.lib.embeddings import MyEmbeddings
+# from langchain.vectorstores import FAISS
+from chains.lib.vectorstores import FAISSVS
 from langchain.document_loaders import UnstructuredFileLoader
 from models.chatglm_llm import ChatGLM
 import sentence_transformers
@@ -50,7 +52,7 @@ class LocalDocQA:
                                use_ptuning_v2=use_ptuning_v2)
         self.llm.history_len = llm_history_len
-        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
+        self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model],
                                                 model_kwargs={'device': embedding_device})
         # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
         #                                                                    device=embedding_device)
@@ -97,12 +99,12 @@ class LocalDocQA:
                 print(f"{file} 未能成功加载")
         if len(docs) > 0:
             if vs_path and os.path.isdir(vs_path):
-                vector_store = FAISS.load_local(vs_path, self.embeddings)
+                vector_store = FAISSVS.load_local(vs_path, self.embeddings)
                 vector_store.add_documents(docs)
             else:
                 if not vs_path:
                     vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
-                vector_store = FAISS.from_documents(docs, self.embeddings)
+                vector_store = FAISSVS.from_documents(docs, self.embeddings)
             vector_store.save_local(vs_path)
         return vs_path, loaded_files
@@ -127,7 +129,7 @@ class LocalDocQA:
             input_variables=["context", "question"]
         )
         self.llm.history = chat_history
-        vector_store = FAISS.load_local(vs_path, self.embeddings)
+        vector_store = FAISSVS.load_local(vs_path, self.embeddings)
         knowledge_chain = RetrievalQA.from_llm(
             llm=self.llm,
             retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
…
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.question_answering import load_qa_chain\n",
"from langchain.prompts import PromptTemplate\n",
"from lib.embeds import MyEmbeddings\n",
"from lib.faiss import FAISSVS\n",
"from langchain.chains.combine_documents.stuff import StuffDocumentsChain\n",
"from langchain.chains.llm import LLMChain\n",
"from lib.chatglm_llm import ChatGLM, AlpacaGLM\n",
"from lib.config import *\n",
"from lib.utils import get_docs\n",
"\n",
"\n",
"class LocalDocQA:\n",
" def __init__(self, \n",
" embedding_model=EMBEDDING_MODEL, \n",
" embedding_device=EMBEDDING_DEVICE, \n",
" llm_model=LLM_MODEL, \n",
" llm_device=LLM_DEVICE, \n",
" llm_history_len=LLM_HISTORY_LEN, \n",
" top_k=VECTOR_SEARCH_TOP_K,\n",
" vs_name = VS_NAME\n",
" ) -> None:\n",
" \n",
" torch.cuda.empty_cache()\n",
" torch.cuda.empty_cache()\n",
"\n",
" self.embedding_model = embedding_model\n",
" self.llm_model = llm_model\n",
" self.embedding_device = embedding_device\n",
" self.llm_device = llm_device\n",
" self.llm_history_len = llm_history_len\n",
" self.top_k = top_k\n",
" self.vs_name = vs_name\n",
"\n",
" self.llm = AlpacaGLM()\n",
" self.llm.load_model(model_name_or_path=llm_model_dict[llm_model], llm_device=llm_device)\n",
"\n",
" self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model])\n",
" self.load_vector_store(vs_name)\n",
"\n",
" self.prompt = PromptTemplate(\n",
" template=PROMPT_TEMPLATE,\n",
" input_variables=[\"context\", \"question\"]\n",
" )\n",
" self.search_params = {\n",
" \"engine\": \"bing\",\n",
" \"gl\": \"us\",\n",
" \"hl\": \"en\",\n",
" \"serpapi_api_key\": \"\"\n",
" }\n",
"\n",
" def init_knowledge_vector_store(self, vs_name: str):\n",
" \n",
" docs = get_docs(KNOWLEDGE_PATH)\n",
" vector_store = FAISSVS.from_documents(docs, self.embeddings)\n",
" vs_path = VECTORSTORE_PATH + vs_name\n",
" vector_store.save_local(vs_path)\n",
"\n",
" def add_knowledge_to_vector_store(self, vs_name: str):\n",
" docs = get_docs(ADD_KNOWLEDGE_PATH)\n",
" new_vector_store = FAISSVS.from_documents(docs, self.embeddings)\n",
" vector_store = FAISSVS.load_local(VECTORSTORE_PATH + vs_name, self.embeddings) \n",
" vector_store.merge_from(new_vector_store)\n",
" vector_store.save_local(VECTORSTORE_PATH + vs_name)\n",
"\n",
" def load_vector_store(self, vs_name: str):\n",
" self.vector_store = FAISSVS.load_local(VECTORSTORE_PATH + vs_name, self.embeddings)\n",
"\n",
" # def get_search_based_answer(self, query):\n",
" \n",
" # search = SerpAPIWrapper(params=self.search_params)\n",
" # docs = search.run(query)\n",
" # search_chain = load_qa_chain(self.llm, chain_type=\"stuff\")\n",
" # answer = search_chain.run(input_documents=docs, question=query)\n",
"\n",
" # return answer\n",
" \n",
" def get_knowledge_based_answer(self, query):\n",
" \n",
" docs = self.vector_store.max_marginal_relevance_search(query)\n",
" print(f'召回的文档和相似度分数:{docs}')\n",
" # 这里 doc[1] 就是对应的score \n",
" docs = [doc[0] for doc in docs]\n",
" \n",
" document_prompt = PromptTemplate(\n",
" input_variables=[\"page_content\"], template=\"Context:\\n{page_content}\"\n",
" )\n",
" llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)\n",
" combine_documents_chain = StuffDocumentsChain(\n",
" llm_chain=llm_chain,\n",
" document_variable_name=\"context\",\n",
" document_prompt=document_prompt,\n",
" )\n",
" answer = combine_documents_chain.run(\n",
" input_documents=docs, question=query\n",
" )\n",
"\n",
" self.llm.history[-1][0] = query\n",
" self.llm.history[-1][-1] = answer\n",
" return answer, docs, self.llm.history"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d4342213010c4ed2ad5b04694aa436d6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"qa = LocalDocQA()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"召回的文档和相似度分数:[(Document(page_content='****** LOGI APT Group Intelligence Research Yearbook APT Knowledge Graph APT组织情报 研究年鉴', metadata={'source': './KnowledgeStore/APT group Intelligence Research handbook-2022.pdf', 'page': 0}), 0.45381865), (Document(page_content='9 MANDIANT APT42: Crooked Charms, Cons and Compromises FIGURE 8. APT42 impersonates University of Oxford vaccinologist. APT42 Credential harvesting page masquerading as a Yahoo login portal.', metadata={'source': './KnowledgeStore/APT42_Crooked_Charms_Cons_and_Compromises.pdf', 'page': 8}), 0.4535672), (Document(page_content='The origin story of APT32 macros T H R E A T R E S E A R C H R E P O R T R u n n i n g t h r o u g h a l l t h e S U O f i l e s t r u c t u r e s i s l a b o r i o u s a n d d i d n ’ t y i e l d m u c h m o r e t h a n a s t r i n g d u m p w o u l d h a v e d o n e a n y w a y . W e f i n d p a t h s t o s o u r c e c o d e f i l e s , p r o j e c t n a m e s , e t c . W e c a n i n f e r f r o m t h e m y r i a d o f r e f e r e n c e s i n XmlPackageOptions , O u t l i n i n g S t a t e D i r , e t c . , t h a t t h e HtaDotnet a n d ShellcodeLoader s o l u t i o n s w e r e o r i g i n a l l y u n d e r t h e f o l d e r p a t h G:\\\\WebBuilder\\\\Gift_HtaDotnet\\\\ . T h i s i s a l s o s u p p o r t e d b y t h e P D B p a t h s o f o l d e r b u i l t b i n a r i e s w i t h i n t h e b r o a d e r S t r i k e S u i t G i f t p a c k a g e . F r o m l o o k i n g a t D e b u g g e r W a t c h e s v a l u e s i n o t h e r p r o j e c t s , w e c a n s e e t h a t t h e m a l w a r e d e v e l o p e r w a s a c t i v e l y d e b u g g i n g t h e h i s t o r i c a l p r o g r a m s . 
S U O f i l e D e b u g g e r W a t c h e s WebBuilder/HtaDotNet/HtaDotnet.v11.suo result WebBuilder/ShellcodeLoader/.vs/L/v14/.suo (char)77 WebBuilder/ShellcodeLoader/L.suo (char)77 3 4 04/2022', metadata={'source': './KnowledgeStore/Stairwell-threat-report-The-origin-of-APT32-macros.pdf', 'page': 33}), 0.38091612), (Document(page_content='2 APTs and COVID-19: How advanced persistent threats use the coronavirus as a lureTable of contents Introduction: APT groups using COVID-19 .........................................................', metadata={'source': './KnowledgeStore/200407-MWB-COVID-White-Paper_Final.pdf', 'page': 1}), 0.44476452)]\n"
]
}
],
"source": [
"query = r\"\"\"make a brief introduction of APT?\"\"\"\n",
"ans, docs, _ = qa.get_knowledge_based_answer(query)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\\nAnswer: APT stands for Advanced Persistent Threat, which is a type of malicious cyberattack that is carried out by a sophisticated hacker group or state-sponsored organization. APTs are designed to remain undetected for a long period of time and are often used to steal sensitive data or disrupt critical infrastructure.'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ans"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "chatgpt",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论