提交 9c422cc6 作者: imClumsyPanda

update bing_search.py

上级 f986b756
#coding=utf8
import os
from langchain.utilities import BingSearchAPIWrapper
from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
env_bing_key = os.environ.get("BING_SUBSCRIPTION_KEY")
env_bing_url = os.environ.get("BING_SEARCH_URL")
def search(text, result_len=3):
if not (env_bing_key and env_bing_url):
return [{"snippet":"please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
"title": "env inof not fould", "link":"https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
search = BingSearchAPIWrapper()
def bing_search(text, result_len=3):
if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
"title": "env inof not fould",
"link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
search = BingSearchAPIWrapper(bing_subscription_key=BING_SUBSCRIPTION_KEY,
bing_search_url=BING_SEARCH_URL)
return search.results(text, result_len)
if __name__ == "__main__":
r = search('python')
r = bing_search('python')
print(r)
......@@ -4,7 +4,7 @@ from langchain.document_loaders import UnstructuredFileLoader, TextLoader
from configs.model_config import *
import datetime
from textsplitter import ChineseTextSplitter
from typing import List, Tuple
from typing import List, Tuple, Dict
from langchain.docstore.document import Document
import numpy as np
from utils import torch_gc
......@@ -18,6 +18,8 @@ from models.base import (BaseAnswer,
from models.loader.args import parser
from models.loader import LoaderCheckPoint
import models.shared as shared
from agent import bing_search
from langchain.docstore.document import Document
def load_file(filepath, sentence_size=SENTENCE_SIZE):
......@@ -58,8 +60,9 @@ def write_check_file(filepath, docs):
fout.close()
def generate_prompt(related_docs: List[str], query: str,
prompt_template=PROMPT_TEMPLATE) -> str:
def generate_prompt(related_docs: List[str],
query: str,
prompt_template: str = PROMPT_TEMPLATE, ) -> str:
context = "\n".join([doc.page_content for doc in related_docs])
prompt = prompt_template.replace("{question}", query).replace("{context}", context)
return prompt
......@@ -137,6 +140,16 @@ def similarity_search_with_score_by_vector(
return docs
def search_result2docs(search_results):
docs = []
for result in search_results:
doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "",
metadata={"source": result["link"] if "link" in result.keys() else "",
"filename": result["title"] if "title" in result.keys() else ""})
docs.append(doc)
return docs
class LocalDocQA:
llm: BaseAnswer = None
embeddings: object = None
......@@ -262,7 +275,6 @@ class LocalDocQA:
"source_documents": related_docs_with_score}
yield response, history
# query 查询内容
# vs_path 知识库路径
# chunk_conent 是否启用上下文关联
......@@ -288,11 +300,26 @@ class LocalDocQA:
"source_documents": related_docs_with_score}
return response, prompt
def get_search_result_based_answer(self, query, chat_history=[], streaming: bool = STREAMING):
results = bing_search(query)
result_docs = search_result2docs(results)
prompt = generate_prompt(result_docs, query)
for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
streaming=streaming):
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][0] = query
response = {"query": query,
"result": resp,
"source_documents": result_docs}
yield response, history
if __name__ == "__main__":
# 初始化消息
args = None
args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model'])
args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model'])
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
......@@ -304,13 +331,17 @@ if __name__ == "__main__":
query = "本项目使用的embedding模型是什么,消耗多少显存"
vs_path = "/media/gpt4-pdf-chatbot-langchain/dev-langchain-ChatGLM/vector_store/test"
last_print_len = 0
for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
vs_path=vs_path,
chat_history=[],
streaming=True):
logger.info(resp["result"][last_print_len:], end="", flush=True)
# for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
# vs_path=vs_path,
# chat_history=[],
# streaming=True):
for resp, history in local_doc_qa.get_search_result_based_answer(query=query,
chat_history=[],
streaming=True):
print(resp["result"][last_print_len:], end="", flush=True)
last_print_len = len(resp["result"])
source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
source_text = [f"""出处 [{inum + 1}] {doc.metadata['source'] if doc.metadata['source'].startswith("http")
else os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
# f"""相关度:{doc.metadata['score']}\n\n"""
for inum, doc in
enumerate(resp["source_documents"])]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论