Unverified 提交 5524c476 作者: glide-the 提交者: GitHub

Update moss_llm.py

上级 7c749332
...@@ -58,6 +58,11 @@ class MOSSLLM(BaseAnswer, LLM, ABC): ...@@ -58,6 +58,11 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
history: List[List[str]] = [], history: List[List[str]] = [],
streaming: bool = False, streaming: bool = False,
generate_with_callback: AnswerResultStream = None) -> None: generate_with_callback: AnswerResultStream = None) -> None:
# Create the StoppingCriteriaList with the stopping strings
stopping_criteria_list = transformers.StoppingCriteriaList()
# 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
stopping_criteria_list.append(listenerQueue)
if len(history) > 0: if len(history) > 0:
history = history[-self.history_len:-1] if self.history_len > 0 else [] history = history[-self.history_len:-1] if self.history_len > 0 else []
prompt_w_history = str(history) prompt_w_history = str(history)
...@@ -83,6 +88,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC): ...@@ -83,6 +88,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
self.checkPoint.clear_torch_cache() self.checkPoint.clear_torch_cache()
history += [[prompt, response]] history += [[prompt, response]]
yield response, history answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": response}
if listenerQueue.listenerQueue.__len__() > 0:
answer_result.listenerToken = listenerQueue.listenerQueue.pop()
generate_with_callback(answer_result)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论