Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
J
jinchat-server
概览
概览
详情
活动
周期分析
版本库
存储库
文件
提交
分支
标签
贡献者
分支图
比较
统计图
问题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
aigc-pioneer
jinchat-server
Commits
c5bc2178
提交
c5bc2178
authored
7月 12, 2023
作者:
glide-the
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
修改模型生成的调用方式,兼容Chain调用
修改模型切换的bug
上级
ca13ab81
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
16 个修改的文件
包含
513 行增加
和
329 行删除
+513
-329
.gitignore
.gitignore
+1
-2
api.py
api.py
+4
-3
local_doc_qa.py
chains/local_doc_qa.py
+13
-9
model_config.py
configs/model_config.py
+26
-20
__init__.py
models/__init__.py
+4
-4
__init__.py
models/base/__init__.py
+5
-3
base.py
models/base/base.py
+151
-13
chatglm_llm.py
models/chatglm_llm.py
+68
-31
fastchat_openai_llm.py
models/fastchat_openai_llm.py
+75
-69
llama_llm.py
models/llama_llm.py
+75
-84
loader.py
models/loader/loader.py
+0
-0
moss_llm.py
models/moss_llm.py
+53
-22
shared.py
models/shared.py
+3
-4
test_fastchat_openai_llm.py
test/models/test_fastchat_openai_llm.py
+0
-39
webui.py
webui.py
+27
-20
webui_st.py
webui_st.py
+8
-6
没有找到文件。
.gitignore
浏览文件 @
c5bc2178
...
...
@@ -174,4 +174,4 @@ embedding/*
pyrightconfig.json
loader/tmp_files
flagged/*
\ No newline at end of file
flagged/*
api.py
浏览文件 @
c5bc2178
...
...
@@ -384,8 +384,10 @@ async def chat(
],
),
):
for
answer_result
in
local_doc_qa
.
llm
.
generatorAnswer
(
prompt
=
question
,
history
=
history
,
streaming
=
True
):
answer_result_stream_result
=
local_doc_qa
.
llm_model_chain
(
{
"prompt"
:
question
,
"history"
:
history
,
"streaming"
:
True
})
for
answer_result
in
answer_result_stream_result
[
'answer_result_stream'
]:
resp
=
answer_result
.
llm_output
[
"answer"
]
history
=
answer_result
.
history
pass
...
...
@@ -486,7 +488,6 @@ def api_start(host, port, **kwargs):
global
local_doc_qa
llm_model_ins
=
shared
.
loaderLLM
()
llm_model_ins
.
set_history_len
(
LLM_HISTORY_LEN
)
app
=
FastAPI
()
# Add CORS middleware to allow all origins
...
...
chains/local_doc_qa.py
浏览文件 @
c5bc2178
...
...
@@ -18,6 +18,7 @@ from agent import bing_search
from
langchain.docstore.document
import
Document
from
functools
import
lru_cache
from
textsplitter.zh_title_enhance
import
zh_title_enhance
from
langchain.chains.base
import
Chain
# patch HuggingFaceEmbeddings to make it hashable
...
...
@@ -119,7 +120,7 @@ def search_result2docs(search_results):
class
LocalDocQA
:
llm
:
BaseAnswer
=
None
llm
_model_chain
:
Chain
=
None
embeddings
:
object
=
None
top_k
:
int
=
VECTOR_SEARCH_TOP_K
chunk_size
:
int
=
CHUNK_SIZE
...
...
@@ -129,10 +130,10 @@ class LocalDocQA:
def
init_cfg
(
self
,
embedding_model
:
str
=
EMBEDDING_MODEL
,
embedding_device
=
EMBEDDING_DEVICE
,
llm_model
:
BaseAnswer
=
None
,
llm_model
:
Chain
=
None
,
top_k
=
VECTOR_SEARCH_TOP_K
,
):
self
.
llm
=
llm_model
self
.
llm
_model_chain
=
llm_model
self
.
embeddings
=
HuggingFaceEmbeddings
(
model_name
=
embedding_model_dict
[
embedding_model
],
model_kwargs
=
{
'device'
:
embedding_device
})
self
.
top_k
=
top_k
...
...
@@ -236,8 +237,10 @@ class LocalDocQA:
else
:
prompt
=
query
for
answer_result
in
self
.
llm
.
generatorAnswer
(
prompt
=
prompt
,
history
=
chat_history
,
streaming
=
streaming
):
answer_result_stream_result
=
self
.
llm_model_chain
(
{
"prompt"
:
prompt
,
"history"
:
chat_history
,
"streaming"
:
streaming
})
for
answer_result
in
answer_result_stream_result
[
'answer_result_stream'
]:
resp
=
answer_result
.
llm_output
[
"answer"
]
history
=
answer_result
.
history
history
[
-
1
][
0
]
=
query
...
...
@@ -276,8 +279,10 @@ class LocalDocQA:
result_docs
=
search_result2docs
(
results
)
prompt
=
generate_prompt
(
result_docs
,
query
)
for
answer_result
in
self
.
llm
.
generatorAnswer
(
prompt
=
prompt
,
history
=
chat_history
,
streaming
=
streaming
):
answer_result_stream_result
=
self
.
llm_model_chain
(
{
"prompt"
:
prompt
,
"history"
:
chat_history
,
"streaming"
:
streaming
})
for
answer_result
in
answer_result_stream_result
[
'answer_result_stream'
]:
resp
=
answer_result
.
llm_output
[
"answer"
]
history
=
answer_result
.
history
history
[
-
1
][
0
]
=
query
...
...
@@ -296,7 +301,7 @@ class LocalDocQA:
def
update_file_from_vector_store
(
self
,
filepath
:
str
or
List
[
str
],
vs_path
,
docs
:
List
[
Document
],):
docs
:
List
[
Document
],
):
vector_store
=
load_vector_store
(
vs_path
,
self
.
embeddings
)
status
=
vector_store
.
update_doc
(
filepath
,
docs
)
return
status
...
...
@@ -320,7 +325,6 @@ if __name__ == "__main__":
args_dict
=
vars
(
args
)
shared
.
loaderCheckPoint
=
LoaderCheckPoint
(
args_dict
)
llm_model_ins
=
shared
.
loaderLLM
()
llm_model_ins
.
set_history_len
(
LLM_HISTORY_LEN
)
local_doc_qa
=
LocalDocQA
()
local_doc_qa
.
init_cfg
(
llm_model
=
llm_model_ins
)
...
...
configs/model_config.py
浏览文件 @
c5bc2178
...
...
@@ -37,61 +37,67 @@ llm_model_dict = {
"name"
:
"chatglm-6b-int4-qe"
,
"pretrained_model_name"
:
"THUDM/chatglm-6b-int4-qe"
,
"local_model_path"
:
None
,
"provides"
:
"ChatGLM"
"provides"
:
"ChatGLM
LLMChain
"
},
"chatglm-6b-int4"
:
{
"name"
:
"chatglm-6b-int4"
,
"pretrained_model_name"
:
"THUDM/chatglm-6b-int4"
,
"local_model_path"
:
None
,
"provides"
:
"ChatGLM"
"provides"
:
"ChatGLM
LLMChain
"
},
"chatglm-6b-int8"
:
{
"name"
:
"chatglm-6b-int8"
,
"pretrained_model_name"
:
"THUDM/chatglm-6b-int8"
,
"local_model_path"
:
None
,
"provides"
:
"ChatGLM"
"provides"
:
"ChatGLM
LLMChain
"
},
"chatglm-6b"
:
{
"name"
:
"chatglm-6b"
,
"pretrained_model_name"
:
"THUDM/chatglm-6b"
,
"local_model_path"
:
None
,
"provides"
:
"ChatGLM"
"provides"
:
"ChatGLM
LLMChain
"
},
"chatglm2-6b"
:
{
"name"
:
"chatglm2-6b"
,
"pretrained_model_name"
:
"THUDM/chatglm2-6b"
,
"local_model_path"
:
None
,
"provides"
:
"ChatGLM"
"provides"
:
"ChatGLM
LLMChain
"
},
"chatglm2-6b-int4"
:
{
"name"
:
"chatglm2-6b-int4"
,
"pretrained_model_name"
:
"THUDM/chatglm2-6b-int4"
,
"local_model_path"
:
None
,
"provides"
:
"ChatGLM"
"provides"
:
"ChatGLM
LLMChain
"
},
"chatglm2-6b-int8"
:
{
"name"
:
"chatglm2-6b-int8"
,
"pretrained_model_name"
:
"THUDM/chatglm2-6b-int8"
,
"local_model_path"
:
None
,
"provides"
:
"ChatGLM"
"provides"
:
"ChatGLM
LLMChain
"
},
"chatyuan"
:
{
"name"
:
"chatyuan"
,
"pretrained_model_name"
:
"ClueAI/ChatYuan-large-v2"
,
"local_model_path"
:
None
,
"provides"
:
"MOSSLLM"
"provides"
:
"MOSSLLM
Chain
"
},
"moss"
:
{
"name"
:
"moss"
,
"pretrained_model_name"
:
"fnlp/moss-moon-003-sft"
,
"local_model_path"
:
None
,
"provides"
:
"MOSSLLM"
"provides"
:
"MOSSLLM
Chain
"
},
"vicuna-13b-hf"
:
{
"name"
:
"vicuna-13b-hf"
,
"pretrained_model_name"
:
"vicuna-13b-hf"
,
"local_model_path"
:
None
,
"provides"
:
"LLamaLLM"
"provides"
:
"LLamaLLMChain"
},
"vicuna-7b-hf"
:
{
"name"
:
"vicuna-13b-hf"
,
"pretrained_model_name"
:
"vicuna-13b-hf"
,
"local_model_path"
:
None
,
"provides"
:
"LLamaLLMChain"
},
# 直接调用返回requests.exceptions.ConnectionError错误,需要通过huggingface_hub包里的snapshot_download函数
# 下载模型,如果snapshot_download还是返回网络错误,多试几次,一般是可以的,
...
...
@@ -101,7 +107,7 @@ llm_model_dict = {
"name"
:
"bloomz-7b1"
,
"pretrained_model_name"
:
"bigscience/bloomz-7b1"
,
"local_model_path"
:
None
,
"provides"
:
"MOSSLLM"
"provides"
:
"MOSSLLM
Chain
"
},
# 实测加载bigscience/bloom-3b需要170秒左右,暂不清楚为什么这么慢
...
...
@@ -110,14 +116,14 @@ llm_model_dict = {
"name"
:
"bloom-3b"
,
"pretrained_model_name"
:
"bigscience/bloom-3b"
,
"local_model_path"
:
None
,
"provides"
:
"MOSSLLM"
"provides"
:
"MOSSLLM
Chain
"
},
"baichuan-7b"
:
{
"name"
:
"baichuan-7b"
,
"pretrained_model_name"
:
"baichuan-inc/baichuan-7B"
,
"local_model_path"
:
None
,
"provides"
:
"MOSSLLM"
"provides"
:
"MOSSLLM
Chain
"
},
# llama-cpp模型的兼容性问题参考https://github.com/abetlen/llama-cpp-python/issues/204
"ggml-vicuna-13b-1.1-q5"
:
{
...
...
@@ -131,7 +137,7 @@ llm_model_dict = {
# 需要手动从https://github.com/abetlen/llama-cpp-python/releases/tag/下载对应的wheel安装
# 实测v0.1.63与本模型的vicuna/ggml-vicuna-13b-1.1/ggml-vic13b-q5_1.bin可以兼容
"local_model_path"
:
f
'''{"/".join(os.path.abspath(__file__).split("/")[:3])}/.cache/huggingface/hub/models--vicuna--ggml-vicuna-13b-1.1/blobs/'''
,
"provides"
:
"LLamaLLM"
"provides"
:
"LLamaLLM
Chain
"
},
# 通过 fastchat 调用的模型请参考如下格式
...
...
@@ -139,7 +145,7 @@ llm_model_dict = {
"name"
:
"chatglm-6b"
,
# "name"修改为fastchat服务中的"model_name"
"pretrained_model_name"
:
"chatglm-6b"
,
"local_model_path"
:
None
,
"provides"
:
"FastChatOpenAILLM
"
,
# 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM
"
"provides"
:
"FastChatOpenAILLM
Chain"
,
# 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain
"
"api_base_url"
:
"http://localhost:8000/v1"
,
# "name"修改为fastchat服务中的"api_base_url"
"api_key"
:
"EMPTY"
},
...
...
@@ -147,7 +153,7 @@ llm_model_dict = {
"name"
:
"chatglm2-6b"
,
# "name"修改为fastchat服务中的"model_name"
"pretrained_model_name"
:
"chatglm2-6b"
,
"local_model_path"
:
None
,
"provides"
:
"FastChatOpenAILLM
"
,
# 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM
"
"provides"
:
"FastChatOpenAILLM
Chain"
,
# 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain
"
"api_base_url"
:
"http://localhost:8000/v1"
# "name"修改为fastchat服务中的"api_base_url"
},
...
...
@@ -156,7 +162,7 @@ llm_model_dict = {
"name"
:
"vicuna-13b-hf"
,
# "name"修改为fastchat服务中的"model_name"
"pretrained_model_name"
:
"vicuna-13b-hf"
,
"local_model_path"
:
None
,
"provides"
:
"FastChatOpenAILLM
"
,
# 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM
"
"provides"
:
"FastChatOpenAILLM
Chain"
,
# 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain
"
"api_base_url"
:
"http://localhost:8000/v1"
,
# "name"修改为fastchat服务中的"api_base_url"
"api_key"
:
"EMPTY"
},
...
...
@@ -165,13 +171,13 @@ llm_model_dict = {
# 则需要将urllib3版本修改为1.25.11
# 如果报出:raise NewConnectionError(
# urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x000001FE4BDB85E0>:
# urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x000001FE4BDB85E0>:
# Failed to establish a new connection: [WinError 10060]
# 则是因为内地和香港的IP都被OPENAI封了,需要挂切换为日本、新加坡等地
"openai-chatgpt-3.5"
:
{
"name"
:
"gpt-3.5-turbo"
,
"pretrained_model_name"
:
"gpt-3.5-turbo"
,
"provides"
:
"FastChatOpenAILLM"
,
"provides"
:
"FastChatOpenAILLM
Chain
"
,
"local_model_path"
:
None
,
"api_base_url"
:
"https://api.openapi.com/v1"
,
"api_key"
:
""
...
...
@@ -226,7 +232,7 @@ LLM_HISTORY_LEN = 3
VECTOR_SEARCH_TOP_K
=
5
# 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,经测试设置为小于500时,匹配结果更精准
VECTOR_SEARCH_SCORE_THRESHOLD
=
0
VECTOR_SEARCH_SCORE_THRESHOLD
=
39
0
NLTK_DATA_PATH
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
__file__
)),
"nltk_data"
)
...
...
models/__init__.py
浏览文件 @
c5bc2178
from
.chatglm_llm
import
ChatGLM
from
.llama_llm
import
LLamaLLM
from
.
moss_llm
import
MOSSLLM
from
.
fastchat_openai_llm
import
FastChatOpenAILLM
from
.chatglm_llm
import
ChatGLM
LLMChain
from
.llama_llm
import
LLamaLLM
Chain
from
.
fastchat_openai_llm
import
FastChatOpenAILLMChain
from
.
moss_llm
import
MOSSLLMChain
models/base/__init__.py
浏览文件 @
c5bc2178
from
models.base.base
import
(
AnswerResult
,
BaseAnswer
)
BaseAnswer
,
AnswerResultStream
,
AnswerResultQueueSentinelTokenListenerQueue
)
from
models.base.remote_rpc_model
import
(
RemoteRpcModel
)
__all__
=
[
"AnswerResult"
,
"BaseAnswer"
,
"RemoteRpcModel"
,
"AnswerResultStream"
,
"AnswerResultQueueSentinelTokenListenerQueue"
]
models/base/base.py
浏览文件 @
c5bc2178
from
abc
import
ABC
,
abstractmethod
from
typing
import
Optional
,
List
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Generator
import
traceback
from
collections
import
deque
from
queue
import
Queue
from
threading
import
Thread
from
langchain.callbacks.manager
import
CallbackManagerForChainRun
from
models.loader
import
LoaderCheckPoint
import
torch
import
transformers
from
models.loader
import
LoaderCheckPoint
class
ListenerToken
:
"""
观测结果
"""
input_ids
:
torch
.
LongTensor
_scores
:
torch
.
FloatTensor
def
__init__
(
self
,
input_ids
:
torch
.
LongTensor
,
_scores
:
torch
.
FloatTensor
):
self
.
input_ids
=
input_ids
self
.
_scores
=
_scores
class
AnswerResult
:
...
...
@@ -16,6 +29,123 @@ class AnswerResult:
"""
history
:
List
[
List
[
str
]]
=
[]
llm_output
:
Optional
[
dict
]
=
None
listenerToken
:
ListenerToken
=
None
class
AnswerResultStream
:
def
__init__
(
self
,
callback_func
=
None
):
self
.
callback_func
=
callback_func
def
__call__
(
self
,
answerResult
:
AnswerResult
):
if
self
.
callback_func
is
not
None
:
self
.
callback_func
(
answerResult
)
class
AnswerResultQueueSentinelTokenListenerQueue
(
transformers
.
StoppingCriteria
):
"""
定义模型stopping_criteria 监听者,在每次响应时将队列数据同步到AnswerResult
实现此监听器的目的是,不同模型的预测输出可能不是矢量信息,hf框架可以自定义transformers.StoppingCriteria入参来接收每次预测的Tensor和损失函数,
通过给 StoppingCriteriaList指定模型生成答案时停止的条件。每个 StoppingCriteria 对象表示一个停止条件
当每轮预测任务开始时,StoppingCriteria都会收到相同的预测结果,最终由下层实现类确认是否结束
输出值可用于 generatorAnswer generate_with_streaming的自定义参数观测,以实现更加精细的控制
"""
listenerQueue
:
deque
=
deque
(
maxlen
=
1
)
def
__init__
(
self
):
transformers
.
StoppingCriteria
.
__init__
(
self
)
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
_scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
bool
:
"""
每次响应时将数据添加到响应队列
:param input_ids:
:param _scores:
:param kwargs:
:return:
"""
self
.
listenerQueue
.
append
(
ListenerToken
(
input_ids
=
input_ids
,
_scores
=
_scores
))
return
False
class
Iteratorize
:
"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""
def
__init__
(
self
,
func
,
kwargs
=
{}):
self
.
mfunc
=
func
self
.
q
=
Queue
()
self
.
sentinel
=
object
()
self
.
kwargs
=
kwargs
self
.
stop_now
=
False
def
_callback
(
val
):
"""
模型输出预测结果收集
通过定义generate_with_callback收集器AnswerResultStream,收集模型预测的AnswerResult响应结果,最终由下层实现类确认是否结束
结束条件包含如下
1、模型预测结束、收集器self.q队列收到 self.sentinel标识
2、在处理迭代器队列消息时返回了break跳出迭代器,触发了StopIteration事件
3、模型预测出错
因为当前类是迭代器,所以在for in 中执行了break后 __exit__ 方法会被调用,最终stop_now属性会被更新,然后抛出异常结束预测行为
迭代器收集的行为如下
创建Iteratorize迭代对象,
定义generate_with_callback收集器AnswerResultStream
启动一个线程异步预测结果来调用上游checkpoint的实现方法_generate_answer
_generate_answer通过generate_with_callback定义的收集器,收集上游checkpoint包装的AnswerResult消息体
由于self.q是阻塞模式,每次预测后会被消费后才会执行下次预测
这时generate_with_callback会被阻塞
主线程Iteratorize对象的__next__方法调用获取阻塞消息并消费
1、消息为上游checkpoint包装的AnswerResult消息体,返回下游处理
2、消息为self.sentinel标识,抛出StopIteration异常
主线程Iteratorize对象__exit__收到消息,最终stop_now属性会被更新
异步线程检测stop_now属性被更新,抛出异常结束预测行为
迭代行为结束
:param val:
:return:
"""
if
self
.
stop_now
:
raise
ValueError
self
.
q
.
put
(
val
)
def
gen
():
try
:
ret
=
self
.
mfunc
(
callback
=
_callback
,
**
self
.
kwargs
)
except
ValueError
:
pass
except
:
traceback
.
print_exc
()
pass
self
.
q
.
put
(
self
.
sentinel
)
self
.
thread
=
Thread
(
target
=
gen
)
self
.
thread
.
start
()
def
__iter__
(
self
):
return
self
def
__next__
(
self
):
obj
=
self
.
q
.
get
(
True
,
None
)
if
obj
is
self
.
sentinel
:
raise
StopIteration
else
:
return
obj
def
__del__
(
self
):
"""
暂无实现
:return:
"""
pass
def
__enter__
(
self
):
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
""" break 后会执行 """
self
.
stop_now
=
True
class
BaseAnswer
(
ABC
):
...
...
@@ -25,17 +155,25 @@ class BaseAnswer(ABC):
@abstractmethod
def
_check_point
(
self
)
->
LoaderCheckPoint
:
"""Return _check_point of llm."""
def
generatorAnswer
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,)
->
Generator
[
Any
,
str
,
bool
]:
def
generate_with_callback
(
callback
=
None
,
**
kwargs
):
kwargs
[
'generate_with_callback'
]
=
AnswerResultStream
(
callback_func
=
callback
)
self
.
_generate_answer
(
**
kwargs
)
@property
@abstractmethod
def
_history_len
(
self
)
->
int
:
"""Return _history_len of llm."""
def
generate_with_streaming
(
**
kwargs
):
return
Iteratorize
(
generate_with_callback
,
kwargs
)
@abstractmethod
def
set_history_len
(
self
,
history_len
:
int
)
->
None
:
"""Return _history_len of llm."""
with
generate_with_streaming
(
inputs
=
inputs
,
run_manager
=
run_manager
)
as
generator
:
for
answerResult
in
generator
:
if
answerResult
.
listenerToken
:
output
=
answerResult
.
listenerToken
.
input_ids
yield
answerResult
def
generatorAnswer
(
self
,
prompt
:
str
,
history
:
List
[
List
[
str
]]
=
[],
streaming
:
bool
=
False
):
@abstractmethod
def
_generate_answer
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
generate_with_callback
:
AnswerResultStream
=
None
)
->
None
:
pass
models/chatglm_llm.py
浏览文件 @
c5bc2178
from
abc
import
ABC
from
langchain.llms.base
import
LLM
from
typing
import
Optional
,
List
from
langchain.chains.base
import
Chain
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Generator
from
langchain.callbacks.manager
import
CallbackManagerForChainRun
from
transformers.generation.logits_process
import
LogitsProcessor
from
transformers.generation.utils
import
LogitsProcessorList
,
StoppingCriteriaList
from
models.loader
import
LoaderCheckPoint
from
models.base
import
(
BaseAnswer
,
AnswerResult
)
AnswerResult
,
AnswerResultStream
,
AnswerResultQueueSentinelTokenListenerQueue
)
import
torch
import
transformers
class
ChatGLM
(
BaseAnswer
,
LLM
,
ABC
):
class
ChatGLM
LLMChain
(
BaseAnswer
,
Chain
,
ABC
):
max_token
:
int
=
10000
temperature
:
float
=
0.01
top_p
=
0.9
# 相关度
top_p
=
0.4
# 候选词数量
top_k
=
10
checkPoint
:
LoaderCheckPoint
=
None
# history = []
history_len
:
int
=
10
streaming_key
:
str
=
"streaming"
#: :meta private:
history_key
:
str
=
"history"
#: :meta private:
prompt_key
:
str
=
"prompt"
#: :meta private:
output_key
:
str
=
"answer_result_stream"
#: :meta private:
def
__init__
(
self
,
checkPoint
:
LoaderCheckPoint
=
None
):
super
()
.
__init__
()
self
.
checkPoint
=
checkPoint
@property
def
_
llm
_type
(
self
)
->
str
:
return
"ChatGLM"
def
_
chain
_type
(
self
)
->
str
:
return
"ChatGLM
LLMChain
"
@property
def
_check_point
(
self
)
->
LoaderCheckPoint
:
return
self
.
checkPoint
@property
def
_history_len
(
self
)
->
int
:
return
self
.
history_len
def
input_keys
(
self
)
->
List
[
str
]
:
"""Will be whatever keys the prompt expects.
def
set_history_len
(
self
,
history_len
:
int
=
10
)
->
None
:
self
.
history_len
=
history_len
:meta private:
"""
return
[
self
.
prompt_key
]
def
_call
(
self
,
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
)
->
str
:
print
(
f
"__call:{prompt}"
)
response
,
_
=
self
.
checkPoint
.
model
.
chat
(
self
.
checkPoint
.
tokenizer
,
prompt
,
history
=
[],
max_length
=
self
.
max_token
,
temperature
=
self
.
temperature
)
print
(
f
"response:{response}"
)
print
(
f
"+++++++++++++++++++++++++++++++++++"
)
return
response
@property
def
output_keys
(
self
)
->
List
[
str
]:
"""Will always return text key.
def
generatorAnswer
(
self
,
prompt
:
str
,
history
:
List
[
List
[
str
]]
=
[],
streaming
:
bool
=
False
):
:meta private:
"""
return
[
self
.
output_key
]
def
_call
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
)
->
Dict
[
str
,
Generator
]:
generator
=
self
.
generatorAnswer
(
inputs
=
inputs
,
run_manager
=
run_manager
)
return
{
self
.
output_key
:
generator
}
def
_generate_answer
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
generate_with_callback
:
AnswerResultStream
=
None
)
->
None
:
history
=
inputs
[
self
.
history_key
]
streaming
=
inputs
[
self
.
streaming_key
]
prompt
=
inputs
[
self
.
prompt_key
]
print
(
f
"__call:{prompt}"
)
# Create the StoppingCriteriaList with the stopping strings
stopping_criteria_list
=
transformers
.
StoppingCriteriaList
()
# 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
listenerQueue
=
AnswerResultQueueSentinelTokenListenerQueue
()
stopping_criteria_list
.
append
(
listenerQueue
)
if
streaming
:
history
+=
[[]]
for
inum
,
(
stream_resp
,
_
)
in
enumerate
(
self
.
checkPoint
.
model
.
stream_chat
(
self
.
checkPoint
.
tokenizer
,
prompt
,
history
=
history
[
-
self
.
history_len
:
-
1
]
if
self
.
history_len
>
1
else
[],
history
=
history
[
-
self
.
history_len
:
-
1
]
if
self
.
history_len
>
0
else
[],
max_length
=
self
.
max_token
,
temperature
=
self
.
temperature
temperature
=
self
.
temperature
,
top_p
=
self
.
top_p
,
top_k
=
self
.
top_k
,
stopping_criteria
=
stopping_criteria_list
)):
# self.checkPoint.clear_torch_cache()
history
[
-
1
]
=
[
prompt
,
stream_resp
]
answer_result
=
AnswerResult
()
answer_result
.
history
=
history
answer_result
.
llm_output
=
{
"answer"
:
stream_resp
}
yield
answer_result
if
listenerQueue
.
listenerQueue
.
__len__
()
>
0
:
answer_result
.
listenerToken
=
listenerQueue
.
listenerQueue
.
pop
()
generate_with_callback
(
answer_result
)
self
.
checkPoint
.
clear_torch_cache
()
else
:
response
,
_
=
self
.
checkPoint
.
model
.
chat
(
...
...
@@ -72,13 +104,18 @@ class ChatGLM(BaseAnswer, LLM, ABC):
prompt
,
history
=
history
[
-
self
.
history_len
:]
if
self
.
history_len
>
0
else
[],
max_length
=
self
.
max_token
,
temperature
=
self
.
temperature
temperature
=
self
.
temperature
,
top_p
=
self
.
top_p
,
top_k
=
self
.
top_k
,
stopping_criteria
=
stopping_criteria_list
)
self
.
checkPoint
.
clear_torch_cache
()
history
+=
[[
prompt
,
response
]]
answer_result
=
AnswerResult
()
answer_result
.
history
=
history
answer_result
.
llm_output
=
{
"answer"
:
response
}
yield
answer_result
if
listenerQueue
.
listenerQueue
.
__len__
()
>
0
:
answer_result
.
listenerToken
=
listenerQueue
.
listenerQueue
.
pop
()
generate_with_callback
(
answer_result
)
models/fastchat_openai_llm.py
浏览文件 @
c5bc2178
from
abc
import
ABC
import
requests
from
typing
import
Optional
,
List
from
langchain.llms.base
import
LLM
from
langchain.chains.base
import
Chain
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Generator
,
Collection
from
models.loader
import
LoaderCheckPoint
from
models.base
import
(
RemoteRpcModel
,
AnswerResult
)
from
typing
import
(
Collection
,
Dict
)
from
langchain.callbacks.manager
import
CallbackManagerForChainRun
from
models.base
import
(
BaseAnswer
,
RemoteRpcModel
,
AnswerResult
,
AnswerResultStream
,
AnswerResultQueueSentinelTokenListenerQueue
)
import
torch
import
transformers
def
_build_message_template
()
->
Dict
[
str
,
str
]:
...
...
@@ -22,41 +22,74 @@ def _build_message_template() -> Dict[str, str]:
}
class
FastChatOpenAILLM
(
RemoteRpcModel
,
LLM
,
ABC
):
# 将历史对话数组转换为文本格式
def
build_message_list
(
query
,
history
:
List
[
List
[
str
]])
->
Collection
[
Dict
[
str
,
str
]]:
build_messages
:
Collection
[
Dict
[
str
,
str
]]
=
[]
for
i
,
(
old_query
,
response
)
in
enumerate
(
history
):
user_build_message
=
_build_message_template
()
user_build_message
[
'role'
]
=
'user'
user_build_message
[
'content'
]
=
old_query
system_build_message
=
_build_message_template
()
system_build_message
[
'role'
]
=
'system'
system_build_message
[
'content'
]
=
response
build_messages
.
append
(
user_build_message
)
build_messages
.
append
(
system_build_message
)
user_build_message
=
_build_message_template
()
user_build_message
[
'role'
]
=
'user'
user_build_message
[
'content'
]
=
query
build_messages
.
append
(
user_build_message
)
return
build_messages
class
FastChatOpenAILLMChain
(
RemoteRpcModel
,
Chain
,
ABC
):
api_base_url
:
str
=
"http://localhost:8000/v1"
model_name
:
str
=
"chatglm-6b"
max_token
:
int
=
10000
temperature
:
float
=
0.01
top_p
=
0.9
checkPoint
:
LoaderCheckPoint
=
None
history
=
[]
#
history = []
history_len
:
int
=
10
api_key
:
str
=
""
def
__init__
(
self
,
streaming_key
:
str
=
"streaming"
#: :meta private:
history_key
:
str
=
"history"
#: :meta private:
prompt_key
:
str
=
"prompt"
#: :meta private:
output_key
:
str
=
"answer_result_stream"
#: :meta private:
def
__init__
(
self
,
checkPoint
:
LoaderCheckPoint
=
None
,
# api_base_url:str="http://localhost:8000/v1",
# model_name:str="chatglm-6b",
# api_key:str=""
# api_base_url:str="http://localhost:8000/v1",
# model_name:str="chatglm-6b",
# api_key:str=""
):
super
()
.
__init__
()
self
.
checkPoint
=
checkPoint
@property
def
_
llm
_type
(
self
)
->
str
:
return
"
FastChat
"
def
_
chain
_type
(
self
)
->
str
:
return
"
LLamaLLMChain
"
@property
def
_check_point
(
self
)
->
LoaderCheckPoint
:
return
self
.
checkPoint
@property
def
_history_len
(
self
)
->
int
:
return
self
.
history_len
def
input_keys
(
self
)
->
List
[
str
]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return
[
self
.
prompt_key
]
@property
def
output_keys
(
self
)
->
List
[
str
]:
"""Will always return text key.
def
set_history_len
(
self
,
history_len
:
int
=
10
)
->
None
:
self
.
history_len
=
history_len
:meta private:
"""
return
[
self
.
output_key
]
@property
def
_api_key
(
self
)
->
str
:
...
...
@@ -75,53 +108,25 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
def
call_model_name
(
self
,
model_name
):
self
.
model_name
=
model_name
def
_call
(
self
,
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
)
->
str
:
def
_call
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
)
->
Dict
[
str
,
Generator
]:
generator
=
self
.
generatorAnswer
(
inputs
=
inputs
,
run_manager
=
run_manager
)
return
{
self
.
output_key
:
generator
}
def
_generate_answer
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
generate_with_callback
:
AnswerResultStream
=
None
)
->
None
:
history
=
inputs
[
self
.
history_key
]
streaming
=
inputs
[
self
.
streaming_key
]
prompt
=
inputs
[
self
.
prompt_key
]
print
(
f
"__call:{prompt}"
)
try
:
import
openai
# Not support yet
# openai.api_key = "EMPTY"
openai
.
key
=
self
.
api_key
openai
.
api_base
=
self
.
api_base_url
except
ImportError
:
raise
ValueError
(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
# create a chat completion
completion
=
openai
.
ChatCompletion
.
create
(
model
=
self
.
model_name
,
messages
=
self
.
build_message_list
(
prompt
)
)
print
(
f
"response:{completion.choices[0].message.content}"
)
print
(
f
"+++++++++++++++++++++++++++++++++++"
)
return
completion
.
choices
[
0
]
.
message
.
content
# 将历史对话数组转换为文本格式
def
build_message_list
(
self
,
query
)
->
Collection
[
Dict
[
str
,
str
]]:
build_message_list
:
Collection
[
Dict
[
str
,
str
]]
=
[]
history
=
self
.
history
[
-
self
.
history_len
:]
if
self
.
history_len
>
0
else
[]
for
i
,
(
old_query
,
response
)
in
enumerate
(
history
):
user_build_message
=
_build_message_template
()
user_build_message
[
'role'
]
=
'user'
user_build_message
[
'content'
]
=
old_query
system_build_message
=
_build_message_template
()
system_build_message
[
'role'
]
=
'system'
system_build_message
[
'content'
]
=
response
build_message_list
.
append
(
user_build_message
)
build_message_list
.
append
(
system_build_message
)
user_build_message
=
_build_message_template
()
user_build_message
[
'role'
]
=
'user'
user_build_message
[
'content'
]
=
query
build_message_list
.
append
(
user_build_message
)
return
build_message_list
def
generatorAnswer
(
self
,
prompt
:
str
,
history
:
List
[
List
[
str
]]
=
[],
streaming
:
bool
=
False
):
try
:
import
openai
# Not support yet
# openai.api_key = "EMPTY"
...
...
@@ -135,12 +140,13 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
# create a chat completion
completion
=
openai
.
ChatCompletion
.
create
(
model
=
self
.
model_name
,
messages
=
self
.
build_message_list
(
prompt
)
messages
=
build_message_list
(
prompt
)
)
print
(
f
"response:{completion.choices[0].message.content}"
)
print
(
f
"+++++++++++++++++++++++++++++++++++"
)
history
+=
[[
prompt
,
completion
.
choices
[
0
]
.
message
.
content
]]
answer_result
=
AnswerResult
()
answer_result
.
history
=
history
answer_result
.
llm_output
=
{
"answer"
:
completion
.
choices
[
0
]
.
message
.
content
}
yield
answer_result
generate_with_callback
(
answer_result
)
models/llama_llm.py
浏览文件 @
c5bc2178
差异被折叠。
点击展开。
models/loader/loader.py
浏览文件 @
c5bc2178
差异被折叠。
点击展开。
models/moss_llm.py
浏览文件 @
c5bc2178
from
abc
import
ABC
from
langchain.llms.base
import
LLM
from
typing
import
Optional
,
List
from
langchain.chains.base
import
Chain
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Generator
,
Union
from
langchain.callbacks.manager
import
CallbackManagerForChainRun
from
transformers.generation.logits_process
import
LogitsProcessor
from
transformers.generation.utils
import
LogitsProcessorList
,
StoppingCriteriaList
from
models.loader
import
LoaderCheckPoint
from
models.base
import
(
BaseAnswer
,
AnswerResult
)
AnswerResult
,
AnswerResultStream
,
AnswerResultQueueSentinelTokenListenerQueue
)
import
torch
import
transformers
import
torch
# todo 建议重写instruction,在该instruction下,各模型的表现比较差
META_INSTRUCTION
=
\
"""You are an AI assistant whose name is MOSS.
...
...
@@ -20,41 +28,65 @@ META_INSTRUCTION = \
Capabilities and tools that MOSS can possess.
"""
# todo 在MOSSLLM类下,各模型的响应速度很慢,后续要检查一下原因
class
MOSSLLM
(
BaseAnswer
,
LLM
,
ABC
):
class
MOSSLLM
Chain
(
BaseAnswer
,
Chain
,
ABC
):
max_token
:
int
=
2048
temperature
:
float
=
0.7
top_p
=
0.8
# history = []
checkPoint
:
LoaderCheckPoint
=
None
history_len
:
int
=
10
streaming_key
:
str
=
"streaming"
#: :meta private:
history_key
:
str
=
"history"
#: :meta private:
prompt_key
:
str
=
"prompt"
#: :meta private:
output_key
:
str
=
"answer_result_stream"
#: :meta private:
def
__init__
(
self
,
checkPoint
:
LoaderCheckPoint
=
None
):
super
()
.
__init__
()
self
.
checkPoint
=
checkPoint
@property
def
_
llm
_type
(
self
)
->
str
:
return
"MOSS"
def
_
chain
_type
(
self
)
->
str
:
return
"MOSS
LLMChain
"
@property
def
_check_point
(
self
)
->
LoaderCheckPoint
:
return
self
.
checkPoint
def
input_keys
(
self
)
->
List
[
str
]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return
[
self
.
prompt_key
]
@property
def
_history_len
(
self
)
->
int
:
def
output_keys
(
self
)
->
List
[
str
]:
"""Will always return text key.
return
self
.
history_len
:meta private:
"""
return
[
self
.
output_key
]
def
set_history_len
(
self
,
history_len
:
int
)
->
None
:
self
.
history_len
=
history_len
@property
def
_check_point
(
self
)
->
LoaderCheckPoint
:
return
self
.
checkPoint
def
_call
(
self
,
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
)
->
str
:
pass
def
_call
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
)
->
Dict
[
str
,
Generator
]:
generator
=
self
.
generatorAnswer
(
inputs
=
inputs
,
run_manager
=
run_manager
)
return
{
self
.
output_key
:
generator
}
def
generatorAnswer
(
self
,
prompt
:
str
,
history
:
List
[
List
[
str
]]
=
[],
streaming
:
bool
=
False
):
def
_generate_answer
(
self
,
inputs
:
Dict
[
str
,
Any
],
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
generate_with_callback
:
AnswerResultStream
=
None
)
->
None
:
history
=
inputs
[
self
.
history_key
]
streaming
=
inputs
[
self
.
streaming_key
]
prompt
=
inputs
[
self
.
prompt_key
]
print
(
f
"__call:{prompt}"
)
if
len
(
history
)
>
0
:
history
=
history
[
-
self
.
history_len
:]
if
self
.
history_len
>
0
else
[]
prompt_w_history
=
str
(
history
)
...
...
@@ -66,7 +98,7 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
inputs
=
self
.
checkPoint
.
tokenizer
(
prompt_w_history
,
return_tensors
=
"pt"
)
with
torch
.
no_grad
():
# max_length似乎可以设的小一些,而repetion_penalty应大一些,否则chatyuan,bloom等模型为满足max会重复输出
#
#
outputs
=
self
.
checkPoint
.
model
.
generate
(
inputs
.
input_ids
.
cuda
(),
attention_mask
=
inputs
.
attention_mask
.
cuda
(),
...
...
@@ -79,13 +111,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
num_return_sequences
=
1
,
eos_token_id
=
106068
,
pad_token_id
=
self
.
checkPoint
.
tokenizer
.
pad_token_id
)
response
=
self
.
checkPoint
.
tokenizer
.
decode
(
outputs
[
0
][
inputs
.
input_ids
.
shape
[
1
]:],
skip_special_tokens
=
True
)
response
=
self
.
checkPoint
.
tokenizer
.
decode
(
outputs
[
0
][
inputs
.
input_ids
.
shape
[
1
]:],
skip_special_tokens
=
True
)
self
.
checkPoint
.
clear_torch_cache
()
history
+=
[[
prompt
,
response
]]
answer_result
=
AnswerResult
()
answer_result
.
history
=
history
answer_result
.
llm_output
=
{
"answer"
:
response
}
yield
answer_result
generate_with_callback
(
answer_result
)
models/shared.py
浏览文件 @
c5bc2178
...
...
@@ -24,13 +24,12 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
if
use_ptuning_v2
:
loaderCheckPoint
.
use_ptuning_v2
=
use_ptuning_v2
# 如果指定了参数,则使用参数的配置
if
llm_model
:
llm_model_info
=
llm_model_dict
[
llm_model
]
if
loaderCheckPoint
.
no_remote_model
:
loaderCheckPoint
.
model_name
=
llm_model_info
[
'name'
]
else
:
loaderCheckPoint
.
model_name
=
llm_model_info
[
'pretrained_model_name'
]
loaderCheckPoint
.
model_name
=
llm_model_info
[
'name'
]
loaderCheckPoint
.
pretrained_model_name
=
llm_model_info
[
'pretrained_model_name'
]
loaderCheckPoint
.
model_path
=
llm_model_info
[
"local_model_path"
]
...
...
test/models/test_fastchat_openai_llm.py
deleted
100644 → 0
浏览文件 @
ca13ab81
import
sys
import
os
sys
.
path
.
append
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
+
'/../../'
)
import
asyncio
from
argparse
import
Namespace
from
models.loader.args
import
parser
from
models.loader
import
LoaderCheckPoint
import
models.shared
as
shared
async
def
dispatch
(
args
:
Namespace
):
args_dict
=
vars
(
args
)
shared
.
loaderCheckPoint
=
LoaderCheckPoint
(
args_dict
)
llm_model_ins
=
shared
.
loaderLLM
()
history
=
[
(
"which city is this?"
,
"tokyo"
),
(
"why?"
,
"she's japanese"
),
]
for
answer_result
in
llm_model_ins
.
generatorAnswer
(
prompt
=
"你好? "
,
history
=
history
,
streaming
=
False
):
resp
=
answer_result
.
llm_output
[
"answer"
]
print
(
resp
)
if
__name__
==
'__main__'
:
args
=
None
args
=
parser
.
parse_args
(
args
=
[
'--model-dir'
,
'/media/checkpoint/'
,
'--model'
,
'fastchat-chatglm-6b'
,
'--no-remote-model'
])
loop
=
asyncio
.
new_event_loop
()
asyncio
.
set_event_loop
(
loop
)
loop
.
run_until_complete
(
dispatch
(
args
))
webui.py
浏览文件 @
c5bc2178
...
...
@@ -85,8 +85,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
yield
history
+
[[
query
,
"请选择知识库后进行测试,当前未选择知识库。"
]],
""
else
:
for
answer_result
in
local_doc_qa
.
llm
.
generatorAnswer
(
prompt
=
query
,
history
=
history
,
streaming
=
streaming
):
answer_result_stream_result
=
local_doc_qa
.
llm_model_chain
(
{
"prompt"
:
query
,
"history"
:
history
,
"streaming"
:
streaming
})
for
answer_result
in
answer_result_stream_result
[
'answer_result_stream'
]:
resp
=
answer_result
.
llm_output
[
"answer"
]
history
=
answer_result
.
history
history
[
-
1
][
-
1
]
=
resp
...
...
@@ -101,11 +104,12 @@ def init_model():
args_dict
=
vars
(
args
)
shared
.
loaderCheckPoint
=
LoaderCheckPoint
(
args_dict
)
llm_model_ins
=
shared
.
loaderLLM
()
llm_model_ins
.
set_history_len
(
LLM_HISTORY_LEN
)
try
:
local_doc_qa
.
init_cfg
(
llm_model
=
llm_model_ins
)
generator
=
local_doc_qa
.
llm
.
generatorAnswer
(
"你好"
)
for
answer_result
in
generator
:
answer_result_stream_result
=
local_doc_qa
.
llm_model_chain
(
{
"prompt"
:
"你好"
,
"history"
:
[],
"streaming"
:
False
})
for
answer_result
in
answer_result_stream_result
[
'answer_result_stream'
]:
print
(
answer_result
.
llm_output
)
reply
=
"""模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
logger
.
info
(
reply
)
...
...
@@ -141,7 +145,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u
def
get_vector_store
(
vs_id
,
files
,
sentence_size
,
history
,
one_conent
,
one_content_segmentation
):
vs_path
=
os
.
path
.
join
(
KB_ROOT_PATH
,
vs_id
,
"vector_store"
)
filelist
=
[]
if
local_doc_qa
.
llm
and
local_doc_qa
.
embeddings
:
if
local_doc_qa
.
llm
_model_chain
and
local_doc_qa
.
embeddings
:
if
isinstance
(
files
,
list
):
for
file
in
files
:
filename
=
os
.
path
.
split
(
file
.
name
)[
-
1
]
...
...
@@ -165,8 +169,8 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte
def
change_vs_name_input
(
vs_id
,
history
):
if
vs_id
==
"新建知识库"
:
return
gr
.
update
(
visible
=
True
),
gr
.
update
(
visible
=
True
),
gr
.
update
(
visible
=
False
),
None
,
history
,
\
gr
.
update
(
choices
=
[]),
gr
.
update
(
visible
=
False
)
return
gr
.
update
(
visible
=
True
),
gr
.
update
(
visible
=
True
),
gr
.
update
(
visible
=
False
),
None
,
history
,
\
gr
.
update
(
choices
=
[]),
gr
.
update
(
visible
=
False
)
else
:
vs_path
=
os
.
path
.
join
(
KB_ROOT_PATH
,
vs_id
,
"vector_store"
)
if
"index.faiss"
in
os
.
listdir
(
vs_path
):
...
...
@@ -218,7 +222,7 @@ def change_chunk_conent(mode, label_conent, history):
def
add_vs_name
(
vs_name
,
chatbot
):
if
vs_name
is
None
or
vs_name
.
strip
()
==
""
:
if
vs_name
is
None
or
vs_name
.
strip
()
==
""
:
vs_status
=
"知识库名称不能为空,请重新填写知识库名称"
chatbot
=
chatbot
+
[[
None
,
vs_status
]]
return
gr
.
update
(
visible
=
True
),
gr
.
update
(
visible
=
True
),
gr
.
update
(
visible
=
True
),
gr
.
update
(
...
...
@@ -262,6 +266,7 @@ def reinit_vector_store(vs_id, history):
def
refresh_vs_list
():
return
gr
.
update
(
choices
=
get_vs_list
()),
gr
.
update
(
choices
=
get_vs_list
())
def
delete_file
(
vs_id
,
files_to_delete
,
chatbot
):
vs_path
=
os
.
path
.
join
(
KB_ROOT_PATH
,
vs_id
,
"vector_store"
)
content_path
=
os
.
path
.
join
(
KB_ROOT_PATH
,
vs_id
,
"content"
)
...
...
@@ -275,11 +280,11 @@ def delete_file(vs_id, files_to_delete, chatbot):
rested_files
=
local_doc_qa
.
list_file_from_vector_store
(
vs_path
)
if
"fail"
in
status
:
vs_status
=
"文件删除失败。"
elif
len
(
rested_files
)
>
0
:
elif
len
(
rested_files
)
>
0
:
vs_status
=
"文件删除成功。"
else
:
vs_status
=
f
"文件删除成功,知识库{vs_id}中无已上传文件,请先上传文件后,再开始提问。"
logger
.
info
(
","
.
join
(
files_to_delete
)
+
vs_status
)
logger
.
info
(
","
.
join
(
files_to_delete
)
+
vs_status
)
chatbot
=
chatbot
+
[[
None
,
vs_status
]]
return
gr
.
update
(
choices
=
local_doc_qa
.
list_file_from_vector_store
(
vs_path
),
value
=
[]),
chatbot
...
...
@@ -290,7 +295,8 @@ def delete_vs(vs_id, chatbot):
status
=
f
"成功删除知识库{vs_id}"
logger
.
info
(
status
)
chatbot
=
chatbot
+
[[
None
,
status
]]
return
gr
.
update
(
choices
=
get_vs_list
(),
value
=
get_vs_list
()[
0
]),
gr
.
update
(
visible
=
True
),
gr
.
update
(
visible
=
True
),
\
return
gr
.
update
(
choices
=
get_vs_list
(),
value
=
get_vs_list
()[
0
]),
gr
.
update
(
visible
=
True
),
gr
.
update
(
visible
=
True
),
\
gr
.
update
(
visible
=
False
),
chatbot
,
gr
.
update
(
visible
=
False
)
except
Exception
as
e
:
logger
.
error
(
e
)
...
...
@@ -333,7 +339,8 @@ default_theme_args = dict(
with
gr
.
Blocks
(
css
=
block_css
,
theme
=
gr
.
themes
.
Default
(
**
default_theme_args
))
as
demo
:
vs_path
,
file_status
,
model_status
=
gr
.
State
(
os
.
path
.
join
(
KB_ROOT_PATH
,
get_vs_list
()[
0
],
"vector_store"
)
if
len
(
get_vs_list
())
>
1
else
""
),
gr
.
State
(
""
),
gr
.
State
(
os
.
path
.
join
(
KB_ROOT_PATH
,
get_vs_list
()[
0
],
"vector_store"
)
if
len
(
get_vs_list
())
>
1
else
""
),
gr
.
State
(
""
),
gr
.
State
(
model_status
)
gr
.
Markdown
(
webui_title
)
with
gr
.
Tab
(
"对话"
):
...
...
@@ -386,8 +393,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
load_folder_button
=
gr
.
Button
(
"上传文件夹并加载知识库"
)
with
gr
.
Tab
(
"删除文件"
):
files_to_delete
=
gr
.
CheckboxGroup
(
choices
=
[],
label
=
"请从知识库已有文件中选择要删除的文件"
,
interactive
=
True
)
label
=
"请从知识库已有文件中选择要删除的文件"
,
interactive
=
True
)
delete_file_button
=
gr
.
Button
(
"从知识库中删除选中文件"
)
vs_refresh
.
click
(
fn
=
refresh_vs_list
,
inputs
=
[],
...
...
@@ -455,9 +462,9 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
with
vs_setting
:
vs_refresh
=
gr
.
Button
(
"更新已有知识库选项"
)
select_vs_test
=
gr
.
Dropdown
(
get_vs_list
(),
label
=
"请选择要加载的知识库"
,
interactive
=
True
,
value
=
get_vs_list
()[
0
]
if
len
(
get_vs_list
())
>
0
else
None
)
label
=
"请选择要加载的知识库"
,
interactive
=
True
,
value
=
get_vs_list
()[
0
]
if
len
(
get_vs_list
())
>
0
else
None
)
vs_name
=
gr
.
Textbox
(
label
=
"请输入新建知识库名称,当前知识库命名暂不支持中文"
,
lines
=
1
,
interactive
=
True
,
...
...
@@ -497,8 +504,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
inputs
=
[
vs_name
,
chatbot
],
outputs
=
[
select_vs_test
,
vs_name
,
vs_add
,
file2vs
,
chatbot
])
select_vs_test
.
change
(
fn
=
change_vs_name_input
,
inputs
=
[
select_vs_test
,
chatbot
],
outputs
=
[
vs_name
,
vs_add
,
file2vs
,
vs_path
,
chatbot
])
inputs
=
[
select_vs_test
,
chatbot
],
outputs
=
[
vs_name
,
vs_add
,
file2vs
,
vs_path
,
chatbot
])
load_file_button
.
click
(
get_vector_store
,
show_progress
=
True
,
inputs
=
[
select_vs_test
,
files
,
sentence_size
,
chatbot
,
vs_add
,
vs_add
],
...
...
webui_st.py
浏览文件 @
c5bc2178
...
...
@@ -85,9 +85,10 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
yield
history
+
[[
query
,
"请选择知识库后进行测试,当前未选择知识库。"
]],
""
else
:
for
answer_result
in
local_doc_qa
.
llm
.
generatorAnswer
(
prompt
=
query
,
history
=
history
,
streaming
=
streaming
):
answer_result_stream_result
=
local_doc_qa
.
llm_model_chain
(
{
"prompt"
:
query
,
"history"
:
history
,
"streaming"
:
streaming
})
for
answer_result
in
answer_result_stream_result
[
'answer_result_stream'
]:
resp
=
answer_result
.
llm_output
[
"answer"
]
history
=
answer_result
.
history
history
[
-
1
][
-
1
]
=
resp
+
(
...
...
@@ -105,13 +106,14 @@ def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'
args_dict
.
update
(
model
=
llm_model
)
shared
.
loaderCheckPoint
=
LoaderCheckPoint
(
args_dict
)
llm_model_ins
=
shared
.
loaderLLM
()
llm_model_ins
.
set_history_len
(
LLM_HISTORY_LEN
)
try
:
local_doc_qa
.
init_cfg
(
llm_model
=
llm_model_ins
,
embedding_model
=
embedding_model
)
generator
=
local_doc_qa
.
llm
.
generatorAnswer
(
"你好"
)
for
answer_result
in
generator
:
answer_result_stream_result
=
local_doc_qa
.
llm_model_chain
(
{
"prompt"
:
"你好"
,
"history"
:
[],
"streaming"
:
False
})
for
answer_result
in
answer_result_stream_result
[
'answer_result_stream'
]:
print
(
answer_result
.
llm_output
)
reply
=
"""模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
logger
.
info
(
reply
)
...
...
@@ -468,7 +470,7 @@ with st.sidebar:
top_k
=
st
.
slider
(
'向量匹配数量'
,
1
,
20
,
VECTOR_SEARCH_TOP_K
)
history_len
=
st
.
slider
(
'LLM对话轮数'
,
1
,
50
,
LLM_HISTORY_LEN
)
# 也许要跟知识库分开设置
local_doc_qa
.
llm
.
set_history_len
(
history_len
)
#
local_doc_qa.llm.set_history_len(history_len)
chunk_conent
=
st
.
checkbox
(
'启用上下文关联'
,
False
)
st
.
text
(
''
)
# chunk_conent = st.checkbox('分割文本', True) # 知识库文本分割入库
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论