Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
J
jinchat-server
概览
概览
详情
活动
周期分析
版本库
存储库
文件
提交
分支
标签
贡献者
分支图
比较
统计图
问题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
aigc-pioneer
jinchat-server
Commits
760abab1
提交
760abab1
authored
7月 14, 2023
作者:
hzg0601
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'dev' of github.com:hzg0601/langchain-ChatGLM-annotation into dev
merge upstream dev
上级
3d082bf5
57cb6b05
显示空白字符变更
内嵌
并排
正在显示
18 个修改的文件
包含
529 行增加
和
349 行删除
+529
-349
.gitignore
.gitignore
+0
-0
api.py
api.py
+4
-3
local_doc_qa.py
chains/local_doc_qa.py
+13
-9
model_config.py
configs/model_config.py
+28
-19
qr_code_36.jpg
img/qr_code_36.jpg
+0
-0
__init__.py
models/__init__.py
+4
-4
__init__.py
models/base/__init__.py
+5
-3
base.py
models/base/base.py
+151
-13
chatglm_llm.py
models/chatglm_llm.py
+68
-31
fastchat_openai_llm.py
models/fastchat_openai_llm.py
+71
-65
llama_llm.py
models/llama_llm.py
+75
-84
args.py
models/loader/args.py
+1
-1
loader.py
models/loader/loader.py
+28
-36
moss_llm.py
models/moss_llm.py
+52
-21
shared.py
models/shared.py
+2
-3
test_fastchat_openai_llm.py
test/models/test_fastchat_openai_llm.py
+0
-39
webui.py
webui.py
+19
-12
webui_st.py
webui_st.py
+8
-6
没有找到文件。
.gitignore
浏览文件 @
760abab1
api.py
浏览文件 @
760abab1
...
...
@@ -384,8 +384,10 @@ async def chat(
],
),
):
for
answer_result
in
local_doc_qa
.
llm
.
generatorAnswer
(
prompt
=
question
,
history
=
history
,
streaming
=
True
):
answer_result_stream_result
=
local_doc_qa
.
llm_model_chain
(
{
"prompt"
:
question
,
"history"
:
history
,
"streaming"
:
True
})
for
answer_result
in
answer_result_stream_result
[
'answer_result_stream'
]:
resp
=
answer_result
.
llm_output
[
"answer"
]
history
=
answer_result
.
history
pass
...
...
@@ -486,7 +488,6 @@ def api_start(host, port, **kwargs):
global
local_doc_qa
llm_model_ins
=
shared
.
loaderLLM
()
llm_model_ins
.
set_history_len
(
LLM_HISTORY_LEN
)
app
=
FastAPI
()
# Add CORS middleware to allow all origins
...
...
chains/local_doc_qa.py
浏览文件 @
760abab1
...
...
@@ -18,6 +18,7 @@ from agent import bing_search
from
langchain.docstore.document
import
Document
from
functools
import
lru_cache
from
textsplitter.zh_title_enhance
import
zh_title_enhance
from
langchain.chains.base
import
Chain
# patch HuggingFaceEmbeddings to make it hashable
...
...
@@ -119,7 +120,7 @@ def search_result2docs(search_results):
class
LocalDocQA
:
llm
:
BaseAnswer
=
None
llm
_model_chain
:
Chain
=
None
embeddings
:
object
=
None
top_k
:
int
=
VECTOR_SEARCH_TOP_K
chunk_size
:
int
=
CHUNK_SIZE
...
...
@@ -129,10 +130,10 @@ class LocalDocQA:
def
init_cfg
(
self
,
embedding_model
:
str
=
EMBEDDING_MODEL
,
embedding_device
=
EMBEDDING_DEVICE
,
llm_model
:
BaseAnswer
=
None
,
llm_model
:
Chain
=
None
,
top_k
=
VECTOR_SEARCH_TOP_K
,
):
self
.
llm
=
llm_model
self
.
llm
_model_chain
=
llm_model
self
.
embeddings
=
HuggingFaceEmbeddings
(
model_name
=
embedding_model_dict
[
embedding_model
],
model_kwargs
=
{
'device'
:
embedding_device
})
self
.
top_k
=
top_k
...
...
@@ -236,8 +237,10 @@ class LocalDocQA:
else
:
prompt
=
query
for
answer_result
in
self
.
llm
.
generatorAnswer
(
prompt
=
prompt
,
history
=
chat_history
,
streaming
=
streaming
):
answer_result_stream_result
=
self
.
llm_model_chain
(
{
"prompt"
:
prompt
,
"history"
:
chat_history
,
"streaming"
:
streaming
})
for
answer_result
in
answer_result_stream_result
[
'answer_result_stream'
]:
resp
=
answer_result
.
llm_output
[
"answer"
]
history
=
answer_result
.
history
history
[
-
1
][
0
]
=
query
...
...
@@ -276,8 +279,10 @@ class LocalDocQA:
result_docs
=
search_result2docs
(
results
)
prompt
=
generate_prompt
(
result_docs
,
query
)
for
answer_result
in
self
.
llm
.
generatorAnswer
(
prompt
=
prompt
,
history
=
chat_history
,
streaming
=
streaming
):
answer_result_stream_result
=
self
.
llm_model_chain
(
{
"prompt"
:
prompt
,
"history"
:
chat_history
,
"streaming"
:
streaming
})
for
answer_result
in
answer_result_stream_result
[
'answer_result_stream'
]:
resp
=
answer_result
.
llm_output
[
"answer"
]
history
=
answer_result
.
history
history
[
-
1
][
0
]
=
query
...
...
@@ -296,7 +301,7 @@ class LocalDocQA:
def
update_file_from_vector_store
(
self
,
filepath
:
str
or
List
[
str
],
vs_path
,
docs
:
List
[
Document
],):
docs
:
List
[
Document
],
):
vector_store
=
load_vector_store
(
vs_path
,
self
.
embeddings
)
status
=
vector_store
.
update_doc
(
filepath
,
docs
)
return
status
...
...
@@ -320,7 +325,6 @@ if __name__ == "__main__":
args_dict
=
vars
(
args
)
shared
.
loaderCheckPoint
=
LoaderCheckPoint
(
args_dict
)
llm_model_ins
=
shared
.
loaderLLM
()
llm_model_ins
.
set_history_len
(
LLM_HISTORY_LEN
)
local_doc_qa
=
LocalDocQA
()
local_doc_qa
.
init_cfg
(
llm_model
=
llm_model_ins
)
...
...
configs/model_config.py
浏览文件 @
760abab1
...
...
@@ -37,55 +37,55 @@ llm_model_dict = {
"name"
:
"chatglm-6b-int4-qe"
,
"pretrained_model_name"
:
"THUDM/chatglm-6b-int4-qe"
,
"local_model_path"
:
None
,
"provides"
:
"ChatGLM"
"provides"
:
"ChatGLM
LLMChain
"
},
"chatglm-6b-int4"
:
{
"name"
:
"chatglm-6b-int4"
,
"pretrained_model_name"
:
"THUDM/chatglm-6b-int4"
,
"local_model_path"
:
None
,
"provides"
:
"ChatGLM"
"provides"
:
"ChatGLM
LLMChain
"
},
"chatglm-6b-int8"
:
{
"name"
:
"chatglm-6b-int8"
,
"pretrained_model_name"
:
"THUDM/chatglm-6b-int8"
,
"local_model_path"
:
None
,
"provides"
:
"ChatGLM"
"provides"
:
"ChatGLM
LLMChain
"
},
"chatglm-6b"
:
{
"name"
:
"chatglm-6b"
,
"pretrained_model_name"
:
"THUDM/chatglm-6b"
,
"local_model_path"
:
None
,
"provides"
:
"ChatGLM"
"provides"
:
"ChatGLM
LLMChain
"
},
"chatglm2-6b"
:
{
"name"
:
"chatglm2-6b"
,
"pretrained_model_name"
:
"THUDM/chatglm2-6b"
,
"local_model_path"
:
None
,
"provides"
:
"ChatGLM"
"provides"
:
"ChatGLM
LLMChain
"
},
"chatglm2-6b-int4"
:
{
"name"
:
"chatglm2-6b-int4"
,
"pretrained_model_name"
:
"THUDM/chatglm2-6b-int4"
,
"local_model_path"
:
None
,
"provides"
:
"ChatGLM"
"provides"
:
"ChatGLM
LLMChain
"
},
"chatglm2-6b-int8"
:
{
"name"
:
"chatglm2-6b-int8"
,
"pretrained_model_name"
:
"THUDM/chatglm2-6b-int8"
,
"local_model_path"
:
None
,
"provides"
:
"ChatGLM"
"provides"
:
"ChatGLM
LLMChain
"
},
"chatyuan"
:
{
"name"
:
"chatyuan"
,
"pretrained_model_name"
:
"ClueAI/ChatYuan-large-v2"
,
"local_model_path"
:
None
,
"provides"
:
"MOSSLLM"
"provides"
:
"MOSSLLM
Chain
"
},
"moss"
:
{
"name"
:
"moss"
,
"pretrained_model_name"
:
"fnlp/moss-moon-003-sft"
,
"local_model_path"
:
None
,
"provides"
:
"MOSSLLM"
"provides"
:
"MOSSLLM
Chain
"
},
"moss-int4"
:
{
"name"
:
"moss"
,
...
...
@@ -97,7 +97,13 @@ llm_model_dict = {
"name"
:
"vicuna-13b-hf"
,
"pretrained_model_name"
:
"vicuna-13b-hf"
,
"local_model_path"
:
None
,
"provides"
:
"LLamaLLM"
"provides"
:
"LLamaLLMChain"
},
"vicuna-7b-hf"
:
{
"name"
:
"vicuna-13b-hf"
,
"pretrained_model_name"
:
"vicuna-13b-hf"
,
"local_model_path"
:
None
,
"provides"
:
"LLamaLLMChain"
},
# 直接调用返回requests.exceptions.ConnectionError错误,需要通过huggingface_hub包里的snapshot_download函数
# 下载模型,如果snapshot_download还是返回网络错误,多试几次,一般是可以的,
...
...
@@ -107,7 +113,7 @@ llm_model_dict = {
"name"
:
"bloomz-7b1"
,
"pretrained_model_name"
:
"bigscience/bloomz-7b1"
,
"local_model_path"
:
None
,
"provides"
:
"MOSSLLM"
"provides"
:
"MOSSLLM
Chain
"
},
# 实测加载bigscience/bloom-3b需要170秒左右,暂不清楚为什么这么慢
...
...
@@ -116,14 +122,14 @@ llm_model_dict = {
"name"
:
"bloom-3b"
,
"pretrained_model_name"
:
"bigscience/bloom-3b"
,
"local_model_path"
:
None
,
"provides"
:
"MOSSLLM"
"provides"
:
"MOSSLLM
Chain
"
},
"baichuan-7b"
:
{
"name"
:
"baichuan-7b"
,
"pretrained_model_name"
:
"baichuan-inc/baichuan-7B"
,
"local_model_path"
:
None
,
"provides"
:
"MOSSLLM"
"provides"
:
"MOSSLLM
Chain
"
},
# llama-cpp模型的兼容性问题参考https://github.com/abetlen/llama-cpp-python/issues/204
"ggml-vicuna-13b-1.1-q5"
:
{
...
...
@@ -137,7 +143,7 @@ llm_model_dict = {
# 需要手动从https://github.com/abetlen/llama-cpp-python/releases/tag/下载对应的wheel安装
# 实测v0.1.63与本模型的vicuna/ggml-vicuna-13b-1.1/ggml-vic13b-q5_1.bin可以兼容
"local_model_path"
:
f
'''{"/".join(os.path.abspath(__file__).split("/")[:3])}/.cache/huggingface/hub/models--vicuna--ggml-vicuna-13b-1.1/blobs/'''
,
"provides"
:
"LLamaLLM"
"provides"
:
"LLamaLLM
Chain
"
},
# 通过 fastchat 调用的模型请参考如下格式
...
...
@@ -145,7 +151,7 @@ llm_model_dict = {
"name"
:
"chatglm-6b"
,
# "name"修改为fastchat服务中的"model_name"
"pretrained_model_name"
:
"chatglm-6b"
,
"local_model_path"
:
None
,
"provides"
:
"FastChatOpenAILLM
"
,
# 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM
"
"provides"
:
"FastChatOpenAILLM
Chain"
,
# 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain
"
"api_base_url"
:
"http://localhost:8000/v1"
,
# "name"修改为fastchat服务中的"api_base_url"
"api_key"
:
"EMPTY"
},
...
...
@@ -153,7 +159,7 @@ llm_model_dict = {
"name"
:
"chatglm2-6b"
,
# "name"修改为fastchat服务中的"model_name"
"pretrained_model_name"
:
"chatglm2-6b"
,
"local_model_path"
:
None
,
"provides"
:
"FastChatOpenAILLM
"
,
# 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM
"
"provides"
:
"FastChatOpenAILLM
Chain"
,
# 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain
"
"api_base_url"
:
"http://localhost:8000/v1"
# "name"修改为fastchat服务中的"api_base_url"
},
...
...
@@ -162,7 +168,7 @@ llm_model_dict = {
"name"
:
"vicuna-13b-hf"
,
# "name"修改为fastchat服务中的"model_name"
"pretrained_model_name"
:
"vicuna-13b-hf"
,
"local_model_path"
:
None
,
"provides"
:
"FastChatOpenAILLM
"
,
# 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM
"
"provides"
:
"FastChatOpenAILLM
Chain"
,
# 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain
"
"api_base_url"
:
"http://localhost:8000/v1"
,
# "name"修改为fastchat服务中的"api_base_url"
"api_key"
:
"EMPTY"
},
...
...
@@ -177,7 +183,7 @@ llm_model_dict = {
"openai-chatgpt-3.5"
:
{
"name"
:
"gpt-3.5-turbo"
,
"pretrained_model_name"
:
"gpt-3.5-turbo"
,
"provides"
:
"FastChatOpenAILLM"
,
"provides"
:
"FastChatOpenAILLM
Chain
"
,
"local_model_path"
:
None
,
"api_base_url"
:
"https://api.openapi.com/v1"
,
"api_key"
:
""
...
...
@@ -204,7 +210,10 @@ STREAMING = True
# Use p-tuning-v2 PrefixEncoder
USE_PTUNING_V2
=
False
PTUNING_DIR
=
'./ptuing-v2'
<<<<<<<
HEAD
=======
>>>>>>>
f68d347c25b4bdd07f293c65a6e44a673a11f614
# LLM running device
LLM_DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"mps"
if
torch
.
backends
.
mps
.
is_available
()
else
"cpu"
...
...
@@ -233,7 +242,7 @@ LLM_HISTORY_LEN = 3
VECTOR_SEARCH_TOP_K
=
5
# 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,经测试设置为小于500时,匹配结果更精准
VECTOR_SEARCH_SCORE_THRESHOLD
=
0
VECTOR_SEARCH_SCORE_THRESHOLD
=
39
0
NLTK_DATA_PATH
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
__file__
)),
"nltk_data"
)
...
...
img/qr_code_36.jpg
deleted
100644 → 0
浏览文件 @
3d082bf5
247.1 KB
models/__init__.py
浏览文件 @
760abab1
from
.chatglm_llm
import
ChatGLM
from
.llama_llm
import
LLamaLLM
from
.
moss_llm
import
MOSSLLM
from
.
fastchat_openai_llm
import
FastChatOpenAILLM
from
.chatglm_llm
import
ChatGLM
LLMChain
from
.llama_llm
import
LLamaLLM
Chain
from
.
fastchat_openai_llm
import
FastChatOpenAILLMChain
from
.
moss_llm
import
MOSSLLMChain
models/base/__init__.py
浏览文件 @
760abab1
from
models.base.base
import
(
AnswerResult
,
BaseAnswer
)
BaseAnswer
,
AnswerResultStream
,
AnswerResultQueueSentinelTokenListenerQueue
)
from
models.base.remote_rpc_model
import
(
RemoteRpcModel
)
__all__
=
[
"AnswerResult"
,
"BaseAnswer"
,
"RemoteRpcModel"
,
"AnswerResultStream"
,
"AnswerResultQueueSentinelTokenListenerQueue"
]
models/base/base.py
浏览文件 @
760abab1
from
abc
import
ABC
,
abstractmethod
from
typing
import
Optional
,
List
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Generator
import
traceback
from
collections
import
deque
from
queue
import
Queue
from
threading
import
Thread
from
langchain.callbacks.manager
import
CallbackManagerForChainRun
from
models.loader
import
LoaderCheckPoint
import
torch
import
transformers
from
models.loader
import
LoaderCheckPoint
class
ListenerToken
:
"""
观测结果
"""
input_ids
:
torch
.
LongTensor
_scores
:
torch
.
FloatTensor
def
__init__
(
self
,
input_ids
:
torch
.
LongTensor
,
_scores
:
torch
.
FloatTensor
):
self
.
input_ids
=
input_ids
self
.
_scores
=
_scores
class
AnswerResult
:
...
...
@@ -16,6 +29,123 @@ class AnswerResult:
"""
history
:
List
[
List
[
str
]]
=
[]
llm_output
:
Optional
[
dict
]
=
None
listenerToken
:
ListenerToken
=
None
class
AnswerResultStream
:
def
__init__
(
self
,
callback_func
=
None
):
self
.
callback_func
=
callback_func
def
__call__
(
self
,
answerResult
:
AnswerResult
):
if
self
.
callback_func
is
not
None
:
self
.
callback_func
(
answerResult
)
class
AnswerResultQueueSentinelTokenListenerQueue
(
transformers
.
StoppingCriteria
):
"""
定义模型stopping_criteria 监听者,在每次响应时将队列数据同步到AnswerResult
实现此监听器的目的是,不同模型的预测输出可能不是矢量信息,hf框架可以自定义transformers.StoppingCriteria入参来接收每次预测的Tensor和损失函数,
通过给 StoppingCriteriaList指定模型生成答案时停止的条件。每个 StoppingCriteria 对象表示一个停止条件
当每轮预测任务开始时,StoppingCriteria都会收到相同的预测结果,最终由下层实现类确认是否结束
输出值可用于 generatorAnswer generate_with_streaming的自定义参数观测,以实现更加精细的控制
"""
listenerQueue
:
deque
=
deque
(
maxlen
=
1
)
def
__init__
(
self
):
transformers
.
StoppingCriteria
.
__init__
(
self
)
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
_scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
bool
:
"""
每次响应时将数据添加到响应队列
:param input_ids:
:param _scores:
:param kwargs:
:return:
"""
self
.
listenerQueue
.
append
(
ListenerToken
(
input_ids
=
input_ids
,
_scores
=
_scores
))
return
False
class
Iteratorize
:
"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""
def
__init__
(
self
,
func
,
kwargs
=
{}):
self
.
mfunc
=
func
self
.
q
=
Queue
()
self
.
sentinel
=
object
()
self
.
kwargs
=
kwargs
self
.
stop_now
=
False
def
_callback
(
val
):
"""
模型输出预测结果收集
通过定义generate_with_callback收集器AnswerResultStream,收集模型预测的AnswerResult响应结果,最终由下层实现类确认是否结束
结束条件包含如下
1、模型预测结束、收集器self.q队列收到 self.sentinel标识
2、在处理迭代器队列消息时返回了break跳出迭代器,触发了StopIteration事件
3、模型预测出错
因为当前类是迭代器,所以在for in 中执行了break后 __exit__ 方法会被调用,最终stop_now属性会被更新,然后抛出异常结束预测行为
迭代器收集的行为如下
创建Iteratorize迭代对象,
定义generate_with_callback收集器AnswerResultStream
启动一个线程异步预测结果来调用上游checkpoint的实现方法_generate_answer
_generate_answer通过generate_with_callback定义的收集器,收集上游checkpoint包装的AnswerResult消息体
由于self.q是阻塞模式,每次预测后会被消费后才会执行下次预测
这时generate_with_callback会被阻塞
主线程Iteratorize对象的__next__方法调用获取阻塞消息并消费
1、消息为上游checkpoint包装的AnswerResult消息体,返回下游处理
2、消息为self.sentinel标识,抛出StopIteration异常
主线程Iteratorize对象__exit__收到消息,最终stop_now属性会被更新
异步线程检测stop_now属性被更新,抛出异常结束预测行为
迭代行为结束
:param val:
:return:
"""
if
self
.
stop_now
:
raise
ValueError
self
.
q
.
put
(
val
)
def
gen
():
try
:
ret
=
self
.
mfunc
(
callback
=
_callback
,
**
self
.
kwargs
)
except
ValueError
:
pass
except
:
traceback
.
print_exc
()
pass
self
.
q
.
put
(
self
.
sentinel
)
self
.
thread
=
Thread
(
target
=
gen
)
self
.
thread
.
start
()
def
__iter__
(
self
):
return
self
def
__next__
(
self
):
obj
=
self
.
q
.
get
(
True
,
None
)
if
obj
is
self
.
sentinel
:
raise
StopIteration
else
:
return
obj
def
__del__
(
self
):
"""
暂无实现
:return:
"""
pass
def
__enter__
(
self
):
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
""" break 后会执行 """
self
.
stop_now
=
True
class
BaseAnswer
(
ABC
):
...
...
@@ -25,17 +155,25 @@ class BaseAnswer(ABC):
@abstractmethod
def
_check_point
(
self
)
->
LoaderCheckPoint
:
"""Return _check_point of llm."""
def
generatorAnswer
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,)
->
Generator
[
Any
,
str
,
bool
]:
def
generate_with_callback
(
callback
=
None
,
**
kwargs
):
kwargs
[
'generate_with_callback'
]
=
AnswerResultStream
(
callback_func
=
callback
)
self
.
_generate_answer
(
**
kwargs
)
@property
@abstractmethod
def
_history_len
(
self
)
->
int
:
"""Return _history_len of llm."""
def
generate_with_streaming
(
**
kwargs
):
return
Iteratorize
(
generate_with_callback
,
kwargs
)
@abstractmethod
def
set_history_len
(
self
,
history_len
:
int
)
->
None
:
"""Return _history_len of llm."""
with
generate_with_streaming
(
inputs
=
inputs
,
run_manager
=
run_manager
)
as
generator
:
for
answerResult
in
generator
:
if
answerResult
.
listenerToken
:
output
=
answerResult
.
listenerToken
.
input_ids
yield
answerResult
def
generatorAnswer
(
self
,
prompt
:
str
,
history
:
List
[
List
[
str
]]
=
[],
streaming
:
bool
=
False
):
@abstractmethod
def
_generate_answer
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
generate_with_callback
:
AnswerResultStream
=
None
)
->
None
:
pass
models/chatglm_llm.py
浏览文件 @
760abab1
from
abc
import
ABC
from
langchain.llms.base
import
LLM
from
typing
import
Optional
,
List
from
langchain.chains.base
import
Chain
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Generator
from
langchain.callbacks.manager
import
CallbackManagerForChainRun
from
transformers.generation.logits_process
import
LogitsProcessor
from
transformers.generation.utils
import
LogitsProcessorList
,
StoppingCriteriaList
from
models.loader
import
LoaderCheckPoint
from
models.base
import
(
BaseAnswer
,
AnswerResult
)
AnswerResult
,
AnswerResultStream
,
AnswerResultQueueSentinelTokenListenerQueue
)
import
torch
import
transformers
class
ChatGLM
(
BaseAnswer
,
LLM
,
ABC
):
class
ChatGLM
LLMChain
(
BaseAnswer
,
Chain
,
ABC
):
max_token
:
int
=
10000
temperature
:
float
=
0.01
top_p
=
0.9
# 相关度
top_p
=
0.4
# 候选词数量
top_k
=
10
checkPoint
:
LoaderCheckPoint
=
None
# history = []
history_len
:
int
=
10
streaming_key
:
str
=
"streaming"
#: :meta private:
history_key
:
str
=
"history"
#: :meta private:
prompt_key
:
str
=
"prompt"
#: :meta private:
output_key
:
str
=
"answer_result_stream"
#: :meta private:
def
__init__
(
self
,
checkPoint
:
LoaderCheckPoint
=
None
):
super
()
.
__init__
()
self
.
checkPoint
=
checkPoint
@property
def
_
llm
_type
(
self
)
->
str
:
return
"ChatGLM"
def
_
chain
_type
(
self
)
->
str
:
return
"ChatGLM
LLMChain
"
@property
def
_check_point
(
self
)
->
LoaderCheckPoint
:
return
self
.
checkPoint
@property
def
_history_len
(
self
)
->
int
:
return
self
.
history_len
def
input_keys
(
self
)
->
List
[
str
]
:
"""Will be whatever keys the prompt expects.
def
set_history_len
(
self
,
history_len
:
int
=
10
)
->
None
:
self
.
history_len
=
history_len
:meta private:
"""
return
[
self
.
prompt_key
]
def
_call
(
self
,
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
)
->
str
:
print
(
f
"__call:{prompt}"
)
response
,
_
=
self
.
checkPoint
.
model
.
chat
(
self
.
checkPoint
.
tokenizer
,
prompt
,
history
=
[],
max_length
=
self
.
max_token
,
temperature
=
self
.
temperature
)
print
(
f
"response:{response}"
)
print
(
f
"+++++++++++++++++++++++++++++++++++"
)
return
response
@property
def
output_keys
(
self
)
->
List
[
str
]:
"""Will always return text key.
def
generatorAnswer
(
self
,
prompt
:
str
,
history
:
List
[
List
[
str
]]
=
[],
streaming
:
bool
=
False
):
:meta private:
"""
return
[
self
.
output_key
]
def
_call
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
)
->
Dict
[
str
,
Generator
]:
generator
=
self
.
generatorAnswer
(
inputs
=
inputs
,
run_manager
=
run_manager
)
return
{
self
.
output_key
:
generator
}
def
_generate_answer
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
generate_with_callback
:
AnswerResultStream
=
None
)
->
None
:
history
=
inputs
[
self
.
history_key
]
streaming
=
inputs
[
self
.
streaming_key
]
prompt
=
inputs
[
self
.
prompt_key
]
print
(
f
"__call:{prompt}"
)
# Create the StoppingCriteriaList with the stopping strings
stopping_criteria_list
=
transformers
.
StoppingCriteriaList
()
# 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
listenerQueue
=
AnswerResultQueueSentinelTokenListenerQueue
()
stopping_criteria_list
.
append
(
listenerQueue
)
if
streaming
:
history
+=
[[]]
for
inum
,
(
stream_resp
,
_
)
in
enumerate
(
self
.
checkPoint
.
model
.
stream_chat
(
self
.
checkPoint
.
tokenizer
,
prompt
,
history
=
history
[
-
self
.
history_len
:
-
1
]
if
self
.
history_len
>
1
else
[],
history
=
history
[
-
self
.
history_len
:
-
1
]
if
self
.
history_len
>
0
else
[],
max_length
=
self
.
max_token
,
temperature
=
self
.
temperature
temperature
=
self
.
temperature
,
top_p
=
self
.
top_p
,
top_k
=
self
.
top_k
,
stopping_criteria
=
stopping_criteria_list
)):
# self.checkPoint.clear_torch_cache()
history
[
-
1
]
=
[
prompt
,
stream_resp
]
answer_result
=
AnswerResult
()
answer_result
.
history
=
history
answer_result
.
llm_output
=
{
"answer"
:
stream_resp
}
yield
answer_result
if
listenerQueue
.
listenerQueue
.
__len__
()
>
0
:
answer_result
.
listenerToken
=
listenerQueue
.
listenerQueue
.
pop
()
generate_with_callback
(
answer_result
)
self
.
checkPoint
.
clear_torch_cache
()
else
:
response
,
_
=
self
.
checkPoint
.
model
.
chat
(
...
...
@@ -72,13 +104,18 @@ class ChatGLM(BaseAnswer, LLM, ABC):
prompt
,
history
=
history
[
-
self
.
history_len
:]
if
self
.
history_len
>
0
else
[],
max_length
=
self
.
max_token
,
temperature
=
self
.
temperature
temperature
=
self
.
temperature
,
top_p
=
self
.
top_p
,
top_k
=
self
.
top_k
,
stopping_criteria
=
stopping_criteria_list
)
self
.
checkPoint
.
clear_torch_cache
()
history
+=
[[
prompt
,
response
]]
answer_result
=
AnswerResult
()
answer_result
.
history
=
history
answer_result
.
llm_output
=
{
"answer"
:
response
}
yield
answer_result
if
listenerQueue
.
listenerQueue
.
__len__
()
>
0
:
answer_result
.
listenerToken
=
listenerQueue
.
listenerQueue
.
pop
()
generate_with_callback
(
answer_result
)
models/fastchat_openai_llm.py
浏览文件 @
760abab1
from
abc
import
ABC
import
requests
from
typing
import
Optional
,
List
from
langchain.llms.base
import
LLM
from
langchain.chains.base
import
Chain
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Generator
,
Collection
from
models.loader
import
LoaderCheckPoint
from
models.base
import
(
RemoteRpcModel
,
AnswerResult
)
from
typing
import
(
Collection
,
Dict
)
from
langchain.callbacks.manager
import
CallbackManagerForChainRun
from
models.base
import
(
BaseAnswer
,
RemoteRpcModel
,
AnswerResult
,
AnswerResultStream
,
AnswerResultQueueSentinelTokenListenerQueue
)
import
torch
import
transformers
def
_build_message_template
()
->
Dict
[
str
,
str
]:
...
...
@@ -22,18 +22,42 @@ def _build_message_template() -> Dict[str, str]:
}
class
FastChatOpenAILLM
(
RemoteRpcModel
,
LLM
,
ABC
):
# 将历史对话数组转换为文本格式
def
build_message_list
(
query
,
history
:
List
[
List
[
str
]])
->
Collection
[
Dict
[
str
,
str
]]:
build_messages
:
Collection
[
Dict
[
str
,
str
]]
=
[]
for
i
,
(
old_query
,
response
)
in
enumerate
(
history
):
user_build_message
=
_build_message_template
()
user_build_message
[
'role'
]
=
'user'
user_build_message
[
'content'
]
=
old_query
system_build_message
=
_build_message_template
()
system_build_message
[
'role'
]
=
'system'
system_build_message
[
'content'
]
=
response
build_messages
.
append
(
user_build_message
)
build_messages
.
append
(
system_build_message
)
user_build_message
=
_build_message_template
()
user_build_message
[
'role'
]
=
'user'
user_build_message
[
'content'
]
=
query
build_messages
.
append
(
user_build_message
)
return
build_messages
class
FastChatOpenAILLMChain
(
RemoteRpcModel
,
Chain
,
ABC
):
api_base_url
:
str
=
"http://localhost:8000/v1"
model_name
:
str
=
"chatglm-6b"
max_token
:
int
=
10000
temperature
:
float
=
0.01
top_p
=
0.9
checkPoint
:
LoaderCheckPoint
=
None
history
=
[]
#
history = []
history_len
:
int
=
10
api_key
:
str
=
""
streaming_key
:
str
=
"streaming"
#: :meta private:
history_key
:
str
=
"history"
#: :meta private:
prompt_key
:
str
=
"prompt"
#: :meta private:
output_key
:
str
=
"answer_result_stream"
#: :meta private:
def
__init__
(
self
,
checkPoint
:
LoaderCheckPoint
=
None
,
# api_base_url:str="http://localhost:8000/v1",
...
...
@@ -44,19 +68,28 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
self
.
checkPoint
=
checkPoint
@property
def
_
llm
_type
(
self
)
->
str
:
return
"
FastChat
"
def
_
chain
_type
(
self
)
->
str
:
return
"
LLamaLLMChain
"
@property
def
_check_point
(
self
)
->
LoaderCheckPoint
:
return
self
.
checkPoint
@property
def
_history_len
(
self
)
->
int
:
return
self
.
history_len
def
input_keys
(
self
)
->
List
[
str
]
:
"""Will be whatever keys the prompt expects.
def
set_history_len
(
self
,
history_len
:
int
=
10
)
->
None
:
self
.
history_len
=
history_len
:meta private:
"""
return
[
self
.
prompt_key
]
@property
def
output_keys
(
self
)
->
List
[
str
]:
"""Will always return text key.
:meta private:
"""
return
[
self
.
output_key
]
@property
def
_api_key
(
self
)
->
str
:
...
...
@@ -75,53 +108,25 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
def
call_model_name
(
self
,
model_name
):
self
.
model_name
=
model_name
def
_call
(
self
,
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
)
->
str
:
def
_call
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
)
->
Dict
[
str
,
Generator
]:
generator
=
self
.
generatorAnswer
(
inputs
=
inputs
,
run_manager
=
run_manager
)
return
{
self
.
output_key
:
generator
}
def
_generate_answer
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
generate_with_callback
:
AnswerResultStream
=
None
)
->
None
:
history
=
inputs
[
self
.
history_key
]
streaming
=
inputs
[
self
.
streaming_key
]
prompt
=
inputs
[
self
.
prompt_key
]
print
(
f
"__call:{prompt}"
)
try
:
import
openai
# Not support yet
# openai.api_key = "EMPTY"
openai
.
key
=
self
.
api_key
openai
.
api_base
=
self
.
api_base_url
except
ImportError
:
raise
ValueError
(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
# create a chat completion
completion
=
openai
.
ChatCompletion
.
create
(
model
=
self
.
model_name
,
messages
=
self
.
build_message_list
(
prompt
)
)
print
(
f
"response:{completion.choices[0].message.content}"
)
print
(
f
"+++++++++++++++++++++++++++++++++++"
)
return
completion
.
choices
[
0
]
.
message
.
content
# 将历史对话数组转换为文本格式
def
build_message_list
(
self
,
query
)
->
Collection
[
Dict
[
str
,
str
]]:
build_message_list
:
Collection
[
Dict
[
str
,
str
]]
=
[]
history
=
self
.
history
[
-
self
.
history_len
:]
if
self
.
history_len
>
0
else
[]
for
i
,
(
old_query
,
response
)
in
enumerate
(
history
):
user_build_message
=
_build_message_template
()
user_build_message
[
'role'
]
=
'user'
user_build_message
[
'content'
]
=
old_query
system_build_message
=
_build_message_template
()
system_build_message
[
'role'
]
=
'system'
system_build_message
[
'content'
]
=
response
build_message_list
.
append
(
user_build_message
)
build_message_list
.
append
(
system_build_message
)
user_build_message
=
_build_message_template
()
user_build_message
[
'role'
]
=
'user'
user_build_message
[
'content'
]
=
query
build_message_list
.
append
(
user_build_message
)
return
build_message_list
def
generatorAnswer
(
self
,
prompt
:
str
,
history
:
List
[
List
[
str
]]
=
[],
streaming
:
bool
=
False
):
try
:
import
openai
# Not support yet
# openai.api_key = "EMPTY"
...
...
@@ -135,12 +140,13 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
# create a chat completion
completion
=
openai
.
ChatCompletion
.
create
(
model
=
self
.
model_name
,
messages
=
self
.
build_message_list
(
prompt
)
messages
=
build_message_list
(
prompt
)
)
print
(
f
"response:{completion.choices[0].message.content}"
)
print
(
f
"+++++++++++++++++++++++++++++++++++"
)
history
+=
[[
prompt
,
completion
.
choices
[
0
]
.
message
.
content
]]
answer_result
=
AnswerResult
()
answer_result
.
history
=
history
answer_result
.
llm_output
=
{
"answer"
:
completion
.
choices
[
0
]
.
message
.
content
}
yield
answer_result
generate_with_callback
(
answer_result
)
models/llama_llm.py
浏览文件 @
760abab1
from
abc
import
ABC
from
langchain.llms.base
import
LLM
import
random
import
torch
import
transformers
from
abc
import
ABC
from
langchain.chains.base
import
Chain
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Generator
,
Union
from
langchain.callbacks.manager
import
CallbackManagerForChainRun
from
transformers.generation.logits_process
import
LogitsProcessor
from
transformers.generation.utils
import
LogitsProcessorList
,
StoppingCriteriaList
from
typing
import
Optional
,
List
,
Dict
,
Any
,
Union
from
models.loader
import
LoaderCheckPoint
from
models.base
import
(
BaseAnswer
,
AnswerResult
)
AnswerResult
,
AnswerResultStream
,
AnswerResultQueueSentinelTokenListenerQueue
)
import
torch
import
transformers
class
InvalidScoreLogitsProcessor
(
LogitsProcessor
):
def
__call__
(
self
,
input_ids
:
Union
[
torch
.
LongTensor
,
list
],
scores
:
Union
[
torch
.
FloatTensor
,
list
])
->
torch
.
FloatTensor
:
def
__call__
(
self
,
input_ids
:
Union
[
torch
.
LongTensor
,
list
],
scores
:
Union
[
torch
.
FloatTensor
,
list
])
->
torch
.
FloatTensor
:
# llama-cpp模型返回的是list,为兼容性考虑,需要判断input_ids和scores的类型,将list转换为torch.Tensor
input_ids
=
torch
.
tensor
(
input_ids
)
if
isinstance
(
input_ids
,
list
)
else
input_ids
scores
=
torch
.
tensor
(
scores
)
if
isinstance
(
scores
,
list
)
else
scores
input_ids
=
torch
.
tensor
(
input_ids
)
if
isinstance
(
input_ids
,
list
)
else
input_ids
scores
=
torch
.
tensor
(
scores
)
if
isinstance
(
scores
,
list
)
else
scores
if
torch
.
isnan
(
scores
)
.
any
()
or
torch
.
isinf
(
scores
)
.
any
():
scores
.
zero_
()
scores
[
...
,
5
]
=
5e4
return
scores
class
LLamaLLM
(
BaseAnswer
,
LLM
,
ABC
):
class
LLamaLLM
Chain
(
BaseAnswer
,
Chain
,
ABC
):
checkPoint
:
LoaderCheckPoint
=
None
# history = []
history_len
:
int
=
3
...
...
@@ -37,32 +40,34 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
min_length
:
int
=
0
logits_processor
:
LogitsProcessorList
=
None
stopping_criteria
:
Optional
[
StoppingCriteriaList
]
=
None
eos_token_id
:
Optional
[
int
]
=
[
2
]
state
:
object
=
{
'max_new_tokens'
:
50
,
'seed'
:
1
,
'temperature'
:
0
,
'top_p'
:
0.1
,
'top_k'
:
40
,
'typical_p'
:
1
,
'repetition_penalty'
:
1.2
,
'encoder_repetition_penalty'
:
1
,
'no_repeat_ngram_size'
:
0
,
'min_length'
:
0
,
'penalty_alpha'
:
0
,
'num_beams'
:
1
,
'length_penalty'
:
1
,
'early_stopping'
:
False
,
'add_bos_token'
:
True
,
'ban_eos_token'
:
False
,
'truncation_length'
:
2048
,
'custom_stopping_strings'
:
''
,
'cpu_memory'
:
0
,
'auto_devices'
:
False
,
'disk'
:
False
,
'cpu'
:
False
,
'bf16'
:
False
,
'load_in_8bit'
:
False
,
'wbits'
:
'None'
,
'groupsize'
:
'None'
,
'model_type'
:
'None'
,
'pre_layer'
:
0
,
'gpu_memory_0'
:
0
}
streaming_key
:
str
=
"streaming"
#: :meta private:
history_key
:
str
=
"history"
#: :meta private:
prompt_key
:
str
=
"prompt"
#: :meta private:
output_key
:
str
=
"answer_result_stream"
#: :meta private:
def
__init__
(
self
,
checkPoint
:
LoaderCheckPoint
=
None
):
super
()
.
__init__
()
self
.
checkPoint
=
checkPoint
@property
def
_llm_type
(
self
)
->
str
:
return
"LLamaLLM"
def
_chain_type
(
self
)
->
str
:
return
"LLamaLLMChain"
@property
def
input_keys
(
self
)
->
List
[
str
]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return
[
self
.
prompt_key
]
@property
def
output_keys
(
self
)
->
List
[
str
]:
"""Will always return text key.
:meta private:
"""
return
[
self
.
output_key
]
@property
def
_check_point
(
self
)
->
LoaderCheckPoint
:
...
...
@@ -107,35 +112,31 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
formatted_history
+=
"### Human:{}
\n
### Assistant:"
.
format
(
query
)
return
formatted_history
def
prepare_inputs_for_generation
(
self
,
input_ids
:
torch
.
LongTensor
):
"""
预生成注意力掩码和 输入序列中每个位置的索引的张量
# TODO 没有思路
:return:
"""
mask_positions
=
torch
.
zeros
((
1
,
input_ids
.
shape
[
1
]),
dtype
=
input_ids
.
dtype
)
.
to
(
self
.
checkPoint
.
model
.
device
)
attention_mask
=
self
.
get_masks
(
input_ids
,
input_ids
.
device
)
position_ids
=
self
.
get_position_ids
(
input_ids
,
device
=
input_ids
.
device
,
mask_positions
=
mask_positions
)
return
input_ids
,
position_ids
,
attention_mask
@property
def
_history_len
(
self
)
->
int
:
return
self
.
history_len
def
set_history_len
(
self
,
history_len
:
int
=
10
)
->
None
:
self
.
history_len
=
history_len
def
_call
(
self
,
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
)
->
str
:
def
_call
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
)
->
Dict
[
str
,
Generator
]:
generator
=
self
.
generatorAnswer
(
inputs
=
inputs
,
run_manager
=
run_manager
)
return
{
self
.
output_key
:
generator
}
def
_generate_answer
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
generate_with_callback
:
AnswerResultStream
=
None
)
->
None
:
history
=
inputs
[
self
.
history_key
]
streaming
=
inputs
[
self
.
streaming_key
]
prompt
=
inputs
[
self
.
prompt_key
]
print
(
f
"__call:{prompt}"
)
# Create the StoppingCriteriaList with the stopping strings
self
.
stopping_criteria
=
transformers
.
StoppingCriteriaList
()
# 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
listenerQueue
=
AnswerResultQueueSentinelTokenListenerQueue
()
self
.
stopping_criteria
.
append
(
listenerQueue
)
# TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现
soft_prompt
=
self
.
history_to_text
(
query
=
prompt
,
history
=
history
)
if
self
.
logits_processor
is
None
:
self
.
logits_processor
=
LogitsProcessorList
()
self
.
logits_processor
.
append
(
InvalidScoreLogitsProcessor
())
...
...
@@ -154,16 +155,10 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
"logits_processor"
:
self
.
logits_processor
}
# 向量转换
input_ids
=
self
.
encode
(
prompt
,
add_bos_token
=
self
.
state
[
'add_bos_token'
],
truncation_length
=
self
.
max_new_tokens
)
# input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids)
input_ids
=
self
.
encode
(
soft_prompt
,
add_bos_token
=
self
.
checkPoint
.
tokenizer
.
add_bos_token
,
truncation_length
=
self
.
max_new_tokens
)
gen_kwargs
.
update
({
'inputs'
:
input_ids
})
# 注意力掩码
# gen_kwargs.update({'attention_mask': attention_mask})
# gen_kwargs.update({'position_ids': position_ids})
if
self
.
stopping_criteria
is
None
:
self
.
stopping_criteria
=
transformers
.
StoppingCriteriaList
()
# 观测输出
gen_kwargs
.
update
({
'stopping_criteria'
:
self
.
stopping_criteria
})
# llama-cpp模型的参数与transformers的参数字段有较大差异,直接调用会返回不支持的字段错误
...
...
@@ -173,11 +168,13 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
if
"llama_cpp"
in
self
.
checkPoint
.
model
.
__str__
():
import
inspect
common_kwargs_keys
=
set
(
inspect
.
getfullargspec
(
self
.
checkPoint
.
model
.
generate
)
.
args
)
&
set
(
gen_kwargs
.
keys
())
common_kwargs
=
{
key
:
gen_kwargs
[
key
]
for
key
in
common_kwargs_keys
}
#? llama-cpp模型的generate方法似乎只接受.cpu类型的输入,响应很慢,慢到哭泣
#?为什么会不支持GPU呢,不应该啊?
output_ids
=
torch
.
tensor
([
list
(
self
.
checkPoint
.
model
.
generate
(
input_id_i
.
cpu
(),
**
common_kwargs
))
for
input_id_i
in
input_ids
])
common_kwargs_keys
=
set
(
inspect
.
getfullargspec
(
self
.
checkPoint
.
model
.
generate
)
.
args
)
&
set
(
gen_kwargs
.
keys
())
common_kwargs
=
{
key
:
gen_kwargs
[
key
]
for
key
in
common_kwargs_keys
}
# ? llama-cpp模型的generate方法似乎只接受.cpu类型的输入,响应很慢,慢到哭泣
# ?为什么会不支持GPU呢,不应该啊?
output_ids
=
torch
.
tensor
(
[
list
(
self
.
checkPoint
.
model
.
generate
(
input_id_i
.
cpu
(),
**
common_kwargs
))
for
input_id_i
in
input_ids
])
else
:
output_ids
=
self
.
checkPoint
.
model
.
generate
(
**
gen_kwargs
)
...
...
@@ -185,17 +182,11 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
reply
=
self
.
decode
(
output_ids
[
0
][
-
new_tokens
:])
print
(
f
"response:{reply}"
)
print
(
f
"+++++++++++++++++++++++++++++++++++"
)
return
reply
def
generatorAnswer
(
self
,
prompt
:
str
,
history
:
List
[
List
[
str
]]
=
[],
streaming
:
bool
=
False
):
# TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现
softprompt
=
self
.
history_to_text
(
prompt
,
history
=
history
)
response
=
self
.
_call
(
prompt
=
softprompt
,
stop
=
[
'
\n
###'
])
answer_result
=
AnswerResult
()
answer_result
.
history
=
history
+
[[
prompt
,
response
]]
answer_result
.
llm_output
=
{
"answer"
:
response
}
yield
answer_result
history
+=
[[
prompt
,
reply
]]
answer_result
.
history
=
history
if
listenerQueue
.
listenerQueue
.
__len__
()
>
0
:
answer_result
.
listenerToken
=
listenerQueue
.
listenerQueue
.
pop
()
answer_result
.
llm_output
=
{
"answer"
:
reply
}
generate_with_callback
(
answer_result
)
models/loader/args.py
浏览文件 @
760abab1
import
argparse
import
os
from
configs.model_config
import
*
...
...
@@ -45,7 +46,6 @@ parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the m
parser
.
add_argument
(
"--lora-dir"
,
type
=
str
,
default
=
LORA_DIR
,
help
=
"Path to directory with all the loras"
)
parser
.
add_argument
(
'--use-ptuning-v2'
,
type
=
str
,
default
=
False
,
help
=
"whether use ptuning-v2 checkpoint"
)
parser
.
add_argument
(
"--ptuning-dir"
,
type
=
str
,
default
=
PTUNING_DIR
,
help
=
"the dir of ptuning-v2 checkpoint"
)
# Accelerate/transformers
parser
.
add_argument
(
'--load-in-8bit'
,
action
=
'store_true'
,
default
=
LOAD_IN_8BIT
,
help
=
'Load the model with 8-bit precision.'
)
...
...
models/loader/loader.py
浏览文件 @
760abab1
...
...
@@ -20,6 +20,7 @@ class LoaderCheckPoint:
no_remote_model
:
bool
=
False
# 模型名称
model_name
:
str
=
None
pretrained_model_name
:
str
=
None
tokenizer
:
object
=
None
# 模型全路径
model_path
:
str
=
None
...
...
@@ -67,48 +68,49 @@ class LoaderCheckPoint:
self
.
load_in_8bit
=
params
.
get
(
'load_in_8bit'
,
False
)
self
.
bf16
=
params
.
get
(
'bf16'
,
False
)
def
_load_model_config
(
self
,
model_name
):
def
_load_model_config
(
self
):
if
self
.
model_path
:
self
.
model_path
=
re
.
sub
(
"
\
s"
,
""
,
self
.
model_path
)
self
.
model_path
=
re
.
sub
(
"
\
s"
,
""
,
self
.
model_path
)
checkpoint
=
Path
(
f
'{self.model_path}'
)
else
:
if
not
self
.
no_remote_model
:
checkpoint
=
model_name
else
:
if
self
.
no_remote_model
:
raise
ValueError
(
"本地模型local_model_path未配置路径"
)
else
:
checkpoint
=
self
.
pretrained_model_name
print
(
f
"load_model_config {checkpoint}..."
)
try
:
model_config
=
AutoConfig
.
from_pretrained
(
checkpoint
,
trust_remote_code
=
True
)
return
model_config
except
Exception
as
e
:
print
(
e
)
return
checkpoint
def
_load_model
(
self
,
model_name
):
def
_load_model
(
self
):
"""
加载自定义位置的model
:param model_name:
:return:
"""
print
(
f
"Loading {model_name}..."
)
t0
=
time
.
time
()
if
self
.
model_path
:
self
.
model_path
=
re
.
sub
(
"
\
s"
,
""
,
self
.
model_path
)
self
.
model_path
=
re
.
sub
(
"
\
s"
,
""
,
self
.
model_path
)
checkpoint
=
Path
(
f
'{self.model_path}'
)
else
:
if
not
self
.
no_remote_model
:
checkpoint
=
model_name
else
:
if
self
.
no_remote_model
:
raise
ValueError
(
"本地模型local_model_path未配置路径"
)
else
:
checkpoint
=
self
.
pretrained_model_name
print
(
f
"Loading {checkpoint}..."
)
self
.
is_llamacpp
=
len
(
list
(
Path
(
f
'{checkpoint}'
)
.
glob
(
'ggml*.bin'
)))
>
0
if
'chatglm'
in
model_name
.
lower
()
or
"chatyuan"
in
model_name
.
lower
():
if
'chatglm'
in
self
.
model_name
.
lower
()
or
"chatyuan"
in
self
.
model_name
.
lower
():
LoaderClass
=
AutoModel
else
:
LoaderClass
=
AutoModelForCausalLM
...
...
@@ -138,7 +140,7 @@ class LoaderCheckPoint:
torch_dtype
=
torch
.
bfloat16
if
self
.
bf16
else
torch
.
float16
,
trust_remote_code
=
True
)
.
half
()
.
to
(
self
.
llm_device
)
else
:
from
accelerate
import
dispatch_model
,
infer_auto_device_map
from
accelerate
import
dispatch_model
,
infer_auto_device_map
model
=
LoaderClass
.
from_pretrained
(
checkpoint
,
config
=
self
.
model_config
,
...
...
@@ -146,10 +148,10 @@ class LoaderCheckPoint:
trust_remote_code
=
True
)
.
half
()
# 可传入device_map自定义每张卡的部署情况
if
self
.
device_map
is
None
:
if
'chatglm'
in
model_name
.
lower
():
if
'chatglm'
in
self
.
model_name
.
lower
():
self
.
device_map
=
self
.
chatglm_auto_configure_device_map
(
num_gpus
)
elif
'moss'
in
model_name
.
lower
():
self
.
device_map
=
self
.
moss_auto_configure_device_map
(
num_gpus
,
model_name
)
elif
'moss'
in
self
.
model_name
.
lower
():
self
.
device_map
=
self
.
moss_auto_configure_device_map
(
num_gpus
,
checkpoint
)
else
:
# 基于如下方式作为默认的多卡加载方案针对新模型基本不会失败
# 在chatglm2-6b,bloom-3b,blooz-7b1上进行了测试,GPU负载也相对均衡
...
...
@@ -166,9 +168,9 @@ class LoaderCheckPoint:
# 其他模型定义的层类几乎不可能与chatglm和moss一致,使用chatglm_auto_configure_device_map
# 百分百会报错,使用infer_auto_device_map虽然可能导致负载不均衡,但至少不会报错
# 实测在bloom模型上如此
# self.device_map = infer_auto_device_map(model,
# dtype=torch.int8,
# no_split_module_classes=model._no_split_modules)
# self.device_map = infer_auto_device_map(model,
# dtype=torch.int8,
# no_split_module_classes=model._no_split_modules)
model
=
dispatch_model
(
model
,
device_map
=
self
.
device_map
)
else
:
...
...
@@ -202,7 +204,7 @@ class LoaderCheckPoint:
# tokenizer = model.tokenizer
# todo 此处调用AutoTokenizer的tokenizer,但后续可以测试自带tokenizer是不是兼容
#* -> 自带的tokenizer不与transoformers的tokenizer兼容,无法使用
#
* -> 自带的tokenizer不与transoformers的tokenizer兼容,无法使用
tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
model_name
)
return
model
,
tokenizer
...
...
@@ -231,7 +233,7 @@ class LoaderCheckPoint:
llm_int8_enable_fp32_cpu_offload
=
False
)
with
init_empty_weights
():
model
=
LoaderClass
.
from_config
(
self
.
model_config
,
trust_remote_code
=
True
)
model
=
LoaderClass
.
from_config
(
self
.
model_config
,
trust_remote_code
=
True
)
model
.
tie_weights
()
if
self
.
device_map
is
not
None
:
params
[
'device_map'
]
=
self
.
device_map
...
...
@@ -321,7 +323,7 @@ class LoaderCheckPoint:
return
device_map
def
moss_auto_configure_device_map
(
self
,
num_gpus
:
int
,
model_name
)
->
Dict
[
str
,
int
]:
def
moss_auto_configure_device_map
(
self
,
num_gpus
:
int
,
checkpoint
)
->
Dict
[
str
,
int
]:
try
:
from
accelerate
import
init_empty_weights
...
...
@@ -336,16 +338,6 @@ class LoaderCheckPoint:
"`pip install bitsandbytes``pip install accelerate`."
)
from
exc
if
self
.
model_path
:
checkpoint
=
Path
(
f
'{self.model_path}'
)
else
:
if
not
self
.
no_remote_model
:
checkpoint
=
model_name
else
:
raise
ValueError
(
"本地模型local_model_path未配置路径"
)
cls
=
get_class_from_dynamic_module
(
class_reference
=
"fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM"
,
pretrained_model_name_or_path
=
checkpoint
)
...
...
@@ -452,7 +444,7 @@ class LoaderCheckPoint:
def
reload_model
(
self
):
self
.
unload_model
()
self
.
model_config
=
self
.
_load_model_config
(
self
.
model_name
)
self
.
model_config
=
self
.
_load_model_config
()
if
self
.
use_ptuning_v2
:
try
:
...
...
@@ -464,7 +456,7 @@ class LoaderCheckPoint:
except
Exception
as
e
:
print
(
"加载PrefixEncoder config.json失败"
)
self
.
model
,
self
.
tokenizer
=
self
.
_load_model
(
self
.
model_name
)
self
.
model
,
self
.
tokenizer
=
self
.
_load_model
()
if
self
.
lora
:
self
.
_add_lora_to_model
([
self
.
lora
])
...
...
models/moss_llm.py
浏览文件 @
760abab1
from
abc
import
ABC
from
langchain.llms.base
import
LLM
from
typing
import
Optional
,
List
from
langchain.chains.base
import
Chain
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Generator
,
Union
from
langchain.callbacks.manager
import
CallbackManagerForChainRun
from
transformers.generation.logits_process
import
LogitsProcessor
from
transformers.generation.utils
import
LogitsProcessorList
,
StoppingCriteriaList
from
models.loader
import
LoaderCheckPoint
from
models.base
import
(
BaseAnswer
,
AnswerResult
)
AnswerResult
,
AnswerResultStream
,
AnswerResultQueueSentinelTokenListenerQueue
)
import
torch
import
transformers
import
torch
# todo 建议重写instruction,在该instruction下,各模型的表现比较差
META_INSTRUCTION
=
\
"""You are an AI assistant whose name is MOSS.
...
...
@@ -20,41 +28,65 @@ META_INSTRUCTION = \
Capabilities and tools that MOSS can possess.
"""
# todo 在MOSSLLM类下,各模型的响应速度很慢,后续要检查一下原因
class
MOSSLLM
(
BaseAnswer
,
LLM
,
ABC
):
class
MOSSLLM
Chain
(
BaseAnswer
,
Chain
,
ABC
):
max_token
:
int
=
2048
temperature
:
float
=
0.7
top_p
=
0.8
# history = []
checkPoint
:
LoaderCheckPoint
=
None
history_len
:
int
=
10
streaming_key
:
str
=
"streaming"
#: :meta private:
history_key
:
str
=
"history"
#: :meta private:
prompt_key
:
str
=
"prompt"
#: :meta private:
output_key
:
str
=
"answer_result_stream"
#: :meta private:
def
__init__
(
self
,
checkPoint
:
LoaderCheckPoint
=
None
):
super
()
.
__init__
()
self
.
checkPoint
=
checkPoint
@property
def
_
llm
_type
(
self
)
->
str
:
return
"MOSS"
def
_
chain
_type
(
self
)
->
str
:
return
"MOSS
LLMChain
"
@property
def
_check_point
(
self
)
->
LoaderCheckPoint
:
return
self
.
checkPoint
def
input_keys
(
self
)
->
List
[
str
]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return
[
self
.
prompt_key
]
@property
def
_history_len
(
self
)
->
int
:
def
output_keys
(
self
)
->
List
[
str
]:
"""Will always return text key.
return
self
.
history_len
:meta private:
"""
return
[
self
.
output_key
]
def
set_history_len
(
self
,
history_len
:
int
)
->
None
:
self
.
history_len
=
history_len
@property
def
_check_point
(
self
)
->
LoaderCheckPoint
:
return
self
.
checkPoint
def
_call
(
self
,
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
)
->
str
:
pass
def
_call
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
)
->
Dict
[
str
,
Generator
]:
generator
=
self
.
generatorAnswer
(
inputs
=
inputs
,
run_manager
=
run_manager
)
return
{
self
.
output_key
:
generator
}
def
generatorAnswer
(
self
,
prompt
:
str
,
history
:
List
[
List
[
str
]]
=
[],
streaming
:
bool
=
False
):
def
_generate_answer
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
generate_with_callback
:
AnswerResultStream
=
None
)
->
None
:
history
=
inputs
[
self
.
history_key
]
streaming
=
inputs
[
self
.
streaming_key
]
prompt
=
inputs
[
self
.
prompt_key
]
print
(
f
"__call:{prompt}"
)
if
len
(
history
)
>
0
:
history
=
history
[
-
self
.
history_len
:]
if
self
.
history_len
>
0
else
[]
prompt_w_history
=
str
(
history
)
...
...
@@ -79,13 +111,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
num_return_sequences
=
1
,
eos_token_id
=
106068
,
pad_token_id
=
self
.
checkPoint
.
tokenizer
.
pad_token_id
)
response
=
self
.
checkPoint
.
tokenizer
.
decode
(
outputs
[
0
][
inputs
.
input_ids
.
shape
[
1
]:],
skip_special_tokens
=
True
)
response
=
self
.
checkPoint
.
tokenizer
.
decode
(
outputs
[
0
][
inputs
.
input_ids
.
shape
[
1
]:],
skip_special_tokens
=
True
)
self
.
checkPoint
.
clear_torch_cache
()
history
+=
[[
prompt
,
response
]]
answer_result
=
AnswerResult
()
answer_result
.
history
=
history
answer_result
.
llm_output
=
{
"answer"
:
response
}
yield
answer_result
generate_with_callback
(
answer_result
)
models/shared.py
浏览文件 @
760abab1
...
...
@@ -24,13 +24,12 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
if
use_ptuning_v2
:
loaderCheckPoint
.
use_ptuning_v2
=
use_ptuning_v2
# 如果指定了参数,则使用参数的配置
if
llm_model
:
llm_model_info
=
llm_model_dict
[
llm_model
]
if
loaderCheckPoint
.
no_remote_model
:
loaderCheckPoint
.
model_name
=
llm_model_info
[
'name'
]
else
:
loaderCheckPoint
.
model_name
=
llm_model_info
[
'pretrained_model_name'
]
loaderCheckPoint
.
pretrained_model_name
=
llm_model_info
[
'pretrained_model_name'
]
loaderCheckPoint
.
model_path
=
llm_model_info
[
"local_model_path"
]
...
...
test/models/test_fastchat_openai_llm.py
deleted
100644 → 0
浏览文件 @
3d082bf5
import
sys
import
os
sys
.
path
.
append
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
+
'/../../'
)
import
asyncio
from
argparse
import
Namespace
from
models.loader.args
import
parser
from
models.loader
import
LoaderCheckPoint
import
models.shared
as
shared
async
def
dispatch
(
args
:
Namespace
):
args_dict
=
vars
(
args
)
shared
.
loaderCheckPoint
=
LoaderCheckPoint
(
args_dict
)
llm_model_ins
=
shared
.
loaderLLM
()
history
=
[
(
"which city is this?"
,
"tokyo"
),
(
"why?"
,
"she's japanese"
),
]
for
answer_result
in
llm_model_ins
.
generatorAnswer
(
prompt
=
"你好? "
,
history
=
history
,
streaming
=
False
):
resp
=
answer_result
.
llm_output
[
"answer"
]
print
(
resp
)
if
__name__
==
'__main__'
:
args
=
None
args
=
parser
.
parse_args
(
args
=
[
'--model-dir'
,
'/media/checkpoint/'
,
'--model'
,
'fastchat-chatglm-6b'
,
'--no-remote-model'
])
loop
=
asyncio
.
new_event_loop
()
asyncio
.
set_event_loop
(
loop
)
loop
.
run_until_complete
(
dispatch
(
args
))
webui.py
浏览文件 @
760abab1
...
...
@@ -85,8 +85,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
yield
history
+
[[
query
,
"请选择知识库后进行测试,当前未选择知识库。"
]],
""
else
:
for
answer_result
in
local_doc_qa
.
llm
.
generatorAnswer
(
prompt
=
query
,
history
=
history
,
streaming
=
streaming
):
answer_result_stream_result
=
local_doc_qa
.
llm_model_chain
(
{
"prompt"
:
query
,
"history"
:
history
,
"streaming"
:
streaming
})
for
answer_result
in
answer_result_stream_result
[
'answer_result_stream'
]:
resp
=
answer_result
.
llm_output
[
"answer"
]
history
=
answer_result
.
history
history
[
-
1
][
-
1
]
=
resp
...
...
@@ -101,11 +104,12 @@ def init_model():
args_dict
=
vars
(
args
)
shared
.
loaderCheckPoint
=
LoaderCheckPoint
(
args_dict
)
llm_model_ins
=
shared
.
loaderLLM
()
llm_model_ins
.
set_history_len
(
LLM_HISTORY_LEN
)
try
:
local_doc_qa
.
init_cfg
(
llm_model
=
llm_model_ins
)
generator
=
local_doc_qa
.
llm
.
generatorAnswer
(
"你好"
)
for
answer_result
in
generator
:
answer_result_stream_result
=
local_doc_qa
.
llm_model_chain
(
{
"prompt"
:
"你好"
,
"history"
:
[],
"streaming"
:
False
})
for
answer_result
in
answer_result_stream_result
[
'answer_result_stream'
]:
print
(
answer_result
.
llm_output
)
reply
=
"""模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
logger
.
info
(
reply
)
...
...
@@ -141,7 +145,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u
def
get_vector_store
(
vs_id
,
files
,
sentence_size
,
history
,
one_conent
,
one_content_segmentation
):
vs_path
=
os
.
path
.
join
(
KB_ROOT_PATH
,
vs_id
,
"vector_store"
)
filelist
=
[]
if
local_doc_qa
.
llm
and
local_doc_qa
.
embeddings
:
if
local_doc_qa
.
llm
_model_chain
and
local_doc_qa
.
embeddings
:
if
isinstance
(
files
,
list
):
for
file
in
files
:
filename
=
os
.
path
.
split
(
file
.
name
)[
-
1
]
...
...
@@ -165,7 +169,7 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte
def
change_vs_name_input
(
vs_id
,
history
):
if
vs_id
==
"新建知识库"
:
return
gr
.
update
(
visible
=
True
),
gr
.
update
(
visible
=
True
),
gr
.
update
(
visible
=
False
),
None
,
history
,
\
return
gr
.
update
(
visible
=
True
),
gr
.
update
(
visible
=
True
),
gr
.
update
(
visible
=
False
),
None
,
history
,
\
gr
.
update
(
choices
=
[]),
gr
.
update
(
visible
=
False
)
else
:
vs_path
=
os
.
path
.
join
(
KB_ROOT_PATH
,
vs_id
,
"vector_store"
)
...
...
@@ -218,7 +222,7 @@ def change_chunk_conent(mode, label_conent, history):
def
add_vs_name
(
vs_name
,
chatbot
):
if
vs_name
is
None
or
vs_name
.
strip
()
==
""
:
if
vs_name
is
None
or
vs_name
.
strip
()
==
""
:
vs_status
=
"知识库名称不能为空,请重新填写知识库名称"
chatbot
=
chatbot
+
[[
None
,
vs_status
]]
return
gr
.
update
(
visible
=
True
),
gr
.
update
(
visible
=
True
),
gr
.
update
(
visible
=
True
),
gr
.
update
(
...
...
@@ -262,6 +266,7 @@ def reinit_vector_store(vs_id, history):
def
refresh_vs_list
():
return
gr
.
update
(
choices
=
get_vs_list
()),
gr
.
update
(
choices
=
get_vs_list
())
def
delete_file
(
vs_id
,
files_to_delete
,
chatbot
):
vs_path
=
os
.
path
.
join
(
KB_ROOT_PATH
,
vs_id
,
"vector_store"
)
content_path
=
os
.
path
.
join
(
KB_ROOT_PATH
,
vs_id
,
"content"
)
...
...
@@ -275,11 +280,11 @@ def delete_file(vs_id, files_to_delete, chatbot):
rested_files
=
local_doc_qa
.
list_file_from_vector_store
(
vs_path
)
if
"fail"
in
status
:
vs_status
=
"文件删除失败。"
elif
len
(
rested_files
)
>
0
:
elif
len
(
rested_files
)
>
0
:
vs_status
=
"文件删除成功。"
else
:
vs_status
=
f
"文件删除成功,知识库{vs_id}中无已上传文件,请先上传文件后,再开始提问。"
logger
.
info
(
","
.
join
(
files_to_delete
)
+
vs_status
)
logger
.
info
(
","
.
join
(
files_to_delete
)
+
vs_status
)
chatbot
=
chatbot
+
[[
None
,
vs_status
]]
return
gr
.
update
(
choices
=
local_doc_qa
.
list_file_from_vector_store
(
vs_path
),
value
=
[]),
chatbot
...
...
@@ -290,7 +295,8 @@ def delete_vs(vs_id, chatbot):
status
=
f
"成功删除知识库{vs_id}"
logger
.
info
(
status
)
chatbot
=
chatbot
+
[[
None
,
status
]]
return
gr
.
update
(
choices
=
get_vs_list
(),
value
=
get_vs_list
()[
0
]),
gr
.
update
(
visible
=
True
),
gr
.
update
(
visible
=
True
),
\
return
gr
.
update
(
choices
=
get_vs_list
(),
value
=
get_vs_list
()[
0
]),
gr
.
update
(
visible
=
True
),
gr
.
update
(
visible
=
True
),
\
gr
.
update
(
visible
=
False
),
chatbot
,
gr
.
update
(
visible
=
False
)
except
Exception
as
e
:
logger
.
error
(
e
)
...
...
@@ -333,7 +339,8 @@ default_theme_args = dict(
with
gr
.
Blocks
(
css
=
block_css
,
theme
=
gr
.
themes
.
Default
(
**
default_theme_args
))
as
demo
:
vs_path
,
file_status
,
model_status
=
gr
.
State
(
os
.
path
.
join
(
KB_ROOT_PATH
,
get_vs_list
()[
0
],
"vector_store"
)
if
len
(
get_vs_list
())
>
1
else
""
),
gr
.
State
(
""
),
gr
.
State
(
os
.
path
.
join
(
KB_ROOT_PATH
,
get_vs_list
()[
0
],
"vector_store"
)
if
len
(
get_vs_list
())
>
1
else
""
),
gr
.
State
(
""
),
gr
.
State
(
model_status
)
gr
.
Markdown
(
webui_title
)
with
gr
.
Tab
(
"对话"
):
...
...
webui_st.py
浏览文件 @
760abab1
...
...
@@ -85,9 +85,10 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
yield
history
+
[[
query
,
"请选择知识库后进行测试,当前未选择知识库。"
]],
""
else
:
for
answer_result
in
local_doc_qa
.
llm
.
generatorAnswer
(
prompt
=
query
,
history
=
history
,
streaming
=
streaming
):
answer_result_stream_result
=
local_doc_qa
.
llm_model_chain
(
{
"prompt"
:
query
,
"history"
:
history
,
"streaming"
:
streaming
})
for
answer_result
in
answer_result_stream_result
[
'answer_result_stream'
]:
resp
=
answer_result
.
llm_output
[
"answer"
]
history
=
answer_result
.
history
history
[
-
1
][
-
1
]
=
resp
+
(
...
...
@@ -105,13 +106,14 @@ def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'
args_dict
.
update
(
model
=
llm_model
)
shared
.
loaderCheckPoint
=
LoaderCheckPoint
(
args_dict
)
llm_model_ins
=
shared
.
loaderLLM
()
llm_model_ins
.
set_history_len
(
LLM_HISTORY_LEN
)
try
:
local_doc_qa
.
init_cfg
(
llm_model
=
llm_model_ins
,
embedding_model
=
embedding_model
)
generator
=
local_doc_qa
.
llm
.
generatorAnswer
(
"你好"
)
for
answer_result
in
generator
:
answer_result_stream_result
=
local_doc_qa
.
llm_model_chain
(
{
"prompt"
:
"你好"
,
"history"
:
[],
"streaming"
:
False
})
for
answer_result
in
answer_result_stream_result
[
'answer_result_stream'
]:
print
(
answer_result
.
llm_output
)
reply
=
"""模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
logger
.
info
(
reply
)
...
...
@@ -468,7 +470,7 @@ with st.sidebar:
top_k
=
st
.
slider
(
'向量匹配数量'
,
1
,
20
,
VECTOR_SEARCH_TOP_K
)
history_len
=
st
.
slider
(
'LLM对话轮数'
,
1
,
50
,
LLM_HISTORY_LEN
)
# 也许要跟知识库分开设置
local_doc_qa
.
llm
.
set_history_len
(
history_len
)
#
local_doc_qa.llm.set_history_len(history_len)
chunk_conent
=
st
.
checkbox
(
'启用上下文关联'
,
False
)
st
.
text
(
''
)
# chunk_conent = st.checkbox('分割文本', True) # 知识库文本分割入库
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论