Commit 2c1fd2bd by imClumsyPanda

add api.py

Parent e0cf2601
@@ -97,9 +97,9 @@ async def upload_file(
     files: Annotated[
         List[UploadFile], File(description="Multiple files as UploadFile")
     ],
-    local_doc_id: str = Form(..., description="Local document ID", example="doc_id_1"),
+    knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
 ):
-    saved_path = get_folder_path(local_doc_id)
+    saved_path = get_folder_path(knowledge_base_id)
     if not os.path.exists(saved_path):
         os.makedirs(saved_path)
     for file in files:
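With this rename, clients upload into a knowledge base by name rather than by document ID. A minimal sketch of the corresponding request using `requests`; the host/port (localhost:7861) and the file name are assumptions, not taken from the commit:

```python
# Hypothetical client call for the renamed upload endpoint.
import requests

with open("doc_name_1.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:7861/chat-docs/upload",   # assumed host/port
        files=[("files", f)],                       # repeat tuples for multiple files
        data={"knowledge_base_id": "kb1"},          # formerly local_doc_id
    )
print(resp.json())
```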
@@ -107,17 +107,17 @@ async def upload_file(
         with open(file_path, "wb") as f:
             f.write(file.file.read())
-    local_doc_qa.init_knowledge_vector_store(saved_path, get_vs_path(local_doc_id))
+    local_doc_qa.init_knowledge_vector_store(saved_path, get_vs_path(knowledge_base_id))
     return BaseResponse()


 async def list_docs(
-    local_doc_id: Optional[str] = Query(description="Document ID", example="doc_id1")
+    knowledge_base_id: Optional[str] = Query(description="Knowledge Base Name", example="kb1")
 ):
-    if local_doc_id:
-        local_doc_folder = get_folder_path(local_doc_id)
+    if knowledge_base_id:
+        local_doc_folder = get_folder_path(knowledge_base_id)
         if not os.path.exists(local_doc_folder):
-            return {"code": 1, "msg": f"document {local_doc_id} not found"}
+            return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
         all_doc_names = [
             doc
             for doc in os.listdir(local_doc_folder)
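The list endpoint takes the same renamed query parameter. A sketch under the same assumed host/port:

```python
import requests

resp = requests.get(
    "http://localhost:7861/chat-docs/list",
    params={"knowledge_base_id": "kb1"},  # Optional per the signature above
)
print(resp.json())  # {"code": 0, ...} on success; code 1 if the KB is missing
```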
@@ -138,34 +138,34 @@ async def list_docs(


 async def delete_docs(
-    local_doc_id: str = Form(..., description="local doc id", example="doc_id_1"),
+    knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
     doc_name: Optional[str] = Form(
         None, description="doc name", example="doc_name_1.pdf"
     ),
 ):
-    if not os.path.exists(os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id)):
-        return {"code": 1, "msg": f"document {local_doc_id} not found"}
+    if not os.path.exists(os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id)):
+        return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
     if doc_name:
-        doc_path = get_file_path(local_doc_id, doc_name)
+        doc_path = get_file_path(knowledge_base_id, doc_name)
         if os.path.exists(doc_path):
             os.remove(doc_path)
         else:
             return {"code": 1, "msg": f"document {doc_name} not found"}
-        remain_docs = await list_docs(local_doc_id)
+        remain_docs = await list_docs(knowledge_base_id)
         if remain_docs["code"] != 0 or len(remain_docs["data"]) == 0:
-            shutil.rmtree(get_folder_path(local_doc_id), ignore_errors=True)
+            shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True)
         else:
             local_doc_qa.init_knowledge_vector_store(
-                get_folder_path(local_doc_id), get_vs_path(local_doc_id)
+                get_folder_path(knowledge_base_id), get_vs_path(knowledge_base_id)
             )
     else:
-        shutil.rmtree(get_folder_path(local_doc_id))
+        shutil.rmtree(get_folder_path(knowledge_base_id))
     return BaseResponse()


 async def chat(
-    local_doc_id: str = Body(..., description="Document ID", example="doc_id1"),
+    knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
     question: str = Body(..., description="Question", example="工伤保险是什么?"),
     history: List[List[str]] = Body(
         [],
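The delete handler above accepts the renamed form fields as well. A request sketch; the /chat-docs/delete route itself is not visible in these hunks, so the path is an assumption patterned on the other registrations in main():

```python
import requests

resp = requests.post(
    "http://localhost:7861/chat-docs/delete",  # assumed route path
    data={"knowledge_base_id": "kb1", "doc_name": "doc_name_1.pdf"},
)
print(resp.json())
```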
@@ -178,9 +178,9 @@ async def chat(
         ],
     ),
 ):
-    vs_path = os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, "vector_store")
+    vs_path = os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id, "vector_store")
     if not os.path.exists(vs_path):
-        raise ValueError(f"Document {local_doc_id} not found")
+        raise ValueError(f"Knowledge base {knowledge_base_id} not found")
     for resp, history in local_doc_qa.get_knowledge_based_answer(
         query=question, vs_path=vs_path, chat_history=history, streaming=True
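The chat endpoint now takes knowledge_base_id in its JSON body. A minimal request sketch under the same assumed host/port:

```python
import requests

resp = requests.post(
    "http://localhost:7861/chat-docs/chat",
    json={
        "knowledge_base_id": "kb1",
        "question": "工伤保险是什么?",
        "history": [],
    },
)
print(resp.json())
```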
@@ -200,12 +200,12 @@ async def chat(
     )


-async def stream_chat(websocket: WebSocket, local_doc_id: str):
+async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
     await websocket.accept()
-    vs_path = os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, "vector_store")
+    vs_path = os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id, "vector_store")
     if not os.path.exists(vs_path):
-        await websocket.send_json({"error": f"document {local_doc_id} not found"})
+        await websocket.send_json({"error": f"Knowledge base {knowledge_base_id} not found"})
         await websocket.close()
         return
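The streaming route carries the knowledge base name in the URL path. A client sketch using the third-party websockets package; the URI is an assumption built from the route registered in main() below:

```python
import asyncio
import websockets  # third-party: pip install websockets

async def demo():
    # knowledge_base_id travels in the path, per the route template below
    async with websockets.connect(
        "ws://localhost:7861/chat-docs/stream-chat/kb1"  # assumed host/port
    ) as ws:
        msg = await ws.recv()  # an error JSON arrives here if the KB is missing
        print(msg)

asyncio.run(demo())
```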
@@ -288,7 +288,7 @@ def main():
     args = parser.parse_args()

     app = FastAPI()
-    app.websocket("/chat-docs/stream-chat/{local_doc_id}")(stream_chat)
+    app.websocket("/chat-docs/stream-chat/{knowledge_base_id}")(stream_chat)
     app.post("/chat-docs/chat", response_model=ChatMessage)(chat)
     app.post("/chat-docs/upload", response_model=BaseResponse)(upload_file)
     app.get("/chat-docs/list", response_model=ListDocsResponse)(list_docs)
...
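These routes are registered by calling the decorator factories directly rather than with @app.post syntax; in FastAPI the two forms are equivalent. A standalone illustration:

```python
from fastapi import FastAPI

app = FastAPI()

async def ping():
    return {"pong": True}

# Calling the returned decorator explicitly...
app.get("/ping")(ping)

# ...has the same effect as the usual decorator form:
# @app.get("/ping")
# async def ping(): ...
```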
@@ -184,7 +184,8 @@ class LocalDocQA:
             torch_gc(DEVICE)
         else:
             if not vs_path:
-                vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
+                vs_path = os.path.join(VS_ROOT_PATH,
+                                       f"""{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
             vector_store = FAISS.from_documents(docs, self.embeddings)
             torch_gc(DEVICE)
...
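This change swaps f-string concatenation for os.path.join, which inserts the separator itself instead of relying on the root path ending with one. A quick sketch of the difference (paths are placeholders):

```python
import os

root = "/repo/vector_store"           # no trailing separator
name = "sample_FAISS_20230101_000000"

print(root + name)                    # /repo/vector_storesample_FAISS_... (broken)
print(os.path.join(root, name))       # /repo/vector_store/sample_FAISS_...
```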
@@ -36,9 +36,9 @@ USE_PTUNING_V2 = False

 # LLM running device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

-VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_store", "")
+VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_store")

-UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "")
+UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content")

 API_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "api_content")
...
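The trailing empty component presumably existed to make os.path.join emit a trailing separator, so that plain concatenation at the call sites produced valid paths; now that the call sites use os.path.join themselves, it is dropped. A sketch of the behavior:

```python
import os

print(os.path.join("/repo", "vector_store", ""))  # '/repo/vector_store/' (trailing sep)
print(os.path.join("/repo", "vector_store"))      # '/repo/vector_store'
```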
@@ -7,7 +7,8 @@ def torch_gc(DEVICE):
         torch.cuda.ipc_collect()
     elif torch.backends.mps.is_available():
         try:
-            torch.mps.empty_cache()
+            from torch.mps import empty_cache
+            empty_cache()
         except Exception as e:
             print(e)
             print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")
\ No newline at end of file
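torch.mps.empty_cache() only ships with PyTorch 2.0 and later, and `import torch` does not necessarily load the mps submodule, so plain attribute access can fail even on builds that have it; the explicit `from torch.mps import ...` forces Python's submodule import machinery to run. A version-guarded sketch of the same cleanup logic (the wrapper function itself is illustrative, not from the commit):

```python
import torch

def mps_gc():
    # Guard the import so pre-2.0 macOS builds fail gracefully inside the
    # except block instead of crashing the caller.
    if torch.backends.mps.is_available():
        try:
            from torch.mps import empty_cache
            empty_cache()
        except Exception as e:
            print(e)
            print("Upgrade PyTorch to 2.0.0+ for timely MPS memory cleanup.")
```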
@@ -95,12 +95,12 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to


 def get_vector_store(vs_id, files, history):
-    vs_path = VS_ROOT_PATH + vs_id
+    vs_path = os.path.join(VS_ROOT_PATH, vs_id)
     filelist = []
     for file in files:
         filename = os.path.split(file.name)[-1]
-        shutil.move(file.name, UPLOAD_ROOT_PATH + filename)
-        filelist.append(UPLOAD_ROOT_PATH + filename)
+        shutil.move(file.name, os.path.join(UPLOAD_ROOT_PATH, filename))
+        filelist.append(os.path.join(UPLOAD_ROOT_PATH, filename))
     if local_doc_qa.llm and local_doc_qa.embeddings:
         vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path)
         if len(loaded_files):
@@ -118,7 +118,7 @@ def change_vs_name_input(vs_id):
     if vs_id == "新建知识库":
         return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None
     else:
-        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), VS_ROOT_PATH + vs_id
+        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), os.path.join(VS_ROOT_PATH, vs_id)


 def change_mode(mode):
...