提交 9c422cc6 作者: imClumsyPanda

update bing_search.py

上级 f986b756
#coding=utf8 #coding=utf8
import os
from langchain.utilities import BingSearchAPIWrapper from langchain.utilities import BingSearchAPIWrapper
from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
env_bing_key = os.environ.get("BING_SUBSCRIPTION_KEY") def bing_search(text, result_len=3):
env_bing_url = os.environ.get("BING_SEARCH_URL") if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
"title": "env inof not fould",
def search(text, result_len=3): "link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
if not (env_bing_key and env_bing_url): search = BingSearchAPIWrapper(bing_subscription_key=BING_SUBSCRIPTION_KEY,
return [{"snippet":"please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV", bing_search_url=BING_SEARCH_URL)
"title": "env inof not fould", "link":"https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
search = BingSearchAPIWrapper()
return search.results(text, result_len) return search.results(text, result_len)
if __name__ == "__main__": if __name__ == "__main__":
r = search('python') r = bing_search('python')
print(r)
...@@ -4,7 +4,7 @@ from langchain.document_loaders import UnstructuredFileLoader, TextLoader ...@@ -4,7 +4,7 @@ from langchain.document_loaders import UnstructuredFileLoader, TextLoader
from configs.model_config import * from configs.model_config import *
import datetime import datetime
from textsplitter import ChineseTextSplitter from textsplitter import ChineseTextSplitter
from typing import List, Tuple from typing import List, Tuple, Dict
from langchain.docstore.document import Document from langchain.docstore.document import Document
import numpy as np import numpy as np
from utils import torch_gc from utils import torch_gc
...@@ -18,6 +18,8 @@ from models.base import (BaseAnswer, ...@@ -18,6 +18,8 @@ from models.base import (BaseAnswer,
from models.loader.args import parser from models.loader.args import parser
from models.loader import LoaderCheckPoint from models.loader import LoaderCheckPoint
import models.shared as shared import models.shared as shared
from agent import bing_search
from langchain.docstore.document import Document
def load_file(filepath, sentence_size=SENTENCE_SIZE): def load_file(filepath, sentence_size=SENTENCE_SIZE):
...@@ -58,8 +60,9 @@ def write_check_file(filepath, docs): ...@@ -58,8 +60,9 @@ def write_check_file(filepath, docs):
fout.close() fout.close()
def generate_prompt(related_docs: List[str], query: str, def generate_prompt(related_docs: List[str],
prompt_template=PROMPT_TEMPLATE) -> str: query: str,
prompt_template: str = PROMPT_TEMPLATE, ) -> str:
context = "\n".join([doc.page_content for doc in related_docs]) context = "\n".join([doc.page_content for doc in related_docs])
prompt = prompt_template.replace("{question}", query).replace("{context}", context) prompt = prompt_template.replace("{question}", query).replace("{context}", context)
return prompt return prompt
...@@ -137,6 +140,16 @@ def similarity_search_with_score_by_vector( ...@@ -137,6 +140,16 @@ def similarity_search_with_score_by_vector(
return docs return docs
def search_result2docs(search_results):
docs = []
for result in search_results:
doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "",
metadata={"source": result["link"] if "link" in result.keys() else "",
"filename": result["title"] if "title" in result.keys() else ""})
docs.append(doc)
return docs
class LocalDocQA: class LocalDocQA:
llm: BaseAnswer = None llm: BaseAnswer = None
embeddings: object = None embeddings: object = None
...@@ -262,7 +275,6 @@ class LocalDocQA: ...@@ -262,7 +275,6 @@ class LocalDocQA:
"source_documents": related_docs_with_score} "source_documents": related_docs_with_score}
yield response, history yield response, history
# query 查询内容 # query 查询内容
# vs_path 知识库路径 # vs_path 知识库路径
# chunk_conent 是否启用上下文关联 # chunk_conent 是否启用上下文关联
...@@ -288,11 +300,26 @@ class LocalDocQA: ...@@ -288,11 +300,26 @@ class LocalDocQA:
"source_documents": related_docs_with_score} "source_documents": related_docs_with_score}
return response, prompt return response, prompt
def get_search_result_based_answer(self, query, chat_history=[], streaming: bool = STREAMING):
results = bing_search(query)
result_docs = search_result2docs(results)
prompt = generate_prompt(result_docs, query)
for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
streaming=streaming):
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][0] = query
response = {"query": query,
"result": resp,
"source_documents": result_docs}
yield response, history
if __name__ == "__main__": if __name__ == "__main__":
# 初始化消息 # 初始化消息
args = None args = None
args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model']) args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model'])
args_dict = vars(args) args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict) shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
...@@ -304,13 +331,17 @@ if __name__ == "__main__": ...@@ -304,13 +331,17 @@ if __name__ == "__main__":
query = "本项目使用的embedding模型是什么,消耗多少显存" query = "本项目使用的embedding模型是什么,消耗多少显存"
vs_path = "/media/gpt4-pdf-chatbot-langchain/dev-langchain-ChatGLM/vector_store/test" vs_path = "/media/gpt4-pdf-chatbot-langchain/dev-langchain-ChatGLM/vector_store/test"
last_print_len = 0 last_print_len = 0
for resp, history in local_doc_qa.get_knowledge_based_answer(query=query, # for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
vs_path=vs_path, # vs_path=vs_path,
chat_history=[], # chat_history=[],
streaming=True): # streaming=True):
logger.info(resp["result"][last_print_len:], end="", flush=True) for resp, history in local_doc_qa.get_search_result_based_answer(query=query,
chat_history=[],
streaming=True):
print(resp["result"][last_print_len:], end="", flush=True)
last_print_len = len(resp["result"]) last_print_len = len(resp["result"])
source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n""" source_text = [f"""出处 [{inum + 1}] {doc.metadata['source'] if doc.metadata['source'].startswith("http")
else os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
# f"""相关度:{doc.metadata['score']}\n\n""" # f"""相关度:{doc.metadata['score']}\n\n"""
for inum, doc in for inum, doc in
enumerate(resp["source_documents"])] enumerate(resp["source_documents"])]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论