Commit c389f1a3 by glide-the

Add FastChat typewriter-style (streaming) output

Parent 5cbb86a8
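The "typewriter" effect comes from the OpenAI-compatible streaming API that a FastChat server exposes: the server returns incremental delta chunks, and each chunk's content is surfaced as soon as it arrives. A minimal standalone sketch of that pattern, assuming a FastChat server at http://localhost:8000/v1 serving chatglm-6b (the defaults in this diff) and the pre-1.0 openai client this commit targets:

    # Minimal sketch of the "typewriter" streaming pattern this commit adds.
    # Assumes a FastChat OpenAI-compatible server at localhost:8000 serving
    # chatglm-6b (the defaults in the diff) and openai<1.0.
    import openai

    openai.api_key = "EMPTY"  # FastChat does not check the key
    openai.api_base = "http://localhost:8000/v1"

    for chunk in openai.ChatCompletion.create(
            model="chatglm-6b",
            messages=[{"role": "user", "content": "你好"}],
            stream=True):
        # Each chunk carries an incremental "delta"; printing its content
        # as it arrives produces the typewriter effect.
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)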
@@ -6,6 +6,7 @@ from queue import Queue
 from threading import Thread
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from models.loader import LoaderCheckPoint
+from pydantic import BaseModel
 import torch
 import transformers
@@ -23,13 +24,12 @@
         self._scores = _scores


-class AnswerResult:
+class AnswerResult(BaseModel):
     """
     Message entity
     """
     history: List[List[str]] = []
     llm_output: Optional[dict] = None
-    listenerToken: ListenerToken = None


 class AnswerResultStream:
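Making AnswerResult a pydantic BaseModel gives the message entity declared, validated fields; it also explains why the listenerToken: ListenerToken field is dropped here, since pydantic rejects arbitrary non-pydantic field types by default. A quick sketch of the resulting behaviour (pydantic v1 style, matching this codebase):

    # Sketch: AnswerResult as a pydantic model validates and serializes cleanly.
    from typing import List, Optional
    from pydantic import BaseModel

    class AnswerResult(BaseModel):
        """Message entity"""
        history: List[List[str]] = []
        llm_output: Optional[dict] = None

    r = AnswerResult(history=[["你好", "你好!"]], llm_output={"answer": "你好!"})
    print(r.json(ensure_ascii=False))
    # {"history": [["你好", "你好!"]], "llm_output": {"answer": "你好!"}}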
@@ -167,8 +167,6 @@ class BaseAnswer(ABC):
         with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator:
             for answerResult in generator:
-                if answerResult.listenerToken:
-                    output = answerResult.listenerToken.input_ids
                 yield answerResult

     @abstractmethod
@@ -94,8 +94,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
                 answer_result = AnswerResult()
                 answer_result.history = history
                 answer_result.llm_output = {"answer": stream_resp}
-                if listenerQueue.listenerQueue.__len__() > 0:
-                    answer_result.listenerToken = listenerQueue.listenerQueue.pop()
                 generate_with_callback(answer_result)
             self.checkPoint.clear_torch_cache()
         else:
@@ -114,8 +112,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
             answer_result = AnswerResult()
             answer_result.history = history
             answer_result.llm_output = {"answer": response}
-            if listenerQueue.listenerQueue.__len__() > 0:
-                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
             generate_with_callback(answer_result)
 from abc import ABC
 from langchain.chains.base import Chain
-from typing import Any, Dict, List, Optional, Generator, Collection
+from typing import (
+    Any, Dict, List, Optional, Generator, Collection, Set,
+    Callable,
+    Tuple,
+    Union)
 from models.loader import LoaderCheckPoint
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from models.base import (BaseAnswer,
@@ -8,9 +13,26 @@ from models.base import (BaseAnswer,
                          AnswerResult,
                          AnswerResultStream,
                          AnswerResultQueueSentinelTokenListenerQueue)
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
+from pydantic import Extra, Field, root_validator
+from openai import (
+    ChatCompletion
+)
+import openai
+import logging
 import torch
 import transformers

+logger = logging.getLogger(__name__)
 def _build_message_template() -> Dict[str, str]:
     """
@@ -25,12 +47,18 @@ def _build_message_template() -> Dict[str, str]:
 # Convert the conversation history into the message-list format
 def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
     build_messages: Collection[Dict[str, str]] = []
+
+    system_build_message = _build_message_template()
+    system_build_message['role'] = 'system'
+    system_build_message['content'] = "You are a helpful assistant."
+    build_messages.append(system_build_message)
     for i, (old_query, response) in enumerate(history):
         user_build_message = _build_message_template()
         user_build_message['role'] = 'user'
         user_build_message['content'] = old_query
         system_build_message = _build_message_template()
-        system_build_message['role'] = 'system'
+        system_build_message['role'] = 'assistant'
         system_build_message['content'] = response
         build_messages.append(user_build_message)
         build_messages.append(system_build_message)
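To make the resulting wire format concrete: assuming the tail of the function (outside this hunk) appends the new query as a final user message, a one-turn history now expands like this (hypothetical values):

    # Hypothetical walk-through of build_message_list after this change
    build_message_list("还在吗?", history=[["你好", "你好!有什么可以帮你?"]])
    # -> [{"role": "system",    "content": "You are a helpful assistant."},
    #     {"role": "user",      "content": "你好"},
    #     {"role": "assistant", "content": "你好!有什么可以帮你?"},
    #     {"role": "user",      "content": "还在吗?"}]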
@@ -43,6 +71,9 @@ def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str,

 class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
     client: Any
+    max_retries: int = 6
+    """Maximum number of retries to make when generating."""
     api_base_url: str = "http://localhost:8000/v1"
     model_name: str = "chatglm-6b"
     max_token: int = 10000
@@ -108,6 +139,35 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
     def call_model_name(self, model_name):
         self.model_name = model_name
+    def _create_retry_decorator(self) -> Callable[[Any], Any]:
+        min_seconds = 1
+        max_seconds = 60
+        # Wait 2^x * 1 second between retries, starting at
+        # min_seconds and capping at max_seconds thereafter
+        return retry(
+            reraise=True,
+            stop=stop_after_attempt(self.max_retries),
+            wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+            retry=(
+                retry_if_exception_type(openai.error.Timeout)
+                | retry_if_exception_type(openai.error.APIError)
+                | retry_if_exception_type(openai.error.APIConnectionError)
+                | retry_if_exception_type(openai.error.RateLimitError)
+                | retry_if_exception_type(openai.error.ServiceUnavailableError)
+            ),
+            before_sleep=before_sleep_log(logger, logging.WARNING),
+        )
+
+    def completion_with_retry(self, **kwargs: Any) -> Any:
+        """Use tenacity to retry the completion call."""
+        retry_decorator = self._create_retry_decorator()
+
+        @retry_decorator
+        def _completion_with_retry(**kwargs: Any) -> Any:
+            return self.client.create(**kwargs)
+
+        return _completion_with_retry(**kwargs)
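completion_with_retry follows the standard tenacity pattern, and the policy can be exercised in isolation. A self-contained sketch with a stand-in flaky call (ConnectionError substitutes for the openai.error types above):

    # Self-contained sketch of the retry policy: exponential backoff from
    # 1s up to a 60s cap, at most 6 attempts, re-raising on final failure.
    import logging
    from tenacity import (before_sleep_log, retry, retry_if_exception_type,
                          stop_after_attempt, wait_exponential)

    logger = logging.getLogger(__name__)
    state = {"calls": 0}

    @retry(reraise=True,
           stop=stop_after_attempt(6),
           wait=wait_exponential(multiplier=1, min=1, max=60),
           retry=retry_if_exception_type(ConnectionError),
           before_sleep=before_sleep_log(logger, logging.WARNING))
    def flaky_call():
        state["calls"] += 1
        if state["calls"] < 3:
            raise ConnectionError("transient failure")  # simulated outage
        return "ok"

    print(flaky_call())  # logs two warnings, then prints "ok"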
     def _call(
             self,
             inputs: Dict[str, Any],
@@ -124,29 +184,70 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
         history = inputs[self.history_key]
         streaming = inputs[self.streaming_key]
         prompt = inputs[self.prompt_key]
+        stop = inputs['stop']
         print(f"__call:{prompt}")
         try:
-            import openai
             # Not support yet
             # openai.api_key = "EMPTY"
             openai.api_key = self.api_key
             openai.api_base = self.api_base_url
-        except ImportError:
+            self.client = openai.ChatCompletion
+        except AttributeError:
             raise ValueError(
-                "Could not import openai python package. "
-                "Please install it with `pip install openai`."
+                "`openai` has no `ChatCompletion` attribute, this is likely "
+                "due to an old version of the openai package. Try upgrading it "
+                "with `pip install --upgrade openai`."
             )
-        # create a chat completion
-        completion = openai.ChatCompletion.create(
-            model=self.model_name,
-            messages=build_message_list(prompt)
-        )
-        print(f"response:{completion.choices[0].message.content}")
-        print(f"+++++++++++++++++++++++++++++++++++")
-        history += [[prompt, completion.choices[0].message.content]]
-        answer_result = AnswerResult()
-        answer_result.history = history
-        answer_result.llm_output = {"answer": completion.choices[0].message.content}
-        generate_with_callback(answer_result)
+        msg = build_message_list(prompt, history=history)
+
+        if streaming:
+            params = {"stream": streaming,
+                      "model": self.model_name,
+                      "stop": stop}
+            for stream_resp in self.completion_with_retry(
+                    messages=msg,
+                    **params
+            ):
+                role = stream_resp["choices"][0]["delta"].get("role", "")
+                token = stream_resp["choices"][0]["delta"].get("content", "")
+                history += [[prompt, token]]
+                answer_result = AnswerResult()
+                answer_result.history = history
+                answer_result.llm_output = {"answer": token}
+                generate_with_callback(answer_result)
+        else:
+            params = {"stream": streaming,
+                      "model": self.model_name,
+                      "stop": stop}
+            response = self.completion_with_retry(
+                messages=msg,
+                **params
+            )
+            role = response["choices"][0]["message"].get("role", "")
+            content = response["choices"][0]["message"].get("content", "")
+            history += [[prompt, content]]
+            answer_result = AnswerResult()
+            answer_result.history = history
+            answer_result.llm_output = {"answer": content}
+            generate_with_callback(answer_result)
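Note that the streaming branch yields one AnswerResult per delta token, appending a fresh [prompt, token] pair to history each time, so a consumer that wants the complete reply has to accumulate the tokens itself. A hypothetical helper:

    def consume_typewriter(answer_result_stream):
        """Reassemble the per-token AnswerResults into the full reply."""
        answer = ""
        for answer_result in answer_result_stream:
            token = answer_result.llm_output["answer"]  # one delta per result
            answer += token
            print(token, end="", flush=True)  # typewriter display
        return answer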

 if __name__ == "__main__":
     chain = FastChatOpenAILLMChain()
     chain.set_api_key("sk-***REDACTED***")
     chain.set_api_base_url("https://api.openai.com/v1")
     chain.call_model_name("gpt-3.5-turbo")
     answer_result_stream_result = chain({"streaming": False,
+                                         "stop": "",
                                          "prompt": "你好",
                                          "history": []
                                          })
     for answer_result in answer_result_stream_result['answer_result_stream']:
         resp = answer_result.llm_output["answer"]
         print(resp)
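The demo above exercises the non-streaming path; flipping streaming to True drives the new typewriter branch instead (sketch, not part of the commit):

    # Streaming variant of the demo: tokens arrive one at a time.
    answer_result_stream_result = chain({"streaming": True,
                                         "stop": "",
                                         "prompt": "你好",
                                         "history": []})
    for answer_result in answer_result_stream_result['answer_result_stream']:
        print(answer_result.llm_output["answer"], end="", flush=True)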
@@ -186,7 +186,5 @@ class LLamaLLMChain(BaseAnswer, Chain, ABC):
             answer_result = AnswerResult()
             history += [[prompt, reply]]
             answer_result.history = history
-            if listenerQueue.listenerQueue.__len__() > 0:
-                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
             answer_result.llm_output = {"answer": reply}
             generate_with_callback(answer_result)