Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
J
jinchat-server
概览
概览
详情
活动
周期分析
版本库
存储库
文件
提交
分支
标签
贡献者
分支图
比较
统计图
问题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
aigc-pioneer
jinchat-server
Commits
b4aefca5
提交
b4aefca5
authored
4月 26, 2023
作者:
imClumsyPanda
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add stream support to cli_demo.py
上级
88ab9a1d
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
91 行增加
和
85 行删除
+91
-85
local_doc_qa.py
chains/local_doc_qa.py
+37
-46
cli_demo.py
cli_demo.py
+12
-6
chatglm_llm.py
models/chatglm_llm.py
+42
-33
没有找到文件。
chains/local_doc_qa.py
浏览文件 @
b4aefca5
...
@@ -2,7 +2,6 @@ from langchain.chains import RetrievalQA
...
@@ -2,7 +2,6 @@ from langchain.chains import RetrievalQA
from
langchain.prompts
import
PromptTemplate
from
langchain.prompts
import
PromptTemplate
from
langchain.embeddings.huggingface
import
HuggingFaceEmbeddings
from
langchain.embeddings.huggingface
import
HuggingFaceEmbeddings
from
langchain.vectorstores
import
FAISS
from
langchain.vectorstores
import
FAISS
from
langchain.vectorstores.base
import
VectorStoreRetriever
from
langchain.document_loaders
import
UnstructuredFileLoader
from
langchain.document_loaders
import
UnstructuredFileLoader
from
models.chatglm_llm
import
ChatGLM
from
models.chatglm_llm
import
ChatGLM
import
sentence_transformers
import
sentence_transformers
...
@@ -34,22 +33,20 @@ def load_file(filepath):
...
@@ -34,22 +33,20 @@ def load_file(filepath):
docs
=
loader
.
load_and_split
(
text_splitter
=
textsplitter
)
docs
=
loader
.
load_and_split
(
text_splitter
=
textsplitter
)
return
docs
return
docs
def
generate_prompt
(
related_docs
:
List
[
str
],
def
get_relevant_documents
(
self
,
query
:
str
)
->
List
[
Document
]:
query
:
str
,
if
self
.
search_type
==
"similarity"
:
prompt_template
=
PROMPT_TEMPLATE
)
->
str
:
docs
=
self
.
vectorstore
.
_similarity_search_with_relevance_scores
(
query
,
**
self
.
search_kwargs
)
context
=
"
\n
"
.
join
([
doc
.
page_content
for
doc
in
related_docs
])
for
doc
in
docs
:
prompt
=
prompt_template
.
replace
(
"{question}"
,
query
)
.
replace
(
"{context}"
,
context
)
doc
[
0
]
.
metadata
[
"score"
]
=
doc
[
1
]
return
prompt
docs
=
[
doc
[
0
]
for
doc
in
docs
]
elif
self
.
search_type
==
"mmr"
:
docs
=
self
.
vectorstore
.
max_marginal_relevance_search
(
query
,
**
self
.
search_kwargs
)
else
:
raise
ValueError
(
f
"search_type of {self.search_type} not allowed."
)
return
docs
def
get_docs_with_score
(
docs_with_score
):
docs
=
[]
for
doc
,
score
in
docs_with_score
:
doc
.
metadata
[
"score"
]
=
score
docs
.
append
(
doc
)
return
docs
class
LocalDocQA
:
class
LocalDocQA
:
llm
:
object
=
None
llm
:
object
=
None
...
@@ -73,8 +70,6 @@ class LocalDocQA:
...
@@ -73,8 +70,6 @@ class LocalDocQA:
self
.
embeddings
=
HuggingFaceEmbeddings
(
model_name
=
embedding_model_dict
[
embedding_model
],
self
.
embeddings
=
HuggingFaceEmbeddings
(
model_name
=
embedding_model_dict
[
embedding_model
],
model_kwargs
=
{
'device'
:
embedding_device
})
model_kwargs
=
{
'device'
:
embedding_device
})
# self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
# device=embedding_device)
self
.
top_k
=
top_k
self
.
top_k
=
top_k
def
init_knowledge_vector_store
(
self
,
def
init_knowledge_vector_store
(
self
,
...
@@ -134,34 +129,30 @@ class LocalDocQA:
...
@@ -134,34 +129,30 @@ class LocalDocQA:
def
get_knowledge_based_answer
(
self
,
def
get_knowledge_based_answer
(
self
,
query
,
query
,
vs_path
,
vs_path
,
chat_history
=
[],
):
chat_history
=
[],
prompt_template
=
"""基于以下已知信息,简洁和专业的来回答用户的问题。
streaming
=
True
):
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
self
.
llm
.
streaming
=
streaming
已知内容:
{context}
问题:
{question}"""
prompt
=
PromptTemplate
(
template
=
prompt_template
,
input_variables
=
[
"context"
,
"question"
]
)
self
.
llm
.
history
=
chat_history
vector_store
=
FAISS
.
load_local
(
vs_path
,
self
.
embeddings
)
vector_store
=
FAISS
.
load_local
(
vs_path
,
self
.
embeddings
)
vs_r
=
vector_store
.
as_retriever
(
search_type
=
"mmr"
,
related_docs_with_score
=
vector_store
.
similarity_search_with_score
(
query
,
search_kwargs
=
{
"k"
:
self
.
top_k
})
k
=
self
.
top_k
)
# VectorStoreRetriever.get_relevant_documents = get_relevant_documents
related_docs
=
get_docs_with_score
(
related_docs_with_score
)
knowledge_chain
=
RetrievalQA
.
from_llm
(
prompt
=
generate_prompt
(
related_docs
,
query
)
llm
=
self
.
llm
,
retriever
=
vs_r
,
prompt
=
prompt
)
knowledge_chain
.
combine_documents_chain
.
document_prompt
=
PromptTemplate
(
input_variables
=
[
"page_content"
],
template
=
"{page_content}"
)
knowledge_chain
.
return_source_documents
=
True
if
streaming
:
result
=
knowledge_chain
({
"query"
:
query
})
for
result
,
history
in
self
.
llm
.
_call
(
prompt
=
prompt
,
self
.
llm
.
history
[
-
1
][
0
]
=
query
history
=
chat_history
):
return
result
,
self
.
llm
.
history
history
[
-
1
]
=
list
(
history
[
-
1
])
history
[
-
1
][
0
]
=
query
response
=
{
"query"
:
query
,
"result"
:
result
,
"source_documents"
:
related_docs
}
yield
response
,
history
else
:
result
,
history
=
self
.
llm
.
_call
(
prompt
=
prompt
,
history
=
chat_history
)
history
[
-
1
]
=
list
(
history
[
-
1
])
history
[
-
1
][
0
]
=
query
response
=
{
"query"
:
query
,
"result"
:
result
,
"source_documents"
:
related_docs
}
return
response
,
history
cli_demo.py
浏览文件 @
b4aefca5
...
@@ -28,10 +28,16 @@ if __name__ == "__main__":
...
@@ -28,10 +28,16 @@ if __name__ == "__main__":
history
=
[]
history
=
[]
while
True
:
while
True
:
query
=
input
(
"Input your question 请输入问题:"
)
query
=
input
(
"Input your question 请输入问题:"
)
resp
,
history
=
local_doc_qa
.
get_knowledge_based_answer
(
query
=
query
,
last_print_len
=
0
vs_path
=
vs_path
,
for
resp
,
history
in
local_doc_qa
.
get_knowledge_based_answer
(
query
=
query
,
chat_history
=
history
)
vs_path
=
vs_path
,
chat_history
=
history
,
streaming
=
True
):
print
(
resp
[
"result"
][
last_print_len
:],
end
=
""
,
flush
=
True
)
last_print_len
=
len
(
resp
[
"result"
])
if
REPLY_WITH_SOURCE
:
if
REPLY_WITH_SOURCE
:
print
(
resp
)
source_text
=
[
f
"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:
\n\n
{doc.page_content}
\n\n
"""
else
:
# f"""相关度:{doc.metadata['score']}\n\n"""
print
(
resp
[
"result"
])
for
inum
,
doc
in
enumerate
(
resp
[
"source_documents"
])]
print
(
"
\n\n
"
+
"
\n\n
"
.
join
(
source_text
))
models/chatglm_llm.py
浏览文件 @
b4aefca5
...
@@ -5,7 +5,8 @@ from langchain.llms.utils import enforce_stop_tokens
...
@@ -5,7 +5,8 @@ from langchain.llms.utils import enforce_stop_tokens
from
transformers
import
AutoTokenizer
,
AutoModel
,
AutoConfig
from
transformers
import
AutoTokenizer
,
AutoModel
,
AutoConfig
import
torch
import
torch
from
configs.model_config
import
LLM_DEVICE
from
configs.model_config
import
LLM_DEVICE
from
langchain.callbacks.base
import
CallbackManager
from
langchain.callbacks.streaming_stdout
import
StreamingStdOutCallbackHandler
from
typing
import
Dict
,
Tuple
,
Union
,
Optional
from
typing
import
Dict
,
Tuple
,
Union
,
Optional
DEVICE
=
LLM_DEVICE
DEVICE
=
LLM_DEVICE
...
@@ -54,10 +55,12 @@ class ChatGLM(LLM):
...
@@ -54,10 +55,12 @@ class ChatGLM(LLM):
max_token
:
int
=
10000
max_token
:
int
=
10000
temperature
:
float
=
0.01
temperature
:
float
=
0.01
top_p
=
0.9
top_p
=
0.9
history
=
[]
#
history = []
tokenizer
:
object
=
None
tokenizer
:
object
=
None
model
:
object
=
None
model
:
object
=
None
history_len
:
int
=
10
history_len
:
int
=
10
streaming
:
bool
=
True
callback_manager
=
CallbackManager
([
StreamingStdOutCallbackHandler
()])
def
__init__
(
self
):
def
__init__
(
self
):
super
()
.
__init__
()
super
()
.
__init__
()
...
@@ -68,46 +71,45 @@ class ChatGLM(LLM):
...
@@ -68,46 +71,45 @@ class ChatGLM(LLM):
def
_call
(
self
,
def
_call
(
self
,
prompt
:
str
,
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
,
history
:
List
[
List
[
str
]]
=
[]
,
st
ream
=
Tru
e
)
->
str
:
st
op
:
Optional
[
List
[
str
]]
=
Non
e
)
->
str
:
if
s
tream
:
if
s
elf
.
streaming
:
self
.
history
=
self
.
history
+
[[
None
,
""
]]
history
=
history
+
[[
None
,
""
]]
for
response
,
history
in
self
.
model
.
stream_chat
(
for
stream_resp
,
history
in
self
.
model
.
stream_chat
(
self
.
tokenizer
,
self
.
tokenizer
,
prompt
,
prompt
,
history
=
self
.
history
[
-
self
.
history_len
:]
if
self
.
history_len
>
0
else
[],
history
=
history
[
-
self
.
history_len
:]
if
self
.
history_len
>
0
else
[],
max_length
=
self
.
max_token
,
max_length
=
self
.
max_token
,
temperature
=
self
.
temperature
,
temperature
=
self
.
temperature
,
):
):
torch_gc
()
yield
stream_resp
,
history
self
.
history
[
-
1
][
-
1
]
=
response
yield
response
else
:
else
:
response
,
_
=
self
.
model
.
chat
(
response
,
_
=
self
.
model
.
chat
(
self
.
tokenizer
,
self
.
tokenizer
,
prompt
,
prompt
,
history
=
self
.
history
[
-
self
.
history_len
:]
if
self
.
history_len
>
0
else
[],
history
=
history
[
-
self
.
history_len
:]
if
self
.
history_len
>
0
else
[],
max_length
=
self
.
max_token
,
max_length
=
self
.
max_token
,
temperature
=
self
.
temperature
,
temperature
=
self
.
temperature
,
)
)
torch_gc
()
torch_gc
()
if
stop
is
not
None
:
if
stop
is
not
None
:
response
=
enforce_stop_tokens
(
response
,
stop
)
response
=
enforce_stop_tokens
(
response
,
stop
)
self
.
history
=
self
.
history
+
[[
None
,
response
]]
history
=
history
+
[[
None
,
response
]]
return
response
return
response
,
history
def
chat
(
self
,
#
def chat(self,
prompt
:
str
)
->
str
:
#
prompt: str) -> str:
response
,
_
=
self
.
model
.
chat
(
#
response, _ = self.model.chat(
self
.
tokenizer
,
#
self.tokenizer,
prompt
,
#
prompt,
history
=
self
.
history
[
-
self
.
history_len
:]
if
self
.
history_len
>
0
else
[],
#
history=self.history[-self.history_len:] if self.history_len > 0 else [],
max_length
=
self
.
max_token
,
#
max_length=self.max_token,
temperature
=
self
.
temperature
,
#
temperature=self.temperature,
)
#
)
torch_gc
()
#
torch_gc()
self
.
history
=
self
.
history
+
[[
None
,
response
]]
#
self.history = self.history + [[None, response]]
return
response
#
return response
def
load_model
(
self
,
def
load_model
(
self
,
model_name_or_path
:
str
=
"THUDM/chatglm-6b"
,
model_name_or_path
:
str
=
"THUDM/chatglm-6b"
,
...
@@ -149,7 +151,13 @@ class ChatGLM(LLM):
...
@@ -149,7 +151,13 @@ class ChatGLM(LLM):
else
:
else
:
from
accelerate
import
dispatch_model
from
accelerate
import
dispatch_model
model
=
AutoModel
.
from_pretrained
(
model_name_or_path
,
trust_remote_code
=
True
,
**
kwargs
)
.
half
()
model
=
(
AutoModel
.
from_pretrained
(
model_name_or_path
,
trust_remote_code
=
True
,
config
=
model_config
,
**
kwargs
)
.
half
())
# 可传入device_map自定义每张卡的部署情况
# 可传入device_map自定义每张卡的部署情况
if
device_map
is
None
:
if
device_map
is
None
:
device_map
=
auto_configure_device_map
(
num_gpus
)
device_map
=
auto_configure_device_map
(
num_gpus
)
...
@@ -160,7 +168,8 @@ class ChatGLM(LLM):
...
@@ -160,7 +168,8 @@ class ChatGLM(LLM):
AutoModel
.
from_pretrained
(
AutoModel
.
from_pretrained
(
model_name_or_path
,
model_name_or_path
,
config
=
model_config
,
config
=
model_config
,
trust_remote_code
=
True
)
trust_remote_code
=
True
,
**
kwargs
)
.
float
()
.
float
()
.
to
(
llm_device
)
.
to
(
llm_device
)
)
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论