Commit 5f38645f  Author: imClumsyPanda

update api.py

Parent 2707f58a
api.py
@@ -18,7 +18,7 @@ from chains.local_doc_qa import LocalDocQA
 from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
                                   EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
                                   VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN)
-from agent import bing_search as agent_bing_search
+from agent import bing_search
 import models.shared as shared
 from models.loader.args import parser
 from models.loader import LoaderCheckPoint
@@ -248,7 +248,7 @@ async def local_doc_chat(
 )
 
 
-async def chat(
+async def bing_search_chat(
         question: str = Body(..., description="Question", example="工伤保险是什么?"),
         history: List[List[str]] = Body(
             [],
@@ -261,10 +261,37 @@ async def chat(
             ],
         ),
 ):
+    for resp, history in local_doc_qa.get_search_result_based_answer(
+        query=question, chat_history=history, streaming=True
+    ):
+        pass
+    source_documents = [
+        f"""出处 [{inum + 1}] {doc.metadata['source']}:\n\n{doc.page_content}\n\n"""
+        for inum, doc in enumerate(resp["source_documents"])
+    ]
+    return ChatMessage(
+        question=question,
+        response=resp["result"],
+        history=history,
+        source_documents=source_documents,
+    )
+
+
+async def chat(
+        question: str = Body(..., description="Question", example="工伤保险是什么?"),
+        history: List[List[str]] = Body(
+            [],
+            description="History of previous questions and answers",
+            example=[
+                [
+                    "工伤保险是什么?",
+                    "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
+                ]
+            ],
+        ),
+):
     for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
                                                           streaming=True):
         resp = answer_result.llm_output["answer"]
         history = answer_result.history
         pass
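Both the new bing_search_chat body and the chat body above drain a streaming generator with a bare loop ending in pass: with streaming=True each iteration yields a progressively more complete answer, and only the value bound on the final iteration is used. A minimal standalone sketch of the same idiom (the generator here is a hypothetical stand-in, not the project's API):

def stream_answer():
    # Hypothetical stand-in for a streaming LLM call:
    # each yield carries the answer accumulated so far.
    yield "工伤"
    yield "工伤保险"
    yield "工伤保险是一种社会保险制度。"

answer = None
for answer in stream_answer():
    pass  # drain the generator; 'answer' keeps the last, complete chunk
print(answer)  # -> 工伤保险是一种社会保险制度。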
@@ -323,22 +350,7 @@ async def document():
     return RedirectResponse(url="/docs")
 
 
-async def bing_search(
-        search_text: str = Query(default=None, description="text you want to search", example="langchain")
-):
-    results = agent_bing_search(search_text)
-    result_str = ''
-    for result in results:
-        for k, v in result.items():
-            result_str += "%s: %s\n" % (k, v)
-        result_str += '\n'
-    return ChatMessage(
-        question=search_text,
-        response=result_str,
-        history=[],
-        source_documents=[],
-    )
-
-
 def api_start(host, port):
@@ -369,11 +381,10 @@ def api_start(host, port):
     app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file)
     app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
     app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
+    app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage)(bing_search_chat)
     app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
     app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs)
-    app.get("/bing_search", response_model=ChatMessage)(bing_search)
 
     local_doc_qa = LocalDocQA()
     local_doc_qa.init_cfg(
         llm_model=llm_model_ins,
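The registrations above use FastAPI's imperative style: app.post(path, ...) returns a decorator, which is then applied to the handler function. A small sketch of the equivalence, with a hypothetical ping handler:

from fastapi import FastAPI

app = FastAPI()

async def ping():
    return {"ok": True}

# Imperative registration, as in api_start above:
app.get("/ping")(ping)

# Equivalent decorator form:
# @app.get("/ping")
# async def ping():
#     return {"ok": True}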
@@ -384,9 +395,7 @@ def api_start(host, port):
     uvicorn.run(app, host=host, port=port)
 
 
 if __name__ == "__main__":
     parser.add_argument("--host", type=str, default="0.0.0.0")
     parser.add_argument("--port", type=int, default=7861)
     # Initialize messages
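With the server started using the defaults above (host 0.0.0.0, port 7861), the renamed endpoint can be exercised with a plain HTTP POST. A minimal client sketch using requests; it assumes the server is reachable locally and that a Bing search key is configured for the agent:

import requests

resp = requests.post(
    "http://127.0.0.1:7861/local_doc_qa/bing_search_chat",
    json={"question": "工伤保险是什么?", "history": []},
)
data = resp.json()  # ChatMessage fields: question, response, history, source_documents
print(data["response"])
for source in data["source_documents"]:
    print(source)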
requirements.txt
 pymupdf
 paddlepaddle==2.4.2
-paddleocr
+paddleocr~=2.6.1.3
 langchain==0.0.174
 transformers==4.29.1
 unstructured[local-inference]
 layoutparser[layoutmodels,tesseract]
-nltk
+nltk~=3.8.1
 sentence-transformers
 beautifulsoup4
 icetk
 cpm_kernels
 faiss-cpu
-accelerate
+accelerate~=0.18.0
 gradio==3.28.3
-fastapi
-uvicorn
-peft
-pypinyin
+fastapi~=0.95.0
+uvicorn~=0.21.1
+peft~=0.3.0
+pypinyin~=0.48.0
 click~=8.1.3
 tabulate
-azure-core
 bitsandbytes; platform_system != "Windows"
 llama-cpp-python==0.1.34; platform_system != "Windows"
 https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"
+torch~=2.0.0
+pydantic~=1.10.7
+starlette~=0.26.1
+numpy~=1.23.5
+tqdm~=4.65.0
+requests~=2.28.2
+tenacity~=8.2.2
\ No newline at end of file
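Most of the new pins use the compatible-release operator ~= rather than an exact ==. Per PEP 440, ~=X.Y.Z allows upgrades of the last version segment only, so patch releases are accepted but minor bumps are not; for example:

fastapi~=0.95.0    # equivalent to: fastapi >=0.95.0, ==0.95.*
torch~=2.0.0       # equivalent to: torch >=2.0.0, ==2.0.*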