提交 3b4b660d 作者: imClumsyPanda

update chatglm_llm.py

上级 a1033698
...@@ -68,19 +68,33 @@ class ChatGLM(LLM): ...@@ -68,19 +68,33 @@ class ChatGLM(LLM):
def _call(self,
          prompt: str,
          stop: Optional[List[str]] = None,
          stream: bool = True) -> str:
    """Generate a chat response for *prompt* with the wrapped ChatGLM model.

    Args:
        prompt: The user prompt sent to the model.
        stop: Optional stop tokens; when given, the response is truncated
            via ``enforce_stop_tokens``.
        stream: When True, call the model's ``stream_chat`` API and record
            the answer into a pre-reserved ``self.history`` slot; otherwise
            call ``chat`` and append a new history entry.

    Returns:
        The model's response text.
    """
    # Only the most recent ``history_len`` exchanges are forwarded to the
    # model; a non-positive history_len disables history entirely.
    history = self.history[-self.history_len:] if self.history_len > 0 else []
    if stream:
        # Reserve a history slot up front so the streamed answer can be
        # written into it below.
        self.history = self.history + [[None, ""]]
        # NOTE(review): this unpacking assumes ``stream_chat`` returns a
        # (response, history) pair; if it is a generator of such pairs it
        # must be iterated instead — confirm against the model API.
        response, _ = self.model.stream_chat(
            self.tokenizer,
            prompt,
            history=history,
            max_length=self.max_token,
            temperature=self.temperature,
        )
        torch_gc()
        if stop is not None:
            # Apply stop tokens consistently with the non-stream branch.
            response = enforce_stop_tokens(response, stop)
        self.history[-1][-1] = response
        # BUG FIX: the original used ``yield`` here while the other branch
        # used ``return``, turning _call into a generator in BOTH modes and
        # breaking the declared ``-> str`` contract for every caller.
        return response
    response, _ = self.model.chat(
        self.tokenizer,
        prompt,
        history=history,
        max_length=self.max_token,
        temperature=self.temperature,
    )
    torch_gc()
    if stop is not None:
        response = enforce_stop_tokens(response, stop)
    self.history = self.history + [[None, response]]
    return response
def chat(self, def chat(self,
prompt: str) -> str: prompt: str) -> str:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论