提交 b1ba2003 作者: imClumsyPanda

update webui.py

上级 a0729058
...@@ -252,7 +252,7 @@ class LocalDocQA: ...@@ -252,7 +252,7 @@ class LocalDocQA:
logger.info(f"{file} 未能成功加载") logger.info(f"{file} 未能成功加载")
if len(docs) > 0: if len(docs) > 0:
logger.info("文件加载完毕,正在生成向量库") logger.info("文件加载完毕,正在生成向量库")
if vs_path and os.path.isdir(vs_path): if vs_path and os.path.isdir(vs_path) and "index.faiss" in os.listdir(vs_path):
vector_store = load_vector_store(vs_path, self.embeddings) vector_store = load_vector_store(vs_path, self.embeddings)
vector_store.add_documents(docs) vector_store.add_documents(docs)
torch_gc() torch_gc()
......
...@@ -50,18 +50,19 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR ...@@ -50,18 +50,19 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
enumerate(resp["source_documents"])]) enumerate(resp["source_documents"])])
history[-1][-1] += source history[-1][-1] += source
yield history, "" yield history, ""
elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path): elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path) and "index.faiss" in os.listdir(
for resp, history in local_doc_qa.get_knowledge_based_answer( vs_path):
query=query, vs_path=vs_path, chat_history=history, streaming=streaming): for resp, history in local_doc_qa.get_knowledge_based_answer(
source = "\n\n" query=query, vs_path=vs_path, chat_history=history, streaming=streaming):
source += "".join( source = "\n\n"
[f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n""" source += "".join(
f"""{doc.page_content}\n""" [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
f"""</details>""" f"""{doc.page_content}\n"""
for i, doc in f"""</details>"""
enumerate(resp["source_documents"])]) for i, doc in
history[-1][-1] += source enumerate(resp["source_documents"])])
yield history, "" history[-1][-1] += source
yield history, ""
elif mode == "知识库测试": elif mode == "知识库测试":
if os.path.exists(vs_path): if os.path.exists(vs_path):
resp, prompt = local_doc_qa.get_knowledge_based_conent_test(query=query, vs_path=vs_path, resp, prompt = local_doc_qa.get_knowledge_based_conent_test(query=query, vs_path=vs_path,
...@@ -124,7 +125,6 @@ def init_model(): ...@@ -124,7 +125,6 @@ def init_model():
def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k,
history): history):
try: try:
llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2) llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
llm_model_ins.history_len = llm_history_len llm_model_ins.history_len = llm_history_len
...@@ -229,16 +229,17 @@ def add_vs_name(vs_name, chatbot): ...@@ -229,16 +229,17 @@ def add_vs_name(vs_name, chatbot):
chatbot = chatbot + [[None, vs_status]] chatbot = chatbot + [[None, vs_status]]
return gr.update(visible=True, choices=get_vs_list(), value=vs_name), gr.update( return gr.update(visible=True, choices=get_vs_list(), value=vs_name), gr.update(
visible=False), gr.update(visible=False), gr.update(visible=True), chatbot visible=False), gr.update(visible=False), gr.update(visible=True), chatbot
# 自动化加载固定文件间中文件 # 自动化加载固定文件间中文件
def init_set_vector_store(content_dir,vs_id,history): def reinit_vector_store(vs_id, history):
try: try:
shutil.rmtree(VS_ROOT_PATH) shutil.rmtree(VS_ROOT_PATH)
vs_path = os.path.join(VS_ROOT_PATH, vs_id) vs_path = os.path.join(VS_ROOT_PATH, vs_id)
sentence_size = gr.Number(value=SENTENCE_SIZE, precision=0, sentence_size = gr.Number(value=SENTENCE_SIZE, precision=0,
label="文本入库分句长度限制", label="文本入库分句长度限制",
interactive=True, visible=True) interactive=True, visible=True)
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(content_dir, vs_path, sentence_size) vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(UPLOAD_ROOT_PATH, vs_path, sentence_size)
model_status = """知识库构建成功""" model_status = """知识库构建成功"""
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
...@@ -487,8 +488,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as ...@@ -487,8 +488,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2,
use_lora, top_k, chatbot], outputs=chatbot) use_lora, top_k, chatbot], outputs=chatbot)
load_knowlege_button = gr.Button("重新构建知识库") load_knowlege_button = gr.Button("重新构建知识库")
load_knowlege_button.click(init_set_vector_store, show_progress=True, load_knowlege_button.click(reinit_vector_store, show_progress=True,
inputs=[UPLOAD_ROOT_PATH, select_vs,chatbot], outputs=chatbot) inputs=[select_vs, chatbot], outputs=chatbot)
(demo (demo
.queue(concurrency_count=3) .queue(concurrency_count=3)
.launch(server_name='0.0.0.0', .launch(server_name='0.0.0.0',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论