aigc-pioneer / jinchat-server

Commit c4ee36b8
Authored May 25, 2023 by glide-the
Parent: e7b06a90

Remove AnswerResultStream and the generate_with_callback collector
Showing 7 changed files with 21 additions and 204 deletions.
chains/local_doc_qa.py   +1   -3
models/base.py           +0   -148
models/chatglm_llm.py    +7   -22
models/fastchat_llm.py   +4   -7
models/llama_llm.py      +4   -14
models/moss_llm.py       +4   -7
webui.py                 +1   -3
chains/local_doc_qa.py

@@ -12,9 +12,7 @@ from tqdm import tqdm
 from pypinyin import lazy_pinyin
 from loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 from models.loader.args import parser
 from models.loader import LoaderCheckPoint
 import models.shared as shared
models/base.py

@@ -10,142 +10,12 @@ import transformers
 from models.loader import LoaderCheckPoint
 
 
-class ListenerToken:
-    """
-    Observed result of one generation step.
-    """
-    input_ids: torch.LongTensor
-    _scores: torch.FloatTensor
-
-    def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
-        self.input_ids = input_ids
-        self._scores = _scores
-
-
 class AnswerResult:
     """
     Message entity.
     """
     history: List[List[str]] = []
     llm_output: Optional[dict] = None
-    listenerToken: ListenerToken = None
-
-
-class AnswerResultStream:
-    def __init__(self, callback_func=None):
-        self.callback_func = callback_func
-
-    def __call__(self, answerResult: AnswerResult):
-        if self.callback_func is not None:
-            self.callback_func(answerResult)
-
-
-class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
-    """
-    A stopping_criteria listener for the model: on every generation step it syncs the queued data into AnswerResult.
-    The point of this listener is that different models may not expose their predictions as tensors; the HF
-    framework lets a custom transformers.StoppingCriteria receive the Tensor and scores of each prediction step.
-    A StoppingCriteriaList tells the model when to stop generating an answer; each StoppingCriteria object
-    represents one stop condition. On every prediction step all StoppingCriteria receive the same prediction
-    result, and the concrete implementation decides whether to stop.
-    The captured values can be observed through the custom parameters of generatorAnswer /
-    generate_with_streaming to allow finer-grained control.
-    """
-
-    listenerQueue: deque = deque(maxlen=1)
-
-    def __init__(self):
-        transformers.StoppingCriteria.__init__(self)
-
-    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
-        """
-        Append the data of every generation step to the listener queue.
-        :param input_ids:
-        :param _scores:
-        :param kwargs:
-        :return:
-        """
-        self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores))
-        return False
-
-
-class Iteratorize:
-    """
-    Transforms a function that takes a callback
-    into a lazy iterator (generator).
-    """
-
-    def __init__(self, func, kwargs={}):
-        self.mfunc = func
-        self.q = Queue()
-        self.sentinel = object()
-        self.kwargs = kwargs
-        self.stop_now = False
-
-        def _callback(val):
-            """
-            Collects the model's prediction output.
-            The generate_with_callback collector (an AnswerResultStream) gathers the AnswerResult responses
-            produced by the model; the concrete implementation decides when generation ends.
-            Termination conditions:
-            1. The model finishes predicting and the collector queue self.q receives the self.sentinel marker.
-            2. A break while consuming the iterator queue triggers StopIteration.
-            3. The model prediction raises an error.
-            Because this class is an iterator, breaking out of a "for ... in" loop calls __exit__, which updates
-            the stop_now attribute and then raises an exception to end the prediction.
-            The collection flow:
-            create an Iteratorize object;
-            define the generate_with_callback collector (AnswerResultStream);
-            start a thread that asynchronously calls the upstream checkpoint's _generate_answer;
-            _generate_answer uses the collector to gather the AnswerResult messages wrapped by the upstream
-            checkpoint; because self.q blocks, each prediction is consumed before the next one runs, so
-            generate_with_callback blocks in the meantime;
-            the main thread consumes the blocked messages via Iteratorize.__next__:
-            1. an AnswerResult message wrapped by the upstream checkpoint is returned for downstream handling;
-            2. the self.sentinel marker raises StopIteration;
-            when the main thread's __exit__ runs, stop_now is updated; the async thread detects the update and
-            raises an exception to end the prediction, finishing the iteration.
-            :param val:
-            :return:
-            """
-            if self.stop_now:
-                raise ValueError
-            self.q.put(val)
-
-        def gen():
-            try:
-                ret = self.mfunc(callback=_callback, **self.kwargs)
-            except ValueError:
-                pass
-            except:
-                traceback.print_exc()
-                pass
-
-            self.q.put(self.sentinel)
-
-        self.thread = Thread(target=gen)
-        self.thread.start()
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        obj = self.q.get(True, None)
-        if obj is self.sentinel:
-            raise StopIteration
-        else:
-            return obj
-
-    def __del__(self):
-        """
-        Not implemented.
-        :return:
-        """
-        pass
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        """ Runs after a break. """
-        self.stop_now = True
-
-
 class BaseAnswer(ABC):

@@ -168,22 +38,4 @@ class BaseAnswer(ABC):
     def generatorAnswer(self, prompt: str,
                         history: List[List[str]] = [],
                         streaming: bool = False):
-        def generate_with_callback(callback=None, **kwargs):
-            kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback)
-            self._generate_answer(**kwargs)
-
-        def generate_with_streaming(**kwargs):
-            return Iteratorize(generate_with_callback, kwargs)
-
-        with generate_with_streaming(prompt=prompt, history=history, streaming=streaming) as generator:
-            for answerResult in generator:
-                if answerResult.listenerToken:
-                    output = answerResult.listenerToken.input_ids
-                yield answerResult
-
-    @abstractmethod
-    def _generate_answer(self, prompt: str,
-                         history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
         pass
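
For context, the deleted Iteratorize class is a general callback-to-generator adapter: a worker thread drives a callback-style producer while a blocking queue turns the callback invocations into a plain iterator. The following is only a minimal illustrative sketch of that pattern, not code from this repository; all names are invented.

```python
# Minimal sketch of the callback-to-generator pattern that Iteratorize implemented.
# A worker thread runs the callback-based producer; a blocking queue feeds the iterator.
import queue
import threading

_SENTINEL = object()  # marks the end of the stream, like Iteratorize's self.sentinel


def iterate_callbacks(func, **kwargs):
    """Run func(callback=..., **kwargs) in a thread and yield every value passed to the callback."""
    q = queue.Queue()

    def _callback(value):
        q.put(value)

    def _worker():
        try:
            func(callback=_callback, **kwargs)
        finally:
            q.put(_SENTINEL)  # always signal completion, even on error

    threading.Thread(target=_worker, daemon=True).start()
    while True:
        item = q.get()
        if item is _SENTINEL:
            return
        yield item


# Usage: a callback-style producer becomes a plain for-loop.
def fake_model(callback=None, prompt=""):
    for token in prompt.split():
        callback(token)


for token in iterate_callbacks(fake_model, prompt="streaming without a stream API"):
    print(token)
```

The commit drops this indirection because the concrete model classes can simply yield their AnswerResult objects, as the diffs below show.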
models/chatglm_llm.py

@@ -5,9 +5,7 @@ from langchain.llms.base import LLM
 from typing import Optional, List
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 import transformers

@@ -43,15 +41,9 @@ class ChatGLM(BaseAnswer, LLM, ABC):
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         pass
 
-    def _generate_answer(self, prompt: str,
-                         history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
-        # Create the StoppingCriteriaList with the stopping strings
-        stopping_criteria_list = transformers.StoppingCriteriaList()
-        # stopping_criteria queue for the model: sync torch.LongTensor / torch.FloatTensor into AnswerResult on every step
-        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
-        stopping_criteria_list.append(listenerQueue)
+    def generatorAnswer(self, prompt: str,
+                        history: List[List[str]] = [],
+                        streaming: bool = False):
         if streaming:
             history += [[]]

@@ -60,34 +52,27 @@ class ChatGLM(BaseAnswer, LLM, ABC):
                     prompt,
                     history=history[-self.history_len:-1] if self.history_len > 0 else [],
                     max_length=self.max_token,
-                    temperature=self.temperature,
-                    stopping_criteria=stopping_criteria_list
+                    temperature=self.temperature
             )):
                 # self.checkPoint.clear_torch_cache()
                 history[-1] = [prompt, stream_resp]
                 answer_result = AnswerResult()
                 answer_result.history = history
                 answer_result.llm_output = {"answer": stream_resp}
-                if listenerQueue.listenerQueue.__len__() > 0:
-                    answer_result.listenerToken = listenerQueue.listenerQueue.pop()
-                generate_with_callback(answer_result)
+                yield answer_result
         else:
             response, _ = self.checkPoint.model.chat(
                 self.checkPoint.tokenizer,
                 prompt,
                 history=history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
-                temperature=self.temperature,
-                stopping_criteria=stopping_criteria_list
+                temperature=self.temperature
             )
             self.checkPoint.clear_torch_cache()
             history += [[prompt, response]]
             answer_result = AnswerResult()
             answer_result.history = history
             answer_result.llm_output = {"answer": response}
-            if listenerQueue.listenerQueue.__len__() > 0:
-                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
-            generate_with_callback(answer_result)
+            yield answer_result
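
With this change, generatorAnswer is itself a generator that yields AnswerResult objects, so callers no longer build an AnswerResultStream collector or wrap anything in Iteratorize. A hedged sketch of the consuming side follows; llm stands for any already-loaded BaseAnswer subclass such as ChatGLM, and the helper name is invented.

```python
# Sketch of the calling side after this commit (illustrative, not project code):
# generatorAnswer yields AnswerResult objects directly, so a plain for-loop suffices.
def stream_answer(llm, prompt, history=None):
    """llm is any loaded BaseAnswer subclass, e.g. ChatGLM; yields (partial_answer, history)."""
    history = history or []
    for answer_result in llm.generatorAnswer(prompt=prompt, history=history, streaming=True):
        history = answer_result.history
        yield answer_result.llm_output["answer"], history
```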
models/fastchat_llm.py

@@ -5,9 +5,7 @@ from langchain.llms.base import LLM
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 
 
 class FastChatLLM(BaseAnswer, LLM, ABC):

@@ -40,10 +38,9 @@ class FastChatLLM(BaseAnswer, LLM, ABC):
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         pass
 
-    def _generate_answer(self, prompt: str,
-                         history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
+    def generatorAnswer(self, prompt: str,
+                        history: List[List[str]] = [],
+                        streaming: bool = False):
         response = "fastchat 响应结果"
         history += [[prompt, response]]

@@ -51,4 +48,4 @@ class FastChatLLM(BaseAnswer, LLM, ABC):
         answer_result.history = history
         answer_result.llm_output = {"answer": response}
 
-        generate_with_callback(answer_result)
+        yield answer_result
models/llama_llm.py

@@ -9,9 +9,7 @@ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaL
 from typing import Optional, List, Dict, Any
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 
 
 class InvalidScoreLogitsProcessor(LogitsProcessor):

@@ -178,23 +176,15 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
         self.history = self.history + [[None, reply]]
         return reply
 
-    def _generate_answer(self, prompt: str,
-                         history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
+    def generatorAnswer(self, prompt: str,
+                        history: List[List[str]] = [],
+                        streaming: bool = False):
         if history:
             self.history = history
-        # Create the StoppingCriteriaList with the stopping strings
-        self.stopping_criteria = transformers.StoppingCriteriaList()
-        # stopping_criteria queue for the model: sync torch.LongTensor / torch.FloatTensor into AnswerResult on every step
-        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
-        self.stopping_criteria.append(listenerQueue)
         # TODO: a chat dialogue module and attention handling still need to be implemented; _call is currently
         # the API extended from langchain's LLM and defaults to prompt-free mode. To work with the attention
         # model, see the chat_glm implementation.
         softprompt = self.generate_softprompt_history_tensors(prompt)
         response = self._call(prompt=softprompt, stop=['\n###'])
         answer_result = AnswerResult()
         answer_result.history = self.history
-        if listenerQueue.listenerQueue.__len__() > 0:
-            answer_result.listenerToken = listenerQueue.listenerQueue.pop()
         answer_result.llm_output = {"answer": response}
-        generate_with_callback(answer_result)
+        yield answer_result
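
The listenerQueue plumbing removed here hooked into transformers' stopping-criteria mechanism purely to observe tokens; returning False meant it never actually stopped generation. As background only, here is a standalone sketch of such a listener, with illustrative names rather than the deleted AnswerResultQueueSentinelTokenListenerQueue itself.

```python
# Sketch of a token listener built on transformers.StoppingCriteria (illustrative only).
from collections import deque

import torch
import transformers


class TokenListener(transformers.StoppingCriteria):
    """Records the latest (input_ids, scores) pair without ever stopping generation."""

    def __init__(self):
        super().__init__()
        self.queue = deque(maxlen=1)  # keep only the most recent generation step

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        self.queue.append((input_ids, scores))
        return False  # never request a stop


listener = TokenListener()
stopping_criteria = transformers.StoppingCriteriaList([listener])
# model.generate(..., stopping_criteria=stopping_criteria)  # hypothetical call site
```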
models/moss_llm.py

@@ -3,9 +3,7 @@ from langchain.llms.base import LLM
 from typing import Optional, List
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 import torch

@@ -53,10 +51,9 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         pass
 
-    def _generate_answer(self, prompt: str,
-                         history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
+    def generatorAnswer(self, prompt: str,
+                        history: List[List[str]] = [],
+                        streaming: bool = False):
         if len(history) > 0:
             history = history[-self.history_len:-1] if self.history_len > 0 else []
             prompt_w_history = str(history)

@@ -86,6 +83,6 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
         answer_result.history = history
         answer_result.llm_output = {"answer": response}
 
-        generate_with_callback(answer_result)
+        yield answer_result
webui.py

@@ -6,9 +6,7 @@ from chains.local_doc_qa import LocalDocQA
 from configs.model_config import *
 import nltk
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 import models.shared as shared
 from models.loader.args import parser
 from models.loader import LoaderCheckPoint