Unverified commit 14d998b8, authored by shrimp, committed by GitHub

Optional LoRA weight loading (#231)

* Add files via upload

Add support for using LoRA weights

* Update model_config.py

* Add files via upload

Fix a minor bug: the model-loading call had been left out

* Use LoRA fine-tuned weights

* Update model_config.py
Parent 47922d2e
@@ -61,9 +61,7 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
 def similarity_search_with_score_by_vector(
-        self,
-        embedding: List[float],
-        k: int = 4,
+        self, embedding: List[float], k: int = 4,
 ) -> List[Tuple[Document, float]]:
     scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
     docs = []
@@ -122,12 +120,12 @@ class LocalDocQA:
                  llm_model: str = LLM_MODEL,
                  llm_device=LLM_DEVICE,
                  top_k=VECTOR_SEARCH_TOP_K,
-                 use_ptuning_v2: bool = USE_PTUNING_V2
+                 use_ptuning_v2: bool = USE_PTUNING_V2,
+                 use_lora: bool = USE_LORA,
                  ):
         self.llm = ChatGLM()
         self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
-                            llm_device=llm_device,
-                            use_ptuning_v2=use_ptuning_v2)
+                            llm_device=llm_device, use_ptuning_v2=use_ptuning_v2, use_lora=use_lora)
         self.llm.history_len = llm_history_len
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
......
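With the hunk above, turning LoRA loading on from application code is one extra keyword argument to init_cfg. A minimal usage sketch, assuming LocalDocQA is imported from this project's module layout and using illustrative values for the other parameters:

# Minimal usage sketch (values are illustrative, not the project defaults)
from chains.local_doc_qa import LocalDocQA   # assumption: module path, not shown in this diff

local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(llm_model="chatglm-6b",
                      embedding_model="text2vec",   # assumption: a key of embedding_model_dict
                      llm_history_len=3,
                      use_ptuning_v2=False,
                      use_lora=True,                # new flag introduced by this commit
                      top_k=6)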
@@ -27,6 +27,11 @@ llm_model_dict = {
 # LLM model name
 LLM_MODEL = "chatglm-6b"
+# LLM LoRA path; empty by default. To use LoRA weights, set this to the weight folder path
+# chatglm-6b-belle-zh-lora is recommended
+LLM_LORA_PATH = ""
+USE_LORA = True if LLM_LORA_PATH else False
 # LLM streaming response
 STREAMING = True
......
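For reference, this is what the new model_config.py section looks like once a LoRA weight folder is supplied; the path below is only a placeholder, and USE_LORA becomes True automatically because the path is non-empty:

# Placeholder path for illustration; point it at a local folder holding the LoRA weights,
# e.g. a download of chatglm-6b-belle-zh-lora.
LLM_LORA_PATH = "/path/to/chatglm-6b-belle-zh-lora"
USE_LORA = True if LLM_LORA_PATH else False   # evaluates to True here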
@@ -106,6 +106,7 @@ class ChatGLM(LLM):
                    model_name_or_path: str = "THUDM/chatglm-6b",
                    llm_device=LLM_DEVICE,
                    use_ptuning_v2=False,
+                   use_lora=False,
                    device_map: Optional[Dict[str, int]] = None,
                    **kwargs):
         self.tokenizer = AutoTokenizer.from_pretrained(
@@ -125,45 +126,32 @@ class ChatGLM(LLM):
         except Exception as e:
             print(e)
             print("加载PrefixEncoder config.json失败")
+        self.model = AutoModel.from_pretrained(model_name_or_path, config=model_config, trust_remote_code=True,
+                                               **kwargs)
+        if LLM_LORA_PATH and use_lora:
+            from peft import PeftModel
+            self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)
         if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
             # decide whether to deploy across multiple cards based on the number of GPUs on this machine
             num_gpus = torch.cuda.device_count()
             if num_gpus < 2 and device_map is None:
-                self.model = (
-                    AutoModel.from_pretrained(
-                        model_name_or_path,
-                        config=model_config,
-                        trust_remote_code=True,
-                        **kwargs)
-                    .half()
-                    .cuda()
-                )
+                self.model = self.model.half().cuda()
             else:
                 from accelerate import dispatch_model
-                model = (
-                    AutoModel.from_pretrained(
-                        model_name_or_path,
-                        trust_remote_code=True,
-                        config=model_config,
-                        **kwargs)
-                    .half())
+                model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True,
+                                                  config=model_config, **kwargs)
+                if LLM_LORA_PATH and use_lora:
+                    from peft import PeftModel
+                    model_auto = PeftModel.from_pretrained(model, LLM_LORA_PATH)
                 # a device_map can be passed in to customize the placement on each card
                 if device_map is None:
                     device_map = auto_configure_device_map(num_gpus)
-                self.model = dispatch_model(model, device_map=device_map)
+                self.model = dispatch_model(model_auto.half(), device_map=device_map)
         else:
-            self.model = (
-                AutoModel.from_pretrained(
-                    model_name_or_path,
-                    config=model_config,
-                    trust_remote_code=True,
-                    **kwargs)
-                .float()
-                .to(llm_device)
-            )
+            self.model = self.model.float().to(llm_device)
         if use_ptuning_v2:
             try:
@@ -185,7 +173,7 @@ if __name__ == "__main__":
     llm = ChatGLM()
     llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
                    llm_device=LLM_DEVICE, )
-    last_print_len=0
+    last_print_len = 0
     for resp, history in llm._call("你好", streaming=True):
         print(resp[last_print_len:], end="", flush=True)
         last_print_len = len(resp)
......
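The heart of the load_model change above is wrapping the freshly loaded base model with peft's PeftModel before moving it to the target device. Below is a minimal, self-contained sketch of that flow for the single-GPU case, assuming a local LoRA folder; p-tuning-v2, the multi-GPU dispatch_model branch, and error handling are left out, and the same variable is reused whether or not the adapter is applied so the later device move always sees a defined model:

import torch
from transformers import AutoModel, AutoTokenizer
from peft import PeftModel

model_name_or_path = "THUDM/chatglm-6b"
lora_path = "/path/to/chatglm-6b-belle-zh-lora"   # placeholder, plays the role of LLM_LORA_PATH

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)

# Apply the LoRA adapter on top of the base weights, if a path was given.
if lora_path:
    model = PeftModel.from_pretrained(model, lora_path)

# Move the (possibly wrapped) model to the target device.
if torch.cuda.is_available():
    model = model.half().cuda()
else:
    model = model.float()
model = model.eval()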
@@ -12,4 +12,5 @@ accelerate
 gradio==3.24.1
 fastapi
 uvicorn
+peft
 #detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2
@@ -72,12 +72,13 @@ def init_model():
     return reply
 
 
-def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, history):
+def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, history):
     try:
         local_doc_qa.init_cfg(llm_model=llm_model,
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
                               use_ptuning_v2=use_ptuning_v2,
+                              use_lora=use_lora,
                               top_k=top_k,)
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(model_status)
@@ -246,6 +247,9 @@ with gr.Blocks(css=block_css) as demo:
                 use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
                                              label="使用p-tuning-v2微调过的模型",
                                              interactive=True)
+                use_lora = gr.Checkbox(USE_LORA,
+                                       label="使用lora微调的权重",
+                                       interactive=True)
                 embedding_model = gr.Radio(embedding_model_dict_list,
                                            label="Embedding 模型",
                                            value=EMBEDDING_MODEL,
@@ -259,7 +263,7 @@
             load_model_button = gr.Button("重新加载模型")
             load_model_button.click(reinit_model,
                                     show_progress=True,
-                                    inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, chatbot],
+                                    inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, chatbot],
                                     outputs=chatbot
                                     )
......
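Gradio passes the components listed in `inputs` to the callback positionally, so the new use_lora checkbox has to occupy the same slot in the inputs list as in the reinit_model signature, which the last two hunks keep in sync. A stripped-down, hypothetical illustration of that wiring (not the project's actual web UI code):

import gradio as gr

def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, history):
    # In the real web UI this re-runs local_doc_qa.init_cfg(...); here we just echo the choices.
    return history + [[None, f"reloaded {llm_model} (use_lora={use_lora}, top_k={top_k})"]]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    llm_model = gr.Radio(["chatglm-6b"], value="chatglm-6b", label="LLM")
    embedding_model = gr.Radio(["text2vec"], value="text2vec", label="Embedding")
    llm_history_len = gr.Slider(0, 10, value=3, step=1, label="history length")
    use_ptuning_v2 = gr.Checkbox(False, label="p-tuning-v2")
    use_lora = gr.Checkbox(False, label="lora")
    top_k = gr.Slider(1, 20, value=6, step=1, label="top_k")
    load_model_button = gr.Button("reload model")
    # The order here must match the parameter order of reinit_model.
    load_model_button.click(reinit_model,
                            show_progress=True,
                            inputs=[llm_model, embedding_model, llm_history_len,
                                    use_ptuning_v2, use_lora, top_k, chatbot],
                            outputs=chatbot)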