提交 33bbb477 作者: glide-the

llm_model_dict 处理了loader的一些预设行为,如加载位置,模型名称,模型处理器实例, 定义checkpoint名称和远程路径

loader.py: 模型重载
定义 generatorAnswer 增加 AnswerResultStream
   定义generate_with_callback收集器,在每次响应时将队列数据同步到AnswerResult
requirements.txt 变更项目依赖
上级 c3924b2e
...@@ -12,10 +12,14 @@ from tqdm import tqdm ...@@ -12,10 +12,14 @@ from tqdm import tqdm
from pypinyin import lazy_pinyin from pypinyin import lazy_pinyin
from loader import UnstructuredPaddleImageLoader from loader import UnstructuredPaddleImageLoader
from loader import UnstructuredPaddlePDFLoader from loader import UnstructuredPaddlePDFLoader
from models.base import (BaseAnswer,
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
from models.loader.args import parser
from models.loader import LoaderCheckPoint
import models.shared as shared
DEVICE_ = EMBEDDING_DEVICE
DEVICE_ID = "0" if torch.cuda.is_available() else None
DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
def load_file(filepath, sentence_size=SENTENCE_SIZE): def load_file(filepath, sentence_size=SENTENCE_SIZE):
...@@ -132,7 +136,7 @@ def similarity_search_with_score_by_vector( ...@@ -132,7 +136,7 @@ def similarity_search_with_score_by_vector(
class LocalDocQA: class LocalDocQA:
llm: object = None llm: BaseAnswer = None
embeddings: object = None embeddings: object = None
top_k: int = VECTOR_SEARCH_TOP_K top_k: int = VECTOR_SEARCH_TOP_K
chunk_size: int = CHUNK_SIZE chunk_size: int = CHUNK_SIZE
...@@ -142,23 +146,10 @@ class LocalDocQA: ...@@ -142,23 +146,10 @@ class LocalDocQA:
def init_cfg(self, def init_cfg(self,
embedding_model: str = EMBEDDING_MODEL, embedding_model: str = EMBEDDING_MODEL,
embedding_device=EMBEDDING_DEVICE, embedding_device=EMBEDDING_DEVICE,
llm_history_len: int = LLM_HISTORY_LEN, llm_model: BaseAnswer = None,
llm_model: str = LLM_MODEL,
llm_device=LLM_DEVICE,
top_k=VECTOR_SEARCH_TOP_K, top_k=VECTOR_SEARCH_TOP_K,
use_ptuning_v2: bool = USE_PTUNING_V2,
use_lora: bool = USE_LORA,
): ):
if llm_model.startswith('moss'): self.llm = llm_model
from models.moss_llm import MOSS
self.llm = MOSS()
else:
from models.chatglm_llm import ChatGLM
self.llm = ChatGLM()
self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
llm_device=llm_device, use_ptuning_v2=use_ptuning_v2, use_lora=use_lora)
self.llm.history_len = llm_history_len
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
model_kwargs={'device': embedding_device}) model_kwargs={'device': embedding_device})
self.top_k = top_k self.top_k = top_k
...@@ -259,16 +250,16 @@ class LocalDocQA: ...@@ -259,16 +250,16 @@ class LocalDocQA:
torch_gc() torch_gc()
prompt = generate_prompt(related_docs_with_score, query) prompt = generate_prompt(related_docs_with_score, query)
for result, history in self.llm._call(prompt=prompt, for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
history=chat_history, streaming=streaming):
streaming=streaming): resp = answer_result.llm_output["answer"]
torch_gc() history = answer_result.history
history[-1][0] = query history[-1][0] = query
response = {"query": query, response = {"query": query,
"result": result, "result": resp,
"source_documents": related_docs_with_score} "source_documents": related_docs_with_score}
yield response, history yield response, history
torch_gc()
# query 查询内容 # query 查询内容
# vs_path 知识库路径 # vs_path 知识库路径
...@@ -297,10 +288,19 @@ class LocalDocQA: ...@@ -297,10 +288,19 @@ class LocalDocQA:
if __name__ == "__main__": if __name__ == "__main__":
# 初始化消息
args = None
args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model'])
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
local_doc_qa = LocalDocQA() local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg() local_doc_qa.init_cfg(llm_model=llm_model_ins)
query = "本项目使用的embedding模型是什么,消耗多少显存" query = "本项目使用的embedding模型是什么,消耗多少显存"
vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/aaa" vs_path = "/media/gpt4-pdf-chatbot-langchain/dev-langchain-ChatGLM/vector_store/test"
last_print_len = 0 last_print_len = 0
for resp, history in local_doc_qa.get_knowledge_based_answer(query=query, for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
vs_path=vs_path, vs_path=vs_path,
......
...@@ -22,14 +22,54 @@ EMBEDDING_MODEL = "text2vec" ...@@ -22,14 +22,54 @@ EMBEDDING_MODEL = "text2vec"
# Embedding running device # Embedding running device
EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# supported LLM models # supported LLM models
"""
llm_model_dict 处理了loader的一些预设行为,如加载位置,模型名称,模型处理器实例
"""
llm_model_dict = { llm_model_dict = {
"chatyuan": "ClueAI/ChatYuan-large-v2", "chatglm-6b-int4-qe": {
"chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe", "name": "chatglm-6b-int4-qe",
"chatglm-6b-int4": "THUDM/chatglm-6b-int4", "remote-checkpoint": "THUDM/chatglm-6b-int4-qe",
"chatglm-6b-int8": "THUDM/chatglm-6b-int8", "path": None,
"chatglm-6b": "THUDM/chatglm-6b", "provides": "ChatGLM"
"moss": "fnlp/moss-moon-003-sft", },
"chatglm-6b-int4": {
"name": "chatglm-6b-int4",
"remote-checkpoint": "THUDM/chatglm-6b-int4",
"path": None,
"provides": "ChatGLM"
},
"chatglm-6b": {
"name": "chatglm-6b",
"remote-checkpoint": "THUDM/chatglm-6b-int4",
"path": None,
"provides": "ChatGLM"
},
"llama-7b-hf": {
"name": "llama-7b-hf",
"remote-checkpoint": "llama-7b-hf",
"path": None,
"provides": "LLamaLLM"
},
"vicuna-13b-hf": {
"name": "vicuna-13b-hf",
"remote-checkpoint": "vicuna-13b-hf",
"path": None,
"provides": "LLamaLLM"
},
"chatyuan": {
"name": "chatyuan",
"remote-checkpoint": "ClueAI/ChatYuan-large-v2",
"path": None,
"provides": None
},
"chatglm-6b-int8":{
"name": "chatglm-6b-int8",
"remote-checkpoint": "THUDM/chatglm-6b-int8",
"path": None,
"provides": "ChatGLM"
},
} }
# LLM model name # LLM model name
......
from .fastchat_api import *
\ No newline at end of file
"""
Conversation prompt template.
Now we support
- Vicuna
- Koala
- OpenAssistant/oasst-sft-1-pythia-12b
- StabilityAI/stablelm-tuned-alpha-7b
- databricks/dolly-v2-12b
- THUDM/chatglm-6b
- Alpaca/LLaMa
"""
import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
DOLLY = auto()
OASST_PYTHIA = auto()
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
# Used for gradio server
skip_next: bool = False
conv_id: Any = None
def get_prompt(self):
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system
for role, message in self.messages:
if message:
ret += self.sep + " " + role + ": " + message
else:
ret += self.sep + " " + role + ":"
return ret
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.DOLLY:
seps = [self.sep, self.sep2]
ret = self.system
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ":\n" + message + seps[i % 2]
if i % 2 == 1:
ret += "\n\n"
else:
ret += role + ":\n"
return ret
elif self.sep_style == SeparatorStyle.OASST_PYTHIA:
ret = self.system
for role, message in self.messages:
if message:
ret += role + message + self.sep
else:
ret += role
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def append_message(self, role, message):
self.messages.append([role, message])
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
conv_id=self.conv_id,
)
def dict(self):
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
"conv_id": self.conv_id,
}
conv_one_shot = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
(
"Human",
"What are the key differences between renewable and non-renewable energy sources?",
),
(
"Assistant",
"Renewable energy sources are those that can be replenished naturally in a relatively "
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
"renewable and non-renewable energy sources:\n"
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
"energy sources are finite and will eventually run out.\n"
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
"and other negative effects.\n"
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
"have lower operational costs than non-renewable sources.\n"
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
"locations than non-renewable sources.\n"
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
"non-renewable sources are not, and their depletion can lead to economic and social instability.",
),
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_vicuna_v1_1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_koala_v1 = Conversation(
system="BEGINNING OF CONVERSATION:",
roles=("USER", "GPT"),
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_dolly = Conversation(
system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
roles=("### Instruction", "### Response"),
messages=(),
offset=0,
sep_style=SeparatorStyle.DOLLY,
sep="\n\n",
sep2="### End",
)
conv_oasst = Conversation(
system="",
roles=("<|prompter|>", "<|assistant|>"),
messages=(),
offset=0,
sep_style=SeparatorStyle.OASST_PYTHIA,
sep="<|endoftext|>",
)
conv_stablelm = Conversation(
system="""<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
""",
roles=("<|USER|>", "<|ASSISTANT|>"),
messages=(),
offset=0,
sep_style=SeparatorStyle.OASST_PYTHIA,
sep="",
)
conv_templates = {
"conv_one_shot": conv_one_shot,
"vicuna_v1.1": conv_vicuna_v1_1,
"koala_v1": conv_koala_v1,
"dolly": conv_dolly,
"oasst": conv_oasst,
}
def get_default_conv_template(model_name):
model_name = model_name.lower()
if "vicuna" in model_name or "output" in model_name:
return conv_vicuna_v1_1
elif "koala" in model_name:
return conv_koala_v1
elif "dolly-v2" in model_name:
return conv_dolly
elif "oasst" in model_name and "pythia" in model_name:
return conv_oasst
elif "stablelm" in model_name:
return conv_stablelm
return conv_one_shot
def compute_skip_echo_len(model_name, conv, prompt):
model_name = model_name.lower()
if "chatglm" in model_name:
skip_echo_len = len(conv.messages[-2][1]) + 1
elif "dolly-v2" in model_name:
special_toks = ["### Instruction:", "### Response:", "### End"]
skip_echo_len = len(prompt)
for tok in special_toks:
skip_echo_len -= prompt.count(tok) * len(tok)
elif "oasst" in model_name and "pythia" in model_name:
special_toks = ["<|prompter|>", "<|assistant|>", "<|endoftext|>"]
skip_echo_len = len(prompt)
for tok in special_toks:
skip_echo_len -= prompt.count(tok) * len(tok)
elif "stablelm" in model_name:
special_toks = ["<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>"]
skip_echo_len = len(prompt)
for tok in special_toks:
skip_echo_len -= prompt.count(tok) * len(tok)
else:
skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
return skip_echo_len
if __name__ == "__main__":
print(default_conversation.get_prompt())
from .chatglm_llm import * from .chatglm_llm import ChatGLM
\ No newline at end of file from .llama_llm import LLamaLLM
from .moss_llm import MOSSLLM
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
import asyncio
from argparse import Namespace
from models.loader.args import parser
from models.loader import LoaderCheckPoint
from langchain.agents import initialize_agent, Tool
from langchain.agents import AgentType
import models.shared as shared
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
from langchain.prompts import PromptTemplate
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
from typing import List, Set
class CustomLLMSingleActionAgent(ZeroShotAgent):
allowed_tools: List[str]
def __init__(self, *args, **kwargs):
super(CustomLLMSingleActionAgent, self).__init__(*args, **kwargs)
self.allowed_tools = kwargs['allowed_tools']
def get_allowed_tools(self) -> Set[str]:
return set(self.allowed_tools)
async def dispatch(args: Namespace):
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
template = """This is a conversation between a human and a bot:
{chat_history}
Write a summary of the conversation for {input}:
"""
prompt = PromptTemplate(
input_variables=["input", "chat_history"],
template=template
)
memory = ConversationBufferMemory(memory_key="chat_history")
readonlymemory = ReadOnlySharedMemory(memory=memory)
summry_chain = LLMChain(
llm=llm_model_ins,
prompt=prompt,
verbose=True,
memory=readonlymemory, # use the read-only memory to prevent the tool from modifying the memory
)
tools = [
Tool(
name="Summary",
func=summry_chain.run,
description="useful for when you summarize a conversation. The input to this tool should be a string, representing who will read this summary."
)
]
prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
suffix = """Begin!
Question: {input}
{agent_scratchpad}"""
prompt = CustomLLMSingleActionAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
input_variables=["input", "agent_scratchpad"]
)
tool_names = [tool.name for tool in tools]
llm_chain = LLMChain(llm=llm_model_ins, prompt=prompt)
agent = CustomLLMSingleActionAgent(llm_chain=llm_chain, tools=tools, allowed_tools=tool_names)
agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools)
agent_chain.run(input="你好")
agent_chain.run(input="你是谁?")
agent_chain.run(input="我们之前聊了什么?")
if __name__ == '__main__':
args = None
args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'vicuna-13b-hf', '--no-remote-model', '--load-in-8bit'])
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(dispatch(args))
from abc import ABC, abstractmethod
from typing import Optional, List
import traceback
from collections import deque
from queue import Queue
from threading import Thread
import torch
import transformers
from models.loader import LoaderCheckPoint
class ListenerToken:
"""
观测结果
"""
input_ids: torch.LongTensor
_scores: torch.FloatTensor
def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
self.input_ids = input_ids
self._scores = _scores
class AnswerResult:
"""
消息实体
"""
history: List[List[str]] = []
llm_output: Optional[dict] = None
listenerToken: ListenerToken = None
class AnswerResultStream:
def __init__(self, callback_func=None):
self.callback_func = callback_func
def __call__(self, answerResult: AnswerResult):
if self.callback_func is not None:
self.callback_func(answerResult)
class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
"""
定义模型stopping_criteria 监听者,在每次响应时将队列数据同步到AnswerResult
实现此监听器的目的是,不同模型的预测输出可能不是矢量信息,hf框架可以自定义transformers.StoppingCriteria入参来接收每次预测的Tensor和损失函数,
通过给 StoppingCriteriaList指定模型生成答案时停止的条件。每个 StoppingCriteria 对象表示一个停止条件
当每轮预测任务开始时,StoppingCriteria都会收到相同的预测结果,最终由下层实现类确认是否结束
输出值可用于 generatorAnswer generate_with_streaming的自定义参数观测,以实现更加精细的控制
"""
listenerQueue: deque = deque(maxlen=1)
def __init__(self):
transformers.StoppingCriteria.__init__(self)
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
"""
每次响应时将数据添加到响应队列
:param input_ids:
:param _scores:
:param kwargs:
:return:
"""
self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores))
return False
class Iteratorize:
"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""
def __init__(self, func, kwargs={}):
self.mfunc = func
self.q = Queue()
self.sentinel = object()
self.kwargs = kwargs
self.stop_now = False
def _callback(val):
"""
模型输出预测结果收集
通过定义generate_with_callback收集器AnswerResultStream,收集模型预测的AnswerResult响应结果,最终由下层实现类确认是否结束
结束条件包含如下
1、模型预测结束、收集器self.q队列收到 self.sentinel标识
2、在处理迭代器队列消息时返回了break跳出迭代器,触发了StopIteration事件
3、模型预测出错
因为当前类是迭代器,所以在for in 中执行了break后 __exit__ 方法会被调用,最终stop_now属性会被更新,然后抛出异常结束预测行为
迭代器收集的行为如下
创建Iteratorize迭代对象,
定义generate_with_callback收集器AnswerResultStream
启动一个线程异步预测结果来调用上游checkpoint的实现方法_generate_answer
_generate_answer通过generate_with_callback定义的收集器,收集上游checkpoint包装的AnswerResult消息体
由于self.q是阻塞模式,每次预测后会被消费后才会执行下次预测
这时generate_with_callback会被阻塞
主线程Iteratorize对象的__next__方法调用获取阻塞消息并消费
1、消息为上游checkpoint包装的AnswerResult消息体,返回下游处理
2、消息为self.sentinel标识,抛出StopIteration异常
主线程Iteratorize对象__exit__收到消息,最终stop_now属性会被更新
异步线程检测stop_now属性被更新,抛出异常结束预测行为
迭代行为结束
:param val:
:return:
"""
if self.stop_now:
raise ValueError
self.q.put(val)
def gen():
try:
ret = self.mfunc(callback=_callback, **self.kwargs)
except ValueError:
pass
except:
traceback.print_exc()
pass
self.q.put(self.sentinel)
self.thread = Thread(target=gen)
self.thread.start()
def __iter__(self):
return self
def __next__(self):
obj = self.q.get(True, None)
if obj is self.sentinel:
raise StopIteration
else:
return obj
def __del__(self):
"""
暂无实现
:return:
"""
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
""" break 后会执行 """
self.stop_now = True
class BaseAnswer(ABC):
"""上层业务包装器.用于结果生成统一api调用"""
@property
@abstractmethod
def _check_point(self) -> LoaderCheckPoint:
"""Return _check_point of llm."""
@property
@abstractmethod
def _history_len(self) -> int:
"""Return _history_len of llm."""
@abstractmethod
def set_history_len(self, history_len: int) -> None:
"""Return _history_len of llm."""
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
def generate_with_callback(callback=None, **kwargs):
kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback)
self._generate_answer(**kwargs)
def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs)
"""
eos_token_id是指定token(例如,"</s>"),
用于表示序列的结束。在生成文本任务中,生成器在生成序列时,将不断地生成token,直到生成此特殊的eos_token_id,表示序列生成已经完成。
在Hugging Face Transformer模型中,eos_token_id是由tokenizer自动添加到输入中的。
在模型生成输出时,如果模型生成了eos_token_id,则生成过程将停止并返回生成的序列。
"""
eos_token_ids = [
self._check_point.tokenizer.eos_token_id] if self._check_point.tokenizer.eos_token_id is not None else []
with generate_with_streaming(prompt=prompt, history=history, streaming=streaming) as generator:
for answerResult in generator:
if answerResult.listenerToken:
output = answerResult.listenerToken.input_ids
yield answerResult
@abstractmethod
def _generate_answer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False,
generate_with_callback: AnswerResultStream = None) -> None:
pass
import gc
import traceback
from queue import Queue
from threading import Thread
import threading
from typing import Optional, List, Dict, Any
from collections import deque
import torch
import transformers
from models.extensions.thread_with_exception import ThreadWithException
import models.shared as shared
class LimitedLengthDict(dict):
def __init__(self, maxlen=None, *args, **kwargs):
self.maxlen = maxlen
self._keys = deque()
super().__init__(*args, **kwargs)
def __setitem__(self, key, value):
if key not in self:
if self.maxlen is not None and len(self) >= self.maxlen:
oldest_key = self._keys.popleft()
if oldest_key in self:
del self[oldest_key]
self._keys.append(key)
super().__setitem__(key, value)
class FixedLengthQueue:
# 停止符号列表
stop_sequence: Optional[str] = []
# 缓冲区
max_length: int = 0
# 缓冲区容器
queue: deque = None
# 输入区容器
queue_in: LimitedLengthDict[int, str] = {}
# 输出区容器
queue_out: Dict[int, str] = {}
def __new__(cls, *args, **kwargs):
# 创建新的实例
instance = super().__new__(cls)
# 在这里可以对实例进行额外的设置
return instance
def __init__(self, stop_sequence):
if stop_sequence is None:
self.stop_sequence = []
self.max_length = 0
elif isinstance(stop_sequence, str):
self.stop_sequence = [stop_sequence]
self.max_length = 1
else:
self.stop_sequence = stop_sequence
self.max_length = len(''.join(stop_sequence))
self.queue = deque(maxlen=self.max_length)
self.queue.clear()
self.queue_in.clear()
self.queue_out.clear()
def add(self, index, item):
self.queue_in[index] = item
def _add_out(self, index, item):
self.queue_out[index] = item
def put_replace_out(self, index):
return self.queue_out[index]
def contains_replace_sequence(self):
"""
替换字符
:return:
"""
for key, value in self.queue_in.items():
word_index = value.rfind(":")
if word_index != -1:
value = value.replace(":", ":")
word_index = value.rfind("[")
if word_index != -1:
value = value.replace("[", "")
word_index = value.rfind("]")
if word_index != -1:
value = value.replace("]", "")
self._add_out(key, value)
def contains_stop_sequence(self):
# 截取固定大小的数据判断
self.queue.clear()
last_three_keys = list(self.queue_out.keys())[-self.max_length:]
joined_queue = ''.join([self.queue_out[key] for key in last_three_keys])
for char in joined_queue:
self.queue.append(char)
joined_queue = ''.join(self.queue)
# Initialize a variable to store the index of the last found stop string
last_stop_str_index = -1
# Iterate through the stop string list
for stop_word in self.stop_sequence:
# Find the last occurrence of the stop string in the output
stop_word_index = joined_queue.rfind(stop_word)
# If the stop string is found, compare the index with the previously found index
if stop_word_index != -1 and stop_word_index > last_stop_str_index:
last_stop_str_index = stop_word_index
# Handle the last found stop string index here
return last_stop_str_index
def __repr__(self):
return str(self.queue)
# Copied from https://github.com/PygmalionAI/gradio-ui/
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
def __init__(self, sentinel_token_ids: list, starting_idx: int):
transformers.StoppingCriteria.__init__(self)
self.sentinel_token_ids = sentinel_token_ids
self.starting_idx = starting_idx
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
for sample in input_ids:
trimmed_sample = sample[self.starting_idx:]
for i in range(len(self.sentinel_token_ids)):
# Can't unfold, output is still too tiny. Skip.
if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
continue
for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)):
return True
return False
class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None):
self.callback_func = callback_func
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
if shared.stop_everything:
raise ValueError
if self.callback_func is not None:
self.callback_func(input_ids[0])
return False
class Iteratorize:
"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""
thread: ThreadWithException = None
def __new__(cls, *args, **kwargs):
# 创建新的实例
instance = super().__new__(cls)
# 在这里可以对实例进行额外的设置
return instance
def __init__(self, func, kwargs={}, callback=None):
self.mfunc = func
self.c_callback = callback
self.q = Queue()
self.sentinel = object()
self.kwargs = kwargs
def _callback(val):
if shared.stop_everything:
raise ValueError
self.q.put(val)
def gen():
try:
ret = self.mfunc(callback=_callback, **self.kwargs)
except ValueError:
print("print(ValueError)")
except:
traceback.print_exc()
print("traceback.print_exc()")
self.q.put(self.sentinel)
self.thread = ThreadWithException(target=gen)
self.thread.start()
def __iter__(self):
shared.stop_everything = False
return self
def __next__(self):
obj = self.q.get(True, None)
if obj is self.sentinel:
raise StopIteration
else:
return obj
def __del__(self):
shared.stop_everything = False
self.q.empty()
shared.loaderCheckPoint.clear_torch_cache()
def __enter__(self):
shared.stop_everything = False
return self
def __exit__(self, exc_type, exc_val, exc_tb):
shared.stop_everything = True
shared.loaderCheckPoint.clear_torch_cache()
self.thread.raise_exception()
import gc
import traceback
import torch
# This iterator returns the extensions in the order specified in the command-line
def iterator():
state_extensions = {}
for name in sorted(state_extensions, key=lambda x: state_extensions[x][1]):
if state_extensions[name][0]:
yield getattr(extensions, name).script, name
\ No newline at end of file
'''
Based on
https://github.com/abetlen/llama-cpp-python
Documentation:
https://abetlen.github.io/llama-cpp-python/
'''
from llama_cpp import Llama, LlamaCache
from modules import shared
from modules.callbacks import Iteratorize
class LlamaCppModel:
def __init__(self):
self.initialized = False
@classmethod
def from_pretrained(self, path):
result = self()
params = {
'model_path': str(path),
'n_ctx': 2048,
'seed': 0,
'n_threads': shared.args.threads or None
}
self.model = Llama(**params)
self.model.set_cache(LlamaCache)
# This is ugly, but the model and the tokenizer are the same object in this library.
return result, result
def encode(self, string):
if type(string) is str:
string = string.encode()
return self.model.tokenize(string)
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
if type(context) is str:
context = context.encode()
tokens = self.model.tokenize(context)
output = b""
count = 0
for token in self.model.generate(tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repetition_penalty):
text = self.model.detokenize([token])
output += text
if callback:
callback(text.decode())
count += 1
if count >= token_count or (token == self.model.token_eos()):
break
return output.decode()
def generate_with_streaming(self, **kwargs):
with Iteratorize(self.generate, kwargs, callback=None) as generator:
reply = ''
for token in generator:
reply += token
yield reply
# Python program raising
# exceptions in a python
# thread
import threading
import ctypes
import time
class ThreadWithException(threading.Thread):
def get_id(self):
return self.ident
def raise_exception(self):
"""raises the exception, performs cleanup if needed"""
try:
thread_id = self.get_id()
tid = ctypes.c_long(thread_id)
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(SystemExit))
if res == 0:
# pass
raise ValueError("invalid thread id")
elif res != 1:
# """if it returns a number greater than one, you're in trouble,
# and you should call it again with exc=NULL to revert the effect"""
ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
raise SystemError("PyThreadState_SetAsyncExc failed")
except Exception as err:
print(err)
from .loader import *
import argparse
import os
# Additional argparse types
def path(string):
if not string:
return ''
s = os.path.expanduser(string)
if not os.path.exists(s):
raise argparse.ArgumentTypeError(f'No such file or directory: "{string}"')
return s
def file_path(string):
if not string:
return ''
s = os.path.expanduser(string)
if not os.path.isfile(s):
raise argparse.ArgumentTypeError(f'No such file: "{string}"')
return s
def dir_path(string):
if not string:
return ''
s = os.path.expanduser(string)
if not os.path.isdir(s):
raise argparse.ArgumentTypeError(f'No such directory: "{string}"')
return s
parser = argparse.ArgumentParser(prog='langchina-ChatGLM',
description='基于langchain和chatGML的LLM文档阅读器')
parser.add_argument('--no-remote-model', action='store_true', default=False, help='remote in the model on loader checkpoint, if your load local model to add the ` --no-remote-model`')
parser.add_argument('--model', type=str, default='chatglm-6b', help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
parser.add_argument("--model-dir", type=str, default='model/', help="Path to directory with all the models")
parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")
# Accelerate/transformers
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.')
parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
args = parser.parse_args([])
# Generares dict with a default value for each argument
DEFAULT_ARGS = vars(args)
import json from abc import ABC
from langchain.llms.base import LLM
from typing import List, Dict, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.modeling_utils import no_init_weights
from transformers.utils import ContextManagers
import torch
from configs.model_config import *
from utils import torch_gc
from accelerate import init_empty_weights from langchain.llms.base import LLM
from accelerate.utils import get_balanced_memory, infer_auto_device_map from typing import Optional, List
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
DEVICE_ = LLM_DEVICE import torch
DEVICE_ID = "0" if torch.cuda.is_available() else None
DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
META_INSTRUCTION = \ META_INSTRUCTION = \
"""You are an AI assistant whose name is MOSS. """You are an AI assistant whose name is MOSS.
...@@ -30,45 +24,40 @@ META_INSTRUCTION = \ ...@@ -30,45 +24,40 @@ META_INSTRUCTION = \
""" """
def auto_configure_device_map() -> Dict[str, int]: class MOSSLLM(BaseAnswer, LLM, ABC):
cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
pretrained_model_name_or_path=llm_model_dict['moss'])
with ContextManagers([no_init_weights(_enable=True), init_empty_weights()]):
model_config = AutoConfig.from_pretrained(llm_model_dict['moss'], trust_remote_code=True)
model = cls(model_config)
max_memory = get_balanced_memory(model, dtype=torch.int8 if LOAD_IN_8BIT else None,
low_zero=False, no_split_module_classes=model._no_split_modules)
device_map = infer_auto_device_map(
model, dtype=torch.float16 if not LOAD_IN_8BIT else torch.int8, max_memory=max_memory,
no_split_module_classes=model._no_split_modules)
device_map["transformer.wte"] = 0
device_map["transformer.drop"] = 0
device_map["transformer.ln_f"] = 0
device_map["lm_head"] = 0
return device_map
class MOSS(LLM):
max_token: int = 2048 max_token: int = 2048
temperature: float = 0.7 temperature: float = 0.7
top_p = 0.8 top_p = 0.8
# history = [] # history = []
tokenizer: object = None checkPoint: LoaderCheckPoint = None
model: object = None
history_len: int = 10 history_len: int = 10
def __init__(self): def __init__(self, checkPoint: LoaderCheckPoint = None):
super().__init__() super().__init__()
self.checkPoint = checkPoint
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:
return "MOSS" return "MOSS"
def _call(self, @property
prompt: str, def _check_point(self) -> LoaderCheckPoint:
history: List[List[str]] = [], return self.checkPoint
streaming: bool = STREAMING): # -> Tuple[str, List[List[str]]]:
@property
def set_history_len(self) -> int:
return self.history_len
def _set_history_len(self, history_len: int) -> None:
self.history_len = history_len
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
pass
def _generate_answer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False,
generate_with_callback: AnswerResultStream = None) -> None:
if len(history) > 0: if len(history) > 0:
history = history[-self.history_len:-1] if self.history_len > 0 else [] history = history[-self.history_len:-1] if self.history_len > 0 else []
prompt_w_history = str(history) prompt_w_history = str(history)
...@@ -77,9 +66,9 @@ class MOSS(LLM): ...@@ -77,9 +66,9 @@ class MOSS(LLM):
prompt_w_history = META_INSTRUCTION prompt_w_history = META_INSTRUCTION
prompt_w_history += '<|Human|>: ' + prompt + '<eoh>' prompt_w_history += '<|Human|>: ' + prompt + '<eoh>'
inputs = self.tokenizer(prompt_w_history, return_tensors="pt") inputs = self.checkPoint.tokenizer(prompt_w_history, return_tensors="pt")
with torch.no_grad(): with torch.no_grad():
outputs = self.model.generate( outputs = self.checkPoint.model.generate(
inputs.input_ids.cuda(), inputs.input_ids.cuda(),
attention_mask=inputs.attention_mask.cuda(), attention_mask=inputs.attention_mask.cuda(),
max_length=self.max_token, max_length=self.max_token,
...@@ -92,78 +81,8 @@ class MOSS(LLM): ...@@ -92,78 +81,8 @@ class MOSS(LLM):
eos_token_id=106068, eos_token_id=106068,
pad_token_id=self.tokenizer.pad_token_id) pad_token_id=self.tokenizer.pad_token_id)
response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
torch_gc() self.checkPoint.clear_torch_cache()
history += [[prompt, response]] history += [[prompt, response]]
yield response, history yield response, history
torch_gc()
def load_model(self,
model_name_or_path: str = "fnlp/moss-moon-003-sft",
llm_device=LLM_DEVICE,
use_ptuning_v2=False,
use_lora=False,
device_map: Optional[Dict[str, int]] = None,
**kwargs):
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
trust_remote_code=True
)
model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
if use_ptuning_v2:
try:
prefix_encoder_file = open('ptuning-v2/config.json', 'r')
prefix_encoder_config = json.loads(prefix_encoder_file.read())
prefix_encoder_file.close()
model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
model_config.prefix_projection = prefix_encoder_config['prefix_projection']
except Exception as e:
print(e)
print("加载PrefixEncoder config.json失败")
if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
# accelerate自动多卡部署
self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=model_config,
load_in_8bit=LOAD_IN_8BIT, trust_remote_code=True,
device_map=auto_configure_device_map(), **kwargs)
if LLM_LORA_PATH and use_lora:
from peft import PeftModel
self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)
else:
self.model = self.model.float().to(llm_device)
if LLM_LORA_PATH and use_lora:
from peft import PeftModel
self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)
if use_ptuning_v2:
try:
prefix_state_dict = torch.load('ptuning-v2/pytorch_model.bin')
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
self.model.transformer.prefix_encoder.float()
except Exception as e:
print(e)
print("加载PrefixEncoder模型参数失败")
self.model = self.model.eval()
if __name__ == "__main__":
llm = MOSS()
llm.load_model(model_name_or_path=llm_model_dict['moss'],
llm_device=LLM_DEVICE, )
last_print_len = 0
# for resp, history in llm._call("你好", streaming=True):
# print(resp[last_print_len:], end="", flush=True)
# last_print_len = len(resp)
for resp, history in llm._call("你好", streaming=False):
print(resp)
import time
time.sleep(10)
pass
import sys
from models.loader.args import parser
from models.loader import LoaderCheckPoint
from configs.model_config import (llm_model_dict, LLM_MODEL)
from models.base import BaseAnswer
"""迭代器是否停止状态"""
stop_everything = False
loaderCheckPoint: LoaderCheckPoint = None
def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_v2: bool = False) -> BaseAnswer:
"""
init llm_model_ins LLM
:param llm_model: model_name
:param no_remote_model: remote in the model on loader checkpoint, if your load local model to add the ` --no-remote-model
:param use_ptuning_v2: Use p-tuning-v2 PrefixEncoder
:return:
"""
pre_model_name = loaderCheckPoint.model_name
llm_model_info = llm_model_dict[pre_model_name]
if no_remote_model:
loaderCheckPoint.no_remote_model = no_remote_model
if use_ptuning_v2:
loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2
if llm_model:
llm_model_info = llm_model_dict[llm_model]
if loaderCheckPoint.no_remote_model:
loaderCheckPoint.model_name = llm_model_info['name']
else:
loaderCheckPoint.model_name = llm_model_info['remote-checkpoint']
loaderCheckPoint.model_path = llm_model_info['path']
loaderCheckPoint.reload_model()
provides_class = getattr(sys.modules['models'], llm_model_info['provides'])
modelInsLLM = provides_class(checkPoint=loaderCheckPoint)
return modelInsLLM
...@@ -17,6 +17,8 @@ fastapi ...@@ -17,6 +17,8 @@ fastapi
uvicorn uvicorn
peft peft
pypinyin pypinyin
bitsandbytes
click~=8.1.3 click~=8.1.3
tabulate tabulate
bitsandbytes; platform_system != "Windows"
llama-cpp-python==0.1.34; platform_system != "Windows"
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"
import gradio as gr import gradio as gr
import os import os
import shutil import shutil
from chains.local_doc_qa import LocalDocQA from chains.local_doc_qa import LocalDocQA
from configs.model_config import * from configs.model_config import *
import nltk import nltk
from models.base import (BaseAnswer,
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
import models.shared as shared
from models.loader.args import parser
from models.loader import LoaderCheckPoint
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
...@@ -69,7 +77,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR ...@@ -69,7 +77,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
yield history + [[query, yield history + [[query,
"请选择知识库后进行测试,当前未选择知识库。"]], "" "请选择知识库后进行测试,当前未选择知识库。"]], ""
else: else:
for resp, history in local_doc_qa.llm._call(query, history, streaming=streaming): for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
streaming=streaming):
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][-1] = resp + ( history[-1][-1] = resp + (
"\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "") "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
yield history, "" yield history, ""
...@@ -77,10 +89,12 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR ...@@ -77,10 +89,12 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME) flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)
def init_model(): def init_model(llm_model: BaseAnswer = None):
try: try:
local_doc_qa.init_cfg() local_doc_qa.init_cfg(llm_model=llm_model)
local_doc_qa.llm._call("你好") generator = local_doc_qa.llm.generatorAnswer("你好")
for answer_result in generator:
print(answer_result.llm_output)
reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话""" reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
logger.info(reply) logger.info(reply)
return reply return reply
...@@ -95,14 +109,13 @@ def init_model(): ...@@ -95,14 +109,13 @@ def init_model():
return reply return reply
def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, history): def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history):
try: try:
local_doc_qa.init_cfg(llm_model=llm_model, llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
llm_model_ins.history_len = llm_history_len
local_doc_qa.init_cfg(llm_model=llm_model_ins,
embedding_model=embedding_model, embedding_model=embedding_model,
llm_history_len=llm_history_len, top_k=top_k)
use_ptuning_v2=use_ptuning_v2,
use_lora=use_lora,
top_k=top_k, )
model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话""" model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
logger.info(model_status) logger.info(model_status)
except Exception as e: except Exception as e:
...@@ -219,7 +232,17 @@ init_message = f"""欢迎使用 langchain-ChatGLM Web UI! ...@@ -219,7 +232,17 @@ init_message = f"""欢迎使用 langchain-ChatGLM Web UI!
知识库暂不支持文件删除,该功能将在后续版本中推出。 知识库暂不支持文件删除,该功能将在后续版本中推出。
""" """
model_status = init_model() # 初始化消息
args = None
args = parser.parse_args()
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
model_status = init_model(llm_model=llm_model_ins)
default_theme_args = dict( default_theme_args = dict(
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'], font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
...@@ -399,6 +422,10 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as ...@@ -399,6 +422,10 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
label="LLM 模型", label="LLM 模型",
value=LLM_MODEL, value=LLM_MODEL,
interactive=True) interactive=True)
no_remote_model = gr.Checkbox(shared.LoaderCheckPoint.no_remote_model,
label="加载本地模型",
interactive=True)
llm_history_len = gr.Slider(0, 10, llm_history_len = gr.Slider(0, 10,
value=LLM_HISTORY_LEN, value=LLM_HISTORY_LEN,
step=1, step=1,
...@@ -418,7 +445,7 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as ...@@ -418,7 +445,7 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
label="向量匹配 top k", interactive=True) label="向量匹配 top k", interactive=True)
load_model_button = gr.Button("重新加载模型") load_model_button = gr.Button("重新加载模型")
load_model_button.click(reinit_model, show_progress=True, load_model_button.click(reinit_model, show_progress=True,
inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora,
top_k, chatbot], outputs=chatbot) top_k, chatbot], outputs=chatbot)
(demo (demo
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论