提交 ed7c5485 作者: imClumsyPanda

add delete knowledge base and delete files from knowledge base to webui

上级 331f39cd
...@@ -167,6 +167,7 @@ log/* ...@@ -167,6 +167,7 @@ log/*
vector_store/* vector_store/*
content/* content/*
api_content/* api_content/*
knowledge_base/*
llm/* llm/*
embedding/* embedding/*
......
...@@ -297,6 +297,16 @@ class LocalDocQA: ...@@ -297,6 +297,16 @@ class LocalDocQA:
status = vector_store.update_doc(filepath, docs) status = vector_store.update_doc(filepath, docs)
return status return status
def list_file_from_vector_store(self,
vs_path,
fullpath=False):
vector_store = load_vector_store(vs_path, self.embeddings)
docs = vector_store.list_docs()
if fullpath:
return docs
else:
return [os.path.split(doc)[-1] for doc in docs]
if __name__ == "__main__": if __name__ == "__main__":
# 初始化消息 # 初始化消息
......
from langchain.vectorstores import FAISS from langchain.vectorstores import FAISS
from langchain.vectorstores.base import VectorStore from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.faiss import dependable_faiss_import from langchain.vectorstores.faiss import dependable_faiss_import
from typing import Any, Callable, List, Tuple, Dict from typing import Any, Callable, List, Dict
from langchain.docstore.base import Docstore from langchain.docstore.base import Docstore
from langchain.docstore.document import Document from langchain.docstore.document import Document
import numpy as np import numpy as np
...@@ -109,15 +109,22 @@ class MyFAISS(FAISS, VectorStore): ...@@ -109,15 +109,22 @@ class MyFAISS(FAISS, VectorStore):
docs.append(doc) docs.append(doc)
return docs return docs
def delete_doc(self, source): def delete_doc(self, source: str or List[str]):
try: try:
if isinstance(source, str):
ids = [k for k, v in self.docstore._dict.items() if v.metadata["source"] == source] ids = [k for k, v in self.docstore._dict.items() if v.metadata["source"] == source]
else:
ids = [k for k, v in self.docstore._dict.items() if v.metadata["source"] in source]
if len(ids) == 0:
return f"docs delete fail"
else:
for id in ids: for id in ids:
index = list(self.index_to_docstore_id.keys())[list(self.index_to_docstore_id.values()).index(id)] index = list(self.index_to_docstore_id.keys())[list(self.index_to_docstore_id.values()).index(id)]
self.index_to_docstore_id.pop(index) self.index_to_docstore_id.pop(index)
self.docstore._dict.pop(id) self.docstore._dict.pop(id)
return f"docs delete success" return f"docs delete success"
except: except Exception as e:
print(e)
return f"docs delete fail" return f"docs delete fail"
def update_doc(self, source, new_docs): def update_doc(self, source, new_docs):
...@@ -125,5 +132,9 @@ class MyFAISS(FAISS, VectorStore): ...@@ -125,5 +132,9 @@ class MyFAISS(FAISS, VectorStore):
delete_len = self.delete_doc(source) delete_len = self.delete_doc(source)
ls = self.add_documents(new_docs) ls = self.add_documents(new_docs)
return f"docs update success" return f"docs update success"
except: except Exception as e:
print(e)
return f"docs update fail" return f"docs update fail"
def list_docs(self):
return list(set(v.metadata["source"] for v in self.docstore._dict.values()))
import gradio as gr import gradio as gr
import os
import shutil import shutil
from chains.local_doc_qa import LocalDocQA from chains.local_doc_qa import LocalDocQA
from configs.model_config import * from configs.model_config import *
import nltk import nltk
from models.base import (BaseAnswer,
AnswerResult)
import models.shared as shared import models.shared as shared
from models.loader.args import parser from models.loader.args import parser
from models.loader import LoaderCheckPoint from models.loader import LoaderCheckPoint
import os
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
...@@ -161,20 +159,26 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte ...@@ -161,20 +159,26 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte
file_status = "模型未完成加载,请先在加载模型后再导入文件" file_status = "模型未完成加载,请先在加载模型后再导入文件"
vs_path = None vs_path = None
logger.info(file_status) logger.info(file_status)
return vs_path, None, history + [[None, file_status]] return vs_path, None, history + [[None, file_status]], \
gr.update(choices=local_doc_qa.list_file_from_vector_store(vs_path) if vs_path else [])
def change_vs_name_input(vs_id, history): def change_vs_name_input(vs_id, history):
if vs_id == "新建知识库": if vs_id == "新建知识库":
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history,\
gr.update(choices=[]), gr.update(visible=False)
else: else:
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
if "index.faiss" in os.listdir(vs_path): if "index.faiss" in os.listdir(vs_path):
file_status = f"已加载知识库{vs_id},请开始提问" file_status = f"已加载知识库{vs_id},请开始提问"
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), \
vs_path, history + [[None, file_status]], \
gr.update(choices=local_doc_qa.list_file_from_vector_store(vs_path)), gr.update(visible=True)
else: else:
file_status = f"已选择知识库{vs_id},当前知识库中未上传文件,请先上传文件后,再开始提问" file_status = f"已选择知识库{vs_id},当前知识库中未上传文件,请先上传文件后,再开始提问"
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), \ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), \
vs_path, history + [[None, file_status]] vs_path, history + [[None, file_status]], \
gr.update(choices=[]), gr.update(visible=True)
knowledge_base_test_mode_info = ("【注意】\n\n" knowledge_base_test_mode_info = ("【注意】\n\n"
...@@ -217,7 +221,7 @@ def add_vs_name(vs_name, chatbot): ...@@ -217,7 +221,7 @@ def add_vs_name(vs_name, chatbot):
vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交" vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交"
chatbot = chatbot + [[None, vs_status]] chatbot = chatbot + [[None, vs_status]]
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update( return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
visible=False), chatbot visible=False), chatbot, gr.update(visible=False)
else: else:
# 新建上传文件存储路径 # 新建上传文件存储路径
if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_name, "content")): if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_name, "content")):
...@@ -228,7 +232,7 @@ def add_vs_name(vs_name, chatbot): ...@@ -228,7 +232,7 @@ def add_vs_name(vs_name, chatbot):
vs_status = f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """ vs_status = f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """
chatbot = chatbot + [[None, vs_status]] chatbot = chatbot + [[None, vs_status]]
return gr.update(visible=True, choices=get_vs_list(), value=vs_name), gr.update( return gr.update(visible=True, choices=get_vs_list(), value=vs_name), gr.update(
visible=False), gr.update(visible=False), gr.update(visible=True), chatbot visible=False), gr.update(visible=False), gr.update(visible=True), chatbot, gr.update(visible=True)
# 自动化加载固定文件间中文件 # 自动化加载固定文件间中文件
...@@ -252,6 +256,38 @@ def reinit_vector_store(vs_id, history): ...@@ -252,6 +256,38 @@ def reinit_vector_store(vs_id, history):
def refresh_vs_list(): def refresh_vs_list():
return gr.update(choices=get_vs_list()), gr.update(choices=get_vs_list()) return gr.update(choices=get_vs_list()), gr.update(choices=get_vs_list())
def delete_file(vs_id, files_to_delete, chatbot):
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
content_path = os.path.join(KB_ROOT_PATH, vs_id, "content")
status = local_doc_qa.delete_file_from_vector_store(vs_path=vs_path,
filepath=[os.path.join(content_path, file) for file in files_to_delete])
rested_files = local_doc_qa.list_file_from_vector_store(vs_path)
if "fail" in status:
vs_status = "文件删除失败。"
elif len(rested_files)>0:
vs_status = "文件删除成功。"
else:
vs_status = f"文件删除成功,知识库{vs_id}中无已上传文件,请先上传文件后,再开始提问。"
logger.info(",".join(files_to_delete)+vs_status)
chatbot = chatbot + [[None, vs_status]]
return gr.update(choices=local_doc_qa.list_file_from_vector_store(vs_path)), chatbot
def delete_vs(vs_id, chatbot):
try:
shutil.rmtree(os.path.join(KB_ROOT_PATH, vs_id))
status = f"成功删除知识库{vs_id}"
logger.info(status)
chatbot = chatbot + [[None, status]]
return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=False), chatbot, gr.update(visible=False)
except Exception as e:
logger.error(e)
status = f"删除知识库{vs_id}失败"
chatbot = chatbot + [[None, status]]
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=True), chatbot, gr.update(visible=True)
block_css = """.importantButton { block_css = """.importantButton {
background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important; background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
...@@ -318,6 +354,7 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as ...@@ -318,6 +354,7 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
interactive=True, interactive=True,
visible=True) visible=True)
vs_add = gr.Button(value="添加至知识库选项", visible=True) vs_add = gr.Button(value="添加至知识库选项", visible=True)
vs_delete = gr.Button("删除本知识库", visible=False)
file2vs = gr.Column(visible=False) file2vs = gr.Column(visible=False)
with file2vs: with file2vs:
# load_vs = gr.Button("加载知识库") # load_vs = gr.Button("加载知识库")
...@@ -336,28 +373,40 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as ...@@ -336,28 +373,40 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
file_count="directory", file_count="directory",
show_label=False) show_label=False)
load_folder_button = gr.Button("上传文件夹并加载知识库") load_folder_button = gr.Button("上传文件夹并加载知识库")
with gr.Tab("删除文件"):
files_to_delete = gr.CheckboxGroup(choices=[],
label="请从知识库已有文件中选择要删除的文件",
interactive=True)
delete_file_button = gr.Button("从知识库中删除选中文件")
vs_refresh.click(fn=refresh_vs_list, vs_refresh.click(fn=refresh_vs_list,
inputs=[], inputs=[],
outputs=select_vs) outputs=select_vs)
vs_add.click(fn=add_vs_name, vs_add.click(fn=add_vs_name,
inputs=[vs_name, chatbot], inputs=[vs_name, chatbot],
outputs=[select_vs, vs_name, vs_add, file2vs, chatbot]) outputs=[select_vs, vs_name, vs_add, file2vs, chatbot, vs_delete])
vs_delete.click(fn=delete_vs,
inputs=[select_vs, chatbot],
outputs=[select_vs, vs_name, vs_add, file2vs, chatbot, vs_delete])
select_vs.change(fn=change_vs_name_input, select_vs.change(fn=change_vs_name_input,
inputs=[select_vs, chatbot], inputs=[select_vs, chatbot],
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot]) outputs=[vs_name, vs_add, file2vs, vs_path, chatbot, files_to_delete, vs_delete])
load_file_button.click(get_vector_store, load_file_button.click(get_vector_store,
show_progress=True, show_progress=True,
inputs=[select_vs, files, sentence_size, chatbot, vs_add, vs_add], inputs=[select_vs, files, sentence_size, chatbot, vs_add, vs_add],
outputs=[vs_path, files, chatbot], ) outputs=[vs_path, files, chatbot, files_to_delete], )
load_folder_button.click(get_vector_store, load_folder_button.click(get_vector_store,
show_progress=True, show_progress=True,
inputs=[select_vs, folder_files, sentence_size, chatbot, vs_add, inputs=[select_vs, folder_files, sentence_size, chatbot, vs_add,
vs_add], vs_add],
outputs=[vs_path, folder_files, chatbot], ) outputs=[vs_path, folder_files, chatbot, files_to_delete], )
flag_csv_logger.setup([query, vs_path, chatbot, mode], "flagged") flag_csv_logger.setup([query, vs_path, chatbot, mode], "flagged")
query.submit(get_answer, query.submit(get_answer,
[query, vs_path, chatbot, mode], [query, vs_path, chatbot, mode],
[chatbot, query]) [chatbot, query])
delete_file_button.click(delete_file,
show_progress=True,
inputs=[select_vs, files_to_delete, chatbot],
outputs=[files_to_delete, chatbot])
with gr.Tab("知识库测试 Beta"): with gr.Tab("知识库测试 Beta"):
with gr.Row(): with gr.Row():
with gr.Column(scale=10): with gr.Column(scale=10):
...@@ -488,9 +537,9 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as ...@@ -488,9 +537,9 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
load_model_button.click(reinit_model, show_progress=True, load_model_button.click(reinit_model, show_progress=True,
inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2,
use_lora, top_k, chatbot], outputs=chatbot) use_lora, top_k, chatbot], outputs=chatbot)
load_knowlege_button = gr.Button("重新构建知识库") # load_knowlege_button = gr.Button("重新构建知识库")
load_knowlege_button.click(reinit_vector_store, show_progress=True, # load_knowlege_button.click(reinit_vector_store, show_progress=True,
inputs=[select_vs, chatbot], outputs=chatbot) # inputs=[select_vs, chatbot], outputs=chatbot)
demo.load( demo.load(
fn=refresh_vs_list, fn=refresh_vs_list,
inputs=None, inputs=None,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论