提交 09248391 作者: imClumsyPanda

Merge remote-tracking branch 'origin/dev' into dev

......@@ -130,13 +130,13 @@ class LoaderCheckPoint:
# 可传入device_map自定义每张卡的部署情况
if self.device_map is None:
if 'chatglm' in model_name.lower():
device_map = self.chatglm_auto_configure_device_map(num_gpus)
self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
elif 'moss' in model_name.lower():
device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
else:
device_map = self.chatglm_auto_configure_device_map(num_gpus)
self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
model = dispatch_model(model, device_map=device_map)
model = dispatch_model(model, device_map=self.device_map)
else:
print(
"Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been "
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论