提交 d5ffdaa2 作者: imClumsyPanda

update loader.py

上级 aa266454
...@@ -16,6 +16,7 @@ from transformers.modeling_utils import no_init_weights ...@@ -16,6 +16,7 @@ from transformers.modeling_utils import no_init_weights
from transformers.utils import ContextManagers from transformers.utils import ContextManagers
from accelerate import init_empty_weights from accelerate import init_empty_weights
from accelerate.utils import get_balanced_memory, infer_auto_device_map from accelerate.utils import get_balanced_memory, infer_auto_device_map
from configs.model_config import LLM_DEVICE
class LoaderCheckPoint: class LoaderCheckPoint:
...@@ -44,7 +45,7 @@ class LoaderCheckPoint: ...@@ -44,7 +45,7 @@ class LoaderCheckPoint:
# 自定义设备网络 # 自定义设备网络
device_map: Optional[Dict[str, int]] = None device_map: Optional[Dict[str, int]] = None
# 默认 cuda ,如果不支持cuda使用多卡, 如果不支持多卡 使用cpu # 默认 cuda ,如果不支持cuda使用多卡, 如果不支持多卡 使用cpu
llm_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" llm_device = LLM_DEVICE
def __init__(self, params: dict = None): def __init__(self, params: dict = None):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论