提交 22d08f5e 作者: glide-the

必要参数校验

上级 1e2124ff
...@@ -16,7 +16,7 @@ embedding_model_dict = { ...@@ -16,7 +16,7 @@ embedding_model_dict = {
"ernie-tiny": "nghuyong/ernie-3.0-nano-zh", "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
"ernie-base": "nghuyong/ernie-3.0-base-zh", "ernie-base": "nghuyong/ernie-3.0-base-zh",
"text2vec-base": "shibing624/text2vec-base-chinese", "text2vec-base": "shibing624/text2vec-base-chinese",
"text2vec": "GanymedeNil/text2vec-large-chinese", "text2vec": "/media/checkpoint/text2vec-large-chinese/",
"m3e-small": "moka-ai/m3e-small", "m3e-small": "moka-ai/m3e-small",
"m3e-base": "moka-ai/m3e-base", "m3e-base": "moka-ai/m3e-base",
} }
...@@ -186,7 +186,7 @@ llm_model_dict = { ...@@ -186,7 +186,7 @@ llm_model_dict = {
} }
# LLM 名称 # LLM 名称
LLM_MODEL = "chatglm-6b" LLM_MODEL = "fastchat-chatglm-6b"
# 量化加载8bit 模型 # 量化加载8bit 模型
LOAD_IN_8BIT = False LOAD_IN_8BIT = False
# Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. # Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
......
...@@ -52,15 +52,20 @@ def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, ...@@ -52,15 +52,20 @@ def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str,
system_build_message['role'] = 'system' system_build_message['role'] = 'system'
system_build_message['content'] = "You are a helpful assistant." system_build_message['content'] = "You are a helpful assistant."
build_messages.append(system_build_message) build_messages.append(system_build_message)
if history:
for i, (user, assistant) in enumerate(history):
if user:
for i, (old_query, response) in enumerate(history):
user_build_message = _build_message_template() user_build_message = _build_message_template()
user_build_message['role'] = 'user' user_build_message['role'] = 'user'
user_build_message['content'] = old_query user_build_message['content'] = user
build_messages.append(user_build_message)
if not assistant:
raise RuntimeError("历史数据结构不正确")
system_build_message = _build_message_template() system_build_message = _build_message_template()
system_build_message['role'] = 'assistant' system_build_message['role'] = 'assistant'
system_build_message['content'] = response system_build_message['content'] = assistant
build_messages.append(user_build_message)
build_messages.append(system_build_message) build_messages.append(system_build_message)
user_build_message = _build_message_template() user_build_message = _build_message_template()
...@@ -181,10 +186,10 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC): ...@@ -181,10 +186,10 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None: generate_with_callback: AnswerResultStream = None) -> None:
history = inputs[self.history_key] history = inputs.get(self.history_key, [])
streaming = inputs[self.streaming_key] streaming = inputs.get(self.streaming_key, False)
prompt = inputs[self.prompt_key] prompt = inputs[self.prompt_key]
stop = inputs.get("stop", None) stop = inputs.get("stop", "stop")
print(f"__call:{prompt}") print(f"__call:{prompt}")
try: try:
...@@ -205,16 +210,18 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC): ...@@ -205,16 +210,18 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
params = {"stream": streaming, params = {"stream": streaming,
"model": self.model_name, "model": self.model_name,
"stop": stop} "stop": stop}
out_str = ""
for stream_resp in self.completion_with_retry( for stream_resp in self.completion_with_retry(
messages=msg, messages=msg,
**params **params
): ):
role = stream_resp["choices"][0]["delta"].get("role", "") role = stream_resp["choices"][0]["delta"].get("role", "")
token = stream_resp["choices"][0]["delta"].get("content", "") token = stream_resp["choices"][0]["delta"].get("content", "")
history += [[prompt, token]] out_str += token
history[-1] = [prompt, out_str]
answer_result = AnswerResult() answer_result = AnswerResult()
answer_result.history = history answer_result.history = history
answer_result.llm_output = {"answer": token} answer_result.llm_output = {"answer": out_str}
generate_with_callback(answer_result) generate_with_callback(answer_result)
else: else:
...@@ -239,10 +246,10 @@ if __name__ == "__main__": ...@@ -239,10 +246,10 @@ if __name__ == "__main__":
chain = FastChatOpenAILLMChain() chain = FastChatOpenAILLMChain()
chain.set_api_key("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") chain.set_api_key("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")
chain.set_api_base_url("https://api.openai.com/v1") # chain.set_api_base_url("https://api.openai.com/v1")
chain.call_model_name("gpt-3.5-turbo") # chain.call_model_name("gpt-3.5-turbo")
answer_result_stream_result = chain({"streaming": False, answer_result_stream_result = chain({"streaming": True,
"prompt": "你好", "prompt": "你好",
"history": [] "history": []
}) })
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论