提交 c7106317 作者: imClumsyPanda

update webui.py

上级 89fe20b5
...@@ -25,8 +25,6 @@ def get_vs_list(): ...@@ -25,8 +25,6 @@ def get_vs_list():
return lst_default + lst return lst_default + lst
vs_list = get_vs_list()
embedding_model_dict_list = list(embedding_model_dict.keys()) embedding_model_dict_list = list(embedding_model_dict.keys())
llm_model_dict_list = list(llm_model_dict.keys()) llm_model_dict_list = list(llm_model_dict.keys())
...@@ -44,11 +42,12 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR ...@@ -44,11 +42,12 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
query=query, chat_history=history, streaming=streaming): query=query, chat_history=history, streaming=streaming):
source = "\n\n" source = "\n\n"
source += "".join( source += "".join(
[f"""<details> <summary>出处 [{i + 1}] <a href="{doc.metadata["source"]}" target="_blank">{doc.metadata["source"]}</a> </summary>\n""" [
f"""{doc.page_content}\n""" f"""<details> <summary>出处 [{i + 1}] <a href="{doc.metadata["source"]}" target="_blank">{doc.metadata["source"]}</a> </summary>\n"""
f"""</details>""" f"""{doc.page_content}\n"""
for i, doc in f"""</details>"""
enumerate(resp["source_documents"])]) for i, doc in
enumerate(resp["source_documents"])])
history[-1][-1] += source history[-1][-1] += source
yield history, "" yield history, ""
elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path): elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path):
...@@ -89,7 +88,6 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR ...@@ -89,7 +88,6 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
else: else:
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history, for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
streaming=streaming): streaming=streaming):
resp = answer_result.llm_output["answer"] resp = answer_result.llm_output["answer"]
history = answer_result.history history = answer_result.history
history[-1][-1] = resp + ( history[-1][-1] = resp + (
...@@ -99,9 +97,15 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR ...@@ -99,9 +97,15 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME) flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)
def init_model(llm_model: BaseAnswer = None): def init_model():
args = parser.parse_args()
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
try: try:
local_doc_qa.init_cfg(llm_model=llm_model) local_doc_qa.init_cfg(llm_model=llm_model_ins)
generator = local_doc_qa.llm.generatorAnswer("你好") generator = local_doc_qa.llm.generatorAnswer("你好")
for answer_result in generator: for answer_result in generator:
print(answer_result.llm_output) print(answer_result.llm_output)
...@@ -119,7 +123,9 @@ def init_model(llm_model: BaseAnswer = None): ...@@ -119,7 +123,9 @@ def init_model(llm_model: BaseAnswer = None):
return reply return reply
def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history): def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k,
history):
try: try:
llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2) llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
llm_model_ins.history_len = llm_history_len llm_model_ins.history_len = llm_history_len
...@@ -138,8 +144,6 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u ...@@ -138,8 +144,6 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation): def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
vs_path = os.path.join(VS_ROOT_PATH, vs_id) vs_path = os.path.join(VS_ROOT_PATH, vs_id)
filelist = [] filelist = []
if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_id)):
os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_id))
if local_doc_qa.llm and local_doc_qa.embeddings: if local_doc_qa.llm and local_doc_qa.embeddings:
if isinstance(files, list): if isinstance(files, list):
for file in files: for file in files:
...@@ -166,9 +170,8 @@ def change_vs_name_input(vs_id, history): ...@@ -166,9 +170,8 @@ def change_vs_name_input(vs_id, history):
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history
else: else:
file_status = f"已加载知识库{vs_id},请开始提问" file_status = f"已加载知识库{vs_id},请开始提问"
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), os.path.join(VS_ROOT_PATH, return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), \
vs_id), history + [ os.path.join(VS_ROOT_PATH, vs_id), history + [[None, file_status]]
[None, file_status]]
knowledge_base_test_mode_info = ("【注意】\n\n" knowledge_base_test_mode_info = ("【注意】\n\n"
...@@ -206,19 +209,29 @@ def change_chunk_conent(mode, label_conent, history): ...@@ -206,19 +209,29 @@ def change_chunk_conent(mode, label_conent, history):
return gr.update(visible=False), history + [[None, f"【已关闭{conent}】"]] return gr.update(visible=False), history + [[None, f"【已关闭{conent}】"]]
def add_vs_name(vs_name, vs_list, chatbot): def add_vs_name(vs_name, chatbot):
if vs_name in vs_list: if vs_name in get_vs_list():
vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交" vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交"
chatbot = chatbot + [[None, vs_status]] chatbot = chatbot + [[None, vs_status]]
return gr.update(visible=True), vs_list, gr.update(visible=True), gr.update(visible=True), gr.update( return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
visible=False), chatbot visible=False), chatbot
else: else:
# 新建上传文件存储路径
if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_name)):
os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_name))
# 新建向量库存储路径
if not os.path.exists(os.path.join(VS_ROOT_PATH, vs_name)):
os.makedirs(os.path.join(VS_ROOT_PATH, vs_name))
vs_status = f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """ vs_status = f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """
chatbot = chatbot + [[None, vs_status]] chatbot = chatbot + [[None, vs_status]]
return gr.update(visible=True, choices=[vs_name] + vs_list, value=vs_name), [vs_name] + vs_list, gr.update( return gr.update(visible=True, choices=get_vs_list(), value=vs_name), gr.update(
visible=False), gr.update(visible=False), gr.update(visible=True), chatbot visible=False), gr.update(visible=False), gr.update(visible=True), chatbot
def refresh_vs_list():
return gr.update(choices=get_vs_list())
block_css = """.importantButton { block_css = """.importantButton {
background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important; background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
border: none !important; border: none !important;
...@@ -232,7 +245,7 @@ webui_title = """ ...@@ -232,7 +245,7 @@ webui_title = """
# 🎉langchain-ChatGLM WebUI🎉 # 🎉langchain-ChatGLM WebUI🎉
👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM) 👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM)
""" """
default_vs = vs_list[0] if len(vs_list) > 1 else "为空" default_vs = get_vs_list()[0] if len(get_vs_list()) > 1 else "为空"
init_message = f"""欢迎使用 langchain-ChatGLM Web UI! init_message = f"""欢迎使用 langchain-ChatGLM Web UI!
请在右侧切换模式,目前支持直接与 LLM 模型对话或基于本地知识库问答。 请在右侧切换模式,目前支持直接与 LLM 模型对话或基于本地知识库问答。
...@@ -243,16 +256,7 @@ init_message = f"""欢迎使用 langchain-ChatGLM Web UI! ...@@ -243,16 +256,7 @@ init_message = f"""欢迎使用 langchain-ChatGLM Web UI!
""" """
# 初始化消息 # 初始化消息
args = None model_status = init_model()
args = parser.parse_args()
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
model_status = init_model(llm_model=llm_model_ins)
default_theme_args = dict( default_theme_args = dict(
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'], font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
...@@ -260,10 +264,9 @@ default_theme_args = dict( ...@@ -260,10 +264,9 @@ default_theme_args = dict(
) )
with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo: with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo:
vs_path, file_status, model_status, vs_list = gr.State( vs_path, file_status, model_status = gr.State(
os.path.join(VS_ROOT_PATH, vs_list[0]) if len(vs_list) > 1 else ""), gr.State(""), gr.State( os.path.join(VS_ROOT_PATH, get_vs_list()[0]) if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State(
model_status), gr.State(vs_list) model_status)
gr.Markdown(webui_title) gr.Markdown(webui_title)
with gr.Tab("对话"): with gr.Tab("对话"):
with gr.Row(): with gr.Row():
...@@ -283,10 +286,11 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as ...@@ -283,10 +286,11 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
inputs=[mode, chatbot], inputs=[mode, chatbot],
outputs=[vs_setting, knowledge_set, chatbot]) outputs=[vs_setting, knowledge_set, chatbot])
with vs_setting: with vs_setting:
select_vs = gr.Dropdown(vs_list.value, vs_refresh = gr.Button("更新已有知识库选项")
select_vs = gr.Dropdown(get_vs_list(),
label="请选择要加载的知识库", label="请选择要加载的知识库",
interactive=True, interactive=True,
value=vs_list.value[0] if len(vs_list.value) > 0 else None value=get_vs_list()[0] if len(get_vs_list()) > 0 else None
) )
vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文", vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文",
lines=1, lines=1,
...@@ -302,19 +306,21 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as ...@@ -302,19 +306,21 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
interactive=True, visible=True) interactive=True, visible=True)
with gr.Tab("上传文件"): with gr.Tab("上传文件"):
files = gr.File(label="添加文件", files = gr.File(label="添加文件",
file_types=['.txt', '.md', '.docx', '.pdf'], file_types=['.txt', '.md', '.docx', '.pdf', '.png', '.jpg'],
file_count="multiple", file_count="multiple",
show_label=False) show_label=False)
load_file_button = gr.Button("上传文件并加载知识库") load_file_button = gr.Button("上传文件并加载知识库")
with gr.Tab("上传文件夹"): with gr.Tab("上传文件夹"):
folder_files = gr.File(label="添加文件", folder_files = gr.File(label="添加文件",
# file_types=['.txt', '.md', '.docx', '.pdf'],
file_count="directory", file_count="directory",
show_label=False) show_label=False)
load_folder_button = gr.Button("上传文件夹并加载知识库") load_folder_button = gr.Button("上传文件夹并加载知识库")
vs_refresh.click(fn=refresh_vs_list,
inputs=[],
outputs=select_vs)
vs_add.click(fn=add_vs_name, vs_add.click(fn=add_vs_name,
inputs=[vs_name, vs_list, chatbot], inputs=[vs_name, chatbot],
outputs=[select_vs, vs_list, vs_name, vs_add, file2vs, chatbot]) outputs=[select_vs, vs_name, vs_add, file2vs, chatbot])
select_vs.change(fn=change_vs_name_input, select_vs.change(fn=change_vs_name_input,
inputs=[select_vs, chatbot], inputs=[select_vs, chatbot],
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot]) outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
...@@ -366,10 +372,11 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as ...@@ -366,10 +372,11 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
inputs=[chunk_conent, gr.Textbox(value="chunk_conent", visible=False), chatbot], inputs=[chunk_conent, gr.Textbox(value="chunk_conent", visible=False), chatbot],
outputs=[chunk_sizes, chatbot]) outputs=[chunk_sizes, chatbot])
with vs_setting: with vs_setting:
select_vs = gr.Dropdown(vs_list.value, vs_refresh = gr.Button("更新已有知识库选项")
select_vs = gr.Dropdown(get_vs_list(),
label="请选择要加载的知识库", label="请选择要加载的知识库",
interactive=True, interactive=True,
value=vs_list.value[0] if len(vs_list.value) > 0 else None) value=get_vs_list()[0] if len(get_vs_list()) > 0 else None)
vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文", vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文",
lines=1, lines=1,
interactive=True, interactive=True,
...@@ -402,9 +409,12 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as ...@@ -402,9 +409,12 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
interactive=True) interactive=True)
load_conent_button = gr.Button("添加内容并加载知识库") load_conent_button = gr.Button("添加内容并加载知识库")
# 将上传的文件保存到content文件夹下,并更新下拉框 # 将上传的文件保存到content文件夹下,并更新下拉框
vs_refresh.click(fn=refresh_vs_list,
inputs=[],
outputs=select_vs)
vs_add.click(fn=add_vs_name, vs_add.click(fn=add_vs_name,
inputs=[vs_name, vs_list, chatbot], inputs=[vs_name, chatbot],
outputs=[select_vs, vs_list, vs_name, vs_add, file2vs, chatbot]) outputs=[select_vs, vs_name, vs_add, file2vs, chatbot])
select_vs.change(fn=change_vs_name_input, select_vs.change(fn=change_vs_name_input,
inputs=[select_vs, chatbot], inputs=[select_vs, chatbot],
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot]) outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
...@@ -455,8 +465,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as ...@@ -455,8 +465,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
label="向量匹配 top k", interactive=True) label="向量匹配 top k", interactive=True)
load_model_button = gr.Button("重新加载模型") load_model_button = gr.Button("重新加载模型")
load_model_button.click(reinit_model, show_progress=True, load_model_button.click(reinit_model, show_progress=True,
inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2,
top_k, chatbot], outputs=chatbot) use_lora, top_k, chatbot], outputs=chatbot)
(demo (demo
.queue(concurrency_count=3) .queue(concurrency_count=3)
...@@ -464,4 +474,4 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as ...@@ -464,4 +474,4 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
server_port=7860, server_port=7860,
show_api=False, show_api=False,
share=False, share=False,
inbrowser=False)) inbrowser=False))
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论