Unverified commit 0d9db37f authored by shrimp, committed by GitHub

Improve the API interface and model loading (#247)

* Fix knowledge base path handling and improve the API interface

Unify the knowledge base paths used by the webui and the API. The paths are now:
Knowledge base path: /project code folder/vector_store/'knowledge base name'
File storage path: /project code folder/content/'knowledge base name'

Fix the bug in creating a knowledge base through the API, and improve the API's functionality.
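
A minimal sketch of the unified layout described above. Only VS_ROOT_PATH appears in the diff below; CONTENT_ROOT_PATH and kb_paths are hypothetical names used purely for illustration:

# Illustrative layout sketch; only VS_ROOT_PATH is real (see the diff below),
# CONTENT_ROOT_PATH and kb_paths are hypothetical names for this example.
import os

PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
VS_ROOT_PATH = os.path.join(PROJECT_ROOT, "vector_store")
CONTENT_ROOT_PATH = os.path.join(PROJECT_ROOT, "content")

def kb_paths(knowledge_base_id):
    # Returns (vector store dir, uploaded-file dir) for one knowledge base.
    return (os.path.join(VS_ROOT_PATH, knowledge_base_id),
            os.path.join(CONTENT_ROOT_PATH, knowledge_base_id))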

* Update model_config.py


* Fix knowledge base path handling and improve the API interface (#245) (#246)

* Fix knowledge base upload failure caused by an incorrect NLTK_DATA_PATH (#236)

* Update chatglm_llm.py (#242)

* Fix knowledge base path handling and improve the API interface

Unify the knowledge base paths used by the webui and the API. The paths are now:
Knowledge base path: /project code folder/vector_store/'knowledge base name'
File storage path: /project code folder/content/'knowledge base name'

Fix the bug in creating a knowledge base through the API, and improve the API's functionality.

* Update model_config.py

---------

Co-authored-by: shrimp <411161555@qq.com>
Co-authored-by: Bob Chang <bob-chang@outlook.com>

* Optimize the API and pass through the model's top_p parameter

Optimize the API: selecting a knowledge base is now optional.
Pass the model's top_p parameter through.

* Improve the API and model loading

The knowledge base is no longer a required input for the API.
Pass the model's top_p parameter through.

---------

Co-authored-by: imClumsyPanda <littlepanda0716@gmail.com>
Co-authored-by: Bob Chang <bob-chang@outlook.com>
Parent 6ac8f73a
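
For reference, a hedged client-side sketch of the new optional-knowledge-base behavior. The host, port, and the "/chat" route are assumptions for illustration; only the payload field names come from the chat() handler in the diff below:

# Hypothetical client call; host/port and the "/chat" route are assumed,
# only the field names come from the chat() handler in the diff below.
import requests

payload = {
    "knowledge_base_id": "kb1",   # may name a knowledge base that does not exist
    "question": "工伤保险是什么?",
    "history": [],
}
r = requests.post("http://localhost:7861/chat", json=payload)
# If vector_store/kb1 exists, the answer is knowledge-based with cited
# sources; otherwise the LLM now answers directly instead of raising.
print(r.json())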
@@ -170,32 +170,36 @@ async def delete_docs(
 async def chat(
-    knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
-    question: str = Body(..., description="Question", example="工伤保险是什么?"),
+    knowledge_base_id: str = Body(..., description="知识库名字", example="kb1"),
+    question: str = Body(..., description="问题", example="工伤保险是什么?"),
     history: List[List[str]] = Body(
         [],
-        description="History of previous questions and answers",
+        description="问题及答案的历史记录",
         example=[
             [
-                "工伤保险是什么?",
-                "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
+                "这里是问题,如:工伤保险是什么?",
+                "答案:工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
             ]
         ],
     ),
 ):
     vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
-    if not os.path.exists(vs_path):
-        raise ValueError(f"Knowledge base {knowledge_base_id} not found")
-    for resp, history in local_doc_qa.get_knowledge_based_answer(
-        query=question, vs_path=vs_path, chat_history=history, streaming=True
-    ):
-        pass
-    source_documents = [
-        f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
-        f"""相关度:{doc.metadata['score']}\n\n"""
-        for inum, doc in enumerate(resp["source_documents"])
-    ]
+    resp = {}
+    if os.path.exists(vs_path) and knowledge_base_id:
+        for resp, history in local_doc_qa.get_knowledge_based_answer(
+            query=question, vs_path=vs_path, chat_history=history, streaming=False
+        ):
+            pass
+        source_documents = [
+            f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+            f"""相关度:{doc.metadata['score']}\n\n"""
+            for inum, doc in enumerate(resp["source_documents"])
+        ]
+    else:
+        for resp_s, history in local_doc_qa.llm._call(prompt=question, history=history, streaming=False):
+            pass
+        resp["result"] = resp_s
+        source_documents = [("当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。")]
     return ChatMessage(
         question=question,
...
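
A note on the `for ...: pass` idiom above: get_knowledge_based_answer and llm._call are generators, so the loop simply drains them and the loop variables keep the final yielded values. A minimal self-contained sketch of that pattern; answers() is invented purely for illustration:

# Sketch of the drain-a-generator idiom used in chat() above;
# answers() stands in for get_knowledge_based_answer / llm._call.
def answers():
    for partial in ("工伤保险是", "工伤保险是一种社会保险制度。"):
        yield {"result": partial}, []

resp = {}
for resp, history in answers():
    pass                      # loop body does nothing; we only drain
print(resp["result"])         # the final yielded value survives the loop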
@@ -43,7 +43,7 @@ def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
 class ChatGLM(LLM):
     max_token: int = 10000
-    temperature: float = 0.01
+    temperature: float = 0.8
     top_p = 0.9
     # history = []
     tokenizer: object = None
@@ -68,6 +68,7 @@ class ChatGLM(LLM):
             history=history[-self.history_len:-1] if self.history_len > 0 else [],
             max_length=self.max_token,
             temperature=self.temperature,
+            top_p=self.top_p,
         )):
             torch_gc()
             if inum == 0:
@@ -83,6 +84,7 @@ class ChatGLM(LLM):
             history=history[-self.history_len:] if self.history_len > 0 else [],
             max_length=self.max_token,
             temperature=self.temperature,
+            top_p=self.top_p,
         )
         torch_gc()
         history += [[prompt, response]]
@@ -141,7 +143,7 @@ class ChatGLM(LLM):
         from accelerate import dispatch_model
         model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True,
                                           config=model_config, **kwargs)
         if LLM_LORA_PATH and use_lora:
             from peft import PeftModel
             model = PeftModel.from_pretrained(model, LLM_LORA_PATH)
...
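
The functional change here: top_p (nucleus sampling, where generation samples only from the smallest token set whose cumulative probability exceeds top_p) was previously declared on the class but never forwarded to the model; now both the streaming and non-streaming paths pass it through. A hedged sketch of the effect, where the model name and loading steps are illustrative but the chat(...) keyword arguments match the calls in the diff above:

# Sketch: top_p now reaches the model. Model name and loading are
# illustrative; the chat(...) kwargs mirror the calls in the diff above.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

response, history = model.chat(
    tokenizer,
    "工伤保险是什么?",
    history=[],
    max_length=10000,
    temperature=0.8,   # new default (was 0.01)
    top_p=0.9,         # previously defined on the class but silently ignored
)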