Commit 6ec1e56e by hzg0601

Merge branch 'dev' of github.com:imClumsyPanda/langchain-ChatGLM into dev

@@ -203,7 +203,7 @@ llm_model_dict = {
 }
 # LLM name
-LLM_MODEL = "fastchat-chatglm-6b-int4"
+LLM_MODEL = "fastchat-chatglm"
 # Load the model quantized to 8 bits
 LOAD_IN_8BIT = False
 # Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
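For context, LOAD_IN_8BIT is the kind of switch that is typically forwarded to transformers' from_pretrained by the loader; a minimal sketch of that pattern, assuming the surrounding code wires the flag through (the model id is illustrative, not the fastchat-chatglm entry this commit selects, and load_in_8bit=True additionally requires bitsandbytes and accelerate):

import torch
from transformers import AutoModel

LOAD_IN_8BIT = False

# Illustrative only: how a loader might consume the config flag.
model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b",          # example model id, not part of this commit
    trust_remote_code=True,
    load_in_8bit=LOAD_IN_8BIT,   # 8-bit quantized weights when True
)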
@@ -220,7 +220,7 @@ STREAMING = True
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
-PTUNING_DIR='./ptuing-v2'
+PTUNING_DIR='./ptuning-v2'
 # LLM running device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
......
@@ -2,14 +2,14 @@ from abc import ABC
 from langchain.chains.base import Chain
 from typing import Any, Dict, List, Optional, Generator
 from langchain.callbacks.manager import CallbackManagerForChainRun
-from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
+# from transformers.generation.logits_process import LogitsProcessor
+# from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
                          AnswerResult,
                          AnswerResultStream,
                          AnswerResultQueueSentinelTokenListenerQueue)
-import torch
+# import torch
 import transformers
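Commenting these imports out keeps the module importable when the generation utilities and torch are not needed on this code path. An alternative pattern (my suggestion, not what this commit does) is to defer such optional imports into the function that uses them:

# Hypothetical helper: resolve the generation utilities lazily, so the names
# are only imported on the code path that actually runs local generation.
def make_logits_processors():
    from transformers.generation.utils import LogitsProcessorList
    return LogitsProcessorList()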
......
@@ -44,7 +44,7 @@ parser.add_argument('--no-remote-model', action='store_true', help='remote in th
 parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
 parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
 parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
-parser.add_argument('--use-ptuning-v2',type=str,default=USE_PTUNING_V2,help="whether use ptuning-v2 checkpoint")
+parser.add_argument('--use-ptuning-v2',action='store_true',help="whether use ptuning-v2 checkpoint")
 parser.add_argument("--ptuning-dir",type=str,default=PTUNING_DIR,help="the dir of ptuning-v2 checkpoint")
 # Accelerate/transformers
 parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
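The --use-ptuning-v2 fix matters because argparse's type=str does not parse booleans: any explicit value, including the string "False", is truthy downstream. A minimal reproduction of the difference (the flag names here are illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--old-flag', type=str, default=False)  # pre-fix style
parser.add_argument('--new-flag', action='store_true')      # post-fix style

args = parser.parse_args(['--old-flag', 'False'])
print(bool(args.old_flag))  # True: the non-empty string "False" is truthy
print(args.new_flag)        # False: store_true only flips when the flag is given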
......
@@ -441,7 +441,7 @@ class LoaderCheckPoint:
         if self.use_ptuning_v2:
             try:
-                prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r')
+                prefix_encoder_file = open(Path(f'{os.path.abspath(self.ptuning_dir)}/config.json'), 'r')
                 prefix_encoder_config = json.loads(prefix_encoder_file.read())
                 prefix_encoder_file.close()
                 self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
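Both path hunks assume os is already imported in this module; the diff does not add the import. A Path-native alternative that avoids os.path entirely (my suggestion, not part of the commit) would be:

from pathlib import Path

# Hypothetical equivalent of the os.path.abspath() change: resolve()
# absolutizes the directory before joining the file name.
ptuning_dir = './ptuning-v2'
config_path = Path(ptuning_dir).resolve() / 'config.json'
print(config_path)  # e.g. /abs/path/to/ptuning-v2/config.json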
@@ -457,13 +457,14 @@ class LoaderCheckPoint:
         if self.use_ptuning_v2:
             try:
-                prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin'))
+                prefix_state_dict = torch.load(Path(f'{os.path.abspath(self.ptuning_dir)}/pytorch_model.bin'))
                 new_prefix_state_dict = {}
                 for k, v in prefix_state_dict.items():
                     if k.startswith("transformer.prefix_encoder."):
                         new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                 self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                 self.model.transformer.prefix_encoder.float()
+                print("Loaded ptuning checkpoint successfully!")
             except Exception as e:
                 print(e)
                 print("Failed to load PrefixEncoder model parameters")
......