Commit 24324563 by glide-the

Adapt for remote LLM invocation

Parent 99e9d1d7
@@ -69,12 +69,23 @@ llm_model_dict = {
         "local_model_path": None,
         "provides": "LLamaLLM"
     },
-    "fastChatOpenAI": {
+    "fast-chat-chatglm-6b": {
         "name": "FastChatOpenAI",
         "pretrained_model_name": "FastChatOpenAI",
         "local_model_path": None,
-        "provides": "FastChatOpenAILLM"
-    }
+        "provides": "FastChatOpenAILLM",
+        "api_base_url": "http://localhost:8000/v1",
+        "model_name": "chatglm-6b"
+    },
+    "fast-chat-vicuna-13b-hf": {
+        "name": "FastChatOpenAI",
+        "pretrained_model_name": "vicuna-13b-hf",
+        "local_model_path": None,
+        "provides": "FastChatOpenAILLM",
+        "api_base_url": "http://localhost:8000/v1",
+        "model_name": "vicuna-13b-hf"
+    },
 }
 # LLM 名称
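For context, a minimal sketch of how a caller might consume the new per-model fields. The dict literal mirrors the `fast-chat-chatglm-6b` entry added above; the lookup code itself is illustrative rather than taken from the repository.

```python
# Illustrative only: the entry mirrors the hunk above; the lookup is a sketch.
llm_model_dict = {
    "fast-chat-chatglm-6b": {
        "name": "FastChatOpenAI",
        "pretrained_model_name": "FastChatOpenAI",
        "local_model_path": None,
        "provides": "FastChatOpenAILLM",             # class name resolved at load time
        "api_base_url": "http://localhost:8000/v1",  # FastChat's OpenAI-compatible endpoint
        "model_name": "chatglm-6b",                  # model served behind that endpoint
    },
}

info = llm_model_dict["fast-chat-chatglm-6b"]
print(info["provides"], info["api_base_url"], info["model_name"])
```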
@@ -111,9 +111,9 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
             messages=self.build_message_list(prompt)
         )
-        self.history += [[prompt, completion.choices[0].message.content]]
+        history += [[prompt, completion.choices[0].message.content]]
         answer_result = AnswerResult()
-        answer_result.history = self.history
+        answer_result.history = history
         answer_result.llm_output = {"answer": completion.choices[0].message.content}
         yield answer_result
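The method around this hunk builds a message list and reads `completion.choices[0].message.content`, which matches a chat-completion call against FastChat's OpenAI-compatible server. Below is a minimal sketch of such a call, assuming the pre-1.0 `openai` Python SDK, with the endpoint URL and model name taken from the config entries added in this commit; it illustrates the pattern, not the project's exact code.

```python
# Sketch of an OpenAI-compatible call against a FastChat server (assumes openai<1.0).
import openai

openai.api_key = "EMPTY"                      # FastChat does not validate the key
openai.api_base = "http://localhost:8000/v1"  # api_base_url from model_config

completion = openai.ChatCompletion.create(
    model="chatglm-6b",                       # model_name from model_config
    messages=[{"role": "user", "content": "你好? "}],
)
print(completion.choices[0].message.content)
```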
@@ -34,11 +34,14 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
         loaderCheckPoint.model_path = llm_model_info["local_model_path"]
-    if 'FastChat' in loaderCheckPoint.model_name:
+    if 'FastChatOpenAILLM' in llm_model_info["local_model_path"]:
         loaderCheckPoint.unload_model()
     else:
         loaderCheckPoint.reload_model()
     provides_class = getattr(sys.modules['models'], llm_model_info['provides'])
     modelInsLLM = provides_class(checkPoint=loaderCheckPoint)
+    if 'FastChatOpenAILLM' in llm_model_info["provides"]:
+        modelInsLLM.set_api_base_url(llm_model_info['api_base_url'])
+        modelInsLLM.call_model_name(llm_model_info['model_name'])
     return modelInsLLM
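A small, self-contained sketch of the wiring the new lines perform: when the configured provider is `FastChatOpenAILLM`, the loader passes the endpoint and model name from the config entry into the instance instead of loading local weights. The stand-in class and dict below are illustrative; only the two setter names and the config keys come from the diff.

```python
# Illustrative stand-in for the remote-provider branch added to loaderLLM above.
class RemoteLLMStub:
    """Mimics the two setters the loader calls on a FastChatOpenAILLM instance."""

    def set_api_base_url(self, url: str) -> None:
        self.api_base_url = url

    def call_model_name(self, name: str) -> None:
        self.model_name = name


llm_model_info = {                       # shape mirrors a model_config entry
    "provides": "FastChatOpenAILLM",
    "api_base_url": "http://localhost:8000/v1",
    "model_name": "chatglm-6b",
}

modelInsLLM = RemoteLLMStub()
if 'FastChatOpenAILLM' in llm_model_info["provides"]:
    # Remote model: no local checkpoint to reload, just point it at the server.
    modelInsLLM.set_api_base_url(llm_model_info['api_base_url'])
    modelInsLLM.call_model_name(llm_model_info['model_name'])
```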
@@ -18,14 +18,13 @@ async def dispatch(args: Namespace):
     shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
     llm_model_ins = shared.loaderLLM()
-    llm_model_ins.set_api_base_url("http://localhost:8000/v1")
-    llm_model_ins.call_model_name("chatglm-6b")
     history = [
         ("which city is this?", "tokyo"),
         ("why?", "she's japanese"),
     ]
-    for answer_result in llm_model_ins.generatorAnswer(prompt="她在做什么? ", history=history,
+    for answer_result in llm_model_ins.generatorAnswer(prompt="你好? ", history=history,
                                                        streaming=False):
         resp = answer_result.llm_output["answer"]
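With the endpoint configuration moved into the model config, the caller only supplies a prompt and a prior history and reads the extended history back from each `AnswerResult`. A minimal sketch of that contract, with a fake generator standing in for the real model instance (all names besides `generatorAnswer`, `AnswerResult`, `history`, and `llm_output` are illustrative):

```python
# Sketch of the generatorAnswer contract as exercised in the hunk above.
from typing import Generator, List, Tuple


class AnswerResult:                      # minimal stand-in for the project's AnswerResult
    history: List = []
    llm_output: dict = {}


class FakeLLM:
    def generatorAnswer(self, prompt: str, history: List[Tuple[str, str]] = None,
                        streaming: bool = False) -> Generator[AnswerResult, None, None]:
        history = list(history or [])
        answer = f"echo: {prompt}"       # stand-in for the remote completion call
        history += [[prompt, answer]]    # the caller's history is extended, not self.history
        result = AnswerResult()
        result.history = history
        result.llm_output = {"answer": answer}
        yield result


for answer_result in FakeLLM().generatorAnswer(prompt="你好? ",
                                               history=[("which city is this?", "tokyo")],
                                               streaming=False):
    print(answer_result.llm_output["answer"], answer_result.history)
```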