Unverified 提交 ff5f73e0 作者: NieLamu 提交者: GitHub

feat: fastapi 接口优化 (#684)

1. 接口增加参数校验,防止攻击
2. 优化接口参数和逻辑
3. 规范接口错误响应
4. 增加接口描述

Co-authored-by: imClumsyPanda <littlepanda0716@gmail.com>
上级 a5ca4bf2
...@@ -79,23 +79,37 @@ class ChatMessage(BaseModel): ...@@ -79,23 +79,37 @@ class ChatMessage(BaseModel):
} }
def get_folder_path(local_doc_id: str): def get_kb_path(local_doc_id: str):
return os.path.join(KB_ROOT_PATH, local_doc_id, "content") return os.path.join(KB_ROOT_PATH, local_doc_id)
def get_doc_path(local_doc_id: str):
    """Return the "content" sub-directory of the given knowledge base.

    The knowledge-base directory itself is resolved by ``get_kb_path``.
    """
    kb_dir = get_kb_path(local_doc_id)
    return os.path.join(kb_dir, "content")
def get_vs_path(local_doc_id: str): def get_vs_path(local_doc_id: str):
return os.path.join(KB_ROOT_PATH, local_doc_id, "vector_store") return os.path.join(get_kb_path(local_doc_id), "vector_store")
def get_file_path(local_doc_id: str, doc_name: str): def get_file_path(local_doc_id: str, doc_name: str):
return os.path.join(KB_ROOT_PATH, local_doc_id, "content", doc_name) return os.path.join(get_doc_path(local_doc_id), doc_name)
def validate_kb_name(knowledge_base_id: str) -> bool:
    """Return True when *knowledge_base_id* is safe to join onto the KB root.

    The name is later passed to ``os.path.join`` and, in the delete
    endpoints, to ``shutil.rmtree``, so it must not be able to escape or
    alias the knowledge-base root directory. The name is rejected when it:

    * is not a string, or contains a NUL byte;
    * is absolute (leading ``/`` or ``\\``) or carries a Windows drive
      prefix such as ``C:`` -- ``os.path.join`` drops the root for those;
    * contains a ``..`` path segment, with either ``/`` or ``\\`` as the
      separator (the original check only caught the literal ``"../"``);
    * resolves to the root itself (empty, ``"."``, or only separators).

    Callers URL-decode the name *after* validating it, so both the raw
    value and its percent-decoded form are checked; otherwise
    ``%2e%2e%2f`` would slip through.

    Args:
        knowledge_base_id: Knowledge base name received from the client.

    Returns:
        ``True`` if the name is acceptable, ``False`` otherwise.
    """
    # Local import keeps this block self-contained.
    from urllib.parse import unquote

    if not isinstance(knowledge_base_id, str):
        return False

    # Check for unexpected characters or path-traversal keywords, both in
    # the raw value and in the form the endpoints see after unquoting.
    for candidate in (knowledge_base_id, unquote(knowledge_base_id)):
        if "\x00" in candidate:
            return False

        # Treat both separator styles alike so "..\\" is caught as well.
        normalized = candidate.replace("\\", "/")

        # Absolute paths make os.path.join discard the KB root.
        if normalized.startswith("/"):
            return False

        # Windows drive prefix ("C:", "C:/dir") has the same effect.
        if len(normalized) >= 2 and normalized[0].isalpha() and normalized[1] == ":":
            return False

        segments = normalized.split("/")
        if ".." in segments:
            return False

        # Names made only of "" / "." segments point at the KB root itself.
        if all(segment in ("", ".") for segment in segments):
            return False

    return True
async def upload_file( async def upload_file(
file: UploadFile = File(description="A single binary file"), file: UploadFile = File(description="A single binary file"),
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"), knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
): ):
saved_path = get_folder_path(knowledge_base_id) if not validate_kb_name(knowledge_base_id):
return BaseResponse(code=403, msg="Don't attack me", data=[])
saved_path = get_doc_path(knowledge_base_id)
if not os.path.exists(saved_path): if not os.path.exists(saved_path):
os.makedirs(saved_path) os.makedirs(saved_path)
...@@ -125,21 +139,25 @@ async def upload_files( ...@@ -125,21 +139,25 @@ async def upload_files(
], ],
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"), knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
): ):
saved_path = get_folder_path(knowledge_base_id) if not validate_kb_name(knowledge_base_id):
return BaseResponse(code=403, msg="Don't attack me", data=[])
saved_path = get_doc_path(knowledge_base_id)
if not os.path.exists(saved_path): if not os.path.exists(saved_path):
os.makedirs(saved_path) os.makedirs(saved_path)
filelist = [] filelist = []
for file in files: for file in files:
file_content = '' file_content = ''
file_path = os.path.join(saved_path, file.filename) file_path = os.path.join(saved_path, file.filename)
file_content = file.file.read() file_content = await file.read()
if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content): if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
continue continue
with open(file_path, "ab+") as f: with open(file_path, "wb") as f:
f.write(file_content) f.write(file_content)
filelist.append(file_path) filelist.append(file_path)
if filelist: if filelist:
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, get_vs_path(knowledge_base_id)) vs_path = get_vs_path(knowledge_base_id)
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path)
if len(loaded_files): if len(loaded_files):
file_status = f"documents {', '.join([os.path.split(i)[-1] for i in loaded_files])} upload success" file_status = f"documents {', '.join([os.path.split(i)[-1] for i in loaded_files])} upload success"
return BaseResponse(code=200, msg=file_status) return BaseResponse(code=200, msg=file_status)
...@@ -163,16 +181,24 @@ async def list_kbs(): ...@@ -163,16 +181,24 @@ async def list_kbs():
async def list_docs( async def list_docs(
knowledge_base_id: Optional[str] = Query(default=None, description="Knowledge Base Name", example="kb1") knowledge_base_id: str = Query(..., description="Knowledge Base Name", example="kb1")
): ):
local_doc_folder = get_folder_path(knowledge_base_id) if not validate_kb_name(knowledge_base_id):
return ListDocsResponse(code=403, msg="Don't attack me", data=[])
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
kb_path = get_kb_path(knowledge_base_id)
local_doc_folder = get_doc_path(knowledge_base_id)
if not os.path.exists(kb_path):
return ListDocsResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found", data=[])
if not os.path.exists(local_doc_folder): if not os.path.exists(local_doc_folder):
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} all_doc_names = []
all_doc_names = [ else:
doc all_doc_names = [
for doc in os.listdir(local_doc_folder) doc
if os.path.isfile(os.path.join(local_doc_folder, doc)) for doc in os.listdir(local_doc_folder)
] if os.path.isfile(os.path.join(local_doc_folder, doc))
]
return ListDocsResponse(data=all_doc_names) return ListDocsResponse(data=all_doc_names)
...@@ -181,11 +207,15 @@ async def delete_kb( ...@@ -181,11 +207,15 @@ async def delete_kb(
description="Knowledge Base Name", description="Knowledge Base Name",
example="kb1"), example="kb1"),
): ):
if not validate_kb_name(knowledge_base_id):
return BaseResponse(code=403, msg="Don't attack me")
# TODO: 确认是否支持批量删除知识库 # TODO: 确认是否支持批量删除知识库
knowledge_base_id = urllib.parse.unquote(knowledge_base_id) knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
if not os.path.exists(get_folder_path(knowledge_base_id)): kb_path = get_kb_path(knowledge_base_id)
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} if not os.path.exists(kb_path):
shutil.rmtree(get_folder_path(knowledge_base_id)) return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
shutil.rmtree(kb_path)
return BaseResponse(code=200, msg=f"Knowledge Base {knowledge_base_id} delete success") return BaseResponse(code=200, msg=f"Knowledge Base {knowledge_base_id} delete success")
...@@ -194,27 +224,30 @@ async def delete_doc( ...@@ -194,27 +224,30 @@ async def delete_doc(
description="Knowledge Base Name", description="Knowledge Base Name",
example="kb1"), example="kb1"),
doc_name: str = Query( doc_name: str = Query(
None, description="doc name", example="doc_name_1.pdf" ..., description="doc name", example="doc_name_1.pdf"
), ),
): ):
if not validate_kb_name(knowledge_base_id):
return BaseResponse(code=403, msg="Don't attack me")
knowledge_base_id = urllib.parse.unquote(knowledge_base_id) knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
if not os.path.exists(get_folder_path(knowledge_base_id)): if not os.path.exists(get_kb_path(knowledge_base_id)):
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
doc_path = get_file_path(knowledge_base_id, doc_name) doc_path = get_file_path(knowledge_base_id, doc_name)
if os.path.exists(doc_path): if os.path.exists(doc_path):
os.remove(doc_path) os.remove(doc_path)
remain_docs = await list_docs(knowledge_base_id) remain_docs = await list_docs(knowledge_base_id)
if len(remain_docs.data) == 0: if len(remain_docs.data) == 0:
shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True) shutil.rmtree(get_kb_path(knowledge_base_id), ignore_errors=True)
return BaseResponse(code=200, msg=f"document {doc_name} delete success") return BaseResponse(code=200, msg=f"document {doc_name} delete success")
else: else:
status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id)) status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id))
if "success" in status: if "success" in status:
return BaseResponse(code=200, msg=f"document {doc_name} delete success") return BaseResponse(code=200, msg=f"document {doc_name} delete success")
else: else:
return BaseResponse(code=1, msg=f"document {doc_name} delete fail") return BaseResponse(code=500, msg=f"document {doc_name} delete fail")
else: else:
return BaseResponse(code=1, msg=f"document {doc_name} not found") return BaseResponse(code=404, msg=f"document {doc_name} not found")
async def update_doc( async def update_doc(
...@@ -222,23 +255,26 @@ async def update_doc( ...@@ -222,23 +255,26 @@ async def update_doc(
description="知识库名", description="知识库名",
example="kb1"), example="kb1"),
old_doc: str = Query( old_doc: str = Query(
None, description="待删除文件名,已存储在知识库中", example="doc_name_1.pdf" ..., description="待删除文件名,已存储在知识库中", example="doc_name_1.pdf"
), ),
new_doc: UploadFile = File(description="待上传文件"), new_doc: UploadFile = File(description="待上传文件"),
): ):
if not validate_kb_name(knowledge_base_id):
return BaseResponse(code=403, msg="Don't attack me")
knowledge_base_id = urllib.parse.unquote(knowledge_base_id) knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
if not os.path.exists(get_folder_path(knowledge_base_id)): if not os.path.exists(get_kb_path(knowledge_base_id)):
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
doc_path = get_file_path(knowledge_base_id, old_doc) doc_path = get_file_path(knowledge_base_id, old_doc)
if not os.path.exists(doc_path): if not os.path.exists(doc_path):
return BaseResponse(code=1, msg=f"document {old_doc} not found") return BaseResponse(code=404, msg=f"document {old_doc} not found")
else: else:
os.remove(doc_path) os.remove(doc_path)
delete_status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id)) delete_status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id))
if "fail" in delete_status: if "fail" in delete_status:
return BaseResponse(code=1, msg=f"document {old_doc} delete failed") return BaseResponse(code=500, msg=f"document {old_doc} delete failed")
else: else:
saved_path = get_folder_path(knowledge_base_id) saved_path = get_doc_path(knowledge_base_id)
if not os.path.exists(saved_path): if not os.path.exists(saved_path):
os.makedirs(saved_path) os.makedirs(saved_path)
...@@ -279,7 +315,7 @@ async def local_doc_chat( ...@@ -279,7 +315,7 @@ async def local_doc_chat(
): ):
vs_path = get_vs_path(knowledge_base_id) vs_path = get_vs_path(knowledge_base_id)
if not os.path.exists(vs_path): if not os.path.exists(vs_path):
# return BaseResponse(code=1, msg=f"Knowledge base {knowledge_base_id} not found") # return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
return ChatMessage( return ChatMessage(
question=question, question=question,
response=f"Knowledge base {knowledge_base_id} not found", response=f"Knowledge base {knowledge_base_id} not found",
...@@ -467,7 +503,7 @@ def api_start(host, port, **kwargs): ...@@ -467,7 +503,7 @@ def api_start(host, port, **kwargs):
# 修改了stream_chat的接口,直接通过ws://localhost:7861/local_doc_qa/stream_chat建立连接,在请求体中选择knowledge_base_id # 修改了stream_chat的接口,直接通过ws://localhost:7861/local_doc_qa/stream_chat建立连接,在请求体中选择knowledge_base_id
app.websocket("/local_doc_qa/stream_chat")(stream_chat) app.websocket("/local_doc_qa/stream_chat")(stream_chat)
app.get("/", response_model=BaseResponse)(document) app.get("/", response_model=BaseResponse, summary="swagger 文档")(document)
# 增加基于bing搜索的流式问答 # 增加基于bing搜索的流式问答
# 需要说明的是,如果想测试websocket的流式问答,需要使用支持websocket的测试工具,如postman,insomnia # 需要说明的是,如果想测试websocket的流式问答,需要使用支持websocket的测试工具,如postman,insomnia
...@@ -475,17 +511,17 @@ def api_start(host, port, **kwargs): ...@@ -475,17 +511,17 @@ def api_start(host, port, **kwargs):
# 在测试时选择new websocket request,并将url的协议改为ws,如ws://localhost:7861/local_doc_qa/stream_chat_bing # 在测试时选择new websocket request,并将url的协议改为ws,如ws://localhost:7861/local_doc_qa/stream_chat_bing
app.websocket("/local_doc_qa/stream_chat_bing")(stream_chat_bing) app.websocket("/local_doc_qa/stream_chat_bing")(stream_chat_bing)
app.post("/chat", response_model=ChatMessage)(chat) app.post("/chat", response_model=ChatMessage, summary="与模型对话")(chat)
app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file) app.post("/local_doc_qa/upload_file", response_model=BaseResponse, summary="上传文件到知识库")(upload_file)
app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files) app.post("/local_doc_qa/upload_files", response_model=BaseResponse, summary="批量上传文件到知识库")(upload_files)
app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat) app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage, summary="与知识库对话")(local_doc_chat)
app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage)(bing_search_chat) app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage, summary="与必应搜索对话")(bing_search_chat)
app.get("/local_doc_qa/list_knowledge_base", response_model=ListDocsResponse)(list_kbs) app.get("/local_doc_qa/list_knowledge_base", response_model=ListDocsResponse, summary="获取知识库列表")(list_kbs)
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs) app.get("/local_doc_qa/list_files", response_model=ListDocsResponse, summary="获取知识库内的文件列表")(list_docs)
app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse)(delete_kb) app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse, summary="删除知识库")(delete_kb)
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_doc) app.delete("/local_doc_qa/delete_file", response_model=BaseResponse, summary="删除知识库内的文件")(delete_doc)
app.post("/local_doc_qa/update_file", response_model=BaseResponse)(update_doc) app.post("/local_doc_qa/update_file", response_model=BaseResponse, summary="上传文件到知识库,并删除另一个文件")(update_doc)
local_doc_qa = LocalDocQA() local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg( local_doc_qa.init_cfg(
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论