提交 c5bc2178 作者: glide-the

修改模型生成的调用方式,兼容Chain调用

修改模型切换的bug
上级 ca13ab81
...@@ -384,8 +384,10 @@ async def chat( ...@@ -384,8 +384,10 @@ async def chat(
], ],
), ),
): ):
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history, answer_result_stream_result = local_doc_qa.llm_model_chain(
streaming=True): {"prompt": question, "history": history, "streaming": True})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"] resp = answer_result.llm_output["answer"]
history = answer_result.history history = answer_result.history
pass pass
...@@ -486,7 +488,6 @@ def api_start(host, port, **kwargs): ...@@ -486,7 +488,6 @@ def api_start(host, port, **kwargs):
global local_doc_qa global local_doc_qa
llm_model_ins = shared.loaderLLM() llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
app = FastAPI() app = FastAPI()
# Add CORS middleware to allow all origins # Add CORS middleware to allow all origins
......
...@@ -18,6 +18,7 @@ from agent import bing_search ...@@ -18,6 +18,7 @@ from agent import bing_search
from langchain.docstore.document import Document from langchain.docstore.document import Document
from functools import lru_cache from functools import lru_cache
from textsplitter.zh_title_enhance import zh_title_enhance from textsplitter.zh_title_enhance import zh_title_enhance
from langchain.chains.base import Chain
# patch HuggingFaceEmbeddings to make it hashable # patch HuggingFaceEmbeddings to make it hashable
...@@ -119,7 +120,7 @@ def search_result2docs(search_results): ...@@ -119,7 +120,7 @@ def search_result2docs(search_results):
class LocalDocQA: class LocalDocQA:
llm: BaseAnswer = None llm_model_chain: Chain = None
embeddings: object = None embeddings: object = None
top_k: int = VECTOR_SEARCH_TOP_K top_k: int = VECTOR_SEARCH_TOP_K
chunk_size: int = CHUNK_SIZE chunk_size: int = CHUNK_SIZE
...@@ -129,10 +130,10 @@ class LocalDocQA: ...@@ -129,10 +130,10 @@ class LocalDocQA:
def init_cfg(self, def init_cfg(self,
embedding_model: str = EMBEDDING_MODEL, embedding_model: str = EMBEDDING_MODEL,
embedding_device=EMBEDDING_DEVICE, embedding_device=EMBEDDING_DEVICE,
llm_model: BaseAnswer = None, llm_model: Chain = None,
top_k=VECTOR_SEARCH_TOP_K, top_k=VECTOR_SEARCH_TOP_K,
): ):
self.llm = llm_model self.llm_model_chain = llm_model
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
model_kwargs={'device': embedding_device}) model_kwargs={'device': embedding_device})
self.top_k = top_k self.top_k = top_k
...@@ -236,8 +237,10 @@ class LocalDocQA: ...@@ -236,8 +237,10 @@ class LocalDocQA:
else: else:
prompt = query prompt = query
for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history, answer_result_stream_result = self.llm_model_chain(
streaming=streaming): {"prompt": prompt, "history": chat_history, "streaming": streaming})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"] resp = answer_result.llm_output["answer"]
history = answer_result.history history = answer_result.history
history[-1][0] = query history[-1][0] = query
...@@ -276,8 +279,10 @@ class LocalDocQA: ...@@ -276,8 +279,10 @@ class LocalDocQA:
result_docs = search_result2docs(results) result_docs = search_result2docs(results)
prompt = generate_prompt(result_docs, query) prompt = generate_prompt(result_docs, query)
for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history, answer_result_stream_result = self.llm_model_chain(
streaming=streaming): {"prompt": prompt, "history": chat_history, "streaming": streaming})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"] resp = answer_result.llm_output["answer"]
history = answer_result.history history = answer_result.history
history[-1][0] = query history[-1][0] = query
...@@ -296,7 +301,7 @@ class LocalDocQA: ...@@ -296,7 +301,7 @@ class LocalDocQA:
def update_file_from_vector_store(self, def update_file_from_vector_store(self,
filepath: str or List[str], filepath: str or List[str],
vs_path, vs_path,
docs: List[Document],): docs: List[Document], ):
vector_store = load_vector_store(vs_path, self.embeddings) vector_store = load_vector_store(vs_path, self.embeddings)
status = vector_store.update_doc(filepath, docs) status = vector_store.update_doc(filepath, docs)
return status return status
...@@ -320,7 +325,6 @@ if __name__ == "__main__": ...@@ -320,7 +325,6 @@ if __name__ == "__main__":
args_dict = vars(args) args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict) shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM() llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
local_doc_qa = LocalDocQA() local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(llm_model=llm_model_ins) local_doc_qa.init_cfg(llm_model=llm_model_ins)
......
...@@ -37,61 +37,67 @@ llm_model_dict = { ...@@ -37,61 +37,67 @@ llm_model_dict = {
"name": "chatglm-6b-int4-qe", "name": "chatglm-6b-int4-qe",
"pretrained_model_name": "THUDM/chatglm-6b-int4-qe", "pretrained_model_name": "THUDM/chatglm-6b-int4-qe",
"local_model_path": None, "local_model_path": None,
"provides": "ChatGLM" "provides": "ChatGLMLLMChain"
}, },
"chatglm-6b-int4": { "chatglm-6b-int4": {
"name": "chatglm-6b-int4", "name": "chatglm-6b-int4",
"pretrained_model_name": "THUDM/chatglm-6b-int4", "pretrained_model_name": "THUDM/chatglm-6b-int4",
"local_model_path": None, "local_model_path": None,
"provides": "ChatGLM" "provides": "ChatGLMLLMChain"
}, },
"chatglm-6b-int8": { "chatglm-6b-int8": {
"name": "chatglm-6b-int8", "name": "chatglm-6b-int8",
"pretrained_model_name": "THUDM/chatglm-6b-int8", "pretrained_model_name": "THUDM/chatglm-6b-int8",
"local_model_path": None, "local_model_path": None,
"provides": "ChatGLM" "provides": "ChatGLMLLMChain"
}, },
"chatglm-6b": { "chatglm-6b": {
"name": "chatglm-6b", "name": "chatglm-6b",
"pretrained_model_name": "THUDM/chatglm-6b", "pretrained_model_name": "THUDM/chatglm-6b",
"local_model_path": None, "local_model_path": None,
"provides": "ChatGLM" "provides": "ChatGLMLLMChain"
}, },
"chatglm2-6b": { "chatglm2-6b": {
"name": "chatglm2-6b", "name": "chatglm2-6b",
"pretrained_model_name": "THUDM/chatglm2-6b", "pretrained_model_name": "THUDM/chatglm2-6b",
"local_model_path": None, "local_model_path": None,
"provides": "ChatGLM" "provides": "ChatGLMLLMChain"
}, },
"chatglm2-6b-int4": { "chatglm2-6b-int4": {
"name": "chatglm2-6b-int4", "name": "chatglm2-6b-int4",
"pretrained_model_name": "THUDM/chatglm2-6b-int4", "pretrained_model_name": "THUDM/chatglm2-6b-int4",
"local_model_path": None, "local_model_path": None,
"provides": "ChatGLM" "provides": "ChatGLMLLMChain"
}, },
"chatglm2-6b-int8": { "chatglm2-6b-int8": {
"name": "chatglm2-6b-int8", "name": "chatglm2-6b-int8",
"pretrained_model_name": "THUDM/chatglm2-6b-int8", "pretrained_model_name": "THUDM/chatglm2-6b-int8",
"local_model_path": None, "local_model_path": None,
"provides": "ChatGLM" "provides": "ChatGLMLLMChain"
}, },
"chatyuan": { "chatyuan": {
"name": "chatyuan", "name": "chatyuan",
"pretrained_model_name": "ClueAI/ChatYuan-large-v2", "pretrained_model_name": "ClueAI/ChatYuan-large-v2",
"local_model_path": None, "local_model_path": None,
"provides": "MOSSLLM" "provides": "MOSSLLMChain"
}, },
"moss": { "moss": {
"name": "moss", "name": "moss",
"pretrained_model_name": "fnlp/moss-moon-003-sft", "pretrained_model_name": "fnlp/moss-moon-003-sft",
"local_model_path": None, "local_model_path": None,
"provides": "MOSSLLM" "provides": "MOSSLLMChain"
}, },
"vicuna-13b-hf": { "vicuna-13b-hf": {
"name": "vicuna-13b-hf", "name": "vicuna-13b-hf",
"pretrained_model_name": "vicuna-13b-hf", "pretrained_model_name": "vicuna-13b-hf",
"local_model_path": None, "local_model_path": None,
"provides": "LLamaLLM" "provides": "LLamaLLMChain"
},
"vicuna-7b-hf": {
"name": "vicuna-13b-hf",
"pretrained_model_name": "vicuna-13b-hf",
"local_model_path": None,
"provides": "LLamaLLMChain"
}, },
# 直接调用返回requests.exceptions.ConnectionError错误,需要通过huggingface_hub包里的snapshot_download函数 # 直接调用返回requests.exceptions.ConnectionError错误,需要通过huggingface_hub包里的snapshot_download函数
# 下载模型,如果snapshot_download还是返回网络错误,多试几次,一般是可以的, # 下载模型,如果snapshot_download还是返回网络错误,多试几次,一般是可以的,
...@@ -101,7 +107,7 @@ llm_model_dict = { ...@@ -101,7 +107,7 @@ llm_model_dict = {
"name": "bloomz-7b1", "name": "bloomz-7b1",
"pretrained_model_name": "bigscience/bloomz-7b1", "pretrained_model_name": "bigscience/bloomz-7b1",
"local_model_path": None, "local_model_path": None,
"provides": "MOSSLLM" "provides": "MOSSLLMChain"
}, },
# 实测加载bigscience/bloom-3b需要170秒左右,暂不清楚为什么这么慢 # 实测加载bigscience/bloom-3b需要170秒左右,暂不清楚为什么这么慢
...@@ -110,14 +116,14 @@ llm_model_dict = { ...@@ -110,14 +116,14 @@ llm_model_dict = {
"name": "bloom-3b", "name": "bloom-3b",
"pretrained_model_name": "bigscience/bloom-3b", "pretrained_model_name": "bigscience/bloom-3b",
"local_model_path": None, "local_model_path": None,
"provides": "MOSSLLM" "provides": "MOSSLLMChain"
}, },
"baichuan-7b": { "baichuan-7b": {
"name": "baichuan-7b", "name": "baichuan-7b",
"pretrained_model_name": "baichuan-inc/baichuan-7B", "pretrained_model_name": "baichuan-inc/baichuan-7B",
"local_model_path": None, "local_model_path": None,
"provides": "MOSSLLM" "provides": "MOSSLLMChain"
}, },
# llama-cpp模型的兼容性问题参考https://github.com/abetlen/llama-cpp-python/issues/204 # llama-cpp模型的兼容性问题参考https://github.com/abetlen/llama-cpp-python/issues/204
"ggml-vicuna-13b-1.1-q5": { "ggml-vicuna-13b-1.1-q5": {
...@@ -131,7 +137,7 @@ llm_model_dict = { ...@@ -131,7 +137,7 @@ llm_model_dict = {
# 需要手动从https://github.com/abetlen/llama-cpp-python/releases/tag/下载对应的wheel安装 # 需要手动从https://github.com/abetlen/llama-cpp-python/releases/tag/下载对应的wheel安装
# 实测v0.1.63与本模型的vicuna/ggml-vicuna-13b-1.1/ggml-vic13b-q5_1.bin可以兼容 # 实测v0.1.63与本模型的vicuna/ggml-vicuna-13b-1.1/ggml-vic13b-q5_1.bin可以兼容
"local_model_path": f'''{"/".join(os.path.abspath(__file__).split("/")[:3])}/.cache/huggingface/hub/models--vicuna--ggml-vicuna-13b-1.1/blobs/''', "local_model_path": f'''{"/".join(os.path.abspath(__file__).split("/")[:3])}/.cache/huggingface/hub/models--vicuna--ggml-vicuna-13b-1.1/blobs/''',
"provides": "LLamaLLM" "provides": "LLamaLLMChain"
}, },
# 通过 fastchat 调用的模型请参考如下格式 # 通过 fastchat 调用的模型请参考如下格式
...@@ -139,7 +145,7 @@ llm_model_dict = { ...@@ -139,7 +145,7 @@ llm_model_dict = {
"name": "chatglm-6b", # "name"修改为fastchat服务中的"model_name" "name": "chatglm-6b", # "name"修改为fastchat服务中的"model_name"
"pretrained_model_name": "chatglm-6b", "pretrained_model_name": "chatglm-6b",
"local_model_path": None, "local_model_path": None,
"provides": "FastChatOpenAILLM", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM" "provides": "FastChatOpenAILLMChain", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain"
"api_base_url": "http://localhost:8000/v1", # "name"修改为fastchat服务中的"api_base_url" "api_base_url": "http://localhost:8000/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY" "api_key": "EMPTY"
}, },
...@@ -147,7 +153,7 @@ llm_model_dict = { ...@@ -147,7 +153,7 @@ llm_model_dict = {
"name": "chatglm2-6b", # "name"修改为fastchat服务中的"model_name" "name": "chatglm2-6b", # "name"修改为fastchat服务中的"model_name"
"pretrained_model_name": "chatglm2-6b", "pretrained_model_name": "chatglm2-6b",
"local_model_path": None, "local_model_path": None,
"provides": "FastChatOpenAILLM", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM" "provides": "FastChatOpenAILLMChain", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain"
"api_base_url": "http://localhost:8000/v1" # "name"修改为fastchat服务中的"api_base_url" "api_base_url": "http://localhost:8000/v1" # "name"修改为fastchat服务中的"api_base_url"
}, },
...@@ -156,7 +162,7 @@ llm_model_dict = { ...@@ -156,7 +162,7 @@ llm_model_dict = {
"name": "vicuna-13b-hf", # "name"修改为fastchat服务中的"model_name" "name": "vicuna-13b-hf", # "name"修改为fastchat服务中的"model_name"
"pretrained_model_name": "vicuna-13b-hf", "pretrained_model_name": "vicuna-13b-hf",
"local_model_path": None, "local_model_path": None,
"provides": "FastChatOpenAILLM", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM" "provides": "FastChatOpenAILLMChain", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain"
"api_base_url": "http://localhost:8000/v1", # "name"修改为fastchat服务中的"api_base_url" "api_base_url": "http://localhost:8000/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY" "api_key": "EMPTY"
}, },
...@@ -171,7 +177,7 @@ llm_model_dict = { ...@@ -171,7 +177,7 @@ llm_model_dict = {
"openai-chatgpt-3.5": { "openai-chatgpt-3.5": {
"name": "gpt-3.5-turbo", "name": "gpt-3.5-turbo",
"pretrained_model_name": "gpt-3.5-turbo", "pretrained_model_name": "gpt-3.5-turbo",
"provides": "FastChatOpenAILLM", "provides": "FastChatOpenAILLMChain",
"local_model_path": None, "local_model_path": None,
"api_base_url": "https://api.openapi.com/v1", "api_base_url": "https://api.openapi.com/v1",
"api_key": "" "api_key": ""
...@@ -226,7 +232,7 @@ LLM_HISTORY_LEN = 3 ...@@ -226,7 +232,7 @@ LLM_HISTORY_LEN = 3
VECTOR_SEARCH_TOP_K = 5 VECTOR_SEARCH_TOP_K = 5
# 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,经测试设置为小于500时,匹配结果更精准 # 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,经测试设置为小于500时,匹配结果更精准
VECTOR_SEARCH_SCORE_THRESHOLD = 0 VECTOR_SEARCH_SCORE_THRESHOLD = 390
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data") NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
......
from .chatglm_llm import ChatGLM from .chatglm_llm import ChatGLMLLMChain
from .llama_llm import LLamaLLM from .llama_llm import LLamaLLMChain
from .moss_llm import MOSSLLM from .fastchat_openai_llm import FastChatOpenAILLMChain
from .fastchat_openai_llm import FastChatOpenAILLM from .moss_llm import MOSSLLMChain
from models.base.base import ( from models.base.base import (
AnswerResult, AnswerResult,
BaseAnswer BaseAnswer,
) AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
from models.base.remote_rpc_model import ( from models.base.remote_rpc_model import (
RemoteRpcModel RemoteRpcModel
) )
__all__ = [ __all__ = [
"AnswerResult", "AnswerResult",
"BaseAnswer", "BaseAnswer",
"RemoteRpcModel", "RemoteRpcModel",
"AnswerResultStream",
"AnswerResultQueueSentinelTokenListenerQueue"
] ]
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional, List from typing import Any, Dict, List, Optional, Generator
import traceback import traceback
from collections import deque from collections import deque
from queue import Queue from queue import Queue
from threading import Thread from threading import Thread
from langchain.callbacks.manager import CallbackManagerForChainRun
from models.loader import LoaderCheckPoint
import torch import torch
import transformers import transformers
from models.loader import LoaderCheckPoint
class ListenerToken:
"""
观测结果
"""
input_ids: torch.LongTensor
_scores: torch.FloatTensor
def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
self.input_ids = input_ids
self._scores = _scores
class AnswerResult: class AnswerResult:
...@@ -16,6 +29,123 @@ class AnswerResult: ...@@ -16,6 +29,123 @@ class AnswerResult:
""" """
history: List[List[str]] = [] history: List[List[str]] = []
llm_output: Optional[dict] = None llm_output: Optional[dict] = None
listenerToken: ListenerToken = None
class AnswerResultStream:
def __init__(self, callback_func=None):
self.callback_func = callback_func
def __call__(self, answerResult: AnswerResult):
if self.callback_func is not None:
self.callback_func(answerResult)
class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
"""
定义模型stopping_criteria 监听者,在每次响应时将队列数据同步到AnswerResult
实现此监听器的目的是,不同模型的预测输出可能不是矢量信息,hf框架可以自定义transformers.StoppingCriteria入参来接收每次预测的Tensor和损失函数,
通过给 StoppingCriteriaList指定模型生成答案时停止的条件。每个 StoppingCriteria 对象表示一个停止条件
当每轮预测任务开始时,StoppingCriteria都会收到相同的预测结果,最终由下层实现类确认是否结束
输出值可用于 generatorAnswer generate_with_streaming的自定义参数观测,以实现更加精细的控制
"""
listenerQueue: deque = deque(maxlen=1)
def __init__(self):
transformers.StoppingCriteria.__init__(self)
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
"""
每次响应时将数据添加到响应队列
:param input_ids:
:param _scores:
:param kwargs:
:return:
"""
self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores))
return False
class Iteratorize:
"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""
def __init__(self, func, kwargs={}):
self.mfunc = func
self.q = Queue()
self.sentinel = object()
self.kwargs = kwargs
self.stop_now = False
def _callback(val):
"""
模型输出预测结果收集
通过定义generate_with_callback收集器AnswerResultStream,收集模型预测的AnswerResult响应结果,最终由下层实现类确认是否结束
结束条件包含如下
1、模型预测结束、收集器self.q队列收到 self.sentinel标识
2、在处理迭代器队列消息时返回了break跳出迭代器,触发了StopIteration事件
3、模型预测出错
因为当前类是迭代器,所以在for in 中执行了break后 __exit__ 方法会被调用,最终stop_now属性会被更新,然后抛出异常结束预测行为
迭代器收集的行为如下
创建Iteratorize迭代对象,
定义generate_with_callback收集器AnswerResultStream
启动一个线程异步预测结果来调用上游checkpoint的实现方法_generate_answer
_generate_answer通过generate_with_callback定义的收集器,收集上游checkpoint包装的AnswerResult消息体
由于self.q是阻塞模式,每次预测后会被消费后才会执行下次预测
这时generate_with_callback会被阻塞
主线程Iteratorize对象的__next__方法调用获取阻塞消息并消费
1、消息为上游checkpoint包装的AnswerResult消息体,返回下游处理
2、消息为self.sentinel标识,抛出StopIteration异常
主线程Iteratorize对象__exit__收到消息,最终stop_now属性会被更新
异步线程检测stop_now属性被更新,抛出异常结束预测行为
迭代行为结束
:param val:
:return:
"""
if self.stop_now:
raise ValueError
self.q.put(val)
def gen():
try:
ret = self.mfunc(callback=_callback, **self.kwargs)
except ValueError:
pass
except:
traceback.print_exc()
pass
self.q.put(self.sentinel)
self.thread = Thread(target=gen)
self.thread.start()
def __iter__(self):
return self
def __next__(self):
obj = self.q.get(True, None)
if obj is self.sentinel:
raise StopIteration
else:
return obj
def __del__(self):
"""
暂无实现
:return:
"""
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
""" break 后会执行 """
self.stop_now = True
class BaseAnswer(ABC): class BaseAnswer(ABC):
...@@ -25,17 +155,25 @@ class BaseAnswer(ABC): ...@@ -25,17 +155,25 @@ class BaseAnswer(ABC):
@abstractmethod @abstractmethod
def _check_point(self) -> LoaderCheckPoint: def _check_point(self) -> LoaderCheckPoint:
"""Return _check_point of llm.""" """Return _check_point of llm."""
def generatorAnswer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,) -> Generator[Any, str, bool]:
def generate_with_callback(callback=None, **kwargs):
kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback)
self._generate_answer(**kwargs)
@property def generate_with_streaming(**kwargs):
@abstractmethod return Iteratorize(generate_with_callback, kwargs)
def _history_len(self) -> int:
"""Return _history_len of llm."""
@abstractmethod with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator:
def set_history_len(self, history_len: int) -> None: for answerResult in generator:
"""Return _history_len of llm.""" if answerResult.listenerToken:
output = answerResult.listenerToken.input_ids
yield answerResult
def generatorAnswer(self, prompt: str, @abstractmethod
history: List[List[str]] = [], def _generate_answer(self,
streaming: bool = False): inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
pass pass
from abc import ABC from abc import ABC
from langchain.llms.base import LLM from langchain.chains.base import Chain
from typing import Optional, List from typing import Any, Dict, List, Optional, Generator
from langchain.callbacks.manager import CallbackManagerForChainRun
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from models.loader import LoaderCheckPoint from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer, from models.base import (BaseAnswer,
AnswerResult) AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
import torch
import transformers
class ChatGLM(BaseAnswer, LLM, ABC): class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
max_token: int = 10000 max_token: int = 10000
temperature: float = 0.01 temperature: float = 0.01
top_p = 0.9 # 相关度
top_p = 0.4
# 候选词数量
top_k = 10
checkPoint: LoaderCheckPoint = None checkPoint: LoaderCheckPoint = None
# history = [] # history = []
history_len: int = 10 history_len: int = 10
streaming_key: str = "streaming" #: :meta private:
history_key: str = "history" #: :meta private:
prompt_key: str = "prompt" #: :meta private:
output_key: str = "answer_result_stream" #: :meta private:
def __init__(self, checkPoint: LoaderCheckPoint = None): def __init__(self, checkPoint: LoaderCheckPoint = None):
super().__init__() super().__init__()
self.checkPoint = checkPoint self.checkPoint = checkPoint
@property @property
def _llm_type(self) -> str: def _chain_type(self) -> str:
return "ChatGLM" return "ChatGLMLLMChain"
@property @property
def _check_point(self) -> LoaderCheckPoint: def _check_point(self) -> LoaderCheckPoint:
return self.checkPoint return self.checkPoint
@property @property
def _history_len(self) -> int: def input_keys(self) -> List[str]:
return self.history_len """Will be whatever keys the prompt expects.
def set_history_len(self, history_len: int = 10) -> None: :meta private:
self.history_len = history_len """
return [self.prompt_key]
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: @property
print(f"__call:{prompt}") def output_keys(self) -> List[str]:
response, _ = self.checkPoint.model.chat( """Will always return text key.
self.checkPoint.tokenizer,
prompt,
history=[],
max_length=self.max_token,
temperature=self.temperature
)
print(f"response:{response}")
print(f"+++++++++++++++++++++++++++++++++++")
return response
def generatorAnswer(self, prompt: str, :meta private:
history: List[List[str]] = [], """
streaming: bool = False): return [self.output_key]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Generator]:
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
return {self.output_key: generator}
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
history = inputs[self.history_key]
streaming = inputs[self.streaming_key]
prompt = inputs[self.prompt_key]
print(f"__call:{prompt}")
# Create the StoppingCriteriaList with the stopping strings
stopping_criteria_list = transformers.StoppingCriteriaList()
# 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
stopping_criteria_list.append(listenerQueue)
if streaming: if streaming:
history += [[]] history += [[]]
for inum, (stream_resp, _) in enumerate(self.checkPoint.model.stream_chat( for inum, (stream_resp, _) in enumerate(self.checkPoint.model.stream_chat(
self.checkPoint.tokenizer, self.checkPoint.tokenizer,
prompt, prompt,
history=history[-self.history_len:-1] if self.history_len > 1 else [], history=history[-self.history_len:-1] if self.history_len > 0 else [],
max_length=self.max_token, max_length=self.max_token,
temperature=self.temperature temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
stopping_criteria=stopping_criteria_list
)): )):
# self.checkPoint.clear_torch_cache() # self.checkPoint.clear_torch_cache()
history[-1] = [prompt, stream_resp] history[-1] = [prompt, stream_resp]
answer_result = AnswerResult() answer_result = AnswerResult()
answer_result.history = history answer_result.history = history
answer_result.llm_output = {"answer": stream_resp} answer_result.llm_output = {"answer": stream_resp}
yield answer_result if listenerQueue.listenerQueue.__len__() > 0:
answer_result.listenerToken = listenerQueue.listenerQueue.pop()
generate_with_callback(answer_result)
self.checkPoint.clear_torch_cache() self.checkPoint.clear_torch_cache()
else: else:
response, _ = self.checkPoint.model.chat( response, _ = self.checkPoint.model.chat(
...@@ -72,13 +104,18 @@ class ChatGLM(BaseAnswer, LLM, ABC): ...@@ -72,13 +104,18 @@ class ChatGLM(BaseAnswer, LLM, ABC):
prompt, prompt,
history=history[-self.history_len:] if self.history_len > 0 else [], history=history[-self.history_len:] if self.history_len > 0 else [],
max_length=self.max_token, max_length=self.max_token,
temperature=self.temperature temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
stopping_criteria=stopping_criteria_list
) )
self.checkPoint.clear_torch_cache() self.checkPoint.clear_torch_cache()
history += [[prompt, response]] history += [[prompt, response]]
answer_result = AnswerResult() answer_result = AnswerResult()
answer_result.history = history answer_result.history = history
answer_result.llm_output = {"answer": response} answer_result.llm_output = {"answer": response}
yield answer_result if listenerQueue.listenerQueue.__len__() > 0:
answer_result.listenerToken = listenerQueue.listenerQueue.pop()
generate_with_callback(answer_result)
from abc import ABC from abc import ABC
import requests from langchain.chains.base import Chain
from typing import Optional, List from typing import Any, Dict, List, Optional, Generator, Collection
from langchain.llms.base import LLM
from models.loader import LoaderCheckPoint from models.loader import LoaderCheckPoint
from models.base import (RemoteRpcModel, from langchain.callbacks.manager import CallbackManagerForChainRun
AnswerResult) from models.base import (BaseAnswer,
from typing import ( RemoteRpcModel,
Collection, AnswerResult,
Dict AnswerResultStream,
) AnswerResultQueueSentinelTokenListenerQueue)
import torch
import transformers
def _build_message_template() -> Dict[str, str]: def _build_message_template() -> Dict[str, str]:
...@@ -22,18 +22,42 @@ def _build_message_template() -> Dict[str, str]: ...@@ -22,18 +22,42 @@ def _build_message_template() -> Dict[str, str]:
} }
class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC): # 将历史对话数组转换为文本格式
def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
build_messages: Collection[Dict[str, str]] = []
for i, (old_query, response) in enumerate(history):
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = old_query
system_build_message = _build_message_template()
system_build_message['role'] = 'system'
system_build_message['content'] = response
build_messages.append(user_build_message)
build_messages.append(system_build_message)
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = query
build_messages.append(user_build_message)
return build_messages
class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
api_base_url: str = "http://localhost:8000/v1" api_base_url: str = "http://localhost:8000/v1"
model_name: str = "chatglm-6b" model_name: str = "chatglm-6b"
max_token: int = 10000 max_token: int = 10000
temperature: float = 0.01 temperature: float = 0.01
top_p = 0.9 top_p = 0.9
checkPoint: LoaderCheckPoint = None checkPoint: LoaderCheckPoint = None
history = [] # history = []
history_len: int = 10 history_len: int = 10
api_key: str = "" api_key: str = ""
streaming_key: str = "streaming" #: :meta private:
history_key: str = "history" #: :meta private:
prompt_key: str = "prompt" #: :meta private:
output_key: str = "answer_result_stream" #: :meta private:
def __init__(self, def __init__(self,
checkPoint: LoaderCheckPoint = None, checkPoint: LoaderCheckPoint = None,
# api_base_url:str="http://localhost:8000/v1", # api_base_url:str="http://localhost:8000/v1",
...@@ -44,19 +68,28 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC): ...@@ -44,19 +68,28 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
self.checkPoint = checkPoint self.checkPoint = checkPoint
@property @property
def _llm_type(self) -> str: def _chain_type(self) -> str:
return "FastChat" return "LLamaLLMChain"
@property @property
def _check_point(self) -> LoaderCheckPoint: def _check_point(self) -> LoaderCheckPoint:
return self.checkPoint return self.checkPoint
@property @property
def _history_len(self) -> int: def input_keys(self) -> List[str]:
return self.history_len """Will be whatever keys the prompt expects.
def set_history_len(self, history_len: int = 10) -> None: :meta private:
self.history_len = history_len """
return [self.prompt_key]
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
@property @property
def _api_key(self) -> str: def _api_key(self) -> str:
...@@ -75,53 +108,25 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC): ...@@ -75,53 +108,25 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
def call_model_name(self, model_name): def call_model_name(self, model_name):
self.model_name = model_name self.model_name = model_name
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Generator]:
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
return {self.output_key: generator}
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
history = inputs[self.history_key]
streaming = inputs[self.streaming_key]
prompt = inputs[self.prompt_key]
print(f"__call:{prompt}") print(f"__call:{prompt}")
try: try:
import openai
# Not support yet
# openai.api_key = "EMPTY"
openai.key = self.api_key
openai.api_base = self.api_base_url
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
# create a chat completion
completion = openai.ChatCompletion.create(
model=self.model_name,
messages=self.build_message_list(prompt)
)
print(f"response:{completion.choices[0].message.content}")
print(f"+++++++++++++++++++++++++++++++++++")
return completion.choices[0].message.content
# 将历史对话数组转换为文本格式
def build_message_list(self, query) -> Collection[Dict[str, str]]:
build_message_list: Collection[Dict[str, str]] = []
history = self.history[-self.history_len:] if self.history_len > 0 else []
for i, (old_query, response) in enumerate(history):
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = old_query
system_build_message = _build_message_template()
system_build_message['role'] = 'system'
system_build_message['content'] = response
build_message_list.append(user_build_message)
build_message_list.append(system_build_message)
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = query
build_message_list.append(user_build_message)
return build_message_list
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
try:
import openai import openai
# Not support yet # Not support yet
# openai.api_key = "EMPTY" # openai.api_key = "EMPTY"
...@@ -135,12 +140,13 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC): ...@@ -135,12 +140,13 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
# create a chat completion # create a chat completion
completion = openai.ChatCompletion.create( completion = openai.ChatCompletion.create(
model=self.model_name, model=self.model_name,
messages=self.build_message_list(prompt) messages=build_message_list(prompt)
) )
print(f"response:{completion.choices[0].message.content}")
print(f"+++++++++++++++++++++++++++++++++++")
history += [[prompt, completion.choices[0].message.content]] history += [[prompt, completion.choices[0].message.content]]
answer_result = AnswerResult() answer_result = AnswerResult()
answer_result.history = history answer_result.history = history
answer_result.llm_output = {"answer": completion.choices[0].message.content} answer_result.llm_output = {"answer": completion.choices[0].message.content}
generate_with_callback(answer_result)
yield answer_result
from abc import ABC
from langchain.llms.base import LLM from abc import ABC
import random from langchain.chains.base import Chain
import torch from typing import Any, Dict, List, Optional, Generator, Union
import transformers from langchain.callbacks.manager import CallbackManagerForChainRun
from transformers.generation.logits_process import LogitsProcessor from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from typing import Optional, List, Dict, Any,Union
from models.loader import LoaderCheckPoint from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer, from models.base import (BaseAnswer,
AnswerResult) AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
import torch
import transformers
class InvalidScoreLogitsProcessor(LogitsProcessor): class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: Union[torch.LongTensor,list], scores: Union[torch.FloatTensor,list]) -> torch.FloatTensor: def __call__(self, input_ids: Union[torch.LongTensor, list],
scores: Union[torch.FloatTensor, list]) -> torch.FloatTensor:
# llama-cpp模型返回的是list,为兼容性考虑,需要判断input_ids和scores的类型,将list转换为torch.Tensor # llama-cpp模型返回的是list,为兼容性考虑,需要判断input_ids和scores的类型,将list转换为torch.Tensor
input_ids = torch.tensor(input_ids) if isinstance(input_ids,list) else input_ids input_ids = torch.tensor(input_ids) if isinstance(input_ids, list) else input_ids
scores = torch.tensor(scores) if isinstance(scores,list) else scores scores = torch.tensor(scores) if isinstance(scores, list) else scores
if torch.isnan(scores).any() or torch.isinf(scores).any(): if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_() scores.zero_()
scores[..., 5] = 5e4 scores[..., 5] = 5e4
return scores return scores
class LLamaLLM(BaseAnswer, LLM, ABC): class LLamaLLMChain(BaseAnswer, Chain, ABC):
checkPoint: LoaderCheckPoint = None checkPoint: LoaderCheckPoint = None
# history = [] # history = []
history_len: int = 3 history_len: int = 3
...@@ -37,32 +40,34 @@ class LLamaLLM(BaseAnswer, LLM, ABC): ...@@ -37,32 +40,34 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
min_length: int = 0 min_length: int = 0
logits_processor: LogitsProcessorList = None logits_processor: LogitsProcessorList = None
stopping_criteria: Optional[StoppingCriteriaList] = None stopping_criteria: Optional[StoppingCriteriaList] = None
eos_token_id: Optional[int] = [2] streaming_key: str = "streaming" #: :meta private:
history_key: str = "history" #: :meta private:
state: object = {'max_new_tokens': 50, prompt_key: str = "prompt" #: :meta private:
'seed': 1, output_key: str = "answer_result_stream" #: :meta private:
'temperature': 0, 'top_p': 0.1,
'top_k': 40, 'typical_p': 1,
'repetition_penalty': 1.2,
'encoder_repetition_penalty': 1,
'no_repeat_ngram_size': 0,
'min_length': 0,
'penalty_alpha': 0,
'num_beams': 1,
'length_penalty': 1,
'early_stopping': False, 'add_bos_token': True, 'ban_eos_token': False,
'truncation_length': 2048, 'custom_stopping_strings': '',
'cpu_memory': 0, 'auto_devices': False, 'disk': False, 'cpu': False, 'bf16': False,
'load_in_8bit': False, 'wbits': 'None', 'groupsize': 'None', 'model_type': 'None',
'pre_layer': 0, 'gpu_memory_0': 0}
def __init__(self, checkPoint: LoaderCheckPoint = None): def __init__(self, checkPoint: LoaderCheckPoint = None):
super().__init__() super().__init__()
self.checkPoint = checkPoint self.checkPoint = checkPoint
@property @property
def _llm_type(self) -> str: def _chain_type(self) -> str:
return "LLamaLLM" return "LLamaLLMChain"
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return [self.prompt_key]
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
@property @property
def _check_point(self) -> LoaderCheckPoint: def _check_point(self) -> LoaderCheckPoint:
...@@ -107,35 +112,31 @@ class LLamaLLM(BaseAnswer, LLM, ABC): ...@@ -107,35 +112,31 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
formatted_history += "### Human:{}\n### Assistant:".format(query) formatted_history += "### Human:{}\n### Assistant:".format(query)
return formatted_history return formatted_history
def prepare_inputs_for_generation(self, def _call(
input_ids: torch.LongTensor): self,
""" inputs: Dict[str, Any],
预生成注意力掩码和 输入序列中每个位置的索引的张量 run_manager: Optional[CallbackManagerForChainRun] = None,
# TODO 没有思路 ) -> Dict[str, Generator]:
:return: generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
""" return {self.output_key: generator}
mask_positions = torch.zeros((1, input_ids.shape[1]), dtype=input_ids.dtype).to(self.checkPoint.model.device) def _generate_answer(self,
inputs: Dict[str, Any],
attention_mask = self.get_masks(input_ids, input_ids.device) run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
position_ids = self.get_position_ids(
input_ids, history = inputs[self.history_key]
device=input_ids.device, streaming = inputs[self.streaming_key]
mask_positions=mask_positions prompt = inputs[self.prompt_key]
)
return input_ids, position_ids, attention_mask
@property
def _history_len(self) -> int:
return self.history_len
def set_history_len(self, history_len: int = 10) -> None:
self.history_len = history_len
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
print(f"__call:{prompt}") print(f"__call:{prompt}")
# Create the StoppingCriteriaList with the stopping strings
self.stopping_criteria = transformers.StoppingCriteriaList()
# 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
self.stopping_criteria.append(listenerQueue)
# TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现
soft_prompt = self.history_to_text(query=prompt, history=history)
if self.logits_processor is None: if self.logits_processor is None:
self.logits_processor = LogitsProcessorList() self.logits_processor = LogitsProcessorList()
self.logits_processor.append(InvalidScoreLogitsProcessor()) self.logits_processor.append(InvalidScoreLogitsProcessor())
...@@ -154,16 +155,10 @@ class LLamaLLM(BaseAnswer, LLM, ABC): ...@@ -154,16 +155,10 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
"logits_processor": self.logits_processor} "logits_processor": self.logits_processor}
# 向量转换 # 向量转换
input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'], truncation_length=self.max_new_tokens) input_ids = self.encode(soft_prompt, add_bos_token=self.checkPoint.tokenizer.add_bos_token,
# input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids) truncation_length=self.max_new_tokens)
gen_kwargs.update({'inputs': input_ids}) gen_kwargs.update({'inputs': input_ids})
# 注意力掩码
# gen_kwargs.update({'attention_mask': attention_mask})
# gen_kwargs.update({'position_ids': position_ids})
if self.stopping_criteria is None:
self.stopping_criteria = transformers.StoppingCriteriaList()
# 观测输出 # 观测输出
gen_kwargs.update({'stopping_criteria': self.stopping_criteria}) gen_kwargs.update({'stopping_criteria': self.stopping_criteria})
# llama-cpp模型的参数与transformers的参数字段有较大差异,直接调用会返回不支持的字段错误 # llama-cpp模型的参数与transformers的参数字段有较大差异,直接调用会返回不支持的字段错误
...@@ -173,11 +168,13 @@ class LLamaLLM(BaseAnswer, LLM, ABC): ...@@ -173,11 +168,13 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
if "llama_cpp" in self.checkPoint.model.__str__(): if "llama_cpp" in self.checkPoint.model.__str__():
import inspect import inspect
common_kwargs_keys = set(inspect.getfullargspec(self.checkPoint.model.generate).args)&set(gen_kwargs.keys()) common_kwargs_keys = set(inspect.getfullargspec(self.checkPoint.model.generate).args) & set(
common_kwargs = {key:gen_kwargs[key] for key in common_kwargs_keys} gen_kwargs.keys())
#? llama-cpp模型的generate方法似乎只接受.cpu类型的输入,响应很慢,慢到哭泣 common_kwargs = {key: gen_kwargs[key] for key in common_kwargs_keys}
#?为什么会不支持GPU呢,不应该啊? # ? llama-cpp模型的generate方法似乎只接受.cpu类型的输入,响应很慢,慢到哭泣
output_ids = torch.tensor([list(self.checkPoint.model.generate(input_id_i.cpu(),**common_kwargs)) for input_id_i in input_ids]) # ?为什么会不支持GPU呢,不应该啊?
output_ids = torch.tensor(
[list(self.checkPoint.model.generate(input_id_i.cpu(), **common_kwargs)) for input_id_i in input_ids])
else: else:
output_ids = self.checkPoint.model.generate(**gen_kwargs) output_ids = self.checkPoint.model.generate(**gen_kwargs)
...@@ -185,17 +182,11 @@ class LLamaLLM(BaseAnswer, LLM, ABC): ...@@ -185,17 +182,11 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
reply = self.decode(output_ids[0][-new_tokens:]) reply = self.decode(output_ids[0][-new_tokens:])
print(f"response:{reply}") print(f"response:{reply}")
print(f"+++++++++++++++++++++++++++++++++++") print(f"+++++++++++++++++++++++++++++++++++")
return reply
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
# TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现
softprompt = self.history_to_text(prompt,history=history)
response = self._call(prompt=softprompt, stop=['\n###'])
answer_result = AnswerResult() answer_result = AnswerResult()
answer_result.history = history + [[prompt, response]] history += [[prompt, reply]]
answer_result.llm_output = {"answer": response} answer_result.history = history
yield answer_result if listenerQueue.listenerQueue.__len__() > 0:
answer_result.listenerToken = listenerQueue.listenerQueue.pop()
answer_result.llm_output = {"answer": reply}
generate_with_callback(answer_result)
...@@ -20,6 +20,7 @@ class LoaderCheckPoint: ...@@ -20,6 +20,7 @@ class LoaderCheckPoint:
no_remote_model: bool = False no_remote_model: bool = False
# 模型名称 # 模型名称
model_name: str = None model_name: str = None
pretrained_model_name: str = None
tokenizer: object = None tokenizer: object = None
# 模型全路径 # 模型全路径
model_path: str = None model_path: str = None
...@@ -67,48 +68,49 @@ class LoaderCheckPoint: ...@@ -67,48 +68,49 @@ class LoaderCheckPoint:
self.load_in_8bit = params.get('load_in_8bit', False) self.load_in_8bit = params.get('load_in_8bit', False)
self.bf16 = params.get('bf16', False) self.bf16 = params.get('bf16', False)
def _load_model_config(self):
def _load_model_config(self, model_name):
if self.model_path: if self.model_path:
self.model_path = re.sub("\s","",self.model_path) self.model_path = re.sub("\s", "", self.model_path)
checkpoint = Path(f'{self.model_path}') checkpoint = Path(f'{self.model_path}')
else: else:
if not self.no_remote_model: if self.no_remote_model:
checkpoint = model_name
else:
raise ValueError( raise ValueError(
"本地模型local_model_path未配置路径" "本地模型local_model_path未配置路径"
) )
else:
checkpoint = self.pretrained_model_name
print(f"load_model_config {checkpoint}...")
try: try:
model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True) model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
return model_config return model_config
except Exception as e: except Exception as e:
print(e) print(e)
return checkpoint return checkpoint
def _load_model(self, model_name): def _load_model(self):
""" """
加载自定义位置的model 加载自定义位置的model
:param model_name:
:return: :return:
""" """
print(f"Loading {model_name}...")
t0 = time.time() t0 = time.time()
if self.model_path: if self.model_path:
self.model_path = re.sub("\s","",self.model_path) self.model_path = re.sub("\s", "", self.model_path)
checkpoint = Path(f'{self.model_path}') checkpoint = Path(f'{self.model_path}')
else: else:
if not self.no_remote_model: if self.no_remote_model:
checkpoint = model_name
else:
raise ValueError( raise ValueError(
"本地模型local_model_path未配置路径" "本地模型local_model_path未配置路径"
) )
else:
checkpoint = self.pretrained_model_name
print(f"Loading {checkpoint}...")
self.is_llamacpp = len(list(Path(f'{checkpoint}').glob('ggml*.bin'))) > 0 self.is_llamacpp = len(list(Path(f'{checkpoint}').glob('ggml*.bin'))) > 0
if 'chatglm' in model_name.lower() or "chatyuan" in model_name.lower(): if 'chatglm' in self.model_name.lower() or "chatyuan" in self.model_name.lower():
LoaderClass = AutoModel LoaderClass = AutoModel
else: else:
LoaderClass = AutoModelForCausalLM LoaderClass = AutoModelForCausalLM
...@@ -138,7 +140,7 @@ class LoaderCheckPoint: ...@@ -138,7 +140,7 @@ class LoaderCheckPoint:
torch_dtype=torch.bfloat16 if self.bf16 else torch.float16, torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
trust_remote_code=True).half().to(self.llm_device) trust_remote_code=True).half().to(self.llm_device)
else: else:
from accelerate import dispatch_model,infer_auto_device_map from accelerate import dispatch_model, infer_auto_device_map
model = LoaderClass.from_pretrained(checkpoint, model = LoaderClass.from_pretrained(checkpoint,
config=self.model_config, config=self.model_config,
...@@ -146,10 +148,10 @@ class LoaderCheckPoint: ...@@ -146,10 +148,10 @@ class LoaderCheckPoint:
trust_remote_code=True).half() trust_remote_code=True).half()
# 可传入device_map自定义每张卡的部署情况 # 可传入device_map自定义每张卡的部署情况
if self.device_map is None: if self.device_map is None:
if 'chatglm' in model_name.lower(): if 'chatglm' in self.model_name.lower():
self.device_map = self.chatglm_auto_configure_device_map(num_gpus) self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
elif 'moss' in model_name.lower(): elif 'moss' in self.model_name.lower():
self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name) self.device_map = self.moss_auto_configure_device_map(num_gpus, checkpoint)
else: else:
# 基于如下方式作为默认的多卡加载方案针对新模型基本不会失败 # 基于如下方式作为默认的多卡加载方案针对新模型基本不会失败
# 在chatglm2-6b,bloom-3b,blooz-7b1上进行了测试,GPU负载也相对均衡 # 在chatglm2-6b,bloom-3b,blooz-7b1上进行了测试,GPU负载也相对均衡
...@@ -166,9 +168,9 @@ class LoaderCheckPoint: ...@@ -166,9 +168,9 @@ class LoaderCheckPoint:
# 其他模型定义的层类几乎不可能与chatglm和moss一致,使用chatglm_auto_configure_device_map # 其他模型定义的层类几乎不可能与chatglm和moss一致,使用chatglm_auto_configure_device_map
# 百分百会报错,使用infer_auto_device_map虽然可能导致负载不均衡,但至少不会报错 # 百分百会报错,使用infer_auto_device_map虽然可能导致负载不均衡,但至少不会报错
# 实测在bloom模型上如此 # 实测在bloom模型上如此
# self.device_map = infer_auto_device_map(model, # self.device_map = infer_auto_device_map(model,
# dtype=torch.int8, # dtype=torch.int8,
# no_split_module_classes=model._no_split_modules) # no_split_module_classes=model._no_split_modules)
model = dispatch_model(model, device_map=self.device_map) model = dispatch_model(model, device_map=self.device_map)
else: else:
...@@ -202,7 +204,7 @@ class LoaderCheckPoint: ...@@ -202,7 +204,7 @@ class LoaderCheckPoint:
# tokenizer = model.tokenizer # tokenizer = model.tokenizer
# todo 此处调用AutoTokenizer的tokenizer,但后续可以测试自带tokenizer是不是兼容 # todo 此处调用AutoTokenizer的tokenizer,但后续可以测试自带tokenizer是不是兼容
#* -> 自带的tokenizer不与transoformers的tokenizer兼容,无法使用 # * -> 自带的tokenizer不与transoformers的tokenizer兼容,无法使用
tokenizer = AutoTokenizer.from_pretrained(self.model_name) tokenizer = AutoTokenizer.from_pretrained(self.model_name)
return model, tokenizer return model, tokenizer
...@@ -231,7 +233,7 @@ class LoaderCheckPoint: ...@@ -231,7 +233,7 @@ class LoaderCheckPoint:
llm_int8_enable_fp32_cpu_offload=False) llm_int8_enable_fp32_cpu_offload=False)
with init_empty_weights(): with init_empty_weights():
model = LoaderClass.from_config(self.model_config,trust_remote_code = True) model = LoaderClass.from_config(self.model_config, trust_remote_code=True)
model.tie_weights() model.tie_weights()
if self.device_map is not None: if self.device_map is not None:
params['device_map'] = self.device_map params['device_map'] = self.device_map
...@@ -321,7 +323,7 @@ class LoaderCheckPoint: ...@@ -321,7 +323,7 @@ class LoaderCheckPoint:
return device_map return device_map
def moss_auto_configure_device_map(self, num_gpus: int, model_name) -> Dict[str, int]: def moss_auto_configure_device_map(self, num_gpus: int, checkpoint) -> Dict[str, int]:
try: try:
from accelerate import init_empty_weights from accelerate import init_empty_weights
...@@ -336,16 +338,6 @@ class LoaderCheckPoint: ...@@ -336,16 +338,6 @@ class LoaderCheckPoint:
"`pip install bitsandbytes``pip install accelerate`." "`pip install bitsandbytes``pip install accelerate`."
) from exc ) from exc
if self.model_path:
checkpoint = Path(f'{self.model_path}')
else:
if not self.no_remote_model:
checkpoint = model_name
else:
raise ValueError(
"本地模型local_model_path未配置路径"
)
cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM", cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
pretrained_model_name_or_path=checkpoint) pretrained_model_name_or_path=checkpoint)
...@@ -452,7 +444,7 @@ class LoaderCheckPoint: ...@@ -452,7 +444,7 @@ class LoaderCheckPoint:
def reload_model(self): def reload_model(self):
self.unload_model() self.unload_model()
self.model_config = self._load_model_config(self.model_name) self.model_config = self._load_model_config()
if self.use_ptuning_v2: if self.use_ptuning_v2:
try: try:
...@@ -464,7 +456,7 @@ class LoaderCheckPoint: ...@@ -464,7 +456,7 @@ class LoaderCheckPoint:
except Exception as e: except Exception as e:
print("加载PrefixEncoder config.json失败") print("加载PrefixEncoder config.json失败")
self.model, self.tokenizer = self._load_model(self.model_name) self.model, self.tokenizer = self._load_model()
if self.lora: if self.lora:
self._add_lora_to_model([self.lora]) self._add_lora_to_model([self.lora])
......
from abc import ABC from abc import ABC
from langchain.llms.base import LLM from langchain.chains.base import Chain
from typing import Optional, List from typing import Any, Dict, List, Optional, Generator, Union
from langchain.callbacks.manager import CallbackManagerForChainRun
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from models.loader import LoaderCheckPoint from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer, from models.base import (BaseAnswer,
AnswerResult) AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
import torch
import transformers
import torch import torch
# todo 建议重写instruction,在该instruction下,各模型的表现比较差 # todo 建议重写instruction,在该instruction下,各模型的表现比较差
META_INSTRUCTION = \ META_INSTRUCTION = \
"""You are an AI assistant whose name is MOSS. """You are an AI assistant whose name is MOSS.
...@@ -20,41 +28,65 @@ META_INSTRUCTION = \ ...@@ -20,41 +28,65 @@ META_INSTRUCTION = \
Capabilities and tools that MOSS can possess. Capabilities and tools that MOSS can possess.
""" """
# todo 在MOSSLLM类下,各模型的响应速度很慢,后续要检查一下原因 # todo 在MOSSLLM类下,各模型的响应速度很慢,后续要检查一下原因
class MOSSLLM(BaseAnswer, LLM, ABC): class MOSSLLMChain(BaseAnswer, Chain, ABC):
max_token: int = 2048 max_token: int = 2048
temperature: float = 0.7 temperature: float = 0.7
top_p = 0.8 top_p = 0.8
# history = [] # history = []
checkPoint: LoaderCheckPoint = None checkPoint: LoaderCheckPoint = None
history_len: int = 10 history_len: int = 10
streaming_key: str = "streaming" #: :meta private:
history_key: str = "history" #: :meta private:
prompt_key: str = "prompt" #: :meta private:
output_key: str = "answer_result_stream" #: :meta private:
def __init__(self, checkPoint: LoaderCheckPoint = None): def __init__(self, checkPoint: LoaderCheckPoint = None):
super().__init__() super().__init__()
self.checkPoint = checkPoint self.checkPoint = checkPoint
@property @property
def _llm_type(self) -> str: def _chain_type(self) -> str:
return "MOSS" return "MOSSLLMChain"
@property @property
def _check_point(self) -> LoaderCheckPoint: def input_keys(self) -> List[str]:
return self.checkPoint """Will be whatever keys the prompt expects.
:meta private:
"""
return [self.prompt_key]
@property @property
def _history_len(self) -> int: def output_keys(self) -> List[str]:
"""Will always return text key.
return self.history_len :meta private:
"""
return [self.output_key]
def set_history_len(self, history_len: int) -> None: @property
self.history_len = history_len def _check_point(self) -> LoaderCheckPoint:
return self.checkPoint
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: def _call(
pass self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Generator]:
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
return {self.output_key: generator}
def generatorAnswer(self, prompt: str, def _generate_answer(self,
history: List[List[str]] = [], inputs: Dict[str, Any],
streaming: bool = False): run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
history = inputs[self.history_key]
streaming = inputs[self.streaming_key]
prompt = inputs[self.prompt_key]
print(f"__call:{prompt}")
if len(history) > 0: if len(history) > 0:
history = history[-self.history_len:] if self.history_len > 0 else [] history = history[-self.history_len:] if self.history_len > 0 else []
prompt_w_history = str(history) prompt_w_history = str(history)
...@@ -79,13 +111,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC): ...@@ -79,13 +111,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
num_return_sequences=1, num_return_sequences=1,
eos_token_id=106068, eos_token_id=106068,
pad_token_id=self.checkPoint.tokenizer.pad_token_id) pad_token_id=self.checkPoint.tokenizer.pad_token_id)
response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:],
skip_special_tokens=True)
self.checkPoint.clear_torch_cache() self.checkPoint.clear_torch_cache()
history += [[prompt, response]] history += [[prompt, response]]
answer_result = AnswerResult() answer_result = AnswerResult()
answer_result.history = history answer_result.history = history
answer_result.llm_output = {"answer": response} answer_result.llm_output = {"answer": response}
yield answer_result generate_with_callback(answer_result)
...@@ -24,13 +24,12 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_ ...@@ -24,13 +24,12 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
if use_ptuning_v2: if use_ptuning_v2:
loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2 loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2
# 如果指定了参数,则使用参数的配置
if llm_model: if llm_model:
llm_model_info = llm_model_dict[llm_model] llm_model_info = llm_model_dict[llm_model]
if loaderCheckPoint.no_remote_model:
loaderCheckPoint.model_name = llm_model_info['name'] loaderCheckPoint.model_name = llm_model_info['name']
else: loaderCheckPoint.pretrained_model_name = llm_model_info['pretrained_model_name']
loaderCheckPoint.model_name = llm_model_info['pretrained_model_name']
loaderCheckPoint.model_path = llm_model_info["local_model_path"] loaderCheckPoint.model_path = llm_model_info["local_model_path"]
......
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../../')
import asyncio
from argparse import Namespace
from models.loader.args import parser
from models.loader import LoaderCheckPoint
import models.shared as shared
async def dispatch(args: Namespace):
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
history = [
("which city is this?", "tokyo"),
("why?", "she's japanese"),
]
for answer_result in llm_model_ins.generatorAnswer(prompt="你好? ", history=history,
streaming=False):
resp = answer_result.llm_output["answer"]
print(resp)
if __name__ == '__main__':
args = None
args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'fastchat-chatglm-6b', '--no-remote-model'])
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(dispatch(args))
...@@ -85,8 +85,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR ...@@ -85,8 +85,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
yield history + [[query, yield history + [[query,
"请选择知识库后进行测试,当前未选择知识库。"]], "" "请选择知识库后进行测试,当前未选择知识库。"]], ""
else: else:
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
streaming=streaming): answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": query, "history": history, "streaming": streaming})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"] resp = answer_result.llm_output["answer"]
history = answer_result.history history = answer_result.history
history[-1][-1] = resp history[-1][-1] = resp
...@@ -101,11 +104,12 @@ def init_model(): ...@@ -101,11 +104,12 @@ def init_model():
args_dict = vars(args) args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict) shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM() llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
try: try:
local_doc_qa.init_cfg(llm_model=llm_model_ins) local_doc_qa.init_cfg(llm_model=llm_model_ins)
generator = local_doc_qa.llm.generatorAnswer("你好") answer_result_stream_result = local_doc_qa.llm_model_chain(
for answer_result in generator: {"prompt": "你好", "history": [], "streaming": False})
for answer_result in answer_result_stream_result['answer_result_stream']:
print(answer_result.llm_output) print(answer_result.llm_output)
reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话""" reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
logger.info(reply) logger.info(reply)
...@@ -141,7 +145,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u ...@@ -141,7 +145,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation): def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
filelist = [] filelist = []
if local_doc_qa.llm and local_doc_qa.embeddings: if local_doc_qa.llm_model_chain and local_doc_qa.embeddings:
if isinstance(files, list): if isinstance(files, list):
for file in files: for file in files:
filename = os.path.split(file.name)[-1] filename = os.path.split(file.name)[-1]
...@@ -165,7 +169,7 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte ...@@ -165,7 +169,7 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte
def change_vs_name_input(vs_id, history): def change_vs_name_input(vs_id, history):
if vs_id == "新建知识库": if vs_id == "新建知识库":
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history,\ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history, \
gr.update(choices=[]), gr.update(visible=False) gr.update(choices=[]), gr.update(visible=False)
else: else:
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
...@@ -218,7 +222,7 @@ def change_chunk_conent(mode, label_conent, history): ...@@ -218,7 +222,7 @@ def change_chunk_conent(mode, label_conent, history):
def add_vs_name(vs_name, chatbot): def add_vs_name(vs_name, chatbot):
if vs_name is None or vs_name.strip() == "" : if vs_name is None or vs_name.strip() == "":
vs_status = "知识库名称不能为空,请重新填写知识库名称" vs_status = "知识库名称不能为空,请重新填写知识库名称"
chatbot = chatbot + [[None, vs_status]] chatbot = chatbot + [[None, vs_status]]
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update( return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
...@@ -262,6 +266,7 @@ def reinit_vector_store(vs_id, history): ...@@ -262,6 +266,7 @@ def reinit_vector_store(vs_id, history):
def refresh_vs_list(): def refresh_vs_list():
return gr.update(choices=get_vs_list()), gr.update(choices=get_vs_list()) return gr.update(choices=get_vs_list()), gr.update(choices=get_vs_list())
def delete_file(vs_id, files_to_delete, chatbot): def delete_file(vs_id, files_to_delete, chatbot):
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
content_path = os.path.join(KB_ROOT_PATH, vs_id, "content") content_path = os.path.join(KB_ROOT_PATH, vs_id, "content")
...@@ -275,11 +280,11 @@ def delete_file(vs_id, files_to_delete, chatbot): ...@@ -275,11 +280,11 @@ def delete_file(vs_id, files_to_delete, chatbot):
rested_files = local_doc_qa.list_file_from_vector_store(vs_path) rested_files = local_doc_qa.list_file_from_vector_store(vs_path)
if "fail" in status: if "fail" in status:
vs_status = "文件删除失败。" vs_status = "文件删除失败。"
elif len(rested_files)>0: elif len(rested_files) > 0:
vs_status = "文件删除成功。" vs_status = "文件删除成功。"
else: else:
vs_status = f"文件删除成功,知识库{vs_id}中无已上传文件,请先上传文件后,再开始提问。" vs_status = f"文件删除成功,知识库{vs_id}中无已上传文件,请先上传文件后,再开始提问。"
logger.info(",".join(files_to_delete)+vs_status) logger.info(",".join(files_to_delete) + vs_status)
chatbot = chatbot + [[None, vs_status]] chatbot = chatbot + [[None, vs_status]]
return gr.update(choices=local_doc_qa.list_file_from_vector_store(vs_path), value=[]), chatbot return gr.update(choices=local_doc_qa.list_file_from_vector_store(vs_path), value=[]), chatbot
...@@ -290,7 +295,8 @@ def delete_vs(vs_id, chatbot): ...@@ -290,7 +295,8 @@ def delete_vs(vs_id, chatbot):
status = f"成功删除知识库{vs_id}" status = f"成功删除知识库{vs_id}"
logger.info(status) logger.info(status)
chatbot = chatbot + [[None, status]] chatbot = chatbot + [[None, status]]
return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(visible=True), \ return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(
visible=True), \
gr.update(visible=False), chatbot, gr.update(visible=False) gr.update(visible=False), chatbot, gr.update(visible=False)
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
...@@ -333,7 +339,8 @@ default_theme_args = dict( ...@@ -333,7 +339,8 @@ default_theme_args = dict(
with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo: with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo:
vs_path, file_status, model_status = gr.State( vs_path, file_status, model_status = gr.State(
os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State( os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(
""), gr.State(
model_status) model_status)
gr.Markdown(webui_title) gr.Markdown(webui_title)
with gr.Tab("对话"): with gr.Tab("对话"):
......
...@@ -85,9 +85,10 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR ...@@ -85,9 +85,10 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
yield history + [[query, yield history + [[query,
"请选择知识库后进行测试,当前未选择知识库。"]], "" "请选择知识库后进行测试,当前未选择知识库。"]], ""
else: else:
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history, answer_result_stream_result = local_doc_qa.llm_model_chain(
streaming=streaming): {"prompt": query, "history": history, "streaming": streaming})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"] resp = answer_result.llm_output["answer"]
history = answer_result.history history = answer_result.history
history[-1][-1] = resp + ( history[-1][-1] = resp + (
...@@ -105,13 +106,14 @@ def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec' ...@@ -105,13 +106,14 @@ def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'
args_dict.update(model=llm_model) args_dict.update(model=llm_model)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict) shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM() llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
try: try:
local_doc_qa.init_cfg(llm_model=llm_model_ins, local_doc_qa.init_cfg(llm_model=llm_model_ins,
embedding_model=embedding_model) embedding_model=embedding_model)
generator = local_doc_qa.llm.generatorAnswer("你好") answer_result_stream_result = local_doc_qa.llm_model_chain(
for answer_result in generator: {"prompt": "你好", "history": [], "streaming": False})
for answer_result in answer_result_stream_result['answer_result_stream']:
print(answer_result.llm_output) print(answer_result.llm_output)
reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话""" reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
logger.info(reply) logger.info(reply)
...@@ -468,7 +470,7 @@ with st.sidebar: ...@@ -468,7 +470,7 @@ with st.sidebar:
top_k = st.slider('向量匹配数量', 1, 20, VECTOR_SEARCH_TOP_K) top_k = st.slider('向量匹配数量', 1, 20, VECTOR_SEARCH_TOP_K)
history_len = st.slider( history_len = st.slider(
'LLM对话轮数', 1, 50, LLM_HISTORY_LEN) # 也许要跟知识库分开设置 'LLM对话轮数', 1, 50, LLM_HISTORY_LEN) # 也许要跟知识库分开设置
local_doc_qa.llm.set_history_len(history_len) # local_doc_qa.llm.set_history_len(history_len)
chunk_conent = st.checkbox('启用上下文关联', False) chunk_conent = st.checkbox('启用上下文关联', False)
st.text('') st.text('')
# chunk_conent = st.checkbox('分割文本', True) # 知识库文本分割入库 # chunk_conent = st.checkbox('分割文本', True) # 知识库文本分割入库
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论