提交 391dc1d3 作者: hzg0601

debug for fastchat-openai-llm

......@@ -226,6 +226,10 @@ Web UI 可以实现如下功能:
- [x] [THUDM/chatglm-6b-int4-qe](https://huggingface.co/THUDM/chatglm-6b-int4-qe)
- [x] [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2)
- [x] [fnlp/moss-moon-003-sft](https://huggingface.co/fnlp/moss-moon-003-sft)
- [x] [bigscience/bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1)
- [x] [bigscience/bloom-3b](https://huggingface.co/bigscience/bloom-3b)
- [x] [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
- [x] [lmsys/vicuna-13b-delta-v1.1](https://huggingface.co/lmsys/vicuna-13b-delta-v1.1)
- [x] 支持通过调用 [fastchat](https://github.com/lm-sys/FastChat) api 调用 llm
- [x] 增加更多 Embedding 模型支持
- [x] [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
......@@ -251,7 +255,7 @@ Web UI 可以实现如下功能:
- [x] VUE 前端
## 项目交流群
<img src="img/qr_code_42.jpg" alt="二维码" width="300" height="300" />
<img src="img/qr_code_44.jpg" alt="二维码" width="300" height="300" />
🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
#encoding:utf-8
import argparse
import json
import os
......
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from typing import Any, List
class MyEmbeddings(HuggingFaceEmbeddings):
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
texts = list(map(lambda x: x.replace("\n", " "), texts))
embeddings = self.client.encode(texts, normalize_embeddings=True)
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a HuggingFace transformer model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
text = text.replace("\n", " ")
embedding = self.client.encode(text, normalize_embeddings=True)
return embedding.tolist()
from langchain.vectorstores import FAISS
from typing import Any, Callable, List, Optional, Tuple, Dict
from langchain.docstore.document import Document
from langchain.docstore.base import Docstore
from langchain.vectorstores.utils import maximal_marginal_relevance
from langchain.embeddings.base import Embeddings
import uuid
from langchain.docstore.in_memory import InMemoryDocstore
import numpy as np
def dependable_faiss_import() -> Any:
"""Import faiss if available, otherwise raise error."""
try:
import faiss
except ImportError:
raise ValueError(
"Could not import faiss python package. "
"Please install it with `pip install faiss` "
"or `pip install faiss-cpu` (depending on Python version)."
)
return faiss
class FAISSVS(FAISS):
def __init__(self,
embedding_function: Callable[..., Any],
index: Any,
docstore: Docstore,
index_to_docstore_id: Dict[int, str]):
super().__init__(embedding_function, index, docstore, index_to_docstore_id)
def max_marginal_relevance_search_by_vector(
self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
Returns:
List of Documents with scores selected by maximal marginal relevance.
"""
scores, indices = self.index.search(np.array([embedding], dtype=np.float32), fetch_k)
# -1 happens when not enough docs are returned.
embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
mmr_selected = maximal_marginal_relevance(
np.array([embedding], dtype=np.float32), embeddings, k=k
)
selected_indices = [indices[0][i] for i in mmr_selected]
selected_scores = [scores[0][i] for i in mmr_selected]
docs = []
for i, score in zip(selected_indices, selected_scores):
if i == -1:
# This happens when not enough docs are returned.
continue
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
docs.append((doc, score))
return docs
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
Returns:
List of Documents with scores selected by maximal marginal relevance.
"""
embedding = self.embedding_function(query)
docs = self.max_marginal_relevance_search_by_vector(embedding, k, fetch_k)
return docs
@classmethod
def __from(
cls,
texts: List[str],
embeddings: List[List[float]],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> FAISS:
faiss = dependable_faiss_import()
index = faiss.IndexFlatIP(len(embeddings[0]))
index.add(np.array(embeddings, dtype=np.float32))
# # my code, for speeding up search
# quantizer = faiss.IndexFlatL2(len(embeddings[0]))
# index = faiss.IndexIVFFlat(quantizer, len(embeddings[0]), 100)
# index.train(np.array(embeddings, dtype=np.float32))
# index.add(np.array(embeddings, dtype=np.float32))
documents = []
for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
documents.append(Document(page_content=text, metadata=metadata))
index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
docstore = InMemoryDocstore(
{index_to_id[i]: doc for i, doc in enumerate(documents)}
)
return cls(embedding.embed_query, index, docstore, index_to_id)
......@@ -246,8 +246,8 @@ LLM_HISTORY_LEN = 3
# 知识库检索时返回的匹配内容条数
VECTOR_SEARCH_TOP_K = 5
# 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,经测试设置为小于500时,匹配结果更精准
VECTOR_SEARCH_SCORE_THRESHOLD = 390
# 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,建议设置为500左右,经测试设置为小于500时,匹配结果更精准
VECTOR_SEARCH_SCORE_THRESHOLD = 500
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
......
......@@ -6,6 +6,7 @@ from queue import Queue
from threading import Thread
from langchain.callbacks.manager import CallbackManagerForChainRun
from models.loader import LoaderCheckPoint
from pydantic import BaseModel
import torch
import transformers
......@@ -23,13 +24,12 @@ class ListenerToken:
self._scores = _scores
class AnswerResult:
class AnswerResult(BaseModel):
"""
消息实体
"""
history: List[List[str]] = []
llm_output: Optional[dict] = None
listenerToken: ListenerToken = None
class AnswerResultStream:
......@@ -167,8 +167,6 @@ class BaseAnswer(ABC):
with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator:
for answerResult in generator:
if answerResult.listenerToken:
output = answerResult.listenerToken.input_ids
yield answerResult
@abstractmethod
......
......@@ -94,8 +94,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": stream_resp}
if listenerQueue.listenerQueue.__len__() > 0:
answer_result.listenerToken = listenerQueue.listenerQueue.pop()
generate_with_callback(answer_result)
self.checkPoint.clear_torch_cache()
else:
......@@ -114,8 +112,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": response}
if listenerQueue.listenerQueue.__len__() > 0:
answer_result.listenerToken = listenerQueue.listenerQueue.pop()
generate_with_callback(answer_result)
from abc import ABC
from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Generator, Collection
from typing import (
Any, Dict, List, Optional, Generator, Collection, Set,
Callable,
Tuple,
Union)
from models.loader import LoaderCheckPoint
from langchain.callbacks.manager import CallbackManagerForChainRun
from models.base import (BaseAnswer,
......@@ -8,9 +13,26 @@ from models.base import (BaseAnswer,
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from pydantic import Extra, Field, root_validator
from openai import (
ChatCompletion
)
import openai
import logging
import torch
import transformers
logger = logging.getLogger(__name__)
def _build_message_template() -> Dict[str, str]:
"""
......@@ -25,15 +47,26 @@ def _build_message_template() -> Dict[str, str]:
# 将历史对话数组转换为文本格式
def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
build_messages: Collection[Dict[str, str]] = []
for i, (old_query, response) in enumerate(history):
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = old_query
system_build_message = _build_message_template()
system_build_message['role'] = 'system'
system_build_message['content'] = response
build_messages.append(user_build_message)
build_messages.append(system_build_message)
system_build_message = _build_message_template()
system_build_message['role'] = 'system'
system_build_message['content'] = "You are a helpful assistant."
build_messages.append(system_build_message)
if history:
for i, (user, assistant) in enumerate(history):
if user:
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = user
build_messages.append(user_build_message)
if not assistant:
raise RuntimeError("历史数据结构不正确")
system_build_message = _build_message_template()
system_build_message['role'] = 'assistant'
system_build_message['content'] = assistant
build_messages.append(system_build_message)
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
......@@ -43,6 +76,9 @@ def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str,
class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
client: Any
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
max_retries: int = 6
api_base_url: str = "http://localhost:8000/v1"
model_name: str = "chatglm-6b"
max_token: int = 10000
......@@ -108,6 +144,35 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
def call_model_name(self, model_name):
self.model_name = model_name
def _create_retry_decorator(self) -> Callable[[Any], Any]:
min_seconds = 1
max_seconds = 60
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def completion_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = self._create_retry_decorator()
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return self.client.create(**kwargs)
return _completion_with_retry(**kwargs)
def _call(
self,
inputs: Dict[str, Any],
......@@ -121,32 +186,74 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
history = inputs[self.history_key]
streaming = inputs[self.streaming_key]
history = inputs.get(self.history_key, [])
streaming = inputs.get(self.streaming_key, False)
prompt = inputs[self.prompt_key]
stop = inputs.get("stop", "stop")
print(f"__call:{prompt}")
try:
import openai
# Not support yet
# openai.api_key = "EMPTY"
openai.api_key = self.api_key
openai.api_base = self.api_base_url
except ImportError:
self.client = openai.ChatCompletion
except AttributeError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
# create a chat completion
completion = openai.ChatCompletion.create(
model=self.model_name,
messages=build_message_list(prompt,history=history)
)
print(f"response:{completion.choices[0].message.content}")
print(f"+++++++++++++++++++++++++++++++++++")
history += [[prompt, completion.choices[0].message.content]]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": completion.choices[0].message.content}
generate_with_callback(answer_result)
msg = build_message_list(prompt, history=history)
if streaming:
params = {"stream": streaming,
"model": self.model_name,
"stop": stop}
out_str = ""
for stream_resp in self.completion_with_retry(
messages=msg,
**params
):
role = stream_resp["choices"][0]["delta"].get("role", "")
token = stream_resp["choices"][0]["delta"].get("content", "")
out_str += token
history[-1] = [prompt, out_str]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": out_str}
generate_with_callback(answer_result)
else:
params = {"stream": streaming,
"model": self.model_name,
"stop": stop}
response = self.completion_with_retry(
messages=msg,
**params
)
role = response["choices"][0]["message"].get("role", "")
content = response["choices"][0]["message"].get("content", "")
history += [[prompt, content]]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": content}
generate_with_callback(answer_result)
if __name__ == "__main__":
chain = FastChatOpenAILLMChain()
chain.set_api_key("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")
# chain.set_api_base_url("https://api.openai.com/v1")
# chain.call_model_name("gpt-3.5-turbo")
answer_result_stream_result = chain({"streaming": True,
"prompt": "你好",
"history": []
})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
print(resp)
......@@ -186,7 +186,5 @@ class LLamaLLMChain(BaseAnswer, Chain, ABC):
answer_result = AnswerResult()
history += [[prompt, reply]]
answer_result.history = history
if listenerQueue.listenerQueue.__len__() > 0:
answer_result.listenerToken = listenerQueue.listenerQueue.pop()
answer_result.llm_output = {"answer": reply}
generate_with_callback(answer_result)
......@@ -11,7 +11,7 @@ beautifulsoup4
icetk
cpm_kernels
faiss-cpu
gradio==3.28.3
gradio==3.37.0
fastapi~=0.95.0
uvicorn~=0.21.1
pypinyin~=0.48.0
......
import streamlit as st
# from st_btn_select import st_btn_select
from streamlit_chatbox import st_chatbox
import tempfile
###### 从webui借用的代码 #####
###### 做了少量修改 #####
......@@ -23,6 +23,7 @@ def get_vs_list():
if not os.path.exists(KB_ROOT_PATH):
return lst_default
lst = os.listdir(KB_ROOT_PATH)
lst = [x for x in lst if os.path.isdir(os.path.join(KB_ROOT_PATH, x))]
if not lst:
return lst_default
lst.sort()
......@@ -31,7 +32,6 @@ def get_vs_list():
embedding_model_dict_list = list(embedding_model_dict.keys())
llm_model_dict_list = list(llm_model_dict.keys())
# flag_csv_logger = gr.CSVLogger()
def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
......@@ -50,6 +50,9 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
history[-1][-1] += source
yield history, ""
elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path):
local_doc_qa.top_k = vector_search_top_k
local_doc_qa.chunk_conent = chunk_conent
local_doc_qa.chunk_size = chunk_size
for resp, history in local_doc_qa.get_knowledge_based_answer(
query=query, vs_path=vs_path, chat_history=history, streaming=streaming):
source = "\n\n"
......@@ -95,62 +98,15 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
"\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
yield history, ""
logger.info(f"flagging: username={FLAG_USER_NAME},query={query},vs_path={vs_path},mode={mode},history={history}")
# flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)
def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'):
local_doc_qa = LocalDocQA()
# 初始化消息
args = parser.parse_args()
args_dict = vars(args)
args_dict.update(model=llm_model)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
try:
local_doc_qa.init_cfg(llm_model=llm_model_ins,
embedding_model=embedding_model)
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": "你好", "history": [], "streaming": False})
for answer_result in answer_result_stream_result['answer_result_stream']:
print(answer_result.llm_output)
reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
logger.info(reply)
except Exception as e:
logger.error(e)
reply = """模型未成功加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
if str(e) == "Unknown platform: darwin":
logger.info("该报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI,具体方法请参考项目 README 中本地部署方法及常见问题:"
" https://github.com/imClumsyPanda/langchain-ChatGLM")
else:
logger.info(reply)
return local_doc_qa
# 暂未使用到,先保留
# def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history):
# try:
# llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
# llm_model_ins.history_len = llm_history_len
# local_doc_qa.init_cfg(llm_model=llm_model_ins,
# embedding_model=embedding_model,
# top_k=top_k)
# model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
# logger.info(model_status)
# except Exception as e:
# logger.error(e)
# model_status = """模型未成功重新加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
# logger.info(model_status)
# return history + [[None, model_status]]
def get_vector_store(local_doc_qa, vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
filelist = []
if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")):
os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content"))
if local_doc_qa.llm and local_doc_qa.embeddings:
qa = st.session_state.local_doc_qa
if qa.llm_model_chain and qa.embeddings:
if isinstance(files, list):
for file in files:
filename = os.path.split(file.name)[-1]
......@@ -158,10 +114,10 @@ def get_vector_store(local_doc_qa, vs_id, files, sentence_size, history, one_con
KB_ROOT_PATH, vs_id, "content", filename))
filelist.append(os.path.join(
KB_ROOT_PATH, vs_id, "content", filename))
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(
vs_path, loaded_files = qa.init_knowledge_vector_store(
filelist, vs_path, sentence_size)
else:
vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
vs_path, loaded_files = qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
sentence_size)
if len(loaded_files):
file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
......@@ -179,10 +135,7 @@ knowledge_base_test_mode_info = ("【注意】\n\n"
"并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询。\n\n"
"2. 知识相关度 Score 经测试,建议设置为 500 或更低,具体设置情况请结合实际使用调整。"
"""3. 使用"添加单条数据"添加文本至知识库时,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。\n\n"""
"4. 单条内容长度建议设置在100-150左右。\n\n"
"5. 本界面用于知识入库及知识匹配相关参数设定,但当前版本中,"
"本界面中修改的参数并不会直接修改对话界面中参数,仍需前往`configs/model_config.py`修改后生效。"
"相关参数将在后续版本中支持本界面直接修改。")
"4. 单条内容长度建议设置在100-150左右。")
webui_title = """
......@@ -194,7 +147,7 @@ webui_title = """
###### todo #####
# 1. streamlit运行方式与一般web服务器不同,使用模块是无法实现单例模式的,所以shared和local_doc_qa都需要进行全局化处理。
# 目前已经实现了local_doc_qa的全局化,后面要考虑shared
# 目前已经实现了local_doc_qa和shared.loaderCheckPoint的全局化
# 2. 当前local_doc_qa是一个全局变量,一方面:任何一个session对其做出修改,都会影响所有session的对话;另一方面,如何处理所有session的请求竞争也是问题。
# 这个暂时无法避免,在配置普通的机器上暂时也无需考虑。
# 3. 目前只包含了get_answer对应的参数,以后可以添加其他参数,如temperature。
......@@ -203,25 +156,11 @@ webui_title = """
###### 配置项 #####
class ST_CONFIG:
user_bg_color = '#77ff77'
user_icon = 'https://tse2-mm.cn.bing.net/th/id/OIP-C.LTTKrxNWDr_k74wz6jKqBgHaHa?w=203&h=203&c=7&r=0&o=5&pid=1.7'
robot_bg_color = '#ccccee'
robot_icon = 'https://ts1.cn.mm.bing.net/th/id/R-C.5302e2cc6f5c7c4933ebb3394e0c41bc?rik=z4u%2b7efba5Mgxw&riu=http%3a%2f%2fcomic-cons.xyz%2fwp-content%2fuploads%2fStar-Wars-avatar-icon-C3PO.png&ehk=kBBvCvpJMHPVpdfpw1GaH%2brbOaIoHjY5Ua9PKcIs%2bAc%3d&risl=&pid=ImgRaw&r=0'
default_mode = '知识库问答'
defalut_kb = ''
default_mode = "知识库问答"
default_kb = ""
###### #####
class MsgType:
'''
目前仅支持文本类型的输入输出,为以后多模态模型预留图像、视频、音频支持。
'''
TEXT = 1
IMAGE = 2
VIDEO = 3
AUDIO = 4
class TempFile:
'''
为保持与get_vector_store的兼容性,需要将streamlit上传文件转化为其可以接受的方式
......@@ -231,132 +170,54 @@ class TempFile:
self.name = path
def init_session():
st.session_state.setdefault('history', [])
# def get_query_params():
# '''
# 可以用url参数传递配置参数:llm_model, embedding_model, kb, mode。
# 该参数将覆盖model_config中的配置。处于安全考虑,目前只支持kb和mode
# 方便将固定的配置分享给特定的人。
# '''
# params = st.experimental_get_query_params()
# return {k: v[0] for k, v in params.items() if v}
def robot_say(msg, kb=''):
st.session_state['history'].append(
{'is_user': False, 'type': MsgType.TEXT, 'content': msg, 'kb': kb})
def user_say(msg):
st.session_state['history'].append(
{'is_user': True, 'type': MsgType.TEXT, 'content': msg})
def format_md(msg, is_user=False, bg_color='', margin='10%'):
'''
将文本消息格式化为markdown文本
'''
if is_user:
bg_color = bg_color or ST_CONFIG.user_bg_color
text = f'''
<div style="background:{bg_color};
margin-left:{margin};
word-break:break-all;
float:right;
padding:2%;
border-radius:2%;">
{msg}
</div>
'''
else:
bg_color = bg_color or ST_CONFIG.robot_bg_color
text = f'''
<div style="background:{bg_color};
margin-right:{margin};
word-break:break-all;
padding:2%;
border-radius:2%;">
{msg}
</div>
'''
return text
def message(msg,
is_user=False,
msg_type=MsgType.TEXT,
icon='',
bg_color='',
margin='10%',
kb='',
):
'''
渲染单条消息。目前仅支持文本
'''
cols = st.columns([1, 10, 1])
empty = cols[1].empty()
if is_user:
icon = icon or ST_CONFIG.user_icon
bg_color = bg_color or ST_CONFIG.user_bg_color
cols[2].image(icon, width=40)
if msg_type == MsgType.TEXT:
text = format_md(msg, is_user, bg_color, margin)
empty.markdown(text, unsafe_allow_html=True)
else:
raise RuntimeError('only support text message now.')
else:
icon = icon or ST_CONFIG.robot_icon
bg_color = bg_color or ST_CONFIG.robot_bg_color
cols[0].image(icon, width=40)
if kb:
cols[0].write(f'({kb})')
if msg_type == MsgType.TEXT:
text = format_md(msg, is_user, bg_color, margin)
empty.markdown(text, unsafe_allow_html=True)
else:
raise RuntimeError('only support text message now.')
return empty
def output_messages(
user_bg_color='',
robot_bg_color='',
user_icon='',
robot_icon='',
):
with chat_box.container():
last_response = None
for msg in st.session_state['history']:
bg_color = user_bg_color if msg['is_user'] else robot_bg_color
icon = user_icon if msg['is_user'] else robot_icon
empty = message(msg['content'],
is_user=msg['is_user'],
icon=icon,
msg_type=msg['type'],
bg_color=bg_color,
kb=msg.get('kb', '')
)
if not msg['is_user']:
last_response = empty
return last_response
@st.cache_resource(show_spinner=False, max_entries=1)
def load_model(llm_model: str, embedding_model: str):
def load_model(
llm_model: str = LLM_MODEL,
embedding_model: str = EMBEDDING_MODEL,
use_ptuning_v2: bool = USE_PTUNING_V2,
):
'''
对应init_model,利用streamlit cache避免模型重复加载
'''
local_doc_qa = init_model(llm_model, embedding_model)
robot_say('模型已成功加载,可以开始对话,或从左侧选择模式后开始对话。\n请尽量不要刷新页面,以免模型出错或重复加载。')
local_doc_qa = LocalDocQA()
# 初始化消息
args = parser.parse_args()
args_dict = vars(args)
args_dict.update(model=llm_model)
if shared.loaderCheckPoint is None: # avoid checkpoint reloading when reinit model
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
# shared.loaderCheckPoint.model_name is different by no_remote_model.
# if it is not set properly error occurs when reinit llm model(issue#473).
# as no_remote_model is removed from model_config, need workaround to set it automaticlly.
local_model_path = llm_model_dict.get(llm_model, {}).get('local_model_path') or ''
no_remote_model = os.path.isdir(local_model_path)
llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
llm_model_ins.history_len = LLM_HISTORY_LEN
try:
local_doc_qa.init_cfg(llm_model=llm_model_ins,
embedding_model=embedding_model)
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": "你好", "history": [], "streaming": False})
for answer_result in answer_result_stream_result['answer_result_stream']:
print(answer_result.llm_output)
reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
logger.info(reply)
except Exception as e:
logger.error(e)
reply = """模型未成功加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
if str(e) == "Unknown platform: darwin":
logger.info("该报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI,具体方法请参考项目 README 中本地部署方法及常见问题:"
" https://github.com/imClumsyPanda/langchain-ChatGLM")
else:
logger.info(reply)
return local_doc_qa
# @st.cache_data
def answer(query, vs_path='', history=[], mode='', score_threshold=0,
vector_search_top_k=5, chunk_conent=True, chunk_size=100, qa=None
vector_search_top_k=5, chunk_conent=True, chunk_size=100
):
'''
对应get_answer,--利用streamlit cache缓存相同问题的答案--
......@@ -365,48 +226,24 @@ def answer(query, vs_path='', history=[], mode='', score_threshold=0,
vector_search_top_k, chunk_conent, chunk_size)
def load_vector_store(
vs_id,
files,
sentence_size=100,
history=[],
one_conent=None,
one_content_segmentation=None,
):
return get_vector_store(
local_doc_qa,
vs_id,
files,
sentence_size,
history,
one_conent,
one_content_segmentation,
)
def use_kb_mode(m):
return m in ["知识库问答", "知识库测试"]
# main ui
st.set_page_config(webui_title, layout='wide')
init_session()
# params = get_query_params()
# llm_model = params.get('llm_model', LLM_MODEL)
# embedding_model = params.get('embedding_model', EMBEDDING_MODEL)
with st.spinner(f'正在加载模型({LLM_MODEL} + {EMBEDDING_MODEL}),请耐心等候...'):
local_doc_qa = load_model(LLM_MODEL, EMBEDDING_MODEL)
def use_kb_mode(m):
return m in ['知识库问答', '知识库测试']
chat_box = st_chatbox(greetings=["模型已成功加载,可以开始对话,或从左侧选择模式后开始对话。"])
# 使用 help(st_chatbox) 查看自定义参数
# sidebar
modes = ['LLM 对话', '知识库问答', 'Bing搜索问答', '知识库测试']
with st.sidebar:
def on_mode_change():
m = st.session_state.mode
robot_say(f'已切换到"{m}"模式')
chat_box.robot_say(f'已切换到"{m}"模式')
if m == '知识库测试':
robot_say(knowledge_base_test_mode_info)
chat_box.robot_say(knowledge_base_test_mode_info)
index = 0
try:
......@@ -416,7 +253,7 @@ with st.sidebar:
mode = st.selectbox('对话模式', modes, index,
on_change=on_mode_change, key='mode')
with st.expander('模型配置', '知识' not in mode):
with st.expander('模型配置', not use_kb_mode(mode)):
with st.form('model_config'):
index = 0
try:
......@@ -425,9 +262,8 @@ with st.sidebar:
pass
llm_model = st.selectbox('LLM模型', llm_model_dict_list, index)
no_remote_model = st.checkbox('加载本地模型', False)
use_ptuning_v2 = st.checkbox('使用p-tuning-v2微调过的模型', False)
use_lora = st.checkbox('使用lora微调的权重', False)
try:
index = embedding_model_dict_list.index(EMBEDDING_MODEL)
except:
......@@ -437,44 +273,52 @@ with st.sidebar:
btn_load_model = st.form_submit_button('重新加载模型')
if btn_load_model:
local_doc_qa = load_model(llm_model, embedding_model)
local_doc_qa = load_model(llm_model, embedding_model, use_ptuning_v2)
history_len = st.slider(
"LLM对话轮数", 1, 50, LLM_HISTORY_LEN)
if mode in ['知识库问答', '知识库测试']:
if use_kb_mode(mode):
vs_list = get_vs_list()
vs_list.remove('新建知识库')
def on_new_kb():
name = st.session_state.kb_name
if name in vs_list:
st.error(f'名为“{name}”的知识库已存在。')
if not name:
st.sidebar.error(f'新建知识库名称不能为空!')
elif name in vs_list:
st.sidebar.error(f'名为“{name}”的知识库已存在。')
else:
vs_list.append(name)
st.session_state.vs_path = name
st.session_state.kb_name = ''
new_kb_dir = os.path.join(KB_ROOT_PATH, name)
if not os.path.exists(new_kb_dir):
os.makedirs(new_kb_dir)
st.sidebar.success(f'名为“{name}”的知识库创建成功,您可以开始添加文件。')
def on_vs_change():
robot_say(f'已加载知识库: {st.session_state.vs_path}')
chat_box.robot_say(f'已加载知识库: {st.session_state.vs_path}')
with st.expander('知识库配置', True):
cols = st.columns([12, 10])
kb_name = cols[0].text_input(
'新知识库名称', placeholder='新知识库名称', label_visibility='collapsed')
if 'kb_name' not in st.session_state:
st.session_state.kb_name = kb_name
'新知识库名称', placeholder='新知识库名称', label_visibility='collapsed', key='kb_name')
cols[1].button('新建知识库', on_click=on_new_kb)
index = 0
try:
index = vs_list.index(ST_CONFIG.default_kb)
except:
pass
vs_path = st.selectbox(
'选择知识库', vs_list, on_change=on_vs_change, key='vs_path')
'选择知识库', vs_list, index, on_change=on_vs_change, key='vs_path')
st.text('')
score_threshold = st.slider(
'知识相关度阈值', 0, 1000, VECTOR_SEARCH_SCORE_THRESHOLD)
top_k = st.slider('向量匹配数量', 1, 20, VECTOR_SEARCH_TOP_K)
history_len = st.slider(
'LLM对话轮数', 1, 50, LLM_HISTORY_LEN) # 也许要跟知识库分开设置
# local_doc_qa.llm.set_history_len(history_len)
chunk_conent = st.checkbox('启用上下文关联', False)
st.text('')
# chunk_conent = st.checkbox('分割文本', True) # 知识库文本分割入库
chunk_size = st.slider('上下文关联长度', 1, 1000, CHUNK_SIZE)
st.text('')
sentence_size = st.slider('文本入库分句长度限制', 1, 1000, SENTENCE_SIZE)
files = st.file_uploader('上传知识文件',
['docx', 'txt', 'md', 'csv', 'xlsx', 'pdf'],
......@@ -487,56 +331,61 @@ with st.sidebar:
with open(file, 'wb') as fp:
fp.write(f.getvalue())
file_list.append(TempFile(file))
_, _, history = load_vector_store(
_, _, history = get_vector_store(
vs_path, file_list, sentence_size, [], None, None)
st.session_state.files = []
# main body
chat_box = st.empty()
with st.form('my_form', clear_on_submit=True):
# load model after params rendered
with st.spinner(f"正在加载模型({llm_model} + {embedding_model}),请耐心等候..."):
local_doc_qa = load_model(
llm_model,
embedding_model,
use_ptuning_v2,
)
local_doc_qa.llm_model_chain.history_len = history_len
if use_kb_mode(mode):
local_doc_qa.chunk_conent = chunk_conent
local_doc_qa.chunk_size = chunk_size
# local_doc_qa.llm_model_chain.temperature = temperature # 这样设置temperature似乎不起作用
st.session_state.local_doc_qa = local_doc_qa
# input form
with st.form("my_form", clear_on_submit=True):
cols = st.columns([8, 1])
question = cols[0].text_input(
question = cols[0].text_area(
'temp', key='input_question', label_visibility='collapsed')
def on_send():
q = st.session_state.input_question
if q:
user_say(q)
if mode == 'LLM 对话':
robot_say('正在思考...')
last_response = output_messages()
for history, _ in answer(q,
history=[],
mode=mode):
last_response.markdown(
format_md(history[-1][-1], False),
unsafe_allow_html=True
)
elif use_kb_mode(mode):
robot_say('正在思考...', vs_path)
last_response = output_messages()
for history, _ in answer(q,
vs_path=os.path.join(
KB_ROOT_PATH, vs_path, "vector_store"),
history=[],
mode=mode,
score_threshold=score_threshold,
vector_search_top_k=top_k,
chunk_conent=chunk_conent,
chunk_size=chunk_size):
last_response.markdown(
format_md(history[-1][-1], False, 'ligreen'),
unsafe_allow_html=True
)
else:
robot_say('正在思考...')
last_response = output_messages()
st.session_state['history'][-1]['content'] = history[-1][-1]
submit = cols[1].form_submit_button('发送', on_click=on_send)
output_messages()
# st.write(st.session_state['history'])
if cols[1].form_submit_button("发送"):
chat_box.user_say(question)
history = []
if mode == "LLM 对话":
chat_box.robot_say("正在思考...")
chat_box.output_messages()
for history, _ in answer(question,
history=[],
mode=mode):
chat_box.update_last_box_text(history[-1][-1])
elif use_kb_mode(mode):
chat_box.robot_say(f"正在查询 [{vs_path}] ...")
chat_box.output_messages()
for history, _ in answer(question,
vs_path=os.path.join(
KB_ROOT_PATH, vs_path, 'vector_store'),
history=[],
mode=mode,
score_threshold=score_threshold,
vector_search_top_k=top_k,
chunk_conent=chunk_conent,
chunk_size=chunk_size):
chat_box.update_last_box_text(history[-1][-1])
else:
chat_box.robot_say(f"正在执行Bing搜索...")
chat_box.output_messages()
for history, _ in answer(question,
history=[],
mode=mode):
chat_box.update_last_box_text(history[-1][-1])
# st.write(chat_box.history)
chat_box.output_messages()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论