Unverified 提交 dfcd6f61 作者: halfss 提交者: GitHub

添加bing搜索agent (#378)

* 1: 基于langchain的bing search agent添加bing搜索支持(百度不靠谱,Google不可用)
2: 调整输入框交互模式,对话/知识库/搜索,三选一

* fixed bug of message have no text

---------

Co-authored-by: root <jv.liu@i1368.com>
上级 537269d5
from .chatglm_with_shared_memory_openai_llm import * #from .chatglm_with_shared_memory_openai_llm import *
\ No newline at end of file from agent.bing_search import search as bing_search
#coding=utf8
import os
from langchain.utilities import BingSearchAPIWrapper
env_bing_key = os.environ.get("BING_SUBSCRIPTION_KEY")
env_bing_url = os.environ.get("BING_SEARCH_URL")
def search(text, result_len=3):
if not (env_bing_key and env_bing_url):
return [{"snippet":"please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
"title": "env inof not fould", "link":"https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
search = BingSearchAPIWrapper()
return search.results(text, result_len)
if __name__ == "__main__":
r = search('python')
...@@ -13,10 +13,12 @@ from fastapi.middleware.cors import CORSMiddleware ...@@ -13,10 +13,12 @@ from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import Annotated from typing_extensions import Annotated
from starlette.responses import RedirectResponse from starlette.responses import RedirectResponse
from chains.local_doc_qa import LocalDocQA from chains.local_doc_qa import LocalDocQA
from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVICE, from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH, EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN) VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN)
from agent import bing_search as agent_bing_search
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
...@@ -314,6 +316,23 @@ async def document(): ...@@ -314,6 +316,23 @@ async def document():
return RedirectResponse(url="/docs") return RedirectResponse(url="/docs")
async def bing_search(
search_text: str = Query(default=None, description="text you want to search", example="langchain")
):
results = agent_bing_search(search_text)
result_str = ''
for result in results:
for k, v in result.items():
result_str += "%s: %s\n" % (k, v)
result_str += '\n'
return ChatMessage(
question=search_text,
response=result_str,
history=[],
source_documents=[],
)
def api_start(host, port): def api_start(host, port):
global app global app
global local_doc_qa global local_doc_qa
...@@ -342,6 +361,8 @@ def api_start(host, port): ...@@ -342,6 +361,8 @@ def api_start(host, port):
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs) app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs) app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs)
app.get("/bing_search", response_model=ChatMessage)(bing_search)
local_doc_qa = LocalDocQA() local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg( local_doc_qa.init_cfg(
llm_model=LLM_MODEL, llm_model=LLM_MODEL,
......
...@@ -24,7 +24,14 @@ export const getfilelist = (knowledge_base_id: any) => { ...@@ -24,7 +24,14 @@ export const getfilelist = (knowledge_base_id: any) => {
}) })
} }
export const bing_search = (search_text: any) => {
return api({
url: '/bing_search',
method: 'get',
params: { search_text },
})
}
export const deletefile = (params: any) => { export const deletefile = (params: any) => {
return api({ return api({
url: '/local_doc_qa/delete_file', url: '/local_doc_qa/delete_file',
......
...@@ -3,7 +3,7 @@ import type { Ref } from 'vue' ...@@ -3,7 +3,7 @@ import type { Ref } from 'vue'
import { computed, onMounted, onUnmounted, ref } from 'vue' import { computed, onMounted, onUnmounted, ref } from 'vue'
import { useRoute } from 'vue-router' import { useRoute } from 'vue-router'
import { storeToRefs } from 'pinia' import { storeToRefs } from 'pinia'
import { NAutoComplete, NButton, NInput, NSwitch, useDialog, useMessage } from 'naive-ui' import { NAutoComplete, NButton, NInput, NRadioButton, NRadioGroup, useDialog, useMessage } from 'naive-ui'
import html2canvas from 'html2canvas' import html2canvas from 'html2canvas'
import { Message } from './components' import { Message } from './components'
import { useScroll } from './hooks/useScroll' import { useScroll } from './hooks/useScroll'
...@@ -14,7 +14,7 @@ import { HoverButton, SvgIcon } from '@/components/common' ...@@ -14,7 +14,7 @@ import { HoverButton, SvgIcon } from '@/components/common'
import { useBasicLayout } from '@/hooks/useBasicLayout' import { useBasicLayout } from '@/hooks/useBasicLayout'
import { useChatStore, usePromptStore } from '@/store' import { useChatStore, usePromptStore } from '@/store'
import { t } from '@/locales' import { t } from '@/locales'
import { chat, chatfile } from '@/api/chat' import { bing_search, chat, chatfile } from '@/api/chat'
import { idStore } from '@/store/modules/knowledgebaseid/id' import { idStore } from '@/store/modules/knowledgebaseid/id'
let controller = new AbortController() let controller = new AbortController()
...@@ -39,6 +39,7 @@ const conversationList = computed(() => dataSources.value.filter(item => (!item. ...@@ -39,6 +39,7 @@ const conversationList = computed(() => dataSources.value.filter(item => (!item.
const prompt = ref<string>('') const prompt = ref<string>('')
const loading = ref<boolean>(false) const loading = ref<boolean>(false)
const inputRef = ref<Ref | null>(null) const inputRef = ref<Ref | null>(null)
const search = ref<boolean>('对话')
// 添加PromptStore // 添加PromptStore
const promptStore = usePromptStore() const promptStore = usePromptStore()
...@@ -55,16 +56,71 @@ dataSources.value.forEach((item, index) => { ...@@ -55,16 +56,71 @@ dataSources.value.forEach((item, index) => {
updateChatSome(+uuid, index, { loading: false }) updateChatSome(+uuid, index, { loading: false })
}) })
function handleSubmit() { async function handleSubmit() {
if (search.value == 'Bing搜索') {
loading.value = true
const options: Chat.ConversationRequest = {}
const lastText = ''
const message = prompt.value
addChat(
+uuid,
{
dateTime: new Date().toLocaleString(),
text: message,
inversion: true,
error: false,
conversationOptions: null,
requestOptions: { prompt: message, options: null },
},
)
scrollToBottom()
const res = await bing_search(prompt.value)
const result = active.value ? `${res.data.response}\n\n数据来源:\n\n>${res.data.source_documents.join('>')}` : res.data.response
addChat(
+uuid,
{
dateTime: new Date().toLocaleString(),
text: '',
loading: true,
inversion: false,
error: false,
conversationOptions: null,
requestOptions: { prompt: message, options: { ...options } },
},
)
scrollToBottom()
updateChat(
+uuid,
dataSources.value.length - 1,
{
dateTime: new Date().toLocaleString(),
text: lastText + (result ?? ''),
inversion: false,
error: false,
loading: false,
conversationOptions: null,
requestOptions: { prompt: message, options: { ...options } },
},
)
prompt.value = ''
scrollToBottomIfAtBottom()
loading.value = false
}
else {
onConversation() onConversation()
}
} }
async function onConversation() { async function onConversation() {
const message = prompt.value const message = prompt.value
if (usingContext.value) { if (usingContext.value) {
for (let i = 0; i < dataSources.value.length; i = i + 2) for (let i = 0; i < dataSources.value.length; i = i + 2) {
if (!i)
history.value.push([dataSources.value[i].text, dataSources.value[i + 1].text.split('\n\n数据来源:\n\n>')[0]]) history.value.push([dataSources.value[i].text, dataSources.value[i + 1].text.split('\n\n数据来源:\n\n>')[0]])
} }
}
else { history.value.length = 0 } else { history.value.length = 0 }
if (loading.value) if (loading.value)
...@@ -480,6 +536,13 @@ onUnmounted(() => { ...@@ -480,6 +536,13 @@ onUnmounted(() => {
if (loading.value) if (loading.value)
controller.abort() controller.abort()
}) })
function searchfun() {
if (search.value == '知识库')
active.value = true
else
active.value = false
}
</script> </script>
<template> <template>
...@@ -532,14 +595,11 @@ onUnmounted(() => { ...@@ -532,14 +595,11 @@ onUnmounted(() => {
<footer :class="footerClass"> <footer :class="footerClass">
<div class="w-full max-w-screen-xl m-auto"> <div class="w-full max-w-screen-xl m-auto">
<div class="flex items-center justify-between space-x-2"> <div class="flex items-center justify-between space-x-2">
<NSwitch v-model:value="active"> <NRadioGroup v-model:value="search" @change="searchfun">
<template #checked> <NRadioButton value="对话" label="对话" />
知识库 <NRadioButton value="知识库" label="知识库" />
</template> <NRadioButton value="Bing搜索" label="Bing搜索" />
<template #unchecked> </NRadioGroup>
知识库&nbsp;&nbsp;
</template>
</NSwitch>
<HoverButton @click="handleClear"> <HoverButton @click="handleClear">
<span class="text-xl text-[#4f555e] dark:text-white"> <span class="text-xl text-[#4f555e] dark:text-white">
<SvgIcon icon="ri:delete-bin-line" /> <SvgIcon icon="ri:delete-bin-line" />
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论