Commit c389f1a3 by glide-the

Add FastChat typewriter (streaming) output

Parent 5cbb86a8
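This commit streams FastChat's OpenAI-compatible /v1/chat/completions responses back token by token (the "typewriter" effect) instead of waiting for the full completion. A minimal sketch of that pattern with the pre-1.0 `openai` client, assuming a local FastChat server at http://localhost:8000/v1 serving chatglm-6b (both values are the defaults visible in the diff below):

```python
# Sketch only: assumes a FastChat server on localhost exposing the
# OpenAI-compatible API and a pre-1.0 `openai` package.
import openai

openai.api_key = "EMPTY"  # placeholder; some deployments require a real key
openai.api_base = "http://localhost:8000/v1"

for chunk in openai.ChatCompletion.create(
    model="chatglm-6b",
    messages=[{"role": "user", "content": "你好"}],
    stream=True,
):
    delta = chunk["choices"][0]["delta"]
    # Each streamed chunk carries at most one content fragment.
    print(delta.get("content", ""), end="", flush=True)
print()
```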
@@ -6,6 +6,7 @@ from queue import Queue
 from threading import Thread
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from models.loader import LoaderCheckPoint
+from pydantic import BaseModel
 import torch
 import transformers
@@ -23,13 +24,12 @@ class ListenerToken:
         self._scores = _scores
 
-class AnswerResult:
+class AnswerResult(BaseModel):
     """
     Message entity
     """
     history: List[List[str]] = []
     llm_output: Optional[dict] = None
-    listenerToken: ListenerToken = None
 
 class AnswerResultStream:
@@ -167,8 +167,6 @@ class BaseAnswer(ABC):
         with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator:
             for answerResult in generator:
-                if answerResult.listenerToken:
-                    output = answerResult.listenerToken.input_ids
                 yield answerResult
 
     @abstractmethod
...
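AnswerResult now subclasses pydantic's BaseModel, which is presumably why the `listenerToken: ListenerToken` field had to be dropped: pydantic v1 (the major version langchain pinned at the time) rejects fields of arbitrary non-pydantic types unless `arbitrary_types_allowed` is enabled. A minimal sketch of the slimmed-down model under that assumption:

```python
# Sketch of the new AnswerResult, assuming pydantic v1 semantics.
from typing import List, Optional
from pydantic import BaseModel

class AnswerResult(BaseModel):
    """Message entity."""
    history: List[List[str]] = []   # pydantic copies mutable defaults per instance
    llm_output: Optional[dict] = None

r = AnswerResult()
r.history = [["你好", "Hello!"]]
r.llm_output = {"answer": "Hello!"}
print(r.dict())  # {'history': [['你好', 'Hello!']], 'llm_output': {'answer': 'Hello!'}}
```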
@@ -94,8 +94,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
                 answer_result = AnswerResult()
                 answer_result.history = history
                 answer_result.llm_output = {"answer": stream_resp}
-                if listenerQueue.listenerQueue.__len__() > 0:
-                    answer_result.listenerToken = listenerQueue.listenerQueue.pop()
                 generate_with_callback(answer_result)
             self.checkPoint.clear_torch_cache()
         else:
@@ -114,8 +112,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
             answer_result = AnswerResult()
             answer_result.history = history
             answer_result.llm_output = {"answer": response}
-            if listenerQueue.listenerQueue.__len__() > 0:
-                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
             generate_with_callback(answer_result)
...
 from abc import ABC
 from langchain.chains.base import Chain
-from typing import Any, Dict, List, Optional, Generator, Collection
+from typing import (
+    Any, Dict, List, Optional, Generator, Collection, Set,
+    Callable,
+    Tuple,
+    Union)
 from models.loader import LoaderCheckPoint
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from models.base import (BaseAnswer,
@@ -8,9 +13,26 @@ from models.base import (BaseAnswer,
                          AnswerResult,
                          AnswerResultStream,
                          AnswerResultQueueSentinelTokenListenerQueue)
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
+from pydantic import Extra, Field, root_validator
+from openai import (
+    ChatCompletion
+)
+import openai
+import logging
 import torch
 import transformers
 
+logger = logging.getLogger(__name__)
+
 
 def _build_message_template() -> Dict[str, str]:
     """
@@ -25,12 +47,18 @@ def _build_message_template() -> Dict[str, str]:
 
 # Convert the conversation history array into text format
 def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
     build_messages: Collection[Dict[str, str]] = []
+
+    system_build_message = _build_message_template()
+    system_build_message['role'] = 'system'
+    system_build_message['content'] = "You are a helpful assistant."
+    build_messages.append(system_build_message)
     for i, (old_query, response) in enumerate(history):
         user_build_message = _build_message_template()
         user_build_message['role'] = 'user'
         user_build_message['content'] = old_query
         system_build_message = _build_message_template()
-        system_build_message['role'] = 'system'
+        system_build_message['role'] = 'assistant'
         system_build_message['content'] = response
         build_messages.append(user_build_message)
         build_messages.append(system_build_message)
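With the history role corrected from 'system' to 'assistant' and a single system prompt prepended, the payload now follows the standard OpenAI chat layout. A hypothetical illustration of the result (the strings are made up, and it assumes the unshown tail of the function appends the current query as a user message, as the call site `msg = build_message_list(prompt, history=history)` implies):

```python
# Hypothetical inputs:
query = "How large is it?"
history = [["What is FastChat?", "A serving platform for chat LLMs."]]

# Expected build_message_list(query, history) output, in order:
expected = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is FastChat?"},
    {"role": "assistant", "content": "A serving platform for chat LLMs."},
    {"role": "user", "content": "How large is it?"},
]
```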
@@ -43,6 +71,9 @@ def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
 
 class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
+    client: Any
+    max_retries: int = 6
+    """Maximum number of retries to make when generating."""
     api_base_url: str = "http://localhost:8000/v1"
     model_name: str = "chatglm-6b"
     max_token: int = 10000
@@ -108,6 +139,35 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
     def call_model_name(self, model_name):
         self.model_name = model_name
 
+    def _create_retry_decorator(self) -> Callable[[Any], Any]:
+        min_seconds = 1
+        max_seconds = 60
+        # Wait 2^x seconds between retries, starting at min_seconds and
+        # capping at max_seconds, then retry every max_seconds afterwards
+        return retry(
+            reraise=True,
+            stop=stop_after_attempt(self.max_retries),
+            wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+            retry=(
+                retry_if_exception_type(openai.error.Timeout)
+                | retry_if_exception_type(openai.error.APIError)
+                | retry_if_exception_type(openai.error.APIConnectionError)
+                | retry_if_exception_type(openai.error.RateLimitError)
+                | retry_if_exception_type(openai.error.ServiceUnavailableError)
+            ),
+            before_sleep=before_sleep_log(logger, logging.WARNING),
+        )
+
+    def completion_with_retry(self, **kwargs: Any) -> Any:
+        """Use tenacity to retry the completion call."""
+        retry_decorator = self._create_retry_decorator()
+
+        @retry_decorator
+        def _completion_with_retry(**kwargs: Any) -> Any:
+            return self.client.create(**kwargs)
+
+        return _completion_with_retry(**kwargs)
+
     def _call(
             self,
             inputs: Dict[str, Any],
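The retry logic mirrors langchain's own OpenAI wrapper: tenacity retries only the transient `openai.error` types, sleeping roughly 1, 2, 4, ... seconds between attempts (capped at 60) for up to `max_retries` attempts, and logs a warning before each sleep. A standalone sketch of the same tenacity pattern, with ConnectionError standing in for the openai error classes:

```python
import logging
from tenacity import (before_sleep_log, retry, retry_if_exception_type,
                      stop_after_attempt, wait_exponential)

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

@retry(
    reraise=True,                                        # re-raise the final error after giving up
    stop=stop_after_attempt(6),                          # matches max_retries = 6
    wait=wait_exponential(multiplier=1, min=1, max=60),  # 1s, 2s, 4s, ... capped at 60s
    retry=retry_if_exception_type(ConnectionError),      # stand-in for openai.error.* types
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def flaky_call() -> str:
    raise ConnectionError("simulated transient failure")

try:
    flaky_call()
except ConnectionError as err:
    print(f"gave up after 6 attempts: {err}")
```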
@@ -124,29 +184,70 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
         history = inputs[self.history_key]
         streaming = inputs[self.streaming_key]
         prompt = inputs[self.prompt_key]
+        stop = inputs['stop']
         print(f"__call:{prompt}")
         try:
-            import openai
             # Not supported yet
             # openai.api_key = "EMPTY"
             openai.api_key = self.api_key
             openai.api_base = self.api_base_url
-        except ImportError:
+            self.client = openai.ChatCompletion
+        except AttributeError:
             raise ValueError(
-                "Could not import openai python package. "
-                "Please install it with `pip install openai`."
+                "`openai` has no `ChatCompletion` attribute, this is likely "
+                "due to an old version of the openai package. Try upgrading it "
+                "with `pip install --upgrade openai`."
             )
-        # create a chat completion
-        completion = openai.ChatCompletion.create(
-            model=self.model_name,
-            messages=build_message_list(prompt)
-        )
-        print(f"response:{completion.choices[0].message.content}")
-        print(f"+++++++++++++++++++++++++++++++++++")
-
-        history += [[prompt, completion.choices[0].message.content]]
-        answer_result = AnswerResult()
-        answer_result.history = history
-        answer_result.llm_output = {"answer": completion.choices[0].message.content}
-        generate_with_callback(answer_result)
+        msg = build_message_list(prompt, history=history)
+
+        if streaming:
+            params = {"stream": streaming,
+                      "model": self.model_name,
+                      "stop": stop}
+            for stream_resp in self.completion_with_retry(
+                    messages=msg,
+                    **params
+            ):
+                role = stream_resp["choices"][0]["delta"].get("role", "")
+                token = stream_resp["choices"][0]["delta"].get("content", "")
+                history += [[prompt, token]]
+                answer_result = AnswerResult()
+                answer_result.history = history
+                answer_result.llm_output = {"answer": token}
+                generate_with_callback(answer_result)
+        else:
+            params = {"stream": streaming,
+                      "model": self.model_name,
+                      "stop": stop}
+            response = self.completion_with_retry(
+                messages=msg,
+                **params
+            )
+            role = response["choices"][0]["message"].get("role", "")
+            content = response["choices"][0]["message"].get("content", "")
+            history += [[prompt, content]]
+            answer_result = AnswerResult()
+            answer_result.history = history
+            answer_result.llm_output = {"answer": content}
+            generate_with_callback(answer_result)
+
+
+if __name__ == "__main__":
+    chain = FastChatOpenAILLMChain()
+    chain.set_api_key("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")
+    chain.set_api_base_url("https://api.openai.com/v1")
+    chain.call_model_name("gpt-3.5-turbo")
+    answer_result_stream_result = chain({"streaming": False,
+                                         "stop": "",
+                                         "prompt": "你好",
+                                         "history": []
+                                         })
+    for answer_result in answer_result_stream_result['answer_result_stream']:
+        resp = answer_result.llm_output["answer"]
+        print(resp)
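In streaming mode each callback delivers a single delta token, so a consumer has to accumulate the fragments itself; note that the loop above also appends a fresh [prompt, token] pair to history per delta, so the stored history keeps only the last token of the answer. A sketch of typewriter-style consumption under the same assumptions as the demo (the local FastChat endpoint and model name are placeholders):

```python
# Sketch only: assumes a FastChat server at localhost:8000 serving
# "chatglm-6b" and the chain interface from the demo above.
chain = FastChatOpenAILLMChain()
chain.set_api_key("EMPTY")  # placeholder; some deployments require a real key
chain.set_api_base_url("http://localhost:8000/v1")
chain.call_model_name("chatglm-6b")

answer = ""
result = chain({"streaming": True, "stop": "", "prompt": "你好", "history": []})
for answer_result in result["answer_result_stream"]:
    token = answer_result.llm_output["answer"]
    answer += token                   # accumulate the delta fragments
    print(token, end="", flush=True)  # typewriter output
print()
```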
...
@@ -186,7 +186,5 @@ class LLamaLLMChain(BaseAnswer, Chain, ABC):
             answer_result = AnswerResult()
             history += [[prompt, reply]]
             answer_result.history = history
-            if listenerQueue.listenerQueue.__len__() > 0:
-                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
             answer_result.llm_output = {"answer": reply}
             generate_with_callback(answer_result)