aigc-pioneer / jinchat-server / Commits

Commit c5bc2178, authored Jul 12, 2023 by glide-the
Change the way model generation is invoked so that it is compatible with Chain-style calls.
Fix a bug in model switching.
Parent: ca13ab81
Showing 16 changed files with 525 additions and 348 deletions (+525 -348).
.gitignore                                  +0    -0
api.py                                      +4    -3
chains/local_doc_qa.py                      +13   -9
configs/model_config.py                     +25   -19
models/__init__.py                          +4    -4
models/base/__init__.py                     +5    -3
models/base/base.py                         +151  -13
models/chatglm_llm.py                       +68   -31
models/fastchat_openai_llm.py               +71   -65
models/llama_llm.py                         +75   -84
models/loader/loader.py                     +28   -36
models/moss_llm.py                          +52   -21
models/shared.py                            +2    -3
test/models/test_fastchat_openai_llm.py     +0    -39
webui.py                                    +19   -12
webui_st.py                                 +8    -6
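At the heart of the commit, every model wrapper (ChatGLM, MOSS, LLaMA, FastChat/OpenAI) stops being a LangChain LLM with a generatorAnswer(prompt, history, streaming) generator and becomes a Chain that is called with a dict and exposes its generator under the "answer_result_stream" output key. The sketch below shows the new convention as used by the call sites in api.py and chains/local_doc_qa.py; stream_answer itself is a hypothetical helper written for illustration, not code from this commit.

from typing import Iterator, List, Tuple


def stream_answer(llm_model_chain, query: str,
                  history: List[List[str]]) -> Iterator[Tuple[str, List[List[str]]]]:
    """Call a refactored *LLMChain with prompt/history/streaming keys and
    iterate the generator published under "answer_result_stream"."""
    answer_result_stream_result = llm_model_chain(
        {"prompt": query, "history": history, "streaming": True})
    for answer_result in answer_result_stream_result["answer_result_stream"]:
        # each AnswerResult carries the (partial) answer text and updated history
        yield answer_result.llm_output["answer"], answer_result.history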
.gitignore  (+0 -0)
api.py

@@ -384,8 +384,10 @@ async def chat(
         ],
     ),
 ):
-    for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
-                                                          streaming=True):
+    answer_result_stream_result = local_doc_qa.llm_model_chain(
+        {"prompt": question, "history": history, "streaming": True})
+
+    for answer_result in answer_result_stream_result['answer_result_stream']:
         resp = answer_result.llm_output["answer"]
         history = answer_result.history
         pass

@@ -486,7 +488,6 @@ def api_start(host, port, **kwargs):
     global local_doc_qa
     llm_model_ins = shared.loaderLLM()
-    llm_model_ins.set_history_len(LLM_HISTORY_LEN)

     app = FastAPI()
     # Add CORS middleware to allow all origins
chains/local_doc_qa.py

@@ -18,6 +18,7 @@ from agent import bing_search
 from langchain.docstore.document import Document
 from functools import lru_cache
 from textsplitter.zh_title_enhance import zh_title_enhance
+from langchain.chains.base import Chain

 # patch HuggingFaceEmbeddings to make it hashable

@@ -119,7 +120,7 @@ def search_result2docs(search_results):
 class LocalDocQA:
-    llm: BaseAnswer = None
+    llm_model_chain: Chain = None
     embeddings: object = None
     top_k: int = VECTOR_SEARCH_TOP_K
     chunk_size: int = CHUNK_SIZE

@@ -129,10 +130,10 @@ class LocalDocQA:
     def init_cfg(self,
                  embedding_model: str = EMBEDDING_MODEL,
                  embedding_device=EMBEDDING_DEVICE,
-                 llm_model: BaseAnswer = None,
+                 llm_model: Chain = None,
                  top_k=VECTOR_SEARCH_TOP_K,
                  ):
-        self.llm = llm_model
+        self.llm_model_chain = llm_model
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
                                                 model_kwargs={'device': embedding_device})
         self.top_k = top_k

@@ -236,8 +237,10 @@ class LocalDocQA:
         else:
             prompt = query

-        for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
-                                                      streaming=streaming):
+        answer_result_stream_result = self.llm_model_chain(
+            {"prompt": prompt, "history": chat_history, "streaming": streaming})
+
+        for answer_result in answer_result_stream_result['answer_result_stream']:
             resp = answer_result.llm_output["answer"]
             history = answer_result.history
             history[-1][0] = query

@@ -276,8 +279,10 @@ class LocalDocQA:
         result_docs = search_result2docs(results)
         prompt = generate_prompt(result_docs, query)

-        for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
-                                                      streaming=streaming):
+        answer_result_stream_result = self.llm_model_chain(
+            {"prompt": prompt, "history": chat_history, "streaming": streaming})
+
+        for answer_result in answer_result_stream_result['answer_result_stream']:
             resp = answer_result.llm_output["answer"]
             history = answer_result.history
             history[-1][0] = query

@@ -296,7 +301,7 @@ class LocalDocQA:
     def update_file_from_vector_store(self,
                                       filepath: str or List[str],
                                       vs_path,
-                                      docs: List[Document],):
+                                      docs: List[Document], ):
         vector_store = load_vector_store(vs_path, self.embeddings)
         status = vector_store.update_doc(filepath, docs)
         return status

@@ -320,7 +325,6 @@ if __name__ == "__main__":
     args_dict = vars(args)
     shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
     llm_model_ins = shared.loaderLLM()
-    llm_model_ins.set_history_len(LLM_HISTORY_LEN)
     local_doc_qa = LocalDocQA()
     local_doc_qa.init_cfg(llm_model=llm_model_ins)
configs/model_config.py

@@ -37,61 +37,67 @@ llm_model_dict = {
         "name": "chatglm-6b-int4-qe",
         "pretrained_model_name": "THUDM/chatglm-6b-int4-qe",
         "local_model_path": None,
-        "provides": "ChatGLM"
+        "provides": "ChatGLMLLMChain"
     },
     "chatglm-6b-int4": {
         "name": "chatglm-6b-int4",
         "pretrained_model_name": "THUDM/chatglm-6b-int4",
         "local_model_path": None,
-        "provides": "ChatGLM"
+        "provides": "ChatGLMLLMChain"
     },
     "chatglm-6b-int8": {
         "name": "chatglm-6b-int8",
         "pretrained_model_name": "THUDM/chatglm-6b-int8",
         "local_model_path": None,
-        "provides": "ChatGLM"
+        "provides": "ChatGLMLLMChain"
     },
     "chatglm-6b": {
         "name": "chatglm-6b",
         "pretrained_model_name": "THUDM/chatglm-6b",
         "local_model_path": None,
-        "provides": "ChatGLM"
+        "provides": "ChatGLMLLMChain"
     },
     "chatglm2-6b": {
         "name": "chatglm2-6b",
         "pretrained_model_name": "THUDM/chatglm2-6b",
         "local_model_path": None,
-        "provides": "ChatGLM"
+        "provides": "ChatGLMLLMChain"
     },
     "chatglm2-6b-int4": {
         "name": "chatglm2-6b-int4",
         "pretrained_model_name": "THUDM/chatglm2-6b-int4",
         "local_model_path": None,
-        "provides": "ChatGLM"
+        "provides": "ChatGLMLLMChain"
     },
     "chatglm2-6b-int8": {
         "name": "chatglm2-6b-int8",
         "pretrained_model_name": "THUDM/chatglm2-6b-int8",
         "local_model_path": None,
-        "provides": "ChatGLM"
+        "provides": "ChatGLMLLMChain"
     },
     "chatyuan": {
         "name": "chatyuan",
         "pretrained_model_name": "ClueAI/ChatYuan-large-v2",
         "local_model_path": None,
-        "provides": "MOSSLLM"
+        "provides": "MOSSLLMChain"
     },
     "moss": {
         "name": "moss",
         "pretrained_model_name": "fnlp/moss-moon-003-sft",
         "local_model_path": None,
-        "provides": "MOSSLLM"
+        "provides": "MOSSLLMChain"
     },
     "vicuna-13b-hf": {
         "name": "vicuna-13b-hf",
         "pretrained_model_name": "vicuna-13b-hf",
         "local_model_path": None,
-        "provides": "LLamaLLM"
+        "provides": "LLamaLLMChain"
+    },
+    "vicuna-7b-hf": {
+        "name": "vicuna-13b-hf",
+        "pretrained_model_name": "vicuna-13b-hf",
+        "local_model_path": None,
+        "provides": "LLamaLLMChain"
     },
     # 直接调用返回requests.exceptions.ConnectionError错误,需要通过huggingface_hub包里的snapshot_download函数
     # 下载模型,如果snapshot_download还是返回网络错误,多试几次,一般是可以的,

@@ -101,7 +107,7 @@ llm_model_dict = {
         "name": "bloomz-7b1",
         "pretrained_model_name": "bigscience/bloomz-7b1",
         "local_model_path": None,
-        "provides": "MOSSLLM"
+        "provides": "MOSSLLMChain"
     },
     # 实测加载bigscience/bloom-3b需要170秒左右,暂不清楚为什么这么慢

@@ -110,14 +116,14 @@ llm_model_dict = {
         "name": "bloom-3b",
         "pretrained_model_name": "bigscience/bloom-3b",
         "local_model_path": None,
-        "provides": "MOSSLLM"
+        "provides": "MOSSLLMChain"
     },
     "baichuan-7b": {
         "name": "baichuan-7b",
         "pretrained_model_name": "baichuan-inc/baichuan-7B",
         "local_model_path": None,
-        "provides": "MOSSLLM"
+        "provides": "MOSSLLMChain"
     },
     # llama-cpp模型的兼容性问题参考https://github.com/abetlen/llama-cpp-python/issues/204
     "ggml-vicuna-13b-1.1-q5": {

@@ -131,7 +137,7 @@ llm_model_dict = {
         # 需要手动从https://github.com/abetlen/llama-cpp-python/releases/tag/下载对应的wheel安装
         # 实测v0.1.63与本模型的vicuna/ggml-vicuna-13b-1.1/ggml-vic13b-q5_1.bin可以兼容
         "local_model_path": f'''{"/".join(os.path.abspath(__file__).split("/")[:3])}/.cache/huggingface/hub/models--vicuna--ggml-vicuna-13b-1.1/blobs/''',
-        "provides": "LLamaLLM"
+        "provides": "LLamaLLMChain"
     },
     # 通过 fastchat 调用的模型请参考如下格式

@@ -139,7 +145,7 @@ llm_model_dict = {
         "name": "chatglm-6b",  # "name"修改为fastchat服务中的"model_name"
         "pretrained_model_name": "chatglm-6b",
         "local_model_path": None,
-        "provides": "FastChatOpenAILLM",  # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM"
+        "provides": "FastChatOpenAILLMChain",  # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain"
         "api_base_url": "http://localhost:8000/v1",  # "name"修改为fastchat服务中的"api_base_url"
         "api_key": "EMPTY"
     },

@@ -147,7 +153,7 @@ llm_model_dict = {
         "name": "chatglm2-6b",  # "name"修改为fastchat服务中的"model_name"
         "pretrained_model_name": "chatglm2-6b",
         "local_model_path": None,
-        "provides": "FastChatOpenAILLM",  # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM"
+        "provides": "FastChatOpenAILLMChain",  # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain"
         "api_base_url": "http://localhost:8000/v1"  # "name"修改为fastchat服务中的"api_base_url"
     },

@@ -156,7 +162,7 @@ llm_model_dict = {
         "name": "vicuna-13b-hf",  # "name"修改为fastchat服务中的"model_name"
         "pretrained_model_name": "vicuna-13b-hf",
         "local_model_path": None,
-        "provides": "FastChatOpenAILLM",  # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM"
+        "provides": "FastChatOpenAILLMChain",  # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain"
         "api_base_url": "http://localhost:8000/v1",  # "name"修改为fastchat服务中的"api_base_url"
         "api_key": "EMPTY"
     },

@@ -171,7 +177,7 @@ llm_model_dict = {
     "openai-chatgpt-3.5": {
         "name": "gpt-3.5-turbo",
         "pretrained_model_name": "gpt-3.5-turbo",
-        "provides": "FastChatOpenAILLM",
+        "provides": "FastChatOpenAILLMChain",
         "local_model_path": None,
         "api_base_url": "https://api.openapi.com/v1",
         "api_key": ""

@@ -226,7 +232,7 @@ LLM_HISTORY_LEN = 3
 VECTOR_SEARCH_TOP_K = 5

 # 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,经测试设置为小于500时,匹配结果更精准
-VECTOR_SEARCH_SCORE_THRESHOLD = 0
+VECTOR_SEARCH_SCORE_THRESHOLD = 390

 NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
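Each renamed "provides" value has to match a class name exported by models/__init__.py (next file), since the loader selects the wrapper class by that string; models/shared.py, which performs the actual lookup, only appears as a truncated hunk at the end of this page. The snippet below is a rough illustration of that kind of string-to-class resolution, under the assumption that the lookup amounts to an attribute access on the models package; it is not the project's own loaderLLM code.

import models  # exposes ChatGLMLLMChain, LLamaLLMChain, MOSSLLMChain, FastChatOpenAILLMChain
from configs.model_config import llm_model_dict


def provider_class_for(model_key: str):
    """Map an llm_model_dict entry to the chain class named by its "provides" field."""
    provides = llm_model_dict[model_key]["provides"]  # e.g. "ChatGLMLLMChain"
    return getattr(models, provides)                  # class to instantiate with a LoaderCheckPoint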
models/__init__.py

-from .chatglm_llm import ChatGLM
-from .llama_llm import LLamaLLM
-from .moss_llm import MOSSLLM
-from .fastchat_openai_llm import FastChatOpenAILLM
+from .chatglm_llm import ChatGLMLLMChain
+from .llama_llm import LLamaLLMChain
+from .fastchat_openai_llm import FastChatOpenAILLMChain
+from .moss_llm import MOSSLLMChain
models/base/__init__.py

 from models.base.base import (
     AnswerResult,
-    BaseAnswer
+    BaseAnswer,
+    AnswerResultStream,
+    AnswerResultQueueSentinelTokenListenerQueue
 )
 from models.base.remote_rpc_model import (
     RemoteRpcModel
 )

 __all__ = [
     "AnswerResult",
     "BaseAnswer",
     "RemoteRpcModel",
+    "AnswerResultStream",
+    "AnswerResultQueueSentinelTokenListenerQueue"
 ]
models/base/base.py

 from abc import ABC, abstractmethod
-from typing import Optional, List
+from typing import Any, Dict, List, Optional, Generator
 import traceback
 from collections import deque
 from queue import Queue
 from threading import Thread
-from models.loader import LoaderCheckPoint
+from langchain.callbacks.manager import CallbackManagerForChainRun
 import torch
 import transformers
+from models.loader import LoaderCheckPoint
+
+
+class ListenerToken:
+    """
+    观测结果
+    """
+    input_ids: torch.LongTensor
+    _scores: torch.FloatTensor
+
+    def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
+        self.input_ids = input_ids
+        self._scores = _scores


 class AnswerResult:

@@ -16,6 +29,123 @@ class AnswerResult:
     """
     history: List[List[str]] = []
     llm_output: Optional[dict] = None
+    listenerToken: ListenerToken = None
+
+
+class AnswerResultStream:
+    def __init__(self, callback_func=None):
+        self.callback_func = callback_func
+
+    def __call__(self, answerResult: AnswerResult):
+        if self.callback_func is not None:
+            self.callback_func(answerResult)
+
+
+class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
+    """
+    定义模型stopping_criteria 监听者,在每次响应时将队列数据同步到AnswerResult
+    实现此监听器的目的是,不同模型的预测输出可能不是矢量信息,hf框架可以自定义transformers.StoppingCriteria入参来接收每次预测的Tensor和损失函数,
+    通过给 StoppingCriteriaList指定模型生成答案时停止的条件。每个 StoppingCriteria 对象表示一个停止条件
+    当每轮预测任务开始时,StoppingCriteria都会收到相同的预测结果,最终由下层实现类确认是否结束
+    输出值可用于 generatorAnswer generate_with_streaming的自定义参数观测,以实现更加精细的控制
+    """
+
+    listenerQueue: deque = deque(maxlen=1)
+
+    def __init__(self):
+        transformers.StoppingCriteria.__init__(self)
+
+    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
+        """
+        每次响应时将数据添加到响应队列
+        :param input_ids:
+        :param _scores:
+        :param kwargs:
+        :return:
+        """
+        self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores))
+        return False
+
+
+class Iteratorize:
+    """
+    Transforms a function that takes a callback
+    into a lazy iterator (generator).
+    """
+
+    def __init__(self, func, kwargs={}):
+        self.mfunc = func
+        self.q = Queue()
+        self.sentinel = object()
+        self.kwargs = kwargs
+        self.stop_now = False
+
+        def _callback(val):
+            """
+            模型输出预测结果收集
+            通过定义generate_with_callback收集器AnswerResultStream,收集模型预测的AnswerResult响应结果,最终由下层实现类确认是否结束
+            结束条件包含如下
+                1、模型预测结束、收集器self.q队列收到 self.sentinel标识
+                2、在处理迭代器队列消息时返回了break跳出迭代器,触发了StopIteration事件
+                3、模型预测出错
+            因为当前类是迭代器,所以在for in 中执行了break后 __exit__ 方法会被调用,最终stop_now属性会被更新,然后抛出异常结束预测行为
+            迭代器收集的行为如下
+                创建Iteratorize迭代对象,
+                定义generate_with_callback收集器AnswerResultStream
+                启动一个线程异步预测结果来调用上游checkpoint的实现方法_generate_answer
+                _generate_answer通过generate_with_callback定义的收集器,收集上游checkpoint包装的AnswerResult消息体
+                由于self.q是阻塞模式,每次预测后会被消费后才会执行下次预测
+                这时generate_with_callback会被阻塞
+                主线程Iteratorize对象的__next__方法调用获取阻塞消息并消费
+                1、消息为上游checkpoint包装的AnswerResult消息体,返回下游处理
+                2、消息为self.sentinel标识,抛出StopIteration异常
+                主线程Iteratorize对象__exit__收到消息,最终stop_now属性会被更新
+                异步线程检测stop_now属性被更新,抛出异常结束预测行为
+                迭代行为结束
+            :param val:
+            :return:
+            """
+            if self.stop_now:
+                raise ValueError
+            self.q.put(val)
+
+        def gen():
+            try:
+                ret = self.mfunc(callback=_callback, **self.kwargs)
+            except ValueError:
+                pass
+            except:
+                traceback.print_exc()
+                pass
+
+            self.q.put(self.sentinel)
+
+        self.thread = Thread(target=gen)
+        self.thread.start()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        obj = self.q.get(True, None)
+        if obj is self.sentinel:
+            raise StopIteration
+        else:
+            return obj
+
+    def __del__(self):
+        """
+        暂无实现
+        :return:
+        """
+        pass
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """ break 后会执行 """
+        self.stop_now = True


 class BaseAnswer(ABC):

@@ -25,17 +155,25 @@ class BaseAnswer(ABC):
     @abstractmethod
     def _check_point(self) -> LoaderCheckPoint:
         """Return _check_point of llm."""

-    @property
-    @abstractmethod
-    def _history_len(self) -> int:
-        """Return _history_len of llm."""
-
-    @abstractmethod
-    def set_history_len(self, history_len: int) -> None:
-        """Return _history_len of llm."""
-
-    @abstractmethod
-    def generatorAnswer(self, prompt: str,
-                        history: List[List[str]] = [],
-                        streaming: bool = False):
-        pass
+    def generatorAnswer(self, inputs: Dict[str, Any],
+                        run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Generator[Any, str, bool]:
+
+        def generate_with_callback(callback=None, **kwargs):
+            kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback)
+            self._generate_answer(**kwargs)
+
+        def generate_with_streaming(**kwargs):
+            return Iteratorize(generate_with_callback, kwargs)
+
+        with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator:
+            for answerResult in generator:
+                if answerResult.listenerToken:
+                    output = answerResult.listenerToken.input_ids
+                yield answerResult
+
+    @abstractmethod
+    def _generate_answer(self, inputs: Dict[str, Any],
+                         run_manager: Optional[CallbackManagerForChainRun] = None,
+                         generate_with_callback: AnswerResultStream = None) -> None:
+        pass
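The long Chinese docstrings above spell out the producer/consumer protocol: _generate_answer runs on a worker thread and reports each AnswerResult through an AnswerResultStream callback, while Iteratorize converts that callback into a blocking generator terminated by a sentinel. The self-contained sketch below reproduces just that queue-and-sentinel idea with generic names; it illustrates the pattern and is not the project's classes.

import queue
import threading


def iterate_callback(producer):
    """Turn a function that reports results via a callback into a generator.

    `producer` is called on a worker thread with a `callback` argument; every
    value it passes to the callback is yielded here, and a sentinel object
    marks the end of the stream, mirroring base.py's Iteratorize."""
    q: queue.Queue = queue.Queue()
    sentinel = object()

    def worker():
        try:
            producer(callback=q.put)
        finally:
            q.put(sentinel)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = q.get()
        if item is sentinel:
            return
        yield item


# Usage: a stand-in _generate_answer that streams three partial answers.
def fake_generate_answer(callback):
    for token in ("Hel", "Hello", "Hello!"):
        callback({"answer": token})


for partial in iterate_callback(fake_generate_answer):
    print(partial["answer"])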
models/chatglm_llm.py

 from abc import ABC
-from langchain.llms.base import LLM
-from typing import Optional, List
+from langchain.chains.base import Chain
+from typing import Any, Dict, List, Optional, Generator
+from langchain.callbacks.manager import CallbackManagerForChainRun
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult)
+                         AnswerResult,
+                         AnswerResultStream,
+                         AnswerResultQueueSentinelTokenListenerQueue)
+import torch
+import transformers


-class ChatGLM(BaseAnswer, LLM, ABC):
+class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
     max_token: int = 10000
     temperature: float = 0.01
-    top_p = 0.9
+    # 相关度
+    top_p = 0.4
+    # 候选词数量
+    top_k = 10
     checkPoint: LoaderCheckPoint = None
     # history = []
     history_len: int = 10
+    streaming_key: str = "streaming"  #: :meta private:
+    history_key: str = "history"  #: :meta private:
+    prompt_key: str = "prompt"  #: :meta private:
+    output_key: str = "answer_result_stream"  #: :meta private:

     def __init__(self, checkPoint: LoaderCheckPoint = None):
         super().__init__()
         self.checkPoint = checkPoint

     @property
-    def _llm_type(self) -> str:
-        return "ChatGLM"
+    def _chain_type(self) -> str:
+        return "ChatGLMLLMChain"

     @property
     def _check_point(self) -> LoaderCheckPoint:
         return self.checkPoint

     @property
-    def _history_len(self) -> int:
-        return self.history_len
-
-    def set_history_len(self, history_len: int = 10) -> None:
-        self.history_len = history_len
-
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        print(f"__call:{prompt}")
-        response, _ = self.checkPoint.model.chat(
-            self.checkPoint.tokenizer,
-            prompt,
-            history=[],
-            max_length=self.max_token,
-            temperature=self.temperature
-        )
-        print(f"response:{response}")
-        print(f"+++++++++++++++++++++++++++++++++++")
-        return response
-
-    def generatorAnswer(self, prompt: str,
-                        history: List[List[str]] = [],
-                        streaming: bool = False):
+    def input_keys(self) -> List[str]:
+        """Will be whatever keys the prompt expects.
+
+        :meta private:
+        """
+        return [self.prompt_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Will always return text key.
+
+        :meta private:
+        """
+        return [self.output_key]
+
+    def _call(self,
+              inputs: Dict[str, Any],
+              run_manager: Optional[CallbackManagerForChainRun] = None,
+              ) -> Dict[str, Generator]:
+        generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
+        return {self.output_key: generator}
+
+    def _generate_answer(self,
+                         inputs: Dict[str, Any],
+                         run_manager: Optional[CallbackManagerForChainRun] = None,
+                         generate_with_callback: AnswerResultStream = None) -> None:
+        history = inputs[self.history_key]
+        streaming = inputs[self.streaming_key]
+        prompt = inputs[self.prompt_key]
+        print(f"__call:{prompt}")
+        # Create the StoppingCriteriaList with the stopping strings
+        stopping_criteria_list = transformers.StoppingCriteriaList()
+        # 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
+        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
+        stopping_criteria_list.append(listenerQueue)
         if streaming:
             history += [[]]
             for inum, (stream_resp, _) in enumerate(self.checkPoint.model.stream_chat(
                     self.checkPoint.tokenizer,
                     prompt,
-                    history=history[-self.history_len:-1] if self.history_len > 1 else [],
+                    history=history[-self.history_len:-1] if self.history_len > 0 else [],
                     max_length=self.max_token,
-                    temperature=self.temperature
+                    temperature=self.temperature,
+                    top_p=self.top_p,
+                    top_k=self.top_k,
+                    stopping_criteria=stopping_criteria_list
             )):
                 # self.checkPoint.clear_torch_cache()
                 history[-1] = [prompt, stream_resp]
                 answer_result = AnswerResult()
                 answer_result.history = history
                 answer_result.llm_output = {"answer": stream_resp}
-                yield answer_result
+                if listenerQueue.listenerQueue.__len__() > 0:
+                    answer_result.listenerToken = listenerQueue.listenerQueue.pop()
+                generate_with_callback(answer_result)
             self.checkPoint.clear_torch_cache()
         else:
             response, _ = self.checkPoint.model.chat(

@@ -72,13 +104,18 @@ class ChatGLM(BaseAnswer, LLM, ABC):
                 prompt,
                 history=history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
-                temperature=self.temperature
+                temperature=self.temperature,
+                top_p=self.top_p,
+                top_k=self.top_k,
+                stopping_criteria=stopping_criteria_list
             )
             self.checkPoint.clear_torch_cache()
             history += [[prompt, response]]
             answer_result = AnswerResult()
             answer_result.history = history
             answer_result.llm_output = {"answer": response}
-            yield answer_result
+            if listenerQueue.listenerQueue.__len__() > 0:
+                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
+            generate_with_callback(answer_result)
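With input_keys = ["prompt"] and output_keys = ["answer_result_stream"], the rewritten class plugs into LangChain's Chain protocol while still being driven with a plain dict, mirroring the call sites in api.py and chains/local_doc_qa.py. The helper below is a hedged usage sketch: chat_once is hypothetical, and it assumes check_point is an already-loaded models.loader.LoaderCheckPoint (normally produced via models.shared.loaderLLM, which this page only shows in part).

from models.chatglm_llm import ChatGLMLLMChain


def chat_once(check_point, question: str, history=None):
    """Drive ChatGLMLLMChain the way api.py does and return the final answer."""
    chain = ChatGLMLLMChain(checkPoint=check_point)
    result = chain({"prompt": question,
                    "history": history or [],
                    "streaming": False})
    answer, new_history = None, history or []
    for answer_result in result["answer_result_stream"]:
        answer = answer_result.llm_output["answer"]   # full reply in non-streaming mode
        new_history = answer_result.history
    return answer, new_history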
models/fastchat_openai_llm.py

 from abc import ABC
-from typing import Optional, List
-from langchain.llms.base import LLM
+import requests
+from langchain.chains.base import Chain
+from typing import Any, Dict, List, Optional, Generator, Collection
 from models.loader import LoaderCheckPoint
-from models.base import (RemoteRpcModel,
-                         AnswerResult)
-from typing import (
-    Collection,
-    Dict
-)
+from langchain.callbacks.manager import CallbackManagerForChainRun
+from models.base import (BaseAnswer,
+                         RemoteRpcModel,
+                         AnswerResult,
+                         AnswerResultStream,
+                         AnswerResultQueueSentinelTokenListenerQueue)
+import torch
+import transformers


 def _build_message_template() -> Dict[str, str]:

@@ -22,18 +22,42 @@ def _build_message_template() -> Dict[str, str]:
     }


-class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
+# 将历史对话数组转换为文本格式
+def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
+    build_messages: Collection[Dict[str, str]] = []
+    for i, (old_query, response) in enumerate(history):
+        user_build_message = _build_message_template()
+        user_build_message['role'] = 'user'
+        user_build_message['content'] = old_query
+        system_build_message = _build_message_template()
+        system_build_message['role'] = 'system'
+        system_build_message['content'] = response
+        build_messages.append(user_build_message)
+        build_messages.append(system_build_message)
+
+    user_build_message = _build_message_template()
+    user_build_message['role'] = 'user'
+    user_build_message['content'] = query
+    build_messages.append(user_build_message)
+    return build_messages
+
+
+class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
     api_base_url: str = "http://localhost:8000/v1"
     model_name: str = "chatglm-6b"
     max_token: int = 10000
     temperature: float = 0.01
     top_p = 0.9
     checkPoint: LoaderCheckPoint = None
-    history = []
+    # history = []
     history_len: int = 10
     api_key: str = ""
+    streaming_key: str = "streaming"  #: :meta private:
+    history_key: str = "history"  #: :meta private:
+    prompt_key: str = "prompt"  #: :meta private:
+    output_key: str = "answer_result_stream"  #: :meta private:

     def __init__(self,
                  checkPoint: LoaderCheckPoint = None,
                  # api_base_url:str="http://localhost:8000/v1",

@@ -44,19 +68,28 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
         self.checkPoint = checkPoint

     @property
-    def _llm_type(self) -> str:
-        return "FastChat"
+    def _chain_type(self) -> str:
+        return "LLamaLLMChain"

     @property
     def _check_point(self) -> LoaderCheckPoint:
         return self.checkPoint

     @property
-    def _history_len(self) -> int:
-        return self.history_len
-
-    def set_history_len(self, history_len: int = 10) -> None:
-        self.history_len = history_len
+    def input_keys(self) -> List[str]:
+        """Will be whatever keys the prompt expects.
+
+        :meta private:
+        """
+        return [self.prompt_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Will always return text key.
+
+        :meta private:
+        """
+        return [self.output_key]

     @property
     def _api_key(self) -> str:

@@ -75,53 +108,25 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
     def call_model_name(self, model_name):
         self.model_name = model_name

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        print(f"__call:{prompt}")
-        try:
-            import openai
-            # Not support yet
-            # openai.api_key = "EMPTY"
-            openai.key = self.api_key
-            openai.api_base = self.api_base_url
-        except ImportError:
-            raise ValueError(
-                "Could not import openai python package. "
-                "Please install it with `pip install openai`."
-            )
-        # create a chat completion
-        completion = openai.ChatCompletion.create(
-            model=self.model_name,
-            messages=self.build_message_list(prompt)
-        )
-        print(f"response:{completion.choices[0].message.content}")
-        print(f"+++++++++++++++++++++++++++++++++++")
-        return completion.choices[0].message.content
-
-    # 将历史对话数组转换为文本格式
-    def build_message_list(self, query) -> Collection[Dict[str, str]]:
-        build_message_list: Collection[Dict[str, str]] = []
-        history = self.history[-self.history_len:] if self.history_len > 0 else []
-        for i, (old_query, response) in enumerate(history):
-            user_build_message = _build_message_template()
-            user_build_message['role'] = 'user'
-            user_build_message['content'] = old_query
-            system_build_message = _build_message_template()
-            system_build_message['role'] = 'system'
-            system_build_message['content'] = response
-            build_message_list.append(user_build_message)
-            build_message_list.append(system_build_message)
-
-        user_build_message = _build_message_template()
-        user_build_message['role'] = 'user'
-        user_build_message['content'] = query
-        build_message_list.append(user_build_message)
-        return build_message_list
-
-    def generatorAnswer(self, prompt: str,
-                        history: List[List[str]] = [],
-                        streaming: bool = False):
-
+    def _call(self,
+              inputs: Dict[str, Any],
+              run_manager: Optional[CallbackManagerForChainRun] = None,
+              ) -> Dict[str, Generator]:
+        generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
+        return {self.output_key: generator}
+
+    def _generate_answer(self,
+                         inputs: Dict[str, Any],
+                         run_manager: Optional[CallbackManagerForChainRun] = None,
+                         generate_with_callback: AnswerResultStream = None) -> None:
+        history = inputs[self.history_key]
+        streaming = inputs[self.streaming_key]
+        prompt = inputs[self.prompt_key]
+        print(f"__call:{prompt}")
         try:
             import openai
             # Not support yet
             # openai.api_key = "EMPTY"

@@ -135,12 +140,13 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
         # create a chat completion
         completion = openai.ChatCompletion.create(
             model=self.model_name,
-            messages=self.build_message_list(prompt)
+            messages=build_message_list(prompt)
         )
         print(f"response:{completion.choices[0].message.content}")
         print(f"+++++++++++++++++++++++++++++++++++")
         history += [[prompt, completion.choices[0].message.content]]
         answer_result = AnswerResult()
         answer_result.history = history
         answer_result.llm_output = {"answer": completion.choices[0].message.content}
-        yield answer_result
+        generate_with_callback(answer_result)
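The new module-level build_message_list replaces the old instance method of the same name: it no longer reads self.history but takes the history pairs explicitly, turning [[query, response], ...] plus the current query into OpenAI-style role/content dicts (note that prior model replies are emitted with the 'system' role, exactly as in the code above). For one prior exchange, the structure comes out roughly like the sketch below, as implied by the implementation rather than captured output.

# Shape of build_message_list("What's the weather like?",
#                             history=[["Hi", "Hello, how can I help?"]]):
expected_messages = [
    {"role": "user",   "content": "Hi"},
    {"role": "system", "content": "Hello, how can I help?"},
    {"role": "user",   "content": "What's the weather like?"},
]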
models/llama_llm.py

 from abc import ABC
-from langchain.llms.base import LLM
 import random
+from langchain.chains.base import Chain
 import torch
+from typing import Any, Dict, List, Optional, Generator, Union
 import transformers
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from transformers.generation.logits_process import LogitsProcessor
 from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
-from typing import Optional, List, Dict, Any, Union
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult)
-import torch
-import transformers
+                         AnswerResult,
+                         AnswerResultStream,
+                         AnswerResultQueueSentinelTokenListenerQueue)


 class InvalidScoreLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids: Union[torch.LongTensor, list], scores: Union[torch.FloatTensor, list]) -> torch.FloatTensor:
         # llama-cpp模型返回的是list,为兼容性考虑,需要判断input_ids和scores的类型,将list转换为torch.Tensor
         input_ids = torch.tensor(input_ids) if isinstance(input_ids, list) else input_ids
         scores = torch.tensor(scores) if isinstance(scores, list) else scores
         if torch.isnan(scores).any() or torch.isinf(scores).any():
             scores.zero_()
             scores[..., 5] = 5e4
         return scores


-class LLamaLLM(BaseAnswer, LLM, ABC):
+class LLamaLLMChain(BaseAnswer, Chain, ABC):
     checkPoint: LoaderCheckPoint = None
     # history = []
     history_len: int = 3

@@ -37,32 +40,34 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
     min_length: int = 0
     logits_processor: LogitsProcessorList = None
     stopping_criteria: Optional[StoppingCriteriaList] = None
-    eos_token_id: Optional[int] = [2]
-
-    state: object = {'max_new_tokens': 50,
-                     'seed': 1,
-                     'temperature': 0, 'top_p': 0.1, 'top_k': 40, 'typical_p': 1,
-                     'repetition_penalty': 1.2, 'encoder_repetition_penalty': 1,
-                     'no_repeat_ngram_size': 0, 'min_length': 0, 'penalty_alpha': 0,
-                     'num_beams': 1, 'length_penalty': 1, 'early_stopping': False,
-                     'add_bos_token': True, 'ban_eos_token': False,
-                     'truncation_length': 2048, 'custom_stopping_strings': '',
-                     'cpu_memory': 0, 'auto_devices': False, 'disk': False,
-                     'cpu': False, 'bf16': False, 'load_in_8bit': False,
-                     'wbits': 'None', 'groupsize': 'None', 'model_type': 'None',
-                     'pre_layer': 0, 'gpu_memory_0': 0}
+    streaming_key: str = "streaming"  #: :meta private:
+    history_key: str = "history"  #: :meta private:
+    prompt_key: str = "prompt"  #: :meta private:
+    output_key: str = "answer_result_stream"  #: :meta private:

     def __init__(self, checkPoint: LoaderCheckPoint = None):
         super().__init__()
         self.checkPoint = checkPoint

     @property
-    def _llm_type(self) -> str:
-        return "LLamaLLM"
+    def _chain_type(self) -> str:
+        return "LLamaLLMChain"
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Will be whatever keys the prompt expects.
+
+        :meta private:
+        """
+        return [self.prompt_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Will always return text key.
+
+        :meta private:
+        """
+        return [self.output_key]

     @property
     def _check_point(self) -> LoaderCheckPoint:

@@ -107,35 +112,31 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
         formatted_history += "### Human:{}\n### Assistant:".format(query)
         return formatted_history

-    def prepare_inputs_for_generation(self,
-                                      input_ids: torch.LongTensor):
-        """
-        预生成注意力掩码和 输入序列中每个位置的索引的张量
-        # TODO 没有思路
-        :return:
-        """
-        mask_positions = torch.zeros((1, input_ids.shape[1]), dtype=input_ids.dtype).to(self.checkPoint.model.device)
-
-        attention_mask = self.get_masks(input_ids, input_ids.device)
-
-        position_ids = self.get_position_ids(
-            input_ids,
-            device=input_ids.device,
-            mask_positions=mask_positions
-        )
-
-        return input_ids, position_ids, attention_mask
-
-    @property
-    def _history_len(self) -> int:
-        return self.history_len
-
-    def set_history_len(self, history_len: int = 10) -> None:
-        self.history_len = history_len
-
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        print(f"__call:{prompt}")
+    def _call(self,
+              inputs: Dict[str, Any],
+              run_manager: Optional[CallbackManagerForChainRun] = None,
+              ) -> Dict[str, Generator]:
+        generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
+        return {self.output_key: generator}
+
+    def _generate_answer(self,
+                         inputs: Dict[str, Any],
+                         run_manager: Optional[CallbackManagerForChainRun] = None,
+                         generate_with_callback: AnswerResultStream = None) -> None:
+        history = inputs[self.history_key]
+        streaming = inputs[self.streaming_key]
+        prompt = inputs[self.prompt_key]
+        print(f"__call:{prompt}")
+        # Create the StoppingCriteriaList with the stopping strings
+        self.stopping_criteria = transformers.StoppingCriteriaList()
+        # 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
+        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
+        self.stopping_criteria.append(listenerQueue)
+        # TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现
+        soft_prompt = self.history_to_text(query=prompt, history=history)
         if self.logits_processor is None:
             self.logits_processor = LogitsProcessorList()
         self.logits_processor.append(InvalidScoreLogitsProcessor())

@@ -154,16 +155,10 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
                       "logits_processor": self.logits_processor}

         # 向量转换
-        input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'], truncation_length=self.max_new_tokens)
-        # input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids)
+        input_ids = self.encode(soft_prompt, add_bos_token=self.checkPoint.tokenizer.add_bos_token,
+                                truncation_length=self.max_new_tokens)

         gen_kwargs.update({'inputs': input_ids})
-        # 注意力掩码
-        # gen_kwargs.update({'attention_mask': attention_mask})
-        # gen_kwargs.update({'position_ids': position_ids})
-        if self.stopping_criteria is None:
-            self.stopping_criteria = transformers.StoppingCriteriaList()
         # 观测输出
         gen_kwargs.update({'stopping_criteria': self.stopping_criteria})
         # llama-cpp模型的参数与transformers的参数字段有较大差异,直接调用会返回不支持的字段错误

@@ -173,11 +168,13 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
         if "llama_cpp" in self.checkPoint.model.__str__():
             import inspect
-            common_kwargs_keys = set(inspect.getfullargspec(self.checkPoint.model.generate).args) & set(gen_kwargs.keys())
-            common_kwargs = {key: gen_kwargs[key] for key in common_kwargs_keys}
-            #? llama-cpp模型的generate方法似乎只接受.cpu类型的输入,响应很慢,慢到哭泣
-            #?为什么会不支持GPU呢,不应该啊?
-            output_ids = torch.tensor([list(self.checkPoint.model.generate(input_id_i.cpu(), **common_kwargs)) for input_id_i in input_ids])
+            common_kwargs_keys = set(inspect.getfullargspec(self.checkPoint.model.generate).args) & set(
+                gen_kwargs.keys())
+            common_kwargs = {key: gen_kwargs[key] for key in common_kwargs_keys}
+            # ? llama-cpp模型的generate方法似乎只接受.cpu类型的输入,响应很慢,慢到哭泣
+            # ?为什么会不支持GPU呢,不应该啊?
+            output_ids = torch.tensor(
+                [list(self.checkPoint.model.generate(input_id_i.cpu(), **common_kwargs)) for input_id_i in input_ids])
         else:
             output_ids = self.checkPoint.model.generate(**gen_kwargs)

@@ -185,17 +182,11 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
         reply = self.decode(output_ids[0][-new_tokens:])
         print(f"response:{reply}")
         print(f"+++++++++++++++++++++++++++++++++++")
-        return reply
-
-    def generatorAnswer(self, prompt: str,
-                        history: List[List[str]] = [],
-                        streaming: bool = False):
-        # TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现
-        softprompt = self.history_to_text(prompt, history=history)
-        response = self._call(prompt=softprompt, stop=['\n###'])
         answer_result = AnswerResult()
-        answer_result.history = history + [[prompt, response]]
-        answer_result.llm_output = {"answer": response}
-        yield answer_result
+        history += [[prompt, reply]]
+        answer_result.history = history
+        if listenerQueue.listenerQueue.__len__() > 0:
+            answer_result.listenerToken = listenerQueue.listenerQueue.pop()
+        answer_result.llm_output = {"answer": reply}
+        generate_with_callback(answer_result)
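_generate_answer now builds its own soft prompt through history_to_text (the "### Human:"/"### Assistant:" template visible in the context lines) instead of relying on the removed generatorAnswer wrapper. The sketch below shows what that flattening plausibly produces; only the final "### Human:{}\n### Assistant:" line appears in this diff, so the per-turn formatting here is an assumption for illustration, not the project's exact string.

def history_to_text(query: str, history):
    """Assumed re-creation of the template: each past turn becomes a
    Human/Assistant pair, and the new query is appended as a trailing
    '### Human:' block awaiting the model's answer."""
    formatted_history = ""
    for old_query, response in history:
        # per-turn format is an assumption; only the final line is shown in the diff
        formatted_history += "### Human:{}\n### Assistant:{}\n".format(old_query, response)
    formatted_history += "### Human:{}\n### Assistant:".format(query)
    return formatted_history


print(history_to_text("And tomorrow?", [["What day is it?", "It is Monday."]]))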
models/loader/loader.py

@@ -20,6 +20,7 @@ class LoaderCheckPoint:
     no_remote_model: bool = False
     # 模型名称
     model_name: str = None
+    pretrained_model_name: str = None
     tokenizer: object = None
     # 模型全路径
     model_path: str = None

@@ -67,48 +68,49 @@ class LoaderCheckPoint:
         self.load_in_8bit = params.get('load_in_8bit', False)
         self.bf16 = params.get('bf16', False)

-    def _load_model_config(self, model_name):
+    def _load_model_config(self):

         if self.model_path:
             self.model_path = re.sub("\s", "", self.model_path)
             checkpoint = Path(f'{self.model_path}')
         else:
-            if not self.no_remote_model:
-                checkpoint = model_name
-            else:
+            if self.no_remote_model:
                 raise ValueError(
                     "本地模型local_model_path未配置路径"
                 )
+            else:
+                checkpoint = self.pretrained_model_name
+
+        print(f"load_model_config {checkpoint}...")

         try:
             model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
             return model_config
         except Exception as e:
             print(e)
             return checkpoint

-    def _load_model(self, model_name):
+    def _load_model(self):
         """
         加载自定义位置的model
-        :param model_name:
         :return:
         """
-        print(f"Loading {model_name}...")
         t0 = time.time()

         if self.model_path:
             self.model_path = re.sub("\s", "", self.model_path)
             checkpoint = Path(f'{self.model_path}')
         else:
-            if not self.no_remote_model:
-                checkpoint = model_name
-            else:
+            if self.no_remote_model:
                 raise ValueError(
                     "本地模型local_model_path未配置路径"
                 )
+            else:
+                checkpoint = self.pretrained_model_name
+
+        print(f"Loading {checkpoint}...")

         self.is_llamacpp = len(list(Path(f'{checkpoint}').glob('ggml*.bin'))) > 0
-        if 'chatglm' in model_name.lower() or "chatyuan" in model_name.lower():
+        if 'chatglm' in self.model_name.lower() or "chatyuan" in self.model_name.lower():
             LoaderClass = AutoModel
         else:
             LoaderClass = AutoModelForCausalLM

@@ -138,7 +140,7 @@ class LoaderCheckPoint:
                     torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
                     trust_remote_code=True).half().to(self.llm_device)
             else:
                 from accelerate import dispatch_model, infer_auto_device_map

                 model = LoaderClass.from_pretrained(checkpoint,
                                                     config=self.model_config,

@@ -146,10 +148,10 @@ class LoaderCheckPoint:
                                                     trust_remote_code=True).half()
                 # 可传入device_map自定义每张卡的部署情况
                 if self.device_map is None:
-                    if 'chatglm' in model_name.lower():
+                    if 'chatglm' in self.model_name.lower():
                         self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
-                    elif 'moss' in model_name.lower():
-                        self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
+                    elif 'moss' in self.model_name.lower():
+                        self.device_map = self.moss_auto_configure_device_map(num_gpus, checkpoint)
                     else:
                         # 基于如下方式作为默认的多卡加载方案针对新模型基本不会失败
                         # 在chatglm2-6b,bloom-3b,blooz-7b1上进行了测试,GPU负载也相对均衡

@@ -166,9 +168,9 @@ class LoaderCheckPoint:
                         # 其他模型定义的层类几乎不可能与chatglm和moss一致,使用chatglm_auto_configure_device_map
                         # 百分百会报错,使用infer_auto_device_map虽然可能导致负载不均衡,但至少不会报错
                         # 实测在bloom模型上如此
                         # self.device_map = infer_auto_device_map(model,
                         #                                         dtype=torch.int8,
                         #                                         no_split_module_classes=model._no_split_modules)

                 model = dispatch_model(model, device_map=self.device_map)
         else:

@@ -202,7 +204,7 @@ class LoaderCheckPoint:
             # tokenizer = model.tokenizer
             # todo 此处调用AutoTokenizer的tokenizer,但后续可以测试自带tokenizer是不是兼容
-            #* -> 自带的tokenizer不与transoformers的tokenizer兼容,无法使用
+            # * -> 自带的tokenizer不与transoformers的tokenizer兼容,无法使用
             tokenizer = AutoTokenizer.from_pretrained(self.model_name)
             return model, tokenizer

@@ -231,7 +233,7 @@ class LoaderCheckPoint:
                                                  llm_int8_enable_fp32_cpu_offload=False)
         with init_empty_weights():
             model = LoaderClass.from_config(self.model_config, trust_remote_code=True)
         model.tie_weights()
         if self.device_map is not None:
             params['device_map'] = self.device_map

@@ -321,7 +323,7 @@ class LoaderCheckPoint:
         return device_map

-    def moss_auto_configure_device_map(self, num_gpus: int, model_name) -> Dict[str, int]:
+    def moss_auto_configure_device_map(self, num_gpus: int, checkpoint) -> Dict[str, int]:
         try:
             from accelerate import init_empty_weights

@@ -336,16 +338,6 @@ class LoaderCheckPoint:
                 "`pip install bitsandbytes``pip install accelerate`."
             ) from exc

-        if self.model_path:
-            checkpoint = Path(f'{self.model_path}')
-        else:
-            if not self.no_remote_model:
-                checkpoint = model_name
-            else:
-                raise ValueError(
-                    "本地模型local_model_path未配置路径"
-                )
-
         cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
                                             pretrained_model_name_or_path=checkpoint)

@@ -452,7 +444,7 @@ class LoaderCheckPoint:
     def reload_model(self):
         self.unload_model()
-        self.model_config = self._load_model_config(self.model_name)
+        self.model_config = self._load_model_config()

         if self.use_ptuning_v2:
             try:

@@ -464,7 +456,7 @@ class LoaderCheckPoint:
             except Exception as e:
                 print("加载PrefixEncoder config.json失败")

-        self.model, self.tokenizer = self._load_model(self.model_name)
+        self.model, self.tokenizer = self._load_model()

         if self.lora:
             self._add_lora_to_model([self.lora])
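The loader changes repeat one pattern several times: the remote checkpoint name now comes from self.pretrained_model_name instead of a model_name argument, and the no_remote_model test is inverted so that no_remote_model=True is the error path. The function below is a condensed paraphrase of that post-commit branch order, not a drop-in replacement for _load_model_config/_load_model.

import re
from pathlib import Path


def resolve_checkpoint(model_path, pretrained_model_name, no_remote_model):
    """An explicit local path wins; otherwise fall back to the Hugging Face repo
    name, unless remote loading is disabled, in which case configuration is missing."""
    if model_path:
        return Path(re.sub(r"\s", "", model_path))
    if no_remote_model:
        raise ValueError("本地模型local_model_path未配置路径")
    return pretrained_model_name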
models/moss_llm.py  View file @ c5bc2178

from abc import ABC
-from langchain.llms.base import LLM
-from typing import Optional, List
+from langchain.chains.base import Chain
+from typing import Any, Dict, List, Optional, Generator, Union
+from langchain.callbacks.manager import CallbackManagerForChainRun
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
-                         AnswerResult)
+                         AnswerResult,
+                         AnswerResultStream,
+                         AnswerResultQueueSentinelTokenListenerQueue)
import torch
+import transformers

# todo 建议重写instruction,在该instruction下,各模型的表现比较差
META_INSTRUCTION = \
    """You are an AI assistant whose name is MOSS.
...
@@ -20,41 +28,65 @@ META_INSTRUCTION = \
    Capabilities and tools that MOSS can possess.
    """

# todo 在MOSSLLM类下,各模型的响应速度很慢,后续要检查一下原因
-class MOSSLLM(BaseAnswer, LLM, ABC):
+class MOSSLLMChain(BaseAnswer, Chain, ABC):
    max_token: int = 2048
    temperature: float = 0.7
    top_p = 0.8
    # history = []
    checkPoint: LoaderCheckPoint = None
    history_len: int = 10
+    streaming_key: str = "streaming"  #: :meta private:
+    history_key: str = "history"  #: :meta private:
+    prompt_key: str = "prompt"  #: :meta private:
+    output_key: str = "answer_result_stream"  #: :meta private:

    def __init__(self, checkPoint: LoaderCheckPoint = None):
        super().__init__()
        self.checkPoint = checkPoint

    @property
-    def _llm_type(self) -> str:
-        return "MOSS"
+    def _chain_type(self) -> str:
+        return "MOSSLLMChain"

    @property
-    def _check_point(self) -> LoaderCheckPoint:
-        return self.checkPoint
+    def input_keys(self) -> List[str]:
+        """Will be whatever keys the prompt expects.
+        :meta private:
+        """
+        return [self.prompt_key]

    @property
-    def _history_len(self) -> int:
-        return self.history_len
+    def output_keys(self) -> List[str]:
+        """Will always return text key.
+        :meta private:
+        """
+        return [self.output_key]

-    def set_history_len(self, history_len: int) -> None:
-        self.history_len = history_len
+    @property
+    def _check_point(self) -> LoaderCheckPoint:
+        return self.checkPoint

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        pass
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Generator]:
+        generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
+        return {self.output_key: generator}

-    def generatorAnswer(self, prompt: str,
-                        history: List[List[str]] = [],
-                        streaming: bool = False):
+    def _generate_answer(self,
+                         inputs: Dict[str, Any],
+                         run_manager: Optional[CallbackManagerForChainRun] = None,
+                         generate_with_callback: AnswerResultStream = None) -> None:
+        history = inputs[self.history_key]
+        streaming = inputs[self.streaming_key]
+        prompt = inputs[self.prompt_key]
+        print(f"__call:{prompt}")
        if len(history) > 0:
            history = history[-self.history_len:] if self.history_len > 0 else []
            prompt_w_history = str(history)
...
@@ -79,13 +111,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
                                              num_return_sequences=1,
                                              eos_token_id=106068,
                                              pad_token_id=self.checkPoint.tokenizer.pad_token_id)
            response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
            self.checkPoint.clear_torch_cache()
            history += [[prompt, response]]
            answer_result = AnswerResult()
            answer_result.history = history
            answer_result.llm_output = {"answer": response}
-            yield answer_result
+            generate_with_callback(answer_result)
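Review note: with this refactor the MOSS wrapper is a LangChain Chain rather than an LLM, so callers pass a dict keyed by prompt/history/streaming and read AnswerResult objects from the answer_result_stream output. A usage sketch, under the assumption that a LoaderCheckPoint has already been created and loaded through models/shared.py:

# Review note: usage sketch only; assumes shared.loaderCheckPoint was loaded
# elsewhere (see models/shared.py in this commit). The keys "prompt", "history",
# "streaming" and "answer_result_stream" are the ones declared by MOSSLLMChain.
import models.shared as shared
from models.moss_llm import MOSSLLMChain

chain = MOSSLLMChain(checkPoint=shared.loaderCheckPoint)
result = chain({"prompt": "你好", "history": [], "streaming": False})
for answer_result in result["answer_result_stream"]:
    print(answer_result.llm_output["answer"])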
models/shared.py  View file @ c5bc2178

...
@@ -24,13 +24,12 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
    if use_ptuning_v2:
        loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2
-    # 如果指定了参数,则使用参数的配置
    if llm_model:
        llm_model_info = llm_model_dict[llm_model]
-        if loaderCheckPoint.no_remote_model:
-            loaderCheckPoint.model_name = llm_model_info['name']
-        else:
-            loaderCheckPoint.pretrained_model_name = llm_model_info['pretrained_model_name']
+        loaderCheckPoint.model_name = llm_model_info['name']
+        loaderCheckPoint.model_name = llm_model_info['pretrained_model_name']
        loaderCheckPoint.model_path = llm_model_info["local_model_path"]
...
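Review note: loaderLLM now always takes the model name and local path from the llm_model_dict entry instead of branching on no_remote_model. A hypothetical entry is sketched below only to illustrate the keys this function reads ('name', 'pretrained_model_name', 'local_model_path'); the real values live in configs/model_config.py.

# Review note: hypothetical configs/model_config.py entry; only the three keys
# read by loaderLLM after this change are shown, with placeholder values.
llm_model_dict = {
    "moss": {
        "name": "moss",
        "pretrained_model_name": "fnlp/moss-moon-003-sft",
        "local_model_path": None,
    },
}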
test/models/test_fastchat_openai_llm.py  deleted 100644 → 0  View file @ ca13ab81

-import sys
-import os
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../../')
-import asyncio
-from argparse import Namespace
-from models.loader.args import parser
-from models.loader import LoaderCheckPoint
-import models.shared as shared
-
-async def dispatch(args: Namespace):
-    args_dict = vars(args)
-    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
-    llm_model_ins = shared.loaderLLM()
-    history = [
-        ("which city is this?", "tokyo"),
-        ("why?", "she's japanese"),
-    ]
-    for answer_result in llm_model_ins.generatorAnswer(prompt="你好? ", history=history, streaming=False):
-        resp = answer_result.llm_output["answer"]
-        print(resp)
-
-if __name__ == '__main__':
-    args = None
-    args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'fastchat-chatglm-6b', '--no-remote-model'])
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    loop.run_until_complete(dispatch(args))
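Review note: the deleted test exercised the removed generatorAnswer API. A hypothetical replacement smoke test against the new Chain interface, mirroring the call pattern now used in api.py and webui.py, might look like this (paths and model names are placeholders copied from the deleted file):

# Review note: hypothetical rewrite of the deleted test against the Chain
# interface introduced in this commit; not part of the commit itself.
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../../')
import asyncio
from argparse import Namespace
from models.loader.args import parser
from models.loader import LoaderCheckPoint
import models.shared as shared


async def dispatch(args: Namespace):
    args_dict = vars(args)
    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    llm_model_chain = shared.loaderLLM()
    history = [("which city is this?", "tokyo"), ("why?", "she's japanese")]
    result = llm_model_chain({"prompt": "你好? ", "history": history, "streaming": False})
    for answer_result in result["answer_result_stream"]:
        print(answer_result.llm_output["answer"])


if __name__ == '__main__':
    args = parser.parse_args(args=['--model-dir', '/media/checkpoint/',
                                   '--model', 'fastchat-chatglm-6b',
                                   '--no-remote-model'])
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(dispatch(args))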
webui.py  View file @ c5bc2178

...
@@ -85,8 +85,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
        yield history + [[query,
                          "请选择知识库后进行测试,当前未选择知识库。"]], ""
    else:
-        for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
-                                                              streaming=streaming):
+        answer_result_stream_result = local_doc_qa.llm_model_chain(
+            {"prompt": query, "history": history, "streaming": streaming})
+        for answer_result in answer_result_stream_result['answer_result_stream']:
            resp = answer_result.llm_output["answer"]
            history = answer_result.history
            history[-1][-1] = resp
...
@@ -101,11 +104,12 @@ def init_model():
    args_dict = vars(args)
    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    llm_model_ins = shared.loaderLLM()
-    llm_model_ins.set_history_len(LLM_HISTORY_LEN)
    try:
        local_doc_qa.init_cfg(llm_model=llm_model_ins)
-        generator = local_doc_qa.llm.generatorAnswer("你好")
-        for answer_result in generator:
+        answer_result_stream_result = local_doc_qa.llm_model_chain(
+            {"prompt": "你好", "history": [], "streaming": False})
+        for answer_result in answer_result_stream_result['answer_result_stream']:
            print(answer_result.llm_output)
        reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
        logger.info(reply)
...
@@ -141,7 +145,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
    vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
    filelist = []
-    if local_doc_qa.llm and local_doc_qa.embeddings:
+    if local_doc_qa.llm_model_chain and local_doc_qa.embeddings:
        if isinstance(files, list):
            for file in files:
                filename = os.path.split(file.name)[-1]
...
@@ -165,7 +169,7 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte
def change_vs_name_input(vs_id, history):
    if vs_id == "新建知识库":
        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history, \
               gr.update(choices=[]), gr.update(visible=False)
    else:
        vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
...
@@ -218,7 +222,7 @@ def change_chunk_conent(mode, label_conent, history):
def add_vs_name(vs_name, chatbot):
    if vs_name is None or vs_name.strip() == "":
        vs_status = "知识库名称不能为空,请重新填写知识库名称"
        chatbot = chatbot + [[None, vs_status]]
        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
...
@@ -262,6 +266,7 @@ def reinit_vector_store(vs_id, history):
def refresh_vs_list():
    return gr.update(choices=get_vs_list()), gr.update(choices=get_vs_list())

def delete_file(vs_id, files_to_delete, chatbot):
    vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
    content_path = os.path.join(KB_ROOT_PATH, vs_id, "content")
...
@@ -275,11 +280,11 @@ def delete_file(vs_id, files_to_delete, chatbot):
    rested_files = local_doc_qa.list_file_from_vector_store(vs_path)
    if "fail" in status:
        vs_status = "文件删除失败。"
    elif len(rested_files) > 0:
        vs_status = "文件删除成功。"
    else:
        vs_status = f"文件删除成功,知识库{vs_id}中无已上传文件,请先上传文件后,再开始提问。"
    logger.info(",".join(files_to_delete) + vs_status)
    chatbot = chatbot + [[None, vs_status]]
    return gr.update(choices=local_doc_qa.list_file_from_vector_store(vs_path), value=[]), chatbot
...
@@ -290,7 +295,8 @@ def delete_vs(vs_id, chatbot):
            status = f"成功删除知识库{vs_id}"
            logger.info(status)
            chatbot = chatbot + [[None, status]]
            return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(visible=True), \
                   gr.update(visible=False), chatbot, gr.update(visible=False)
    except Exception as e:
        logger.error(e)
...
@@ -333,7 +339,8 @@ default_theme_args = dict(
with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo:
    vs_path, file_status, model_status = gr.State(
        os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State(
        model_status)
    gr.Markdown(webui_title)
    with gr.Tab("对话"):
...
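Review note: in streaming mode the same answer_result_stream key yields partial AnswerResult objects, so a caller can update the chat history incrementally, exactly as get_answer does above. A minimal sketch, with local_doc_qa passed in explicitly instead of relying on the module-level instance that webui.py actually uses:

# Review note: minimal streaming sketch following the get_answer() pattern above;
# the local_doc_qa argument stands in for the module-level instance in webui.py,
# which is assumed to have been initialised via init_model()/init_cfg().
def stream_reply(local_doc_qa, query, history):
    answer_result_stream_result = local_doc_qa.llm_model_chain(
        {"prompt": query, "history": history, "streaming": True})
    for answer_result in answer_result_stream_result['answer_result_stream']:
        history = answer_result.history
        history[-1][-1] = answer_result.llm_output["answer"]
        yield history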
webui_st.py  View file @ c5bc2178

...
@@ -85,9 +85,10 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
        yield history + [[query,
                          "请选择知识库后进行测试,当前未选择知识库。"]], ""
    else:
-        for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
-                                                              streaming=streaming):
+        answer_result_stream_result = local_doc_qa.llm_model_chain(
+            {"prompt": query, "history": history, "streaming": streaming})
+        for answer_result in answer_result_stream_result['answer_result_stream']:
            resp = answer_result.llm_output["answer"]
            history = answer_result.history
            history[-1][-1] = resp + (
...
@@ -105,13 +106,14 @@ def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'
    args_dict.update(model=llm_model)
    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    llm_model_ins = shared.loaderLLM()
-    llm_model_ins.set_history_len(LLM_HISTORY_LEN)
    try:
        local_doc_qa.init_cfg(llm_model=llm_model_ins,
                              embedding_model=embedding_model)
-        generator = local_doc_qa.llm.generatorAnswer("你好")
-        for answer_result in generator:
+        answer_result_stream_result = local_doc_qa.llm_model_chain(
+            {"prompt": "你好", "history": [], "streaming": False})
+        for answer_result in answer_result_stream_result['answer_result_stream']:
            print(answer_result.llm_output)
        reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
        logger.info(reply)
...
@@ -468,7 +470,7 @@ with st.sidebar:
    top_k = st.slider('向量匹配数量', 1, 20, VECTOR_SEARCH_TOP_K)
    history_len = st.slider(
        'LLM对话轮数', 1, 50, LLM_HISTORY_LEN)  # 也许要跟知识库分开设置
-    local_doc_qa.llm.set_history_len(history_len)
+    # local_doc_qa.llm.set_history_len(history_len)
    chunk_conent = st.checkbox('启用上下文关联', False)
    st.text('')
    # chunk_conent = st.checkbox('分割文本', True)  # 知识库文本分割入库
...
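Review note: with set_history_len removed, the 'LLM对话轮数' slider value is currently unused because the call is only commented out. If per-session control is still wanted, one option, assuming the history_len attribute that the refactored chain classes in this commit still define, would be along these lines:

# Review note: hypothetical follow-up, not part of this commit; relies on the
# history_len attribute kept on the chain classes (see models/moss_llm.py above).
if local_doc_qa.llm_model_chain is not None:
    local_doc_qa.llm_model_chain.history_len = history_len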