Commit c4ee36b8 Author: glide-the

Remove AnswerResultStream and the generate_with_callback collector

Parent e7b06a90
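After this change, every BaseAnswer backend yields AnswerResult objects directly from generatorAnswer, instead of pushing them into an AnswerResultStream callback that Iteratorize then wrapped into an iterator. A minimal sketch of the new contract, using a hypothetical DummyLLM stand-in (not part of this commit) in place of a real backend such as ChatGLM:

from typing import List, Optional


class AnswerResult:
    """Message entity, mirroring models.base.AnswerResult after this commit."""
    history: List[List[str]] = []
    llm_output: Optional[dict] = None


class DummyLLM:
    """Hypothetical stand-in for a BaseAnswer backend such as ChatGLM."""

    def generatorAnswer(self, prompt: str,
                        history: List[List[str]] = [],
                        streaming: bool = False):
        # Backends now yield AnswerResult objects directly; no callback collector is involved.
        for partial in ["Hello", "Hello,", "Hello, world"]:
            answer_result = AnswerResult()
            answer_result.history = history + [[prompt, partial]]
            answer_result.llm_output = {"answer": partial}
            yield answer_result


# Callers iterate the generator instead of wrapping a callback in Iteratorize.
for answer_result in DummyLLM().generatorAnswer("hi", history=[], streaming=True):
    print(answer_result.llm_output["answer"])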
@@ -12,9 +12,7 @@ from tqdm import tqdm
 from pypinyin import lazy_pinyin
 from loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 from models.loader.args import parser
 from models.loader import LoaderCheckPoint
 import models.shared as shared
...
@@ -10,142 +10,12 @@ import transformers
 from models.loader import LoaderCheckPoint
-class ListenerToken:
-    """
-    Observation result
-    """
-    input_ids: torch.LongTensor
-    _scores: torch.FloatTensor
-    def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
-        self.input_ids = input_ids
-        self._scores = _scores
 class AnswerResult:
     """
     Message entity
     """
     history: List[List[str]] = []
     llm_output: Optional[dict] = None
-    listenerToken: ListenerToken = None
-class AnswerResultStream:
-    def __init__(self, callback_func=None):
-        self.callback_func = callback_func
-    def __call__(self, answerResult: AnswerResult):
-        if self.callback_func is not None:
-            self.callback_func(answerResult)
-class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
-    """
-    Defines a stopping_criteria listener for the model; on every step it syncs the queued data into AnswerResult.
-    This listener exists because different models may not expose their prediction output as tensors; the HF framework
-    lets a custom transformers.StoppingCriteria receive the tensor and scores of every prediction step, and a
-    StoppingCriteriaList specifies the conditions under which generation stops. Each StoppingCriteria object
-    represents one stop condition.
-    At the start of every prediction round, each StoppingCriteria receives the same prediction result, and the
-    concrete subclass decides whether to stop.
-    The captured values can be observed through the custom parameters of generatorAnswer / generate_with_streaming
-    for finer-grained control.
-    """
-    listenerQueue: deque = deque(maxlen=1)
-    def __init__(self):
-        transformers.StoppingCriteria.__init__(self)
-    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
-        """
-        On every step, append the data to the response queue.
-        :param input_ids:
-        :param _scores:
-        :param kwargs:
-        :return:
-        """
-        self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores))
-        return False
-class Iteratorize:
-    """
-    Transforms a function that takes a callback
-    into a lazy iterator (generator).
-    """
-    def __init__(self, func, kwargs={}):
-        self.mfunc = func
-        self.q = Queue()
-        self.sentinel = object()
-        self.kwargs = kwargs
-        self.stop_now = False
-        def _callback(val):
-            """
-            Collects the model's prediction output.
-            A generate_with_callback collector (AnswerResultStream) gathers the AnswerResult responses produced by
-            the model; the concrete subclass decides when to finish. Termination happens when:
-            1. The model finishes predicting and the collector queue self.q receives the self.sentinel marker.
-            2. A break exits the iterator while queue messages are being processed, triggering StopIteration.
-            3. The model raises an error during prediction.
-            Because this class is an iterator, a break inside a for loop eventually calls __exit__, which updates
-            stop_now; an exception is then raised to end prediction.
-            The collection flow is:
-            create an Iteratorize iterator object,
-            define the generate_with_callback collector (AnswerResultStream),
-            start a thread that asynchronously calls the upstream checkpoint implementation _generate_answer,
-            _generate_answer uses the collector defined by generate_with_callback to gather the AnswerResult
-            messages wrapped by the upstream checkpoint.
-            Because self.q blocks, each prediction must be consumed before the next one runs, so
-            generate_with_callback blocks in the meantime.
-            The main thread calls __next__ on the Iteratorize object to fetch and consume the blocking message:
-            1. If the message is an AnswerResult wrapped by the upstream checkpoint, it is returned downstream.
-            2. If the message is the self.sentinel marker, StopIteration is raised.
-            When __exit__ on the main thread's Iteratorize object is invoked, stop_now is updated; the background
-            thread detects the change, raises an exception to end prediction, and iteration ends.
-            :param val:
-            :return:
-            """
-            if self.stop_now:
-                raise ValueError
-            self.q.put(val)
-        def gen():
-            try:
-                ret = self.mfunc(callback=_callback, **self.kwargs)
-            except ValueError:
-                pass
-            except:
-                traceback.print_exc()
-                pass
-            self.q.put(self.sentinel)
-        self.thread = Thread(target=gen)
-        self.thread.start()
-    def __iter__(self):
-        return self
-    def __next__(self):
-        obj = self.q.get(True, None)
-        if obj is self.sentinel:
-            raise StopIteration
-        else:
-            return obj
-    def __del__(self):
-        """
-        Not implemented yet.
-        :return:
-        """
-        pass
-    def __enter__(self):
-        return self
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        """ Executed after a break. """
-        self.stop_now = True
 class BaseAnswer(ABC):
@@ -168,22 +38,4 @@ class BaseAnswer(ABC):
     def generatorAnswer(self, prompt: str,
                         history: List[List[str]] = [],
                         streaming: bool = False):
-        def generate_with_callback(callback=None, **kwargs):
-            kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback)
-            self._generate_answer(**kwargs)
-        def generate_with_streaming(**kwargs):
-            return Iteratorize(generate_with_callback, kwargs)
-        with generate_with_streaming(prompt=prompt, history=history, streaming=streaming) as generator:
-            for answerResult in generator:
-                if answerResult.listenerToken:
-                    output = answerResult.listenerToken.input_ids
-                yield answerResult
-    @abstractmethod
-    def _generate_answer(self, prompt: str,
-                         history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
         pass
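For reference, the long docstring in the removed Iteratorize class above describes a general callback-to-iterator pattern: a background thread feeds callback values into a blocking queue, a sentinel object marks the end of production, and breaking out of the consuming loop flips a stop flag that aborts the producer. A condensed, self-contained sketch of that pattern, independent of this repository's classes (CallbackIterator and produce_numbers are hypothetical names):

import traceback
from queue import Queue
from threading import Thread


class CallbackIterator:
    """Minimal callback-to-iterator wrapper, echoing the removed Iteratorize."""

    def __init__(self, func, **kwargs):
        self.q = Queue()
        self.sentinel = object()   # marks the end of production
        self.stop_now = False

        def _callback(val):
            if self.stop_now:      # consumer broke out of its loop
                raise ValueError
            self.q.put(val)

        def run():
            try:
                func(callback=_callback, **kwargs)
            except ValueError:
                pass               # producer aborted on purpose
            except Exception:
                traceback.print_exc()
            self.q.put(self.sentinel)

        Thread(target=run, daemon=True).start()

    def __iter__(self):
        return self

    def __next__(self):
        obj = self.q.get()
        if obj is self.sentinel:
            raise StopIteration
        return obj

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_now = True       # runs after a break in the consuming loop


def produce_numbers(callback=None, n=3):
    """Hypothetical callback-style producer standing in for _generate_answer."""
    for i in range(n):
        callback(i)


with CallbackIterator(produce_numbers, n=3) as it:
    for value in it:
        print(value)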
@@ -5,9 +5,7 @@ from langchain.llms.base import LLM
 from typing import Optional, List
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 import transformers
@@ -43,15 +41,9 @@ class ChatGLM(BaseAnswer, LLM, ABC):
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         pass
-    def _generate_answer(self, prompt: str,
-                         history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
+    def generatorAnswer(self, prompt: str,
+                        history: List[List[str]] = [],
+                        streaming: bool = False):
-        # Create the StoppingCriteriaList with the stopping strings
-        stopping_criteria_list = transformers.StoppingCriteriaList()
-        # Define the model stopping_criteria queue; on every step it syncs torch.LongTensor and torch.FloatTensor into AnswerResult
-        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
-        stopping_criteria_list.append(listenerQueue)
         if streaming:
             history += [[]]
@@ -60,34 +52,27 @@ class ChatGLM(BaseAnswer, LLM, ABC):
                 prompt,
                 history=history[-self.history_len:-1] if self.history_len > 0 else [],
                 max_length=self.max_token,
-                temperature=self.temperature,
-                stopping_criteria=stopping_criteria_list
+                temperature=self.temperature
             )):
                 # self.checkPoint.clear_torch_cache()
                 history[-1] = [prompt, stream_resp]
                 answer_result = AnswerResult()
                 answer_result.history = history
                 answer_result.llm_output = {"answer": stream_resp}
-                if listenerQueue.listenerQueue.__len__() > 0:
-                    answer_result.listenerToken = listenerQueue.listenerQueue.pop()
-                generate_with_callback(answer_result)
+                yield answer_result
         else:
             response, _ = self.checkPoint.model.chat(
                 self.checkPoint.tokenizer,
                 prompt,
                 history=history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
-                temperature=self.temperature,
-                stopping_criteria=stopping_criteria_list
+                temperature=self.temperature
             )
             self.checkPoint.clear_torch_cache()
             history += [[prompt, response]]
             answer_result = AnswerResult()
             answer_result.history = history
             answer_result.llm_output = {"answer": response}
-            if listenerQueue.listenerQueue.__len__() > 0:
-                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
-            generate_with_callback(answer_result)
+            yield answer_result
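The lines removed from ChatGLM above attached the listener through the standard transformers stopping-criteria hook. If per-step token observation is still wanted after this commit, the same hook can be wired up directly on a generate() call. A rough sketch, assuming a plain Hugging Face causal LM; "gpt2" is only a placeholder checkpoint, not something this project uses:

import torch
import transformers


class TokenListener(transformers.StoppingCriteria):
    """Observes each generation step without ever requesting a stop."""

    def __init__(self):
        super().__init__()
        self.last_input_ids = None

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        self.last_input_ids = input_ids  # keep the latest token ids for inspection
        return False                     # never stop generation ourselves


tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

listener = TokenListener()
inputs = tokenizer("Hello", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    max_new_tokens=8,
    stopping_criteria=transformers.StoppingCriteriaList([listener]),
)
print(tokenizer.decode(output_ids[0]))
print(listener.last_input_ids)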
@@ -5,9 +5,7 @@ from langchain.llms.base import LLM
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 class FastChatLLM(BaseAnswer, LLM, ABC):
@@ -40,10 +38,9 @@ class FastChatLLM(BaseAnswer, LLM, ABC):
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         pass
-    def _generate_answer(self, prompt: str,
-                         history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
+    def generatorAnswer(self, prompt: str,
+                        history: List[List[str]] = [],
+                        streaming: bool = False):
         response = "fastchat 响应结果"
         history += [[prompt, response]]
@@ -51,4 +48,4 @@ class FastChatLLM(BaseAnswer, LLM, ABC):
         answer_result.history = history
         answer_result.llm_output = {"answer": response}
-        generate_with_callback(answer_result)
+        yield answer_result
@@ -9,9 +9,7 @@ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from typing import Optional, List, Dict, Any
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 class InvalidScoreLogitsProcessor(LogitsProcessor):
@@ -178,23 +176,15 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
         self.history = self.history + [[None, reply]]
         return reply
-    def _generate_answer(self, prompt: str,
-                         history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
+    def generatorAnswer(self, prompt: str,
+                        history: List[List[str]] = [],
+                        streaming: bool = False):
         if history:
             self.history = history
-        # Create the StoppingCriteriaList with the stopping strings
-        self.stopping_criteria = transformers.StoppingCriteriaList()
-        # Define the model stopping_criteria queue; on every step it syncs torch.LongTensor and torch.FloatTensor into AnswerResult
-        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
-        self.stopping_criteria.append(listenerQueue)
         # TODO: implement the chat dialogue module and attention model; _call is currently the LangChain LLM extension API and defaults to prompt-free mode. To work with the attention model, refer to the chat_glm implementation.
         softprompt = self.generate_softprompt_history_tensors(prompt)
         response = self._call(prompt=softprompt, stop=['\n###'])
         answer_result = AnswerResult()
         answer_result.history = self.history
-        if listenerQueue.listenerQueue.__len__() > 0:
-            answer_result.listenerToken = listenerQueue.listenerQueue.pop()
         answer_result.llm_output = {"answer": response}
-        generate_with_callback(answer_result)
+        yield answer_result
@@ -3,9 +3,7 @@ from langchain.llms.base import LLM
 from typing import Optional, List
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 import torch
@@ -53,10 +51,9 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         pass
-    def _generate_answer(self, prompt: str,
-                         history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
+    def generatorAnswer(self, prompt: str,
+                        history: List[List[str]] = [],
+                        streaming: bool = False):
         if len(history) > 0:
             history = history[-self.history_len:-1] if self.history_len > 0 else []
             prompt_w_history = str(history)
@@ -86,6 +83,6 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
         answer_result.history = history
         answer_result.llm_output = {"answer": response}
-        generate_with_callback(answer_result)
+        yield answer_result
@@ -6,9 +6,7 @@ from chains.local_doc_qa import LocalDocQA
 from configs.model_config import *
 import nltk
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 import models.shared as shared
 from models.loader.args import parser
 from models.loader import LoaderCheckPoint
...