提交 5f38645f 作者: imClumsyPanda

update api.py

上级 2707f58a
......@@ -18,7 +18,7 @@ from chains.local_doc_qa import LocalDocQA
from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN)
from agent import bing_search as agent_bing_search
from agent import bing_search
import models.shared as shared
from models.loader.args import parser
from models.loader import LoaderCheckPoint
......@@ -248,7 +248,7 @@ async def local_doc_chat(
)
async def chat(
async def bing_search_chat(
question: str = Body(..., description="Question", example="工伤保险是什么?"),
history: List[List[str]] = Body(
[],
......@@ -261,10 +261,37 @@ async def chat(
],
),
):
for resp, history in local_doc_qa.get_search_result_based_answer(
query=question, chat_history=history, streaming=True
):
pass
source_documents = [
f"""出处 [{inum + 1}] {doc.metadata['source']}:\n\n{doc.page_content}\n\n"""
for inum, doc in enumerate(resp["source_documents"])
]
return ChatMessage(
question=question,
response=resp["result"],
history=history,
source_documents=source_documents,
)
async def chat(
question: str = Body(..., description="Question", example="工伤保险是什么?"),
history: List[List[str]] = Body(
[],
description="History of previous questions and answers",
example=[
[
"工伤保险是什么?",
"工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
]
],
),
):
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
streaming=True):
resp = answer_result.llm_output["answer"]
history = answer_result.history
pass
......@@ -323,22 +350,7 @@ async def document():
return RedirectResponse(url="/docs")
async def bing_search(
search_text: str = Query(default=None, description="text you want to search", example="langchain")
):
results = agent_bing_search(search_text)
result_str = ''
for result in results:
for k, v in result.items():
result_str += "%s: %s\n" % (k, v)
result_str += '\n'
return ChatMessage(
question=search_text,
response=result_str,
history=[],
source_documents=[],
)
def api_start(host, port):
......@@ -369,11 +381,10 @@ def api_start(host, port):
app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file)
app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage)(bing_search_chat)
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs)
app.get("/bing_search", response_model=ChatMessage)(bing_search)
local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(
llm_model=llm_model_ins,
......@@ -384,9 +395,7 @@ def api_start(host, port):
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7861)
# 初始化消息
......
pymupdf
paddlepaddle==2.4.2
paddleocr
paddleocr~=2.6.1.3
langchain==0.0.174
transformers==4.29.1
unstructured[local-inference]
layoutparser[layoutmodels,tesseract]
nltk
nltk~=3.8.1
sentence-transformers
beautifulsoup4
icetk
cpm_kernels
faiss-cpu
accelerate
accelerate~=0.18.0
gradio==3.28.3
fastapi
uvicorn
peft
pypinyin
fastapi~=0.95.0
uvicorn~=0.21.1
peft~=0.3.0
pypinyin~=0.48.0
click~=8.1.3
tabulate
azure-core
bitsandbytes; platform_system != "Windows"
llama-cpp-python==0.1.34; platform_system != "Windows"
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"
torch~=2.0.0
pydantic~=1.10.7
starlette~=0.26.1
numpy~=1.23.5
tqdm~=4.65.0
requests~=2.28.2
tenacity~=8.2.2
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论