Unverified 提交 ff5f73e0 作者: NieLamu 提交者: GitHub

feat: fastapi 接口优化 (#684)

1. 接口增加参数校验,防止攻击
2. 优化接口参数和逻辑
3. 规范接口错误响应
4. 增加接口描述

Co-authored-by: imClumsyPanda <littlepanda0716@gmail.com>
上级 a5ca4bf2
......@@ -79,23 +79,37 @@ class ChatMessage(BaseModel):
}
def get_folder_path(local_doc_id: str):
return os.path.join(KB_ROOT_PATH, local_doc_id, "content")
def get_kb_path(local_doc_id: str):
return os.path.join(KB_ROOT_PATH, local_doc_id)
def get_doc_path(local_doc_id: str):
return os.path.join(get_kb_path(local_doc_id), "content")
def get_vs_path(local_doc_id: str):
return os.path.join(KB_ROOT_PATH, local_doc_id, "vector_store")
return os.path.join(get_kb_path(local_doc_id), "vector_store")
def get_file_path(local_doc_id: str, doc_name: str):
return os.path.join(KB_ROOT_PATH, local_doc_id, "content", doc_name)
return os.path.join(get_doc_path(local_doc_id), doc_name)
def validate_kb_name(knowledge_base_id: str) -> bool:
# 检查是否包含预期外的字符或路径攻击关键字
if "../" in knowledge_base_id:
return False
return True
async def upload_file(
file: UploadFile = File(description="A single binary file"),
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
):
saved_path = get_folder_path(knowledge_base_id)
if not validate_kb_name(knowledge_base_id):
return BaseResponse(code=403, msg="Don't attack me", data=[])
saved_path = get_doc_path(knowledge_base_id)
if not os.path.exists(saved_path):
os.makedirs(saved_path)
......@@ -125,21 +139,25 @@ async def upload_files(
],
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
):
saved_path = get_folder_path(knowledge_base_id)
if not validate_kb_name(knowledge_base_id):
return BaseResponse(code=403, msg="Don't attack me", data=[])
saved_path = get_doc_path(knowledge_base_id)
if not os.path.exists(saved_path):
os.makedirs(saved_path)
filelist = []
for file in files:
file_content = ''
file_path = os.path.join(saved_path, file.filename)
file_content = file.file.read()
file_content = await file.read()
if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
continue
with open(file_path, "ab+") as f:
with open(file_path, "wb") as f:
f.write(file_content)
filelist.append(file_path)
if filelist:
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, get_vs_path(knowledge_base_id))
vs_path = get_vs_path(knowledge_base_id)
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path)
if len(loaded_files):
file_status = f"documents {', '.join([os.path.split(i)[-1] for i in loaded_files])} upload success"
return BaseResponse(code=200, msg=file_status)
......@@ -163,16 +181,24 @@ async def list_kbs():
async def list_docs(
knowledge_base_id: Optional[str] = Query(default=None, description="Knowledge Base Name", example="kb1")
knowledge_base_id: str = Query(..., description="Knowledge Base Name", example="kb1")
):
local_doc_folder = get_folder_path(knowledge_base_id)
if not validate_kb_name(knowledge_base_id):
return ListDocsResponse(code=403, msg="Don't attack me", data=[])
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
kb_path = get_kb_path(knowledge_base_id)
local_doc_folder = get_doc_path(knowledge_base_id)
if not os.path.exists(kb_path):
return ListDocsResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found", data=[])
if not os.path.exists(local_doc_folder):
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
all_doc_names = [
doc
for doc in os.listdir(local_doc_folder)
if os.path.isfile(os.path.join(local_doc_folder, doc))
]
all_doc_names = []
else:
all_doc_names = [
doc
for doc in os.listdir(local_doc_folder)
if os.path.isfile(os.path.join(local_doc_folder, doc))
]
return ListDocsResponse(data=all_doc_names)
......@@ -181,11 +207,15 @@ async def delete_kb(
description="Knowledge Base Name",
example="kb1"),
):
if not validate_kb_name(knowledge_base_id):
return BaseResponse(code=403, msg="Don't attack me")
# TODO: 确认是否支持批量删除知识库
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
if not os.path.exists(get_folder_path(knowledge_base_id)):
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
shutil.rmtree(get_folder_path(knowledge_base_id))
kb_path = get_kb_path(knowledge_base_id)
if not os.path.exists(kb_path):
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
shutil.rmtree(kb_path)
return BaseResponse(code=200, msg=f"Knowledge Base {knowledge_base_id} delete success")
......@@ -194,27 +224,30 @@ async def delete_doc(
description="Knowledge Base Name",
example="kb1"),
doc_name: str = Query(
None, description="doc name", example="doc_name_1.pdf"
..., description="doc name", example="doc_name_1.pdf"
),
):
if not validate_kb_name(knowledge_base_id):
return BaseResponse(code=403, msg="Don't attack me")
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
if not os.path.exists(get_folder_path(knowledge_base_id)):
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
if not os.path.exists(get_kb_path(knowledge_base_id)):
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
doc_path = get_file_path(knowledge_base_id, doc_name)
if os.path.exists(doc_path):
os.remove(doc_path)
remain_docs = await list_docs(knowledge_base_id)
if len(remain_docs.data) == 0:
shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True)
shutil.rmtree(get_kb_path(knowledge_base_id), ignore_errors=True)
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
else:
status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id))
if "success" in status:
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
else:
return BaseResponse(code=1, msg=f"document {doc_name} delete fail")
return BaseResponse(code=500, msg=f"document {doc_name} delete fail")
else:
return BaseResponse(code=1, msg=f"document {doc_name} not found")
return BaseResponse(code=404, msg=f"document {doc_name} not found")
async def update_doc(
......@@ -222,23 +255,26 @@ async def update_doc(
description="知识库名",
example="kb1"),
old_doc: str = Query(
None, description="待删除文件名,已存储在知识库中", example="doc_name_1.pdf"
..., description="待删除文件名,已存储在知识库中", example="doc_name_1.pdf"
),
new_doc: UploadFile = File(description="待上传文件"),
):
if not validate_kb_name(knowledge_base_id):
return BaseResponse(code=403, msg="Don't attack me")
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
if not os.path.exists(get_folder_path(knowledge_base_id)):
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
if not os.path.exists(get_kb_path(knowledge_base_id)):
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
doc_path = get_file_path(knowledge_base_id, old_doc)
if not os.path.exists(doc_path):
return BaseResponse(code=1, msg=f"document {old_doc} not found")
return BaseResponse(code=404, msg=f"document {old_doc} not found")
else:
os.remove(doc_path)
delete_status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id))
if "fail" in delete_status:
return BaseResponse(code=1, msg=f"document {old_doc} delete failed")
return BaseResponse(code=500, msg=f"document {old_doc} delete failed")
else:
saved_path = get_folder_path(knowledge_base_id)
saved_path = get_doc_path(knowledge_base_id)
if not os.path.exists(saved_path):
os.makedirs(saved_path)
......@@ -279,7 +315,7 @@ async def local_doc_chat(
):
vs_path = get_vs_path(knowledge_base_id)
if not os.path.exists(vs_path):
# return BaseResponse(code=1, msg=f"Knowledge base {knowledge_base_id} not found")
# return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
return ChatMessage(
question=question,
response=f"Knowledge base {knowledge_base_id} not found",
......@@ -467,7 +503,7 @@ def api_start(host, port, **kwargs):
# 修改了stream_chat的接口,直接通过ws://localhost:7861/local_doc_qa/stream_chat建立连接,在请求体中选择knowledge_base_id
app.websocket("/local_doc_qa/stream_chat")(stream_chat)
app.get("/", response_model=BaseResponse)(document)
app.get("/", response_model=BaseResponse, summary="swagger 文档")(document)
# 增加基于bing搜索的流式问答
# 需要说明的是,如果想测试websocket的流式问答,需要使用支持websocket的测试工具,如postman,insomnia
......@@ -475,17 +511,17 @@ def api_start(host, port, **kwargs):
# 在测试时选择new websocket request,并将url的协议改为ws,如ws://localhost:7861/local_doc_qa/stream_chat_bing
app.websocket("/local_doc_qa/stream_chat_bing")(stream_chat_bing)
app.post("/chat", response_model=ChatMessage)(chat)
app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file)
app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage)(bing_search_chat)
app.get("/local_doc_qa/list_knowledge_base", response_model=ListDocsResponse)(list_kbs)
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse)(delete_kb)
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_doc)
app.post("/local_doc_qa/update_file", response_model=BaseResponse)(update_doc)
app.post("/chat", response_model=ChatMessage, summary="与模型对话")(chat)
app.post("/local_doc_qa/upload_file", response_model=BaseResponse, summary="上传文件到知识库")(upload_file)
app.post("/local_doc_qa/upload_files", response_model=BaseResponse, summary="批量上传文件到知识库")(upload_files)
app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage, summary="与知识库对话")(local_doc_chat)
app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage, summary="与必应搜索对话")(bing_search_chat)
app.get("/local_doc_qa/list_knowledge_base", response_model=ListDocsResponse, summary="获取知识库列表")(list_kbs)
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse, summary="获取知识库内的文件列表")(list_docs)
app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse, summary="删除知识库")(delete_kb)
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse, summary="删除知识库内的文件")(delete_doc)
app.post("/local_doc_qa/update_file", response_model=BaseResponse, summary="上传文件到知识库,并删除另一个文件")(update_doc)
local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论