提交 0abd2d99 作者: glide-the

llama_llm.py 提示词修改

上级 f1cfd6d6
...@@ -74,7 +74,7 @@ llm_model_dict = { ...@@ -74,7 +74,7 @@ llm_model_dict = {
"vicuna-13b-hf": { "vicuna-13b-hf": {
"name": "vicuna-13b-hf", "name": "vicuna-13b-hf",
"pretrained_model_name": "vicuna-13b-hf", "pretrained_model_name": "vicuna-13b-hf",
"local_model_path": "/media/checkpoint/vicuna-13b-hf", "local_model_path": None,
"provides": "LLamaLLM" "provides": "LLamaLLM"
}, },
......
...@@ -98,9 +98,10 @@ class LLamaLLM(BaseAnswer, LLM, ABC): ...@@ -98,9 +98,10 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
""" """
formatted_history = '' formatted_history = ''
history = history[-self.history_len:] if self.history_len > 0 else [] history = history[-self.history_len:] if self.history_len > 0 else []
for i, (old_query, response) in enumerate(history): if len(history) > 0:
formatted_history += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) for i, (old_query, response) in enumerate(history):
formatted_history += "[Round {}]\n问:{}\n答:".format(len(history), query) formatted_history += "### Human:{}\n### Assistant:{}\n".format(old_query, response)
formatted_history += "### Human:{}\n### Assistant:".format(query)
return formatted_history return formatted_history
def prepare_inputs_for_generation(self, def prepare_inputs_for_generation(self,
...@@ -140,12 +141,13 @@ class LLamaLLM(BaseAnswer, LLM, ABC): ...@@ -140,12 +141,13 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
"max_new_tokens": self.max_new_tokens, "max_new_tokens": self.max_new_tokens,
"num_beams": self.num_beams, "num_beams": self.num_beams,
"top_p": self.top_p, "top_p": self.top_p,
"do_sample": True,
"top_k": self.top_k, "top_k": self.top_k,
"repetition_penalty": self.repetition_penalty, "repetition_penalty": self.repetition_penalty,
"encoder_repetition_penalty": self.encoder_repetition_penalty, "encoder_repetition_penalty": self.encoder_repetition_penalty,
"min_length": self.min_length, "min_length": self.min_length,
"temperature": self.temperature, "temperature": self.temperature,
"eos_token_id": self.eos_token_id, "eos_token_id": self.checkPoint.tokenizer.eos_token_id,
"logits_processor": self.logits_processor} "logits_processor": self.logits_processor}
# 向量转换 # 向量转换
...@@ -178,6 +180,6 @@ class LLamaLLM(BaseAnswer, LLM, ABC): ...@@ -178,6 +180,6 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
response = self._call(prompt=softprompt, stop=['\n###']) response = self._call(prompt=softprompt, stop=['\n###'])
answer_result = AnswerResult() answer_result = AnswerResult()
answer_result.history = history + [[None, response]] answer_result.history = history + [[prompt, response]]
answer_result.llm_output = {"answer": response} answer_result.llm_output = {"answer": response}
yield answer_result yield answer_result
...@@ -75,8 +75,8 @@ class MOSSLLM(BaseAnswer, LLM, ABC): ...@@ -75,8 +75,8 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
repetition_penalty=1.02, repetition_penalty=1.02,
num_return_sequences=1, num_return_sequences=1,
eos_token_id=106068, eos_token_id=106068,
pad_token_id=self.tokenizer.pad_token_id) pad_token_id=self.checkPoint.tokenizer.pad_token_id)
response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
self.checkPoint.clear_torch_cache() self.checkPoint.clear_torch_cache()
history += [[prompt, response]] history += [[prompt, response]]
answer_result = AnswerResult() answer_result = AnswerResult()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论