提交 218aca2e 作者: glide-the

删除model_dir和NO_REMOTE_MODEL

上级 f1f742ce
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -91,14 +91,10 @@ llm_model_dict = { ...@@ -91,14 +91,10 @@ llm_model_dict = {
# LLM 名称 # LLM 名称
LLM_MODEL = "chatglm-6b" LLM_MODEL = "chatglm-6b"
# 如果你需要加载本地的model,指定这个参数 ` --no-remote-model`,或者下方参数修改为 `True`
NO_REMOTE_MODEL = False
# 量化加载8bit 模型 # 量化加载8bit 模型
LOAD_IN_8BIT = False LOAD_IN_8BIT = False
# Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. # Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
BF16 = False BF16 = False
# 本地模型存放的位置
MODEL_DIR = "model/"
# 本地lora存放的位置 # 本地lora存放的位置
LORA_DIR = "loras/" LORA_DIR = "loras/"
......
...@@ -35,14 +35,13 @@ parser = argparse.ArgumentParser(prog='langchina-ChatGLM', ...@@ -35,14 +35,13 @@ parser = argparse.ArgumentParser(prog='langchina-ChatGLM',
description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain | ' description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain | '
'基于本地知识库的 ChatGLM 问答') '基于本地知识库的 ChatGLM 问答')
parser.add_argument('--no-remote-model', action='store_true', default=NO_REMOTE_MODEL, help='remote in the model on ' parser.add_argument('--no-remote-model', action='store_true', help='remote in the model on '
'loader checkpoint, ' 'loader checkpoint, '
'if your load local ' 'if your load local '
'model to add the ` ' 'model to add the ` '
'--no-remote-model`') '--no-remote-model`')
parser.add_argument('--model', type=str, default=LLM_MODEL, help='Name of the model to load by default.') parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.') parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
parser.add_argument("--model-dir", type=str, default=MODEL_DIR, help="Path to directory with all the models")
parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras") parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
# Accelerate/transformers # Accelerate/transformers
......
...@@ -26,7 +26,6 @@ class LoaderCheckPoint: ...@@ -26,7 +26,6 @@ class LoaderCheckPoint:
model: object = None model: object = None
model_config: object = None model_config: object = None
lora_names: set = [] lora_names: set = []
model_dir: str = None
lora_dir: str = None lora_dir: str = None
ptuning_dir: str = None ptuning_dir: str = None
use_ptuning_v2: bool = False use_ptuning_v2: bool = False
...@@ -45,28 +44,30 @@ class LoaderCheckPoint: ...@@ -45,28 +44,30 @@ class LoaderCheckPoint:
模型初始化 模型初始化
:param params: :param params:
""" """
self.model_path = None
self.model = None self.model = None
self.tokenizer = None self.tokenizer = None
self.params = params or {} self.params = params or {}
self.model_name = params.get('model_name', False)
self.model_path = params.get('model_path', None)
self.no_remote_model = params.get('no_remote_model', False) self.no_remote_model = params.get('no_remote_model', False)
self.model_name = params.get('model', '')
self.lora = params.get('lora', '') self.lora = params.get('lora', '')
self.use_ptuning_v2 = params.get('use_ptuning_v2', False) self.use_ptuning_v2 = params.get('use_ptuning_v2', False)
self.model_dir = params.get('model_dir', '')
self.lora_dir = params.get('lora_dir', '') self.lora_dir = params.get('lora_dir', '')
self.ptuning_dir = params.get('ptuning_dir', 'ptuning-v2') self.ptuning_dir = params.get('ptuning_dir', 'ptuning-v2')
self.load_in_8bit = params.get('load_in_8bit', False) self.load_in_8bit = params.get('load_in_8bit', False)
self.bf16 = params.get('bf16', False) self.bf16 = params.get('bf16', False)
def _load_model_config(self, model_name): def _load_model_config(self, model_name):
checkpoint = Path(f'{self.model_dir}/{model_name}')
if self.model_path: if self.model_path:
checkpoint = Path(f'{self.model_path}') checkpoint = Path(f'{self.model_path}')
else: else:
if not self.no_remote_model: if not self.no_remote_model:
checkpoint = model_name checkpoint = model_name
else:
raise ValueError(
"本地模型local_model_path未配置路径"
)
model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True) model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
...@@ -81,16 +82,17 @@ class LoaderCheckPoint: ...@@ -81,16 +82,17 @@ class LoaderCheckPoint:
print(f"Loading {model_name}...") print(f"Loading {model_name}...")
t0 = time.time() t0 = time.time()
checkpoint = Path(f'{self.model_dir}/{model_name}')
self.is_llamacpp = len(list(checkpoint.glob('ggml*.bin'))) > 0
if self.model_path: if self.model_path:
checkpoint = Path(f'{self.model_path}') checkpoint = Path(f'{self.model_path}')
else: else:
if not self.no_remote_model: if not self.no_remote_model:
checkpoint = model_name checkpoint = model_name
else:
raise ValueError(
"本地模型local_model_path未配置路径"
)
self.is_llamacpp = len(list(Path(f'{checkpoint}').glob('ggml*.bin'))) > 0
if 'chatglm' in model_name.lower(): if 'chatglm' in model_name.lower():
LoaderClass = AutoModel LoaderClass = AutoModel
else: else:
...@@ -274,13 +276,16 @@ class LoaderCheckPoint: ...@@ -274,13 +276,16 @@ class LoaderCheckPoint:
"`pip install bitsandbytes``pip install accelerate`." "`pip install bitsandbytes``pip install accelerate`."
) from exc ) from exc
checkpoint = Path(f'{self.model_dir}/{model_name}')
if self.model_path: if self.model_path:
checkpoint = Path(f'{self.model_path}') checkpoint = Path(f'{self.model_path}')
else: else:
if not self.no_remote_model: if not self.no_remote_model:
checkpoint = model_name checkpoint = model_name
else:
raise ValueError(
"本地模型local_model_path未配置路径"
)
cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM", cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
pretrained_model_name_or_path=checkpoint) pretrained_model_name_or_path=checkpoint)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论