Unverified 提交 e679136b 作者: imClumsyPanda 提交者: GitHub

Merge pull request #104 from thaumstrial/master

Support for p-tuningv2
......@@ -24,6 +24,9 @@ llm_model_dict = {
# LLM model name
LLM_MODEL = "chatglm-6b"
# Use p-tuning-v2 PrefixEncoder
USE_PTUNING_V2 = False
# LLM running device
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
import json
import os
from langchain.llms.base import LLM
from typing import Optional, List
from langchain.llms.utils import enforce_stop_tokens
from transformers import AutoTokenizer, AutoModel
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
from configs.model_config import LLM_DEVICE
......@@ -51,15 +54,30 @@ class ChatGLM(LLM):
def load_model(self,
model_name_or_path: str = "THUDM/chatglm-6b",
llm_device=LLM_DEVICE):
llm_device=LLM_DEVICE,
use_ptuning_v2=False):
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
trust_remote_code=True
)
model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
if use_ptuning_v2:
try:
prefix_encoder_file = open('ptuning-v2/config.json', 'r')
prefix_encoder_config = json.loads(prefix_encoder_file.read())
prefix_encoder_file.close()
model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
model_config.prefix_projection = prefix_encoder_config['prefix_projection']
except Exception:
print("加载PrefixEncoder config.json失败")
if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
self.model = (
AutoModel.from_pretrained(
model_name_or_path,
config=model_config,
trust_remote_code=True)
.half()
.cuda()
......@@ -68,8 +86,22 @@ class ChatGLM(LLM):
self.model = (
AutoModel.from_pretrained(
model_name_or_path,
config=model_config,
trust_remote_code=True)
.float()
.to(llm_device)
)
if use_ptuning_v2:
try:
prefix_state_dict = torch.load('ptuning-v2/pytorch_model.bin')
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
self.model.transformer.prefix_encoder.float()
except Exception:
print("加载PrefixEncoder模型参数失败")
self.model = self.model.eval()
如果使用了[p-tuning-v2](https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning)方式微调了模型,可以将得到的PrefixEndoer放入此文件夹。
只需要放入模型的*config.json**pytorch_model.bin*
并在加载模型时勾选 *"使用p-tuning-v2微调过的模型"*
\ No newline at end of file
......@@ -53,11 +53,12 @@ def init_model():
return """模型未成功加载,请重新选择后点击"加载模型"按钮"""
def reinit_model(llm_model, embedding_model, llm_history_len, top_k, history):
def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, history):
try:
local_doc_qa.init_cfg(llm_model=llm_model,
embedding_model=embedding_model,
llm_history_len=llm_history_len,
use_ptuning_v2=use_ptuning_v2,
top_k=top_k)
model_status = """模型已成功重新加载,请选择文件后点击"加载文件"按钮"""
except:
......@@ -97,7 +98,7 @@ webui_title = """
"""
init_message = """欢迎使用 langchain-ChatGLM Web UI,开始提问前,请依次如下 3 个步骤:
1. 选择语言模型、Embedding 模型及相关参数后点击"重新加载模型",并等待加载完成提示
1. 选择语言模型、Embedding 模型及相关参数,如果使用ptuning-v2方式微调过模型,将PrefixEncoder模型放在ptuning-v2文件夹里并勾选相关选项,然后点击"重新加载模型",并等待加载完成提示
2. 上传或选择已有文件作为本地知识文档输入后点击"重新加载文档",并等待加载完成提示
3. 输入要提交的问题后,点击回车提交 """
......@@ -127,6 +128,9 @@ with gr.Blocks(css=block_css) as demo:
step=1,
label="LLM history len",
interactive=True)
use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
label="使用p-tuning-v2微调过的模型",
interactive=True)
embedding_model = gr.Radio(embedding_model_dict_list,
label="Embedding 模型",
value=EMBEDDING_MODEL,
......@@ -152,7 +156,7 @@ with gr.Blocks(css=block_css) as demo:
load_file_button = gr.Button("加载文件")
load_model_button.click(reinit_model,
show_progress=True,
inputs=[llm_model, embedding_model, llm_history_len, top_k, chatbot],
inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, chatbot],
outputs=chatbot
)
# 将上传的文件保存到content文件夹下,并更新下拉框
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论