提交 54c983f4 作者: imClumsyPanda

update chatglm_llm.py

上级 2224dece
...@@ -72,29 +72,29 @@ class ChatGLM(LLM): ...@@ -72,29 +72,29 @@ class ChatGLM(LLM):
response, _ = self.model.chat( response, _ = self.model.chat(
self.tokenizer, self.tokenizer,
prompt, prompt,
history=self.history[-self.history_len:] if self.history_len>0 else [], history=self.history[-self.history_len:] if self.history_len > 0 else [],
max_length=self.max_token, max_length=self.max_token,
temperature=self.temperature, temperature=self.temperature,
) )
torch_gc() torch_gc()
if stop is not None: if stop is not None:
response = enforce_stop_tokens(response, stop) response = enforce_stop_tokens(response, stop)
self.history = self.history+[[None, response]] self.history = self.history + [[None, response]]
return response return response
def chat(self,
         prompt: str) -> str:
    """Send *prompt* to the underlying ChatGLM model and return its reply.

    The most recent ``self.history_len`` turns of ``self.history`` are passed
    as conversation context (no context when ``history_len`` is 0 or less).
    The reply is appended to ``self.history`` as ``[None, response]`` — the
    prompt itself is deliberately not stored.  GPU memory is reclaimed via
    ``torch_gc()`` after generation.

    Args:
        prompt: The user message to send to the model.

    Returns:
        The model's text response.
    """
    # Slice the rolling context window; an empty list disables history.
    context = self.history[-self.history_len:] if self.history_len > 0 else []
    response, _ = self.model.chat(
        self.tokenizer,
        prompt,
        history=context,
        max_length=self.max_token,
        temperature=self.temperature,
    )
    # Free cached GPU memory left over from generation.
    torch_gc()
    # NOTE(review): the prompt is recorded as None — presumably because the
    # caller embeds retrieved context in the prompt; confirm against callers.
    self.history = self.history + [[None, response]]
    return response
def load_model(self, def load_model(self,
model_name_or_path: str = "THUDM/chatglm-6b", model_name_or_path: str = "THUDM/chatglm-6b",
llm_device=LLM_DEVICE, llm_device=LLM_DEVICE,
...@@ -126,7 +126,7 @@ class ChatGLM(LLM): ...@@ -126,7 +126,7 @@ class ChatGLM(LLM):
AutoModel.from_pretrained( AutoModel.from_pretrained(
model_name_or_path, model_name_or_path,
config=model_config, config=model_config,
trust_remote_code=True, trust_remote_code=True,
**kwargs) **kwargs)
.half() .half()
.cuda() .cuda()
...@@ -159,7 +159,8 @@ class ChatGLM(LLM): ...@@ -159,7 +159,8 @@ class ChatGLM(LLM):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
self.model.transformer.prefix_encoder.float() self.model.transformer.prefix_encoder.float()
except Exception: except Exception as e:
print(e)
print("加载PrefixEncoder模型参数失败") print("加载PrefixEncoder模型参数失败")
self.model = self.model.eval() self.model = self.model.eval()
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论