Unverified commit acee2d5a authored by Zhi-guo Huang, committed by GitHub

Merge pull request #905 from chinainfant/dev

Fix the failure to load the ptuning checkpoint
@@ -44,7 +44,7 @@ parser.add_argument('--no-remote-model', action='store_true', help='remote in th
 parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
 parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
 parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
-parser.add_argument('--use-ptuning-v2', type=str, default=USE_PTUNING_V2, help="whether use ptuning-v2 checkpoint")
+parser.add_argument('--use-ptuning-v2', action='store_true', help="whether use ptuning-v2 checkpoint")
 parser.add_argument("--ptuning-dir", type=str, default=PTUNING_DIR, help="the dir of ptuning-v2 checkpoint")
 # Accelerate/transformers
 parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
...
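Why the flag change matters: with type=str, argparse stores whatever string follows the option, and any non-empty string, including "False", is truthy, so the p-tuning branch could be taken even when the user meant to disable it; action='store_true' makes the option a real boolean flag. A minimal standalone sketch of the difference (the option names here are illustrative, not the project's actual parser):

# Sketch: why type=str is the wrong way to model a boolean flag.
import argparse

p = argparse.ArgumentParser()
# Buggy variant: stores a string, so "--use-ptuning-v2-str False" is still truthy.
p.add_argument('--use-ptuning-v2-str', type=str, default='')
# Fixed variant: absent -> False, present -> True.
p.add_argument('--use-ptuning-v2', action='store_true')

args = p.parse_args(['--use-ptuning-v2-str', 'False', '--use-ptuning-v2'])
print(bool(args.use_ptuning_v2_str))  # True, even though the user typed "False"
print(args.use_ptuning_v2)            # True only because the flag was passed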
@@ -441,7 +441,7 @@ class LoaderCheckPoint:
         if self.use_ptuning_v2:
             try:
-                prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r')
+                prefix_encoder_file = open(Path(f'{os.path.abspath(self.ptuning_dir)}/config.json'), 'r')
                 prefix_encoder_config = json.loads(prefix_encoder_file.read())
                 prefix_encoder_file.close()
                 self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
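The other half of the fix wraps self.ptuning_dir in os.path.abspath, so a relative --ptuning-dir is resolved against the current working directory into an unambiguous absolute path before the file is opened. A minimal illustration (the directory name is hypothetical):

# os.path.abspath resolves a relative checkpoint dir against the current
# working directory, yielding an absolute path.
import os
from pathlib import Path

ptuning_dir = "ptuning-v2"  # hypothetical relative value of --ptuning-dir
config_path = Path(f"{os.path.abspath(ptuning_dir)}/config.json")
print(config_path)  # e.g. /home/user/project/ptuning-v2/config.json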
@@ -457,13 +457,14 @@ class LoaderCheckPoint:
         if self.use_ptuning_v2:
             try:
-                prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin'))
+                prefix_state_dict = torch.load(Path(f'{os.path.abspath(self.ptuning_dir)}/pytorch_model.bin'))
                 new_prefix_state_dict = {}
                 for k, v in prefix_state_dict.items():
                     if k.startswith("transformer.prefix_encoder."):
                         new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                 self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                 self.model.transformer.prefix_encoder.float()
+                print("p-tuning checkpoint loaded successfully!")
             except Exception as e:
                 print(e)
                 print("Failed to load PrefixEncoder model parameters")
...
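For context, the loading logic in the hunk above filters the checkpoint's state dict down to the prefix-encoder weights and strips the "transformer.prefix_encoder." module prefix so the keys match what the submodule's load_state_dict expects. A self-contained sketch of that key-stripping step, assuming a ChatGLM-style checkpoint layout (the path and the final commented call are illustrative):

# Sketch of loading a P-Tuning v2 prefix encoder from a full-checkpoint
# state dict whose keys are prefixed with "transformer.prefix_encoder.".
import os
from pathlib import Path
import torch

ptuning_dir = "ptuning-v2"  # hypothetical checkpoint directory
state = torch.load(Path(f"{os.path.abspath(ptuning_dir)}/pytorch_model.bin"),
                   map_location="cpu")

prefix = "transformer.prefix_encoder."
# Keep only prefix-encoder weights and strip the module prefix so the keys
# line up with the PrefixEncoder submodule's own parameter names.
new_state = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
# model.transformer.prefix_encoder.load_state_dict(new_state)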