Unverified commit 14d998b8, authored by shrimp, committed by GitHub

Optional LoRA weight loading (#231)

* Add files via upload

Add LoRA weight support

* Update model_config.py

* Add files via upload

Fix a minor mistake: the model loading call was missing

* Use LoRA fine-tuned weights

Use LoRA fine-tuned weights

* Update model_config.py
Parent 47922d2e
...@@ -61,9 +61,7 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
 def similarity_search_with_score_by_vector(
-        self,
-        embedding: List[float],
-        k: int = 4,
+        self, embedding: List[float], k: int = 4,
 ) -> List[Tuple[Document, float]]:
     scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
     docs = []
...@@ -122,12 +120,12 @@ class LocalDocQA:
                  llm_model: str = LLM_MODEL,
                  llm_device=LLM_DEVICE,
                  top_k=VECTOR_SEARCH_TOP_K,
-                 use_ptuning_v2: bool = USE_PTUNING_V2
+                 use_ptuning_v2: bool = USE_PTUNING_V2,
+                 use_lora: bool = USE_LORA,
                  ):
         self.llm = ChatGLM()
         self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
-                            llm_device=llm_device,
-                            use_ptuning_v2=use_ptuning_v2)
+                            llm_device=llm_device, use_ptuning_v2=use_ptuning_v2, use_lora=use_lora)
         self.llm.history_len = llm_history_len
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
......
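For context, here is a minimal sketch of how a caller opts into the new parameter; the import path is assumed from this repository's layout and is not shown in the diff:

```python
# Hypothetical usage sketch; the module path is assumed, not part of this diff.
from chains.local_doc_qa import LocalDocQA

local_doc_qa = LocalDocQA()
# use_lora defaults to USE_LORA from the config; passing it explicitly overrides that.
# Note that LoRA weights are only actually loaded when LLM_LORA_PATH is also set
# (see the model_config.py and chatglm_llm.py hunks below).
local_doc_qa.init_cfg(llm_model="chatglm-6b", use_lora=True)
```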
...@@ -27,6 +27,11 @@ llm_model_dict = {
 # LLM model name
 LLM_MODEL = "chatglm-6b"
+# LLM lora path,默认为空,如果有请直接指定文件夹路径
+# 推荐使用 chatglm-6b-belle-zh-lora
+LLM_LORA_PATH = ""
+USE_LORA = True if LLM_LORA_PATH else False
 # LLM streaming reponse
 STREAMING = True
......
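To make the two new settings concrete, this is what the configuration looks like with LoRA enabled; the path below is a placeholder, not a value from this PR:

```python
# Illustrative values only: the path is a placeholder for a local adapter folder,
# e.g. a downloaded copy of chatglm-6b-belle-zh-lora.
LLM_LORA_PATH = "/data/lora/chatglm-6b-belle-zh-lora"
USE_LORA = True if LLM_LORA_PATH else False   # non-empty path -> True, "" -> False
```

Leaving LLM_LORA_PATH empty keeps USE_LORA at False, so existing setups behave exactly as before.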
...@@ -78,11 +78,11 @@ class ChatGLM(LLM):
             torch_gc()
         else:
             response, _ = self.model.chat(
                 self.tokenizer,
                 prompt,
                 history=history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
                 temperature=self.temperature,
             )
             torch_gc()
         history += [[prompt, response]]
...@@ -106,6 +106,7 @@ class ChatGLM(LLM):
                    model_name_or_path: str = "THUDM/chatglm-6b",
                    llm_device=LLM_DEVICE,
                    use_ptuning_v2=False,
+                   use_lora=False,
                    device_map: Optional[Dict[str, int]] = None,
                    **kwargs):
         self.tokenizer = AutoTokenizer.from_pretrained(
...@@ -125,45 +126,32 @@ class ChatGLM(LLM):
             except Exception as e:
                 print(e)
                 print("加载PrefixEncoder config.json失败")
+        self.model = AutoModel.from_pretrained(model_name_or_path, config=model_config, trust_remote_code=True,
+                                               **kwargs)
+        if LLM_LORA_PATH and use_lora:
+            from peft import PeftModel
+            self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)
         if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
             # 根据当前设备GPU数量决定是否进行多卡部署
             num_gpus = torch.cuda.device_count()
             if num_gpus < 2 and device_map is None:
-                self.model = (
-                    AutoModel.from_pretrained(
-                        model_name_or_path,
-                        config=model_config,
-                        trust_remote_code=True,
-                        **kwargs)
-                    .half()
-                    .cuda()
-                )
+                self.model = self.model.half().cuda()
             else:
                 from accelerate import dispatch_model
-                model = (
-                    AutoModel.from_pretrained(
-                        model_name_or_path,
-                        trust_remote_code=True,
-                        config=model_config,
-                        **kwargs)
-                    .half())
+                model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True,
+                                                  config=model_config, **kwargs)
+                if LLM_LORA_PATH and use_lora:
+                    from peft import PeftModel
+                    model_auto = PeftModel.from_pretrained(model, LLM_LORA_PATH)
                 # 可传入device_map自定义每张卡的部署情况
                 if device_map is None:
                     device_map = auto_configure_device_map(num_gpus)
-                self.model = dispatch_model(model, device_map=device_map)
+                self.model = dispatch_model(model_auto.half(), device_map=device_map)
         else:
-            self.model = (
-                AutoModel.from_pretrained(
-                    model_name_or_path,
-                    config=model_config,
-                    trust_remote_code=True,
-                    **kwargs)
-                .float()
-                .to(llm_device)
-            )
+            self.model = self.model.float().to(llm_device)
         if use_ptuning_v2:
             try:
...@@ -185,7 +173,7 @@ if __name__ == "__main__":
     llm = ChatGLM()
     llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
                    llm_device=LLM_DEVICE, )
-    last_print_len=0
+    last_print_len = 0
     for resp, history in llm._call("你好", streaming=True):
         print(resp[last_print_len:], end="", flush=True)
         last_print_len = len(resp)
......
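The heart of this file's change is that the base ChatGLM model is now loaded once up front and, when LLM_LORA_PATH is set and use_lora is true, wrapped with peft's PeftModel so the LoRA adapter weights are applied on top of the frozen base model. A minimal standalone sketch of that pattern, with placeholder paths and simplified device handling (not the repository's exact loader):

```python
import torch
from transformers import AutoModel, AutoTokenizer
from peft import PeftModel

base_model_path = "THUDM/chatglm-6b"    # base checkpoint, as used in this repo
lora_adapter_path = "/path/to/lora"     # placeholder for an LLM_LORA_PATH-style folder

# Load the base model and tokenizer (ChatGLM requires trust_remote_code=True).
tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(base_model_path, trust_remote_code=True)

# Wrap the base model so the LoRA adapter weights are applied on top of it.
model = PeftModel.from_pretrained(model, lora_adapter_path)

# Half precision on GPU, full precision on CPU, mirroring the branches in the diff.
if torch.cuda.is_available():
    model = model.half().cuda()
else:
    model = model.float()
model = model.eval()
```

The wrapped model can then be sharded across several GPUs with accelerate's dispatch_model, which is what the multi-GPU branch in the hunk above does.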
...@@ -12,4 +12,5 @@ accelerate
 gradio==3.24.1
 fastapi
 uvicorn
+peft
 #detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2
...@@ -72,12 +72,13 @@ def init_model():
     return reply

-def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, history):
+def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, history):
     try:
         local_doc_qa.init_cfg(llm_model=llm_model,
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
                               use_ptuning_v2=use_ptuning_v2,
+                              use_lora = use_lora,
                               top_k=top_k,)
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(model_status)
...@@ -246,6 +247,9 @@ with gr.Blocks(css=block_css) as demo:
                 use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
                                              label="使用p-tuning-v2微调过的模型",
                                              interactive=True)
+                use_lora = gr.Checkbox(USE_LORA,
+                                       label="使用lora微调的权重",
+                                       interactive=True)
                 embedding_model = gr.Radio(embedding_model_dict_list,
                                            label="Embedding 模型",
                                            value=EMBEDDING_MODEL,
...@@ -259,7 +263,7 @@ with gr.Blocks(css=block_css) as demo:
         load_model_button = gr.Button("重新加载模型")
         load_model_button.click(reinit_model,
                                 show_progress=True,
-                                inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, chatbot],
+                                inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, chatbot],
                                 outputs=chatbot
                                 )
......
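For reference, Gradio passes component values to a click callback positionally, which is why the diff updates reinit_model's signature and the inputs list together: use_lora has to occupy the same slot in both. A stripped-down, runnable illustration of that wiring (reduced to the relevant components, with a dummy callback):

```python
import gradio as gr

def reinit_model(use_ptuning_v2, use_lora, history):
    # Parameter order must match the order of the `inputs` list below.
    status = f"reload requested: ptuning_v2={use_ptuning_v2}, lora={use_lora}"
    return history + [[None, status]]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    use_ptuning_v2 = gr.Checkbox(False, label="使用p-tuning-v2微调过的模型", interactive=True)
    use_lora = gr.Checkbox(False, label="使用lora微调的权重", interactive=True)
    load_model_button = gr.Button("重新加载模型")
    load_model_button.click(reinit_model,
                            show_progress=True,
                            inputs=[use_ptuning_v2, use_lora, chatbot],
                            outputs=chatbot)

if __name__ == "__main__":
    demo.launch()
```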