提交 00d80335 作者: imClumsyPanda

update loader.py

上级 5524c476
from .chatglm_llm import ChatGLM from .chatglm_llm import ChatGLM
from .llama_llm import LLamaLLM # from .llama_llm import LLamaLLM
from .moss_llm import MOSSLLM from .moss_llm import MOSSLLM
...@@ -36,10 +36,6 @@ class LoaderCheckPoint: ...@@ -36,10 +36,6 @@ class LoaderCheckPoint:
lora_dir: str = None lora_dir: str = None
ptuning_dir: str = None ptuning_dir: str = None
use_ptuning_v2: bool = False use_ptuning_v2: bool = False
cpu: bool = False
gpu_memory: object = None
cpu_memory: object = None
auto_devices: object = True
# 如果开启了8bit量化加载,项目无法启动,参考此位置,选择合适的cuda版本,https://github.com/TimDettmers/bitsandbytes/issues/156 # 如果开启了8bit量化加载,项目无法启动,参考此位置,选择合适的cuda版本,https://github.com/TimDettmers/bitsandbytes/issues/156
load_in_8bit: bool = False load_in_8bit: bool = False
is_llamacpp: bool = False is_llamacpp: bool = False
...@@ -56,20 +52,16 @@ class LoaderCheckPoint: ...@@ -56,20 +52,16 @@ class LoaderCheckPoint:
:param params: :param params:
""" """
self.model_path = None self.model_path = None
self.model = None
self.tokenizer = None
self.params = params or {} self.params = params or {}
self.no_remote_model = params.get('no_remote_model', False) self.no_remote_model = params.get('no_remote_model', False)
self.model_name = params.get('model', '') self.model_name = params.get('model', '')
self.lora = params.get('lora', '') self.lora = params.get('lora', '')
self.use_ptuning_v2 = params.get('use_ptuning_v2', False) self.use_ptuning_v2 = params.get('use_ptuning_v2', False)
self.model = None
self.tokenizer = None
self.model_dir = params.get('model_dir', '') self.model_dir = params.get('model_dir', '')
self.lora_dir = params.get('lora_dir', '') self.lora_dir = params.get('lora_dir', '')
self.ptuning_dir = params.get('ptuning_dir', 'ptuning-v2') self.ptuning_dir = params.get('ptuning_dir', 'ptuning-v2')
self.cpu = params.get('cpu', False)
self.gpu_memory = params.get('gpu_memory', None)
self.cpu_memory = params.get('cpu_memory', None)
self.auto_devices = params.get('auto_devices', True)
self.load_in_8bit = params.get('load_in_8bit', False) self.load_in_8bit = params.get('load_in_8bit', False)
self.bf16 = params.get('bf16', False) self.bf16 = params.get('bf16', False)
...@@ -111,8 +103,8 @@ class LoaderCheckPoint: ...@@ -111,8 +103,8 @@ class LoaderCheckPoint:
LoaderClass = AutoModelForCausalLM LoaderClass = AutoModelForCausalLM
# Load the model in simple 16-bit mode by default # Load the model in simple 16-bit mode by default
if not any([self.cpu, self.load_in_8bit, self.auto_devices, self.gpu_memory is not None, if not any([self.llm_device.lower()=="cpu",
self.cpu_memory is not None, self.is_llamacpp]): self.load_in_8bit, self.is_llamacpp]):
if torch.cuda.is_available() and self.llm_device.lower().startswith("cuda"): if torch.cuda.is_available() and self.llm_device.lower().startswith("cuda"):
# 根据当前设备GPU数量决定是否进行多卡部署 # 根据当前设备GPU数量决定是否进行多卡部署
...@@ -140,14 +132,15 @@ class LoaderCheckPoint: ...@@ -140,14 +132,15 @@ class LoaderCheckPoint:
if 'chatglm' in model_name.lower(): if 'chatglm' in model_name.lower():
device_map = self.chatglm_auto_configure_device_map(num_gpus) device_map = self.chatglm_auto_configure_device_map(num_gpus)
elif 'moss' in model_name.lower(): elif 'moss' in model_name.lower():
device_map = self.moss_auto_configure_device_map(num_gpus,model_name) device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
else: else:
device_map = self.chatglm_auto_configure_device_map(num_gpus) device_map = self.chatglm_auto_configure_device_map(num_gpus)
model = dispatch_model(model, device_map=device_map) model = dispatch_model(model, device_map=device_map)
else: else:
print( print(
"Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n") "Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been "
"detected.\nFalling back to CPU mode.\n")
model = ( model = (
AutoModel.from_pretrained( AutoModel.from_pretrained(
checkpoint, checkpoint,
...@@ -169,47 +162,20 @@ class LoaderCheckPoint: ...@@ -169,47 +162,20 @@ class LoaderCheckPoint:
# Custom # Custom
else: else:
params = {"low_cpu_mem_usage": True} params = {"low_cpu_mem_usage": True}
if not any((self.cpu, torch.cuda.is_available(), torch.has_mps)):
print(
"Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
self.cpu = True
if self.cpu: if not self.llm_device.lower().startswith("cuda"):
params["torch_dtype"] = torch.float32 raise SystemError("8bit 模型需要 CUDA 支持,或者改用量化后模型!")
else: else:
params["device_map"] = 'auto' params["device_map"] = 'auto'
params["trust_remote_code"] = True params["trust_remote_code"] = True
if self.load_in_8bit and any((self.auto_devices, self.gpu_memory)): if self.load_in_8bit:
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True,
llm_int8_enable_fp32_cpu_offload=True) llm_int8_enable_fp32_cpu_offload=False)
elif self.load_in_8bit:
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
elif self.bf16: elif self.bf16:
params["torch_dtype"] = torch.bfloat16 params["torch_dtype"] = torch.bfloat16
else: else:
params["torch_dtype"] = torch.float16 params["torch_dtype"] = torch.float16
if self.gpu_memory:
memory_map = list(map(lambda x: x.strip(), self.gpu_memory))
max_cpu_memory = self.cpu_memory.strip() if self.cpu_memory is not None else '99GiB'
max_memory = {}
for i in range(len(memory_map)):
max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else \
memory_map[i]
max_memory['cpu'] = max_cpu_memory
params['max_memory'] = max_memory
elif self.auto_devices:
total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
suggestion = round((total_mem - 1000) / 1000) * 1000
if total_mem - suggestion < 800:
suggestion -= 1000
suggestion = int(round(suggestion / 1000))
print(
f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
max_memory = {0: f'{suggestion}GiB', 'cpu': f'{self.cpu_memory or 99}GiB'}
params['max_memory'] = max_memory
if self.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto': if self.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
config = AutoConfig.from_pretrained(checkpoint) config = AutoConfig.from_pretrained(checkpoint)
with init_empty_weights(): with init_empty_weights():
...@@ -236,7 +202,8 @@ class LoaderCheckPoint: ...@@ -236,7 +202,8 @@ class LoaderCheckPoint:
tokenizer.eos_token_id = 2 tokenizer.eos_token_id = 2
tokenizer.bos_token_id = 1 tokenizer.bos_token_id = 1
tokenizer.pad_token_id = 0 tokenizer.pad_token_id = 0
except: except Exception as e:
print(e)
pass pass
else: else:
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
...@@ -331,7 +298,7 @@ class LoaderCheckPoint: ...@@ -331,7 +298,7 @@ class LoaderCheckPoint:
if len(lora_names) > 0: if len(lora_names) > 0:
print("Applying the following LoRAs to {}: {}".format(self.model_name, ', '.join(lora_names))) print("Applying the following LoRAs to {}: {}".format(self.model_name, ', '.join(lora_names)))
params = {} params = {}
if not self.cpu: if self.llm_device.lower() != "cpu":
params['dtype'] = self.model.dtype params['dtype'] = self.model.dtype
if hasattr(self.model, "hf_device_map"): if hasattr(self.model, "hf_device_map"):
params['device_map'] = {"base_model.model." + k: v for k, v in self.model.hf_device_map.items()} params['device_map'] = {"base_model.model." + k: v for k, v in self.model.hf_device_map.items()}
...@@ -344,7 +311,7 @@ class LoaderCheckPoint: ...@@ -344,7 +311,7 @@ class LoaderCheckPoint:
for lora in lora_names[1:]: for lora in lora_names[1:]:
self.model.load_adapter(Path(f"{self.lora_dir}/{lora}"), lora) self.model.load_adapter(Path(f"{self.lora_dir}/{lora}"), lora)
if not self.load_in_8bit and not self.cpu: if not self.load_in_8bit and self.llm_device.lower() != "cpu":
if not hasattr(self.model, "hf_device_map"): if not hasattr(self.model, "hf_device_map"):
if torch.has_mps: if torch.has_mps:
...@@ -355,12 +322,23 @@ class LoaderCheckPoint: ...@@ -355,12 +322,23 @@ class LoaderCheckPoint:
def clear_torch_cache(self): def clear_torch_cache(self):
gc.collect() gc.collect()
if not self.cpu: if self.llm_device.lower() != "cpu":
device_id = "0" if torch.cuda.is_available() else None if torch.has_mps:
CUDA_DEVICE = f"{self.llm_device}:{device_id}" if device_id else self.llm_device try:
with torch.cuda.device(CUDA_DEVICE): from torch.mps import empty_cache
torch.cuda.empty_cache() empty_cache()
torch.cuda.ipc_collect() except Exception as e:
print(e)
print(
"如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")
elif torch.has_cuda:
device_id = "0" if torch.cuda.is_available() else None
CUDA_DEVICE = f"{self.llm_device}:{device_id}" if device_id else self.llm_device
with torch.cuda.device(CUDA_DEVICE):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
else:
print("未检测到 cuda 或 mps,暂不支持清理显存")
def unload_model(self): def unload_model(self):
del self.model del self.model
...@@ -382,7 +360,7 @@ class LoaderCheckPoint: ...@@ -382,7 +360,7 @@ class LoaderCheckPoint:
prefix_encoder_file.close() prefix_encoder_file.close()
self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len'] self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
self.model_config.prefix_projection = prefix_encoder_config['prefix_projection'] self.model_config.prefix_projection = prefix_encoder_config['prefix_projection']
except Exception: except Exception as e:
print("加载PrefixEncoder config.json失败") print("加载PrefixEncoder config.json失败")
self.model, self.tokenizer = self._load_model(self.model_name) self.model, self.tokenizer = self._load_model(self.model_name)
...@@ -399,7 +377,7 @@ class LoaderCheckPoint: ...@@ -399,7 +377,7 @@ class LoaderCheckPoint:
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
self.model.transformer.prefix_encoder.float() self.model.transformer.prefix_encoder.float()
except Exception: except Exception as e:
print("加载PrefixEncoder模型参数失败") print("加载PrefixEncoder模型参数失败")
self.model = self.model.eval() self.model = self.model.eval()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论