Commit 22d08f5e authored by glide-the

Validate required parameters

Parent 1e2124ff
@@ -16,7 +16,7 @@ embedding_model_dict = {
     "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
     "ernie-base": "nghuyong/ernie-3.0-base-zh",
     "text2vec-base": "shibing624/text2vec-base-chinese",
-    "text2vec": "GanymedeNil/text2vec-large-chinese",
+    "text2vec": "/media/checkpoint/text2vec-large-chinese/",
     "m3e-small": "moka-ai/m3e-small",
     "m3e-base": "moka-ai/m3e-base",
 }
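
Note: the "text2vec" entry now points at a local checkpoint directory instead of the Hugging Face hub ID. A minimal sketch of how such an entry is typically consumed, assuming langchain's HuggingFaceEmbeddings as the loader (an assumption, not code from this commit); it accepts a hub ID and a local path interchangeably:

    # Sketch: resolving an embedding_model_dict entry into an embedding model.
    # HuggingFaceEmbeddings loads from a hub ID ("GanymedeNil/...") or a
    # local directory ("/media/checkpoint/...") the same way.
    from langchain.embeddings import HuggingFaceEmbeddings

    embedding_model_dict = {
        "text2vec": "/media/checkpoint/text2vec-large-chinese/",
    }

    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict["text2vec"])
    vector = embeddings.embed_query("some query text")  # -> List[float]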
@@ -186,7 +186,7 @@ llm_model_dict = {
 }
 # LLM model name
-LLM_MODEL = "chatglm-6b"
+LLM_MODEL = "fastchat-chatglm-6b"
 # Load the model quantized to 8-bit
 LOAD_IN_8BIT = False
 # Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
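
Note: LLM_MODEL is the key used elsewhere in the config to select an entry from llm_model_dict, so the new default "fastchat-chatglm-6b" must exist there as a key. An illustrative lookup with a guard (the dict entries and the guard are assumptions, not this repo's loader code):

    # Illustrative only: LLM_MODEL must name a configured model.
    llm_model_dict = {
        "chatglm-6b": {},
        "fastchat-chatglm-6b": {},  # key required for the new default
    }
    LLM_MODEL = "fastchat-chatglm-6b"

    if LLM_MODEL not in llm_model_dict:
        raise KeyError(f"LLM_MODEL '{LLM_MODEL}' is not in llm_model_dict")
    model_config = llm_model_dict[LLM_MODEL]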
@@ -52,16 +52,21 @@ def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
     system_build_message['role'] = 'system'
     system_build_message['content'] = "You are a helpful assistant."
     build_messages.append(system_build_message)
-    for i, (old_query, response) in enumerate(history):
-        user_build_message = _build_message_template()
-        user_build_message['role'] = 'user'
-        user_build_message['content'] = old_query
-        system_build_message = _build_message_template()
-        system_build_message['role'] = 'assistant'
-        system_build_message['content'] = response
-        build_messages.append(user_build_message)
-        build_messages.append(system_build_message)
+    if history:
+        for i, (user, assistant) in enumerate(history):
+            if user:
+                user_build_message = _build_message_template()
+                user_build_message['role'] = 'user'
+                user_build_message['content'] = user
+                build_messages.append(user_build_message)
+
+                if not assistant:
+                    raise RuntimeError("history has an invalid structure")
+                system_build_message = _build_message_template()
+                system_build_message['role'] = 'assistant'
+                system_build_message['content'] = assistant
+                build_messages.append(system_build_message)

     user_build_message = _build_message_template()
     user_build_message['role'] = 'user'
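
Note: the reworked loop skips history pairs whose user turn is empty and fails fast when a user turn has no assistant reply, instead of blindly unpacking every pair. Roughly, assuming _build_message_template() returns {'role': '', 'content': ''} as elsewhere in this file:

    # Hypothetical call showing the validated output shape.
    history = [["How is the weather?", "Sunny today."]]
    build_message_list("Should I go out?", history)
    # -> [{'role': 'system',    'content': 'You are a helpful assistant.'},
    #     {'role': 'user',      'content': 'How is the weather?'},
    #     {'role': 'assistant', 'content': 'Sunny today.'},
    #     {'role': 'user',      'content': 'Should I go out?'}]
    # A pair like ["question", None] now raises RuntimeError rather than
    # emitting an assistant message with empty content.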
@@ -181,10 +186,10 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
                          run_manager: Optional[CallbackManagerForChainRun] = None,
                          generate_with_callback: AnswerResultStream = None) -> None:
-        history = inputs[self.history_key]
-        streaming = inputs[self.streaming_key]
+        history = inputs.get(self.history_key, [])
+        streaming = inputs.get(self.streaming_key, False)
         prompt = inputs[self.prompt_key]
-        stop = inputs.get("stop", None)
+        stop = inputs.get("stop", "stop")
         print(f"__call:{prompt}")
         try:
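
Note: this is the parameter validation named in the commit message. history and streaming become optional with safe defaults, while prompt remains a hard requirement. The pattern in isolation (literal keys stand in for self.history_key and self.streaming_key):

    # Pattern sketch: optional inputs get defaults; required ones still
    # raise KeyError when absent.
    inputs = {"prompt": "hi"}
    history = inputs.get("history", [])         # optional, defaults to []
    streaming = inputs.get("streaming", False)  # optional, defaults to False
    prompt = inputs["prompt"]                   # required; KeyError if missing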
@@ -205,16 +210,18 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
             params = {"stream": streaming,
                       "model": self.model_name,
                       "stop": stop}
+            out_str = ""
             for stream_resp in self.completion_with_retry(
                     messages=msg,
                     **params
             ):
                 role = stream_resp["choices"][0]["delta"].get("role", "")
                 token = stream_resp["choices"][0]["delta"].get("content", "")
-                history += [[prompt, token]]
+                out_str += token
+                history[-1] = [prompt, out_str]
                 answer_result = AnswerResult()
                 answer_result.history = history
-                answer_result.llm_output = {"answer": token}
+                answer_result.llm_output = {"answer": out_str}
                 generate_with_callback(answer_result)
         else:
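
Note: the streaming branch now accumulates deltas into out_str and rewrites the last history entry on every chunk, so each callback carries the full answer so far rather than a single token. The pattern in isolation, with OpenAI-style chunks mocked as plain dicts (the placeholder history entry is an assumption; the hunk does not show where the current turn is appended):

    # Sketch of delta accumulation over a mocked OpenAI-style stream.
    chunks = [{"choices": [{"delta": {"content": t}}]} for t in ["Hel", "lo", "!"]]

    prompt = "hi"
    history = [[prompt, ""]]  # assumed placeholder for the in-flight turn
    out_str = ""
    for stream_resp in chunks:
        token = stream_resp["choices"][0]["delta"].get("content", "")
        out_str += token
        history[-1] = [prompt, out_str]  # last entry holds the full answer so far
        print(out_str)                   # Hel / Hello / Hello!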
@@ -239,10 +246,10 @@ if __name__ == "__main__":
     chain = FastChatOpenAILLMChain()
     chain.set_api_key("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")
-    chain.set_api_base_url("https://api.openai.com/v1")
-    chain.call_model_name("gpt-3.5-turbo")
+    # chain.set_api_base_url("https://api.openai.com/v1")
+    # chain.call_model_name("gpt-3.5-turbo")
-    answer_result_stream_result = chain({"streaming": False,
+    answer_result_stream_result = chain({"streaming": True,
                                          "prompt": "你好",
                                          "history": []
                                          })