提交 966def8c 作者: imClumsyPanda

add stream support to webui.py

上级 b4aefca5
from configs.model_config import *
from chains.local_doc_qa import LocalDocQA
import os
import nltk
import uvicorn
from fastapi import FastAPI, File, UploadFile
from pydantic import BaseModel
from starlette.responses import RedirectResponse
app = FastAPI()
global local_doc_qa, vs_path
nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
# return top-k text chunk from vector store
VECTOR_SEARCH_TOP_K = 10
# LLM input history length
LLM_HISTORY_LEN = 3
# Show reply with source text from input document
REPLY_WITH_SOURCE = False
class Query(BaseModel):
query: str
@app.get('/')
async def document():
return RedirectResponse(url="/docs")
@app.on_event("startup")
async def get_local_doc_qa():
global local_doc_qa
local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(llm_model=LLM_MODEL,
embedding_model=EMBEDDING_MODEL,
embedding_device=EMBEDDING_DEVICE,
llm_history_len=LLM_HISTORY_LEN,
top_k=VECTOR_SEARCH_TOP_K)
@app.post("/file")
async def upload_file(UserFile: UploadFile=File(...),):
global vs_path
response = {
"msg": None,
"status": 0
}
try:
filepath = './content/' + UserFile.filename
content = await UserFile.read()
# print(UserFile.filename)
with open(filepath, 'wb') as f:
f.write(content)
vs_path, files = local_doc_qa.init_knowledge_vector_store(filepath)
response = {
'msg': 'seccess' if len(files)>0 else 'fail',
'status': 1 if len(files)>0 else 0,
'loaded_files': files
}
except Exception as err:
response["message"] = err
return response
@app.post("/qa")
async def get_answer(query: str = ""):
response = {
"status": 0,
"message": "",
"answer": None
}
global vs_path
history = []
try:
resp, history = local_doc_qa.get_knowledge_based_answer(query=query,
vs_path=vs_path,
chat_history=history)
if REPLY_WITH_SOURCE:
response["answer"] = resp
else:
response['answer'] = resp["result"]
response["message"] = 'successful'
response["status"] = 1
except Exception as err:
response["message"] = err
return response
if __name__ == "__main__":
uvicorn.run(
app=app,
host='0.0.0.0',
port=8100,
reload=True,
)
......@@ -141,7 +141,6 @@ class LocalDocQA:
if streaming:
for result, history in self.llm._call(prompt=prompt,
history=chat_history):
history[-1] = list(history[-1])
history[-1][0] = query
response = {"query": query,
"result": result,
......@@ -150,7 +149,6 @@ class LocalDocQA:
else:
result, history = self.llm._call(prompt=prompt,
history=chat_history)
history[-1] = list(history[-1])
history[-1][0] = query
response = {"query": query,
"result": result,
......
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.question_answering import load_qa_chain\n",
"from langchain.prompts import PromptTemplate\n",
"from lib.embeds import MyEmbeddings\n",
"from lib.faiss import FAISSVS\n",
"from langchain.chains.combine_documents.stuff import StuffDocumentsChain\n",
"from langchain.chains.llm import LLMChain\n",
"from lib.chatglm_llm import ChatGLM, AlpacaGLM\n",
"from lib.config import *\n",
"from lib.utils import get_docs\n",
"\n",
"\n",
"class LocalDocQA:\n",
" def __init__(self, \n",
" embedding_model=EMBEDDING_MODEL, \n",
" embedding_device=EMBEDDING_DEVICE, \n",
" llm_model=LLM_MODEL, \n",
" llm_device=LLM_DEVICE, \n",
" llm_history_len=LLM_HISTORY_LEN, \n",
" top_k=VECTOR_SEARCH_TOP_K,\n",
" vs_name = VS_NAME\n",
" ) -> None:\n",
" \n",
" torch.cuda.empty_cache()\n",
" torch.cuda.empty_cache()\n",
"\n",
" self.embedding_model = embedding_model\n",
" self.llm_model = llm_model\n",
" self.embedding_device = embedding_device\n",
" self.llm_device = llm_device\n",
" self.llm_history_len = llm_history_len\n",
" self.top_k = top_k\n",
" self.vs_name = vs_name\n",
"\n",
" self.llm = AlpacaGLM()\n",
" self.llm.load_model(model_name_or_path=llm_model_dict[llm_model], llm_device=llm_device)\n",
"\n",
" self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model])\n",
" self.load_vector_store(vs_name)\n",
"\n",
" self.prompt = PromptTemplate(\n",
" template=PROMPT_TEMPLATE,\n",
" input_variables=[\"context\", \"question\"]\n",
" )\n",
" self.search_params = {\n",
" \"engine\": \"bing\",\n",
" \"gl\": \"us\",\n",
" \"hl\": \"en\",\n",
" \"serpapi_api_key\": \"\"\n",
" }\n",
"\n",
" def init_knowledge_vector_store(self, vs_name: str):\n",
" \n",
" docs = get_docs(KNOWLEDGE_PATH)\n",
" vector_store = FAISSVS.from_documents(docs, self.embeddings)\n",
" vs_path = VECTORSTORE_PATH + vs_name\n",
" vector_store.save_local(vs_path)\n",
"\n",
" def add_knowledge_to_vector_store(self, vs_name: str):\n",
" docs = get_docs(ADD_KNOWLEDGE_PATH)\n",
" new_vector_store = FAISSVS.from_documents(docs, self.embeddings)\n",
" vector_store = FAISSVS.load_local(VECTORSTORE_PATH + vs_name, self.embeddings) \n",
" vector_store.merge_from(new_vector_store)\n",
" vector_store.save_local(VECTORSTORE_PATH + vs_name)\n",
"\n",
" def load_vector_store(self, vs_name: str):\n",
" self.vector_store = FAISSVS.load_local(VECTORSTORE_PATH + vs_name, self.embeddings)\n",
"\n",
" # def get_search_based_answer(self, query):\n",
" \n",
" # search = SerpAPIWrapper(params=self.search_params)\n",
" # docs = search.run(query)\n",
" # search_chain = load_qa_chain(self.llm, chain_type=\"stuff\")\n",
" # answer = search_chain.run(input_documents=docs, question=query)\n",
"\n",
" # return answer\n",
" \n",
" def get_knowledge_based_answer(self, query):\n",
" \n",
" docs = self.vector_store.max_marginal_relevance_search(query)\n",
" print(f'召回的文档和相似度分数:{docs}')\n",
" # 这里 doc[1] 就是对应的score \n",
" docs = [doc[0] for doc in docs]\n",
" \n",
" document_prompt = PromptTemplate(\n",
" input_variables=[\"page_content\"], template=\"Context:\\n{page_content}\"\n",
" )\n",
" llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)\n",
" combine_documents_chain = StuffDocumentsChain(\n",
" llm_chain=llm_chain,\n",
" document_variable_name=\"context\",\n",
" document_prompt=document_prompt,\n",
" )\n",
" answer = combine_documents_chain.run(\n",
" input_documents=docs, question=query\n",
" )\n",
"\n",
" self.llm.history[-1][0] = query\n",
" self.llm.history[-1][-1] = answer\n",
" return answer, docs, self.llm.history"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d4342213010c4ed2ad5b04694aa436d6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"qa = LocalDocQA()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"召回的文档和相似度分数:[(Document(page_content='****** LOGI APT Group Intelligence Research Yearbook APT Knowledge Graph APT组织情报 研究年鉴', metadata={'source': './KnowledgeStore/APT group Intelligence Research handbook-2022.pdf', 'page': 0}), 0.45381865), (Document(page_content='9 MANDIANT APT42: Crooked Charms, Cons and Compromises FIGURE 8. APT42 impersonates University of Oxford vaccinologist. APT42 Credential harvesting page masquerading as a Yahoo login portal.', metadata={'source': './KnowledgeStore/APT42_Crooked_Charms_Cons_and_Compromises.pdf', 'page': 8}), 0.4535672), (Document(page_content='The origin story of APT32 macros T H R E A T R E S E A R C H R E P O R T R u n n i n g t h r o u g h a l l t h e S U O f i l e s t r u c t u r e s i s l a b o r i o u s a n d d i d n ’ t y i e l d m u c h m o r e t h a n a s t r i n g d u m p w o u l d h a v e d o n e a n y w a y . W e f i n d p a t h s t o s o u r c e c o d e f i l e s , p r o j e c t n a m e s , e t c . W e c a n i n f e r f r o m t h e m y r i a d o f r e f e r e n c e s i n XmlPackageOptions , O u t l i n i n g S t a t e D i r , e t c . , t h a t t h e HtaDotnet a n d ShellcodeLoader s o l u t i o n s w e r e o r i g i n a l l y u n d e r t h e f o l d e r p a t h G:\\\\WebBuilder\\\\Gift_HtaDotnet\\\\ . T h i s i s a l s o s u p p o r t e d b y t h e P D B p a t h s o f o l d e r b u i l t b i n a r i e s w i t h i n t h e b r o a d e r S t r i k e S u i t G i f t p a c k a g e . F r o m l o o k i n g a t D e b u g g e r W a t c h e s v a l u e s i n o t h e r p r o j e c t s , w e c a n s e e t h a t t h e m a l w a r e d e v e l o p e r w a s a c t i v e l y d e b u g g i n g t h e h i s t o r i c a l p r o g r a m s . S U O f i l e D e b u g g e r W a t c h e s WebBuilder/HtaDotNet/HtaDotnet.v11.suo result WebBuilder/ShellcodeLoader/.vs/L/v14/.suo (char)77 WebBuilder/ShellcodeLoader/L.suo (char)77 3 4 04/2022', metadata={'source': './KnowledgeStore/Stairwell-threat-report-The-origin-of-APT32-macros.pdf', 'page': 33}), 0.38091612), (Document(page_content='2 APTs and COVID-19: How advanced persistent threats use the coronavirus as a lureTable of contents Introduction: APT groups using COVID-19 .........................................................', metadata={'source': './KnowledgeStore/200407-MWB-COVID-White-Paper_Final.pdf', 'page': 1}), 0.44476452)]\n"
]
}
],
"source": [
"query = r\"\"\"make a brief introduction of APT?\"\"\"\n",
"ans, docs, _ = qa.get_knowledge_based_answer(query)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\\nAnswer: APT stands for Advanced Persistent Threat, which is a type of malicious cyberattack that is carried out by a sophisticated hacker group or state-sponsored organization. APTs are designed to remain undetected for a long period of time and are often used to steal sensitive data or disrupt critical infrastructure.'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ans"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "chatgpt",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
......@@ -74,14 +74,17 @@ class ChatGLM(LLM):
history: List[List[str]] = [],
stop: Optional[List[str]] = None) -> str:
if self.streaming:
history = history + [[None, ""]]
for stream_resp, history in self.model.stream_chat(
for inum, (stream_resp, _) in enumerate(self.model.stream_chat(
self.tokenizer,
prompt,
history=history[-self.history_len:] if self.history_len > 0 else [],
history=history[-self.history_len:-1] if self.history_len > 0 else [],
max_length=self.max_token,
temperature=self.temperature,
):
)):
if inum == 0:
history += [[prompt, stream_resp]]
else:
history[-1] = [prompt, stream_resp]
yield stream_resp, history
else:
......
......@@ -33,23 +33,23 @@ def get_answer(query, vs_path, history, mode):
if mode == "知识库问答":
if vs_path:
for resp, history in local_doc_qa.get_knowledge_based_answer(
query=query, vs_path=vs_path, chat_history=history):
# source = "".join([f"""<details> <summary>出处 {i + 1}</summary>
# {doc.page_content}
#
# <b>所属文件:</b>{doc.metadata["source"]}
# </details>""" for i, doc in enumerate(resp["source_documents"])])
# history[-1][-1] += source
query=query, vs_path=vs_path, chat_history=history):
source = "\n\n"
source += "".join(
[f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
f"""{doc.page_content}\n"""
f"""</details>"""
for i, doc in
enumerate(resp["source_documents"])])
history[-1][-1] += source
yield history, ""
else:
history = history + [[query, ""]]
for resp in local_doc_qa.llm._call(query):
for resp, history in local_doc_qa.llm._call(query, history):
history[-1][-1] = resp + (
"\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
yield history, ""
else:
history = history + [[query, ""]]
for resp in local_doc_qa.llm._call(query):
for resp, history in local_doc_qa.llm._call(query, history):
history[-1][-1] = resp
yield history, ""
......@@ -269,9 +269,10 @@ with gr.Blocks(css=block_css) as demo:
outputs=chatbot
)
demo.queue(concurrency_count=3
).launch(server_name='0.0.0.0',
server_port=7860,
show_api=False,
share=False,
inbrowser=False)
(demo
.queue(concurrency_count=3)
.launch(server_name='0.0.0.0',
server_port=7860,
show_api=False,
share=False,
inbrowser=False))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论