提交 c5bc2178 作者: glide-the

修改模型生成的调用方式,兼容Chain调用

修改模型切换的bug
上级 ca13ab81
......@@ -384,8 +384,10 @@ async def chat(
],
),
):
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
streaming=True):
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": question, "history": history, "streaming": True})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
pass
......@@ -486,7 +488,6 @@ def api_start(host, port, **kwargs):
global local_doc_qa
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
app = FastAPI()
# Add CORS middleware to allow all origins
......
......@@ -18,6 +18,7 @@ from agent import bing_search
from langchain.docstore.document import Document
from functools import lru_cache
from textsplitter.zh_title_enhance import zh_title_enhance
from langchain.chains.base import Chain
# patch HuggingFaceEmbeddings to make it hashable
......@@ -119,7 +120,7 @@ def search_result2docs(search_results):
class LocalDocQA:
llm: BaseAnswer = None
llm_model_chain: Chain = None
embeddings: object = None
top_k: int = VECTOR_SEARCH_TOP_K
chunk_size: int = CHUNK_SIZE
......@@ -129,10 +130,10 @@ class LocalDocQA:
def init_cfg(self,
embedding_model: str = EMBEDDING_MODEL,
embedding_device=EMBEDDING_DEVICE,
llm_model: BaseAnswer = None,
llm_model: Chain = None,
top_k=VECTOR_SEARCH_TOP_K,
):
self.llm = llm_model
self.llm_model_chain = llm_model
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
model_kwargs={'device': embedding_device})
self.top_k = top_k
......@@ -236,8 +237,10 @@ class LocalDocQA:
else:
prompt = query
for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
streaming=streaming):
answer_result_stream_result = self.llm_model_chain(
{"prompt": prompt, "history": chat_history, "streaming": streaming})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][0] = query
......@@ -276,8 +279,10 @@ class LocalDocQA:
result_docs = search_result2docs(results)
prompt = generate_prompt(result_docs, query)
for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
streaming=streaming):
answer_result_stream_result = self.llm_model_chain(
{"prompt": prompt, "history": chat_history, "streaming": streaming})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][0] = query
......@@ -296,7 +301,7 @@ class LocalDocQA:
def update_file_from_vector_store(self,
filepath: str or List[str],
vs_path,
docs: List[Document],):
docs: List[Document], ):
vector_store = load_vector_store(vs_path, self.embeddings)
status = vector_store.update_doc(filepath, docs)
return status
......@@ -320,7 +325,6 @@ if __name__ == "__main__":
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(llm_model=llm_model_ins)
......
......@@ -37,61 +37,67 @@ llm_model_dict = {
"name": "chatglm-6b-int4-qe",
"pretrained_model_name": "THUDM/chatglm-6b-int4-qe",
"local_model_path": None,
"provides": "ChatGLM"
"provides": "ChatGLMLLMChain"
},
"chatglm-6b-int4": {
"name": "chatglm-6b-int4",
"pretrained_model_name": "THUDM/chatglm-6b-int4",
"local_model_path": None,
"provides": "ChatGLM"
"provides": "ChatGLMLLMChain"
},
"chatglm-6b-int8": {
"name": "chatglm-6b-int8",
"pretrained_model_name": "THUDM/chatglm-6b-int8",
"local_model_path": None,
"provides": "ChatGLM"
"provides": "ChatGLMLLMChain"
},
"chatglm-6b": {
"name": "chatglm-6b",
"pretrained_model_name": "THUDM/chatglm-6b",
"local_model_path": None,
"provides": "ChatGLM"
"provides": "ChatGLMLLMChain"
},
"chatglm2-6b": {
"name": "chatglm2-6b",
"pretrained_model_name": "THUDM/chatglm2-6b",
"local_model_path": None,
"provides": "ChatGLM"
"provides": "ChatGLMLLMChain"
},
"chatglm2-6b-int4": {
"name": "chatglm2-6b-int4",
"pretrained_model_name": "THUDM/chatglm2-6b-int4",
"local_model_path": None,
"provides": "ChatGLM"
"provides": "ChatGLMLLMChain"
},
"chatglm2-6b-int8": {
"name": "chatglm2-6b-int8",
"pretrained_model_name": "THUDM/chatglm2-6b-int8",
"local_model_path": None,
"provides": "ChatGLM"
"provides": "ChatGLMLLMChain"
},
"chatyuan": {
"name": "chatyuan",
"pretrained_model_name": "ClueAI/ChatYuan-large-v2",
"local_model_path": None,
"provides": "MOSSLLM"
"provides": "MOSSLLMChain"
},
"moss": {
"name": "moss",
"pretrained_model_name": "fnlp/moss-moon-003-sft",
"local_model_path": None,
"provides": "MOSSLLM"
"provides": "MOSSLLMChain"
},
"vicuna-13b-hf": {
"name": "vicuna-13b-hf",
"pretrained_model_name": "vicuna-13b-hf",
"local_model_path": None,
"provides": "LLamaLLM"
"provides": "LLamaLLMChain"
},
"vicuna-7b-hf": {
"name": "vicuna-13b-hf",
"pretrained_model_name": "vicuna-13b-hf",
"local_model_path": None,
"provides": "LLamaLLMChain"
},
# 直接调用返回requests.exceptions.ConnectionError错误,需要通过huggingface_hub包里的snapshot_download函数
# 下载模型,如果snapshot_download还是返回网络错误,多试几次,一般是可以的,
......@@ -101,7 +107,7 @@ llm_model_dict = {
"name": "bloomz-7b1",
"pretrained_model_name": "bigscience/bloomz-7b1",
"local_model_path": None,
"provides": "MOSSLLM"
"provides": "MOSSLLMChain"
},
# 实测加载bigscience/bloom-3b需要170秒左右,暂不清楚为什么这么慢
......@@ -110,14 +116,14 @@ llm_model_dict = {
"name": "bloom-3b",
"pretrained_model_name": "bigscience/bloom-3b",
"local_model_path": None,
"provides": "MOSSLLM"
"provides": "MOSSLLMChain"
},
"baichuan-7b": {
"name": "baichuan-7b",
"pretrained_model_name": "baichuan-inc/baichuan-7B",
"local_model_path": None,
"provides": "MOSSLLM"
"provides": "MOSSLLMChain"
},
# llama-cpp模型的兼容性问题参考https://github.com/abetlen/llama-cpp-python/issues/204
"ggml-vicuna-13b-1.1-q5": {
......@@ -131,7 +137,7 @@ llm_model_dict = {
# 需要手动从https://github.com/abetlen/llama-cpp-python/releases/tag/下载对应的wheel安装
# 实测v0.1.63与本模型的vicuna/ggml-vicuna-13b-1.1/ggml-vic13b-q5_1.bin可以兼容
"local_model_path": f'''{"/".join(os.path.abspath(__file__).split("/")[:3])}/.cache/huggingface/hub/models--vicuna--ggml-vicuna-13b-1.1/blobs/''',
"provides": "LLamaLLM"
"provides": "LLamaLLMChain"
},
# 通过 fastchat 调用的模型请参考如下格式
......@@ -139,7 +145,7 @@ llm_model_dict = {
"name": "chatglm-6b", # "name"修改为fastchat服务中的"model_name"
"pretrained_model_name": "chatglm-6b",
"local_model_path": None,
"provides": "FastChatOpenAILLM", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM"
"provides": "FastChatOpenAILLMChain", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain"
"api_base_url": "http://localhost:8000/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY"
},
......@@ -147,7 +153,7 @@ llm_model_dict = {
"name": "chatglm2-6b", # "name"修改为fastchat服务中的"model_name"
"pretrained_model_name": "chatglm2-6b",
"local_model_path": None,
"provides": "FastChatOpenAILLM", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM"
"provides": "FastChatOpenAILLMChain", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain"
"api_base_url": "http://localhost:8000/v1" # "name"修改为fastchat服务中的"api_base_url"
},
......@@ -156,7 +162,7 @@ llm_model_dict = {
"name": "vicuna-13b-hf", # "name"修改为fastchat服务中的"model_name"
"pretrained_model_name": "vicuna-13b-hf",
"local_model_path": None,
"provides": "FastChatOpenAILLM", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM"
"provides": "FastChatOpenAILLMChain", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain"
"api_base_url": "http://localhost:8000/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY"
},
......@@ -171,7 +177,7 @@ llm_model_dict = {
"openai-chatgpt-3.5": {
"name": "gpt-3.5-turbo",
"pretrained_model_name": "gpt-3.5-turbo",
"provides": "FastChatOpenAILLM",
"provides": "FastChatOpenAILLMChain",
"local_model_path": None,
"api_base_url": "https://api.openapi.com/v1",
"api_key": ""
......@@ -226,7 +232,7 @@ LLM_HISTORY_LEN = 3
VECTOR_SEARCH_TOP_K = 5
# 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,经测试设置为小于500时,匹配结果更精准
VECTOR_SEARCH_SCORE_THRESHOLD = 0
VECTOR_SEARCH_SCORE_THRESHOLD = 390
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
......
from .chatglm_llm import ChatGLM
from .llama_llm import LLamaLLM
from .moss_llm import MOSSLLM
from .fastchat_openai_llm import FastChatOpenAILLM
from .chatglm_llm import ChatGLMLLMChain
from .llama_llm import LLamaLLMChain
from .fastchat_openai_llm import FastChatOpenAILLMChain
from .moss_llm import MOSSLLMChain
from models.base.base import (
AnswerResult,
BaseAnswer
)
BaseAnswer,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
from models.base.remote_rpc_model import (
RemoteRpcModel
)
__all__ = [
"AnswerResult",
"BaseAnswer",
"RemoteRpcModel",
"AnswerResultStream",
"AnswerResultQueueSentinelTokenListenerQueue"
]
from abc import ABC, abstractmethod
from typing import Optional, List
from typing import Any, Dict, List, Optional, Generator
import traceback
from collections import deque
from queue import Queue
from threading import Thread
from langchain.callbacks.manager import CallbackManagerForChainRun
from models.loader import LoaderCheckPoint
import torch
import transformers
from models.loader import LoaderCheckPoint
class ListenerToken:
"""
观测结果
"""
input_ids: torch.LongTensor
_scores: torch.FloatTensor
def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
self.input_ids = input_ids
self._scores = _scores
class AnswerResult:
......@@ -16,6 +29,123 @@ class AnswerResult:
"""
history: List[List[str]] = []
llm_output: Optional[dict] = None
listenerToken: ListenerToken = None
class AnswerResultStream:
def __init__(self, callback_func=None):
self.callback_func = callback_func
def __call__(self, answerResult: AnswerResult):
if self.callback_func is not None:
self.callback_func(answerResult)
class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
"""
定义模型stopping_criteria 监听者,在每次响应时将队列数据同步到AnswerResult
实现此监听器的目的是,不同模型的预测输出可能不是矢量信息,hf框架可以自定义transformers.StoppingCriteria入参来接收每次预测的Tensor和损失函数,
通过给 StoppingCriteriaList指定模型生成答案时停止的条件。每个 StoppingCriteria 对象表示一个停止条件
当每轮预测任务开始时,StoppingCriteria都会收到相同的预测结果,最终由下层实现类确认是否结束
输出值可用于 generatorAnswer generate_with_streaming的自定义参数观测,以实现更加精细的控制
"""
listenerQueue: deque = deque(maxlen=1)
def __init__(self):
transformers.StoppingCriteria.__init__(self)
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
"""
每次响应时将数据添加到响应队列
:param input_ids:
:param _scores:
:param kwargs:
:return:
"""
self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores))
return False
class Iteratorize:
"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""
def __init__(self, func, kwargs={}):
self.mfunc = func
self.q = Queue()
self.sentinel = object()
self.kwargs = kwargs
self.stop_now = False
def _callback(val):
"""
模型输出预测结果收集
通过定义generate_with_callback收集器AnswerResultStream,收集模型预测的AnswerResult响应结果,最终由下层实现类确认是否结束
结束条件包含如下
1、模型预测结束、收集器self.q队列收到 self.sentinel标识
2、在处理迭代器队列消息时返回了break跳出迭代器,触发了StopIteration事件
3、模型预测出错
因为当前类是迭代器,所以在for in 中执行了break后 __exit__ 方法会被调用,最终stop_now属性会被更新,然后抛出异常结束预测行为
迭代器收集的行为如下
创建Iteratorize迭代对象,
定义generate_with_callback收集器AnswerResultStream
启动一个线程异步预测结果来调用上游checkpoint的实现方法_generate_answer
_generate_answer通过generate_with_callback定义的收集器,收集上游checkpoint包装的AnswerResult消息体
由于self.q是阻塞模式,每次预测后会被消费后才会执行下次预测
这时generate_with_callback会被阻塞
主线程Iteratorize对象的__next__方法调用获取阻塞消息并消费
1、消息为上游checkpoint包装的AnswerResult消息体,返回下游处理
2、消息为self.sentinel标识,抛出StopIteration异常
主线程Iteratorize对象__exit__收到消息,最终stop_now属性会被更新
异步线程检测stop_now属性被更新,抛出异常结束预测行为
迭代行为结束
:param val:
:return:
"""
if self.stop_now:
raise ValueError
self.q.put(val)
def gen():
try:
ret = self.mfunc(callback=_callback, **self.kwargs)
except ValueError:
pass
except:
traceback.print_exc()
pass
self.q.put(self.sentinel)
self.thread = Thread(target=gen)
self.thread.start()
def __iter__(self):
return self
def __next__(self):
obj = self.q.get(True, None)
if obj is self.sentinel:
raise StopIteration
else:
return obj
def __del__(self):
"""
暂无实现
:return:
"""
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
""" break 后会执行 """
self.stop_now = True
class BaseAnswer(ABC):
......@@ -25,17 +155,25 @@ class BaseAnswer(ABC):
@abstractmethod
def _check_point(self) -> LoaderCheckPoint:
"""Return _check_point of llm."""
def generatorAnswer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,) -> Generator[Any, str, bool]:
def generate_with_callback(callback=None, **kwargs):
kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback)
self._generate_answer(**kwargs)
@property
@abstractmethod
def _history_len(self) -> int:
"""Return _history_len of llm."""
def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs)
@abstractmethod
def set_history_len(self, history_len: int) -> None:
"""Return _history_len of llm."""
with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator:
for answerResult in generator:
if answerResult.listenerToken:
output = answerResult.listenerToken.input_ids
yield answerResult
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
@abstractmethod
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
pass
from abc import ABC
from langchain.llms.base import LLM
from typing import Optional, List
from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Generator
from langchain.callbacks.manager import CallbackManagerForChainRun
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
AnswerResult)
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
import torch
import transformers
class ChatGLM(BaseAnswer, LLM, ABC):
class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
max_token: int = 10000
temperature: float = 0.01
top_p = 0.9
# 相关度
top_p = 0.4
# 候选词数量
top_k = 10
checkPoint: LoaderCheckPoint = None
# history = []
history_len: int = 10
streaming_key: str = "streaming" #: :meta private:
history_key: str = "history" #: :meta private:
prompt_key: str = "prompt" #: :meta private:
output_key: str = "answer_result_stream" #: :meta private:
def __init__(self, checkPoint: LoaderCheckPoint = None):
super().__init__()
self.checkPoint = checkPoint
@property
def _llm_type(self) -> str:
return "ChatGLM"
def _chain_type(self) -> str:
return "ChatGLMLLMChain"
@property
def _check_point(self) -> LoaderCheckPoint:
return self.checkPoint
@property
def _history_len(self) -> int:
return self.history_len
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
def set_history_len(self, history_len: int = 10) -> None:
self.history_len = history_len
:meta private:
"""
return [self.prompt_key]
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
print(f"__call:{prompt}")
response, _ = self.checkPoint.model.chat(
self.checkPoint.tokenizer,
prompt,
history=[],
max_length=self.max_token,
temperature=self.temperature
)
print(f"response:{response}")
print(f"+++++++++++++++++++++++++++++++++++")
return response
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
:meta private:
"""
return [self.output_key]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Generator]:
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
return {self.output_key: generator}
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
history = inputs[self.history_key]
streaming = inputs[self.streaming_key]
prompt = inputs[self.prompt_key]
print(f"__call:{prompt}")
# Create the StoppingCriteriaList with the stopping strings
stopping_criteria_list = transformers.StoppingCriteriaList()
# 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
stopping_criteria_list.append(listenerQueue)
if streaming:
history += [[]]
for inum, (stream_resp, _) in enumerate(self.checkPoint.model.stream_chat(
self.checkPoint.tokenizer,
prompt,
history=history[-self.history_len:-1] if self.history_len > 1 else [],
history=history[-self.history_len:-1] if self.history_len > 0 else [],
max_length=self.max_token,
temperature=self.temperature
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
stopping_criteria=stopping_criteria_list
)):
# self.checkPoint.clear_torch_cache()
history[-1] = [prompt, stream_resp]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": stream_resp}
yield answer_result
if listenerQueue.listenerQueue.__len__() > 0:
answer_result.listenerToken = listenerQueue.listenerQueue.pop()
generate_with_callback(answer_result)
self.checkPoint.clear_torch_cache()
else:
response, _ = self.checkPoint.model.chat(
......@@ -72,13 +104,18 @@ class ChatGLM(BaseAnswer, LLM, ABC):
prompt,
history=history[-self.history_len:] if self.history_len > 0 else [],
max_length=self.max_token,
temperature=self.temperature
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
stopping_criteria=stopping_criteria_list
)
self.checkPoint.clear_torch_cache()
history += [[prompt, response]]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": response}
yield answer_result
if listenerQueue.listenerQueue.__len__() > 0:
answer_result.listenerToken = listenerQueue.listenerQueue.pop()
generate_with_callback(answer_result)
from abc import ABC
import requests
from typing import Optional, List
from langchain.llms.base import LLM
from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Generator, Collection
from models.loader import LoaderCheckPoint
from models.base import (RemoteRpcModel,
AnswerResult)
from typing import (
Collection,
Dict
)
from langchain.callbacks.manager import CallbackManagerForChainRun
from models.base import (BaseAnswer,
RemoteRpcModel,
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
import torch
import transformers
def _build_message_template() -> Dict[str, str]:
......@@ -22,18 +22,42 @@ def _build_message_template() -> Dict[str, str]:
}
class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
# 将历史对话数组转换为文本格式
def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
build_messages: Collection[Dict[str, str]] = []
for i, (old_query, response) in enumerate(history):
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = old_query
system_build_message = _build_message_template()
system_build_message['role'] = 'system'
system_build_message['content'] = response
build_messages.append(user_build_message)
build_messages.append(system_build_message)
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = query
build_messages.append(user_build_message)
return build_messages
class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
api_base_url: str = "http://localhost:8000/v1"
model_name: str = "chatglm-6b"
max_token: int = 10000
temperature: float = 0.01
top_p = 0.9
checkPoint: LoaderCheckPoint = None
history = []
# history = []
history_len: int = 10
api_key: str = ""
streaming_key: str = "streaming" #: :meta private:
history_key: str = "history" #: :meta private:
prompt_key: str = "prompt" #: :meta private:
output_key: str = "answer_result_stream" #: :meta private:
def __init__(self,
checkPoint: LoaderCheckPoint = None,
# api_base_url:str="http://localhost:8000/v1",
......@@ -44,19 +68,28 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
self.checkPoint = checkPoint
@property
def _llm_type(self) -> str:
return "FastChat"
def _chain_type(self) -> str:
return "LLamaLLMChain"
@property
def _check_point(self) -> LoaderCheckPoint:
return self.checkPoint
@property
def _history_len(self) -> int:
return self.history_len
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
def set_history_len(self, history_len: int = 10) -> None:
self.history_len = history_len
:meta private:
"""
return [self.prompt_key]
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
@property
def _api_key(self) -> str:
......@@ -75,53 +108,25 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
def call_model_name(self, model_name):
self.model_name = model_name
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Generator]:
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
return {self.output_key: generator}
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
history = inputs[self.history_key]
streaming = inputs[self.streaming_key]
prompt = inputs[self.prompt_key]
print(f"__call:{prompt}")
try:
import openai
# Not support yet
# openai.api_key = "EMPTY"
openai.key = self.api_key
openai.api_base = self.api_base_url
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
# create a chat completion
completion = openai.ChatCompletion.create(
model=self.model_name,
messages=self.build_message_list(prompt)
)
print(f"response:{completion.choices[0].message.content}")
print(f"+++++++++++++++++++++++++++++++++++")
return completion.choices[0].message.content
# 将历史对话数组转换为文本格式
def build_message_list(self, query) -> Collection[Dict[str, str]]:
build_message_list: Collection[Dict[str, str]] = []
history = self.history[-self.history_len:] if self.history_len > 0 else []
for i, (old_query, response) in enumerate(history):
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = old_query
system_build_message = _build_message_template()
system_build_message['role'] = 'system'
system_build_message['content'] = response
build_message_list.append(user_build_message)
build_message_list.append(system_build_message)
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = query
build_message_list.append(user_build_message)
return build_message_list
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
try:
import openai
# Not support yet
# openai.api_key = "EMPTY"
......@@ -135,12 +140,13 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
# create a chat completion
completion = openai.ChatCompletion.create(
model=self.model_name,
messages=self.build_message_list(prompt)
messages=build_message_list(prompt)
)
print(f"response:{completion.choices[0].message.content}")
print(f"+++++++++++++++++++++++++++++++++++")
history += [[prompt, completion.choices[0].message.content]]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": completion.choices[0].message.content}
yield answer_result
generate_with_callback(answer_result)
from abc import ABC
from langchain.llms.base import LLM
import random
import torch
import transformers
from abc import ABC
from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Generator, Union
from langchain.callbacks.manager import CallbackManagerForChainRun
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from typing import Optional, List, Dict, Any,Union
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
AnswerResult)
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
import torch
import transformers
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: Union[torch.LongTensor,list], scores: Union[torch.FloatTensor,list]) -> torch.FloatTensor:
def __call__(self, input_ids: Union[torch.LongTensor, list],
scores: Union[torch.FloatTensor, list]) -> torch.FloatTensor:
# llama-cpp模型返回的是list,为兼容性考虑,需要判断input_ids和scores的类型,将list转换为torch.Tensor
input_ids = torch.tensor(input_ids) if isinstance(input_ids,list) else input_ids
scores = torch.tensor(scores) if isinstance(scores,list) else scores
input_ids = torch.tensor(input_ids) if isinstance(input_ids, list) else input_ids
scores = torch.tensor(scores) if isinstance(scores, list) else scores
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 5] = 5e4
return scores
class LLamaLLM(BaseAnswer, LLM, ABC):
class LLamaLLMChain(BaseAnswer, Chain, ABC):
checkPoint: LoaderCheckPoint = None
# history = []
history_len: int = 3
......@@ -37,32 +40,34 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
min_length: int = 0
logits_processor: LogitsProcessorList = None
stopping_criteria: Optional[StoppingCriteriaList] = None
eos_token_id: Optional[int] = [2]
state: object = {'max_new_tokens': 50,
'seed': 1,
'temperature': 0, 'top_p': 0.1,
'top_k': 40, 'typical_p': 1,
'repetition_penalty': 1.2,
'encoder_repetition_penalty': 1,
'no_repeat_ngram_size': 0,
'min_length': 0,
'penalty_alpha': 0,
'num_beams': 1,
'length_penalty': 1,
'early_stopping': False, 'add_bos_token': True, 'ban_eos_token': False,
'truncation_length': 2048, 'custom_stopping_strings': '',
'cpu_memory': 0, 'auto_devices': False, 'disk': False, 'cpu': False, 'bf16': False,
'load_in_8bit': False, 'wbits': 'None', 'groupsize': 'None', 'model_type': 'None',
'pre_layer': 0, 'gpu_memory_0': 0}
streaming_key: str = "streaming" #: :meta private:
history_key: str = "history" #: :meta private:
prompt_key: str = "prompt" #: :meta private:
output_key: str = "answer_result_stream" #: :meta private:
def __init__(self, checkPoint: LoaderCheckPoint = None):
super().__init__()
self.checkPoint = checkPoint
@property
def _llm_type(self) -> str:
return "LLamaLLM"
def _chain_type(self) -> str:
return "LLamaLLMChain"
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return [self.prompt_key]
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
@property
def _check_point(self) -> LoaderCheckPoint:
......@@ -107,35 +112,31 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
formatted_history += "### Human:{}\n### Assistant:".format(query)
return formatted_history
def prepare_inputs_for_generation(self,
input_ids: torch.LongTensor):
"""
预生成注意力掩码和 输入序列中每个位置的索引的张量
# TODO 没有思路
:return:
"""
mask_positions = torch.zeros((1, input_ids.shape[1]), dtype=input_ids.dtype).to(self.checkPoint.model.device)
attention_mask = self.get_masks(input_ids, input_ids.device)
position_ids = self.get_position_ids(
input_ids,
device=input_ids.device,
mask_positions=mask_positions
)
return input_ids, position_ids, attention_mask
@property
def _history_len(self) -> int:
return self.history_len
def set_history_len(self, history_len: int = 10) -> None:
self.history_len = history_len
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Generator]:
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
return {self.output_key: generator}
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
history = inputs[self.history_key]
streaming = inputs[self.streaming_key]
prompt = inputs[self.prompt_key]
print(f"__call:{prompt}")
# Create the StoppingCriteriaList with the stopping strings
self.stopping_criteria = transformers.StoppingCriteriaList()
# 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
self.stopping_criteria.append(listenerQueue)
# TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现
soft_prompt = self.history_to_text(query=prompt, history=history)
if self.logits_processor is None:
self.logits_processor = LogitsProcessorList()
self.logits_processor.append(InvalidScoreLogitsProcessor())
......@@ -154,16 +155,10 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
"logits_processor": self.logits_processor}
# 向量转换
input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'], truncation_length=self.max_new_tokens)
# input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids)
input_ids = self.encode(soft_prompt, add_bos_token=self.checkPoint.tokenizer.add_bos_token,
truncation_length=self.max_new_tokens)
gen_kwargs.update({'inputs': input_ids})
# 注意力掩码
# gen_kwargs.update({'attention_mask': attention_mask})
# gen_kwargs.update({'position_ids': position_ids})
if self.stopping_criteria is None:
self.stopping_criteria = transformers.StoppingCriteriaList()
# 观测输出
gen_kwargs.update({'stopping_criteria': self.stopping_criteria})
# llama-cpp模型的参数与transformers的参数字段有较大差异,直接调用会返回不支持的字段错误
......@@ -173,11 +168,13 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
if "llama_cpp" in self.checkPoint.model.__str__():
import inspect
common_kwargs_keys = set(inspect.getfullargspec(self.checkPoint.model.generate).args)&set(gen_kwargs.keys())
common_kwargs = {key:gen_kwargs[key] for key in common_kwargs_keys}
#? llama-cpp模型的generate方法似乎只接受.cpu类型的输入,响应很慢,慢到哭泣
#?为什么会不支持GPU呢,不应该啊?
output_ids = torch.tensor([list(self.checkPoint.model.generate(input_id_i.cpu(),**common_kwargs)) for input_id_i in input_ids])
common_kwargs_keys = set(inspect.getfullargspec(self.checkPoint.model.generate).args) & set(
gen_kwargs.keys())
common_kwargs = {key: gen_kwargs[key] for key in common_kwargs_keys}
# ? llama-cpp模型的generate方法似乎只接受.cpu类型的输入,响应很慢,慢到哭泣
# ?为什么会不支持GPU呢,不应该啊?
output_ids = torch.tensor(
[list(self.checkPoint.model.generate(input_id_i.cpu(), **common_kwargs)) for input_id_i in input_ids])
else:
output_ids = self.checkPoint.model.generate(**gen_kwargs)
......@@ -185,17 +182,11 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
reply = self.decode(output_ids[0][-new_tokens:])
print(f"response:{reply}")
print(f"+++++++++++++++++++++++++++++++++++")
return reply
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
# TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现
softprompt = self.history_to_text(prompt,history=history)
response = self._call(prompt=softprompt, stop=['\n###'])
answer_result = AnswerResult()
answer_result.history = history + [[prompt, response]]
answer_result.llm_output = {"answer": response}
yield answer_result
history += [[prompt, reply]]
answer_result.history = history
if listenerQueue.listenerQueue.__len__() > 0:
answer_result.listenerToken = listenerQueue.listenerQueue.pop()
answer_result.llm_output = {"answer": reply}
generate_with_callback(answer_result)
......@@ -20,6 +20,7 @@ class LoaderCheckPoint:
no_remote_model: bool = False
# 模型名称
model_name: str = None
pretrained_model_name: str = None
tokenizer: object = None
# 模型全路径
model_path: str = None
......@@ -67,48 +68,49 @@ class LoaderCheckPoint:
self.load_in_8bit = params.get('load_in_8bit', False)
self.bf16 = params.get('bf16', False)
def _load_model_config(self, model_name):
def _load_model_config(self):
if self.model_path:
self.model_path = re.sub("\s","",self.model_path)
self.model_path = re.sub("\s", "", self.model_path)
checkpoint = Path(f'{self.model_path}')
else:
if not self.no_remote_model:
checkpoint = model_name
else:
if self.no_remote_model:
raise ValueError(
"本地模型local_model_path未配置路径"
)
else:
checkpoint = self.pretrained_model_name
print(f"load_model_config {checkpoint}...")
try:
model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
return model_config
except Exception as e:
print(e)
return checkpoint
def _load_model(self, model_name):
def _load_model(self):
"""
加载自定义位置的model
:param model_name:
:return:
"""
print(f"Loading {model_name}...")
t0 = time.time()
if self.model_path:
self.model_path = re.sub("\s","",self.model_path)
self.model_path = re.sub("\s", "", self.model_path)
checkpoint = Path(f'{self.model_path}')
else:
if not self.no_remote_model:
checkpoint = model_name
else:
if self.no_remote_model:
raise ValueError(
"本地模型local_model_path未配置路径"
)
else:
checkpoint = self.pretrained_model_name
print(f"Loading {checkpoint}...")
self.is_llamacpp = len(list(Path(f'{checkpoint}').glob('ggml*.bin'))) > 0
if 'chatglm' in model_name.lower() or "chatyuan" in model_name.lower():
if 'chatglm' in self.model_name.lower() or "chatyuan" in self.model_name.lower():
LoaderClass = AutoModel
else:
LoaderClass = AutoModelForCausalLM
......@@ -138,7 +140,7 @@ class LoaderCheckPoint:
torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
trust_remote_code=True).half().to(self.llm_device)
else:
from accelerate import dispatch_model,infer_auto_device_map
from accelerate import dispatch_model, infer_auto_device_map
model = LoaderClass.from_pretrained(checkpoint,
config=self.model_config,
......@@ -146,10 +148,10 @@ class LoaderCheckPoint:
trust_remote_code=True).half()
# 可传入device_map自定义每张卡的部署情况
if self.device_map is None:
if 'chatglm' in model_name.lower():
if 'chatglm' in self.model_name.lower():
self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
elif 'moss' in model_name.lower():
self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
elif 'moss' in self.model_name.lower():
self.device_map = self.moss_auto_configure_device_map(num_gpus, checkpoint)
else:
# 基于如下方式作为默认的多卡加载方案针对新模型基本不会失败
# 在chatglm2-6b,bloom-3b,blooz-7b1上进行了测试,GPU负载也相对均衡
......@@ -166,9 +168,9 @@ class LoaderCheckPoint:
# 其他模型定义的层类几乎不可能与chatglm和moss一致,使用chatglm_auto_configure_device_map
# 百分百会报错,使用infer_auto_device_map虽然可能导致负载不均衡,但至少不会报错
# 实测在bloom模型上如此
# self.device_map = infer_auto_device_map(model,
# dtype=torch.int8,
# no_split_module_classes=model._no_split_modules)
# self.device_map = infer_auto_device_map(model,
# dtype=torch.int8,
# no_split_module_classes=model._no_split_modules)
model = dispatch_model(model, device_map=self.device_map)
else:
......@@ -202,7 +204,7 @@ class LoaderCheckPoint:
# tokenizer = model.tokenizer
# todo 此处调用AutoTokenizer的tokenizer,但后续可以测试自带tokenizer是不是兼容
#* -> 自带的tokenizer不与transoformers的tokenizer兼容,无法使用
# * -> 自带的tokenizer不与transoformers的tokenizer兼容,无法使用
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
return model, tokenizer
......@@ -231,7 +233,7 @@ class LoaderCheckPoint:
llm_int8_enable_fp32_cpu_offload=False)
with init_empty_weights():
model = LoaderClass.from_config(self.model_config,trust_remote_code = True)
model = LoaderClass.from_config(self.model_config, trust_remote_code=True)
model.tie_weights()
if self.device_map is not None:
params['device_map'] = self.device_map
......@@ -321,7 +323,7 @@ class LoaderCheckPoint:
return device_map
def moss_auto_configure_device_map(self, num_gpus: int, model_name) -> Dict[str, int]:
def moss_auto_configure_device_map(self, num_gpus: int, checkpoint) -> Dict[str, int]:
try:
from accelerate import init_empty_weights
......@@ -336,16 +338,6 @@ class LoaderCheckPoint:
"`pip install bitsandbytes``pip install accelerate`."
) from exc
if self.model_path:
checkpoint = Path(f'{self.model_path}')
else:
if not self.no_remote_model:
checkpoint = model_name
else:
raise ValueError(
"本地模型local_model_path未配置路径"
)
cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
pretrained_model_name_or_path=checkpoint)
......@@ -452,7 +444,7 @@ class LoaderCheckPoint:
def reload_model(self):
self.unload_model()
self.model_config = self._load_model_config(self.model_name)
self.model_config = self._load_model_config()
if self.use_ptuning_v2:
try:
......@@ -464,7 +456,7 @@ class LoaderCheckPoint:
except Exception as e:
print("加载PrefixEncoder config.json失败")
self.model, self.tokenizer = self._load_model(self.model_name)
self.model, self.tokenizer = self._load_model()
if self.lora:
self._add_lora_to_model([self.lora])
......
from abc import ABC
from langchain.llms.base import LLM
from typing import Optional, List
from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Generator, Union
from langchain.callbacks.manager import CallbackManagerForChainRun
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
AnswerResult)
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
import torch
import transformers
import torch
# todo 建议重写instruction,在该instruction下,各模型的表现比较差
META_INSTRUCTION = \
"""You are an AI assistant whose name is MOSS.
......@@ -20,41 +28,65 @@ META_INSTRUCTION = \
Capabilities and tools that MOSS can possess.
"""
# todo 在MOSSLLM类下,各模型的响应速度很慢,后续要检查一下原因
class MOSSLLM(BaseAnswer, LLM, ABC):
class MOSSLLMChain(BaseAnswer, Chain, ABC):
max_token: int = 2048
temperature: float = 0.7
top_p = 0.8
# history = []
checkPoint: LoaderCheckPoint = None
history_len: int = 10
streaming_key: str = "streaming" #: :meta private:
history_key: str = "history" #: :meta private:
prompt_key: str = "prompt" #: :meta private:
output_key: str = "answer_result_stream" #: :meta private:
def __init__(self, checkPoint: LoaderCheckPoint = None):
super().__init__()
self.checkPoint = checkPoint
@property
def _llm_type(self) -> str:
return "MOSS"
def _chain_type(self) -> str:
return "MOSSLLMChain"
@property
def _check_point(self) -> LoaderCheckPoint:
return self.checkPoint
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return [self.prompt_key]
@property
def _history_len(self) -> int:
def output_keys(self) -> List[str]:
"""Will always return text key.
return self.history_len
:meta private:
"""
return [self.output_key]
def set_history_len(self, history_len: int) -> None:
self.history_len = history_len
@property
def _check_point(self) -> LoaderCheckPoint:
return self.checkPoint
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
pass
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Generator]:
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
return {self.output_key: generator}
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
history = inputs[self.history_key]
streaming = inputs[self.streaming_key]
prompt = inputs[self.prompt_key]
print(f"__call:{prompt}")
if len(history) > 0:
history = history[-self.history_len:] if self.history_len > 0 else []
prompt_w_history = str(history)
......@@ -79,13 +111,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
num_return_sequences=1,
eos_token_id=106068,
pad_token_id=self.checkPoint.tokenizer.pad_token_id)
response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:],
skip_special_tokens=True)
self.checkPoint.clear_torch_cache()
history += [[prompt, response]]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": response}
yield answer_result
generate_with_callback(answer_result)
......@@ -24,13 +24,12 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
if use_ptuning_v2:
loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2
# 如果指定了参数,则使用参数的配置
if llm_model:
llm_model_info = llm_model_dict[llm_model]
if loaderCheckPoint.no_remote_model:
loaderCheckPoint.model_name = llm_model_info['name']
else:
loaderCheckPoint.model_name = llm_model_info['pretrained_model_name']
loaderCheckPoint.pretrained_model_name = llm_model_info['pretrained_model_name']
loaderCheckPoint.model_path = llm_model_info["local_model_path"]
......
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../../')
import asyncio
from argparse import Namespace
from models.loader.args import parser
from models.loader import LoaderCheckPoint
import models.shared as shared
async def dispatch(args: Namespace):
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
history = [
("which city is this?", "tokyo"),
("why?", "she's japanese"),
]
for answer_result in llm_model_ins.generatorAnswer(prompt="你好? ", history=history,
streaming=False):
resp = answer_result.llm_output["answer"]
print(resp)
if __name__ == '__main__':
args = None
args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'fastchat-chatglm-6b', '--no-remote-model'])
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(dispatch(args))
......@@ -85,8 +85,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
yield history + [[query,
"请选择知识库后进行测试,当前未选择知识库。"]], ""
else:
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
streaming=streaming):
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": query, "history": history, "streaming": streaming})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][-1] = resp
......@@ -101,11 +104,12 @@ def init_model():
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
try:
local_doc_qa.init_cfg(llm_model=llm_model_ins)
generator = local_doc_qa.llm.generatorAnswer("你好")
for answer_result in generator:
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": "你好", "history": [], "streaming": False})
for answer_result in answer_result_stream_result['answer_result_stream']:
print(answer_result.llm_output)
reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
logger.info(reply)
......@@ -141,7 +145,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
filelist = []
if local_doc_qa.llm and local_doc_qa.embeddings:
if local_doc_qa.llm_model_chain and local_doc_qa.embeddings:
if isinstance(files, list):
for file in files:
filename = os.path.split(file.name)[-1]
......@@ -165,7 +169,7 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte
def change_vs_name_input(vs_id, history):
if vs_id == "新建知识库":
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history,\
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history, \
gr.update(choices=[]), gr.update(visible=False)
else:
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
......@@ -218,7 +222,7 @@ def change_chunk_conent(mode, label_conent, history):
def add_vs_name(vs_name, chatbot):
if vs_name is None or vs_name.strip() == "" :
if vs_name is None or vs_name.strip() == "":
vs_status = "知识库名称不能为空,请重新填写知识库名称"
chatbot = chatbot + [[None, vs_status]]
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
......@@ -262,6 +266,7 @@ def reinit_vector_store(vs_id, history):
def refresh_vs_list():
return gr.update(choices=get_vs_list()), gr.update(choices=get_vs_list())
def delete_file(vs_id, files_to_delete, chatbot):
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
content_path = os.path.join(KB_ROOT_PATH, vs_id, "content")
......@@ -275,11 +280,11 @@ def delete_file(vs_id, files_to_delete, chatbot):
rested_files = local_doc_qa.list_file_from_vector_store(vs_path)
if "fail" in status:
vs_status = "文件删除失败。"
elif len(rested_files)>0:
elif len(rested_files) > 0:
vs_status = "文件删除成功。"
else:
vs_status = f"文件删除成功,知识库{vs_id}中无已上传文件,请先上传文件后,再开始提问。"
logger.info(",".join(files_to_delete)+vs_status)
logger.info(",".join(files_to_delete) + vs_status)
chatbot = chatbot + [[None, vs_status]]
return gr.update(choices=local_doc_qa.list_file_from_vector_store(vs_path), value=[]), chatbot
......@@ -290,7 +295,8 @@ def delete_vs(vs_id, chatbot):
status = f"成功删除知识库{vs_id}"
logger.info(status)
chatbot = chatbot + [[None, status]]
return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(visible=True), \
return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(
visible=True), \
gr.update(visible=False), chatbot, gr.update(visible=False)
except Exception as e:
logger.error(e)
......@@ -333,7 +339,8 @@ default_theme_args = dict(
with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo:
vs_path, file_status, model_status = gr.State(
os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State(
os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(
""), gr.State(
model_status)
gr.Markdown(webui_title)
with gr.Tab("对话"):
......
......@@ -85,9 +85,10 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
yield history + [[query,
"请选择知识库后进行测试,当前未选择知识库。"]], ""
else:
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
streaming=streaming):
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": query, "history": history, "streaming": streaming})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][-1] = resp + (
......@@ -105,13 +106,14 @@ def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'
args_dict.update(model=llm_model)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
try:
local_doc_qa.init_cfg(llm_model=llm_model_ins,
embedding_model=embedding_model)
generator = local_doc_qa.llm.generatorAnswer("你好")
for answer_result in generator:
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": "你好", "history": [], "streaming": False})
for answer_result in answer_result_stream_result['answer_result_stream']:
print(answer_result.llm_output)
reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
logger.info(reply)
......@@ -468,7 +470,7 @@ with st.sidebar:
top_k = st.slider('向量匹配数量', 1, 20, VECTOR_SEARCH_TOP_K)
history_len = st.slider(
'LLM对话轮数', 1, 50, LLM_HISTORY_LEN) # 也许要跟知识库分开设置
local_doc_qa.llm.set_history_len(history_len)
# local_doc_qa.llm.set_history_len(history_len)
chunk_conent = st.checkbox('启用上下文关联', False)
st.text('')
# chunk_conent = st.checkbox('分割文本', True) # 知识库文本分割入库
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论