Unverified 提交 453702db 作者: Zhi-guo Huang 提交者: GitHub

Merge branch 'imClumsyPanda:dev' into dev

...@@ -205,11 +205,13 @@ Web UI 可以实现如下功能: ...@@ -205,11 +205,13 @@ Web UI 可以实现如下功能:
- [x] [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2) - [x] [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2)
- [x] [fnlp/moss-moon-003-sft](https://huggingface.co/fnlp/moss-moon-003-sft) - [x] [fnlp/moss-moon-003-sft](https://huggingface.co/fnlp/moss-moon-003-sft)
- [x] 支持通过调用 [fastchat](https://github.com/lm-sys/FastChat) api 调用 llm - [x] 支持通过调用 [fastchat](https://github.com/lm-sys/FastChat) api 调用 llm
- [ ] 增加更多 Embedding 模型支持 - [x] 增加更多 Embedding 模型支持
- [x] [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh) - [x] [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
- [x] [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh) - [x] [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh)
- [x] [shibing624/text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese) - [x] [shibing624/text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese)
- [x] [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese) - [x] [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese)
- [x] [moka-ai/m3e-small](https://huggingface.co/moka-ai/m3e-small)
- [x] [moka-ai/m3e-base](https://huggingface.co/moka-ai/m3e-base)
- [ ] Web UI - [ ] Web UI
- [x] 基于 gradio 实现 Web UI DEMO - [x] 基于 gradio 实现 Web UI DEMO
- [x] 基于 streamlit 实现 Web UI DEMO - [x] 基于 streamlit 实现 Web UI DEMO
...@@ -227,6 +229,6 @@ Web UI 可以实现如下功能: ...@@ -227,6 +229,6 @@ Web UI 可以实现如下功能:
- [x] VUE 前端 - [x] VUE 前端
## 项目交流群 ## 项目交流群
![二维码](img/qr_code_28.jpg) ![二维码](img/qr_code_30.jpg)
🎉 langchain-ChatGLM 项目交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。 🎉 langchain-ChatGLM 项目交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
{ ++ /dev/null
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d2ff171c-f5f8-4590-9ce0-21c87e3d5b39",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO 2023-06-01 20:26:48,576-1d: \n",
"loading model config\n",
"llm device: cuda\n",
"embedding device: cuda\n",
"dir: /media/gpt4-pdf-chatbot-langchain/dev-langchain-ChatGLM\n",
"flagging username: 7daba79785044bceb6896b9e6f8f9894\n",
"\n"
]
}
],
"source": [
"import sys\n",
"sys.path.append('/media/gpt4-pdf-chatbot-langchain/dev-langchain-ChatGLM/')\n",
"from langchain.llms.base import LLM\n",
"import torch\n",
"import transformers \n",
"import models.shared as shared \n",
"from abc import ABC\n",
"\n",
"from langchain.llms.base import LLM\n",
"import random\n",
"from transformers.generation.logits_process import LogitsProcessor\n",
"from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList\n",
"from typing import Optional, List, Dict, Any\n",
"from models.loader import LoaderCheckPoint \n",
"from models.base import (BaseAnswer,\n",
" AnswerResult)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "68978c38-c0e9-4ae9-ba90-9c02aca335be",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading vicuna-7b-hf...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Overriding torch_dtype=None with `torch_dtype=torch.float16` due to requirements of `bitsandbytes` to enable model loading in mixed int8. Either pass torch_dtype=torch.float16 or don't pass this argument at all to remove this warning.\n",
"/media/gpt4-pdf-chatbot-langchain/pyenv-langchain/lib/python3.10/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: /media/gpt4-pdf-chatbot-langchain/pyenv-langchain did not contain ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] as expected! Searching further paths...\n",
" warn(msg)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"===================================BUG REPORT===================================\n",
"Welcome to bitsandbytes. For bug reports, please run\n",
"\n",
"python -m bitsandbytes\n",
"\n",
" and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
"================================================================================\n",
"bin /media/gpt4-pdf-chatbot-langchain/pyenv-langchain/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so\n",
"CUDA SETUP: CUDA runtime path found: /opt/cuda/lib64/libcudart.so.11.0\n",
"CUDA SETUP: Highest compute capability among GPUs detected: 8.6\n",
"CUDA SETUP: Detected CUDA version 118\n",
"CUDA SETUP: Loading binary /media/gpt4-pdf-chatbot-langchain/pyenv-langchain/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9b61d05e18044b009c72b862c84ab5cb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded the model in 6.39 seconds.\n"
]
}
],
"source": [
"import asyncio\n",
"from argparse import Namespace\n",
"from models.loader.args import parser\n",
"from langchain.agents import initialize_agent, Tool\n",
"from langchain.agents import AgentType\n",
" \n",
"args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'vicuna-7b-hf', '--no-remote-model', '--load-in-8bit'])\n",
"\n",
"args_dict = vars(args)\n",
"\n",
"shared.loaderCheckPoint = LoaderCheckPoint(args_dict)\n",
"torch.cuda.empty_cache()\n",
"shared.loaderCheckPoint.unload_model()\n",
"shared.loaderCheckPoint.reload_model() \n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "c8e4a58d-1a3a-484a-8417-bcec0eb7170e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'action': 'State of Dialogue History System', 'action_input': '露ᥫᩣ,'}\n"
]
}
],
"source": [
"from jsonformer import Jsonformer\n",
"json_schema = {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"action\": {\"type\": \"string\"},\n",
" \"action_input\": {\"type\": \"string\"}\n",
" }\n",
"}\n",
"\n",
"prompt = \"\"\" Respond to the human as helpfully and accurately as possible. You have access to the following tools:\n",
"\n",
"State of Dialogue History System: 露ᥫᩣ, 张毛峰的对话 - 当需要查找露ᥫᩣ, 张毛峰之间的聊天内容时,这里面的回答是很有用的。输入应该是一个完整的问题。\n",
"Summary: useful for when you summarize a conversation. The input to this tool should be a string, representing who will read this summary.\n",
"\n",
"The way you use the tools is by specifying a json blob.\n",
"Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here).\n",
"\n",
"The only values that should be in the \"action\" field are: {tool_names}\n",
"\n",
"The $JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB:\n",
"\n",
"```\n",
"{{{{\n",
" \"action\": $TOOL_NAME,\n",
" \"action_input\": $INPUT\n",
"}}}}\n",
"```\n",
"\n",
"ALWAYS use the following format:\n",
"\n",
"Question: the input question you must answer\n",
"Thought: you should always think about what to do\n",
"Action:\n",
"```\n",
"$JSON_BLOB\n",
"```\n",
"Observation: the result of the action\n",
"... (this Thought/Action/Observation can repeat N times)\n",
"Thought: I now know the final answer\n",
"Final Answer: the final answer to the original input question\n",
"\n",
"Begin! Reminder to always use the exact characters `Final Answer` when responding.\n",
"\n",
"Question: 我想查看关于露露的摘要信息\n",
"\"\"\"\n",
"jsonformer = Jsonformer(shared.loaderCheckPoint.model, shared.loaderCheckPoint.tokenizer, json_schema, prompt)\n",
"generated_data = jsonformer()\n",
"\n",
"print(generated_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e089a828-b662-4d9a-8d88-4bf95ccadbab",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
from langchain.agents import Tool
from langchain.tools import BaseTool
from langchain import PromptTemplate, LLMChain
from agent.custom_search import DeepSearch
from langchain.agents import BaseSingleActionAgent, AgentOutputParser, LLMSingleActionAgent, AgentExecutor
from typing import List, Tuple, Any, Union, Optional, Type
from langchain.schema import AgentAction, AgentFinish
from langchain.prompts import StringPromptTemplate
from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.base_language import BaseLanguageModel
import re
agent_template = """
你现在是一个{role}。这里是一些已知信息:
{related_content}
{background_infomation}
{question_guide}:{input}
{answer_format}
"""
class CustomPromptTemplate(StringPromptTemplate):
template: str
tools: List[Tool]
def format(self, **kwargs) -> str:
intermediate_steps = kwargs.pop("intermediate_steps")
# 没有互联网查询信息
if len(intermediate_steps) == 0:
background_infomation = "\n"
role = "傻瓜机器人"
question_guide = "我现在有一个问题"
answer_format = "如果你知道答案,请直接给出你的回答!如果你不知道答案,请你只回答\"DeepSearch('搜索词')\",并将'搜索词'替换为你认为需要搜索的关键词,除此之外不要回答其他任何内容。\n\n下面请回答我上面提出的问题!"
# 返回了背景信息
else:
# 根据 intermediate_steps 中的 AgentAction 拼装 background_infomation
background_infomation = "\n\n你还有这些已知信息作为参考:\n\n"
action, observation = intermediate_steps[0]
background_infomation += f"{observation}\n"
role = "聪明的 AI 助手"
question_guide = "请根据这些已知信息回答我的问题"
answer_format = ""
kwargs["background_infomation"] = background_infomation
kwargs["role"] = role
kwargs["question_guide"] = question_guide
kwargs["answer_format"] = answer_format
return self.template.format(**kwargs)
class CustomSearchTool(BaseTool):
name: str = "DeepSearch"
description: str = ""
def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None):
return DeepSearch.search(query = query)
async def _arun(self, query: str):
raise NotImplementedError("DeepSearch does not support async")
class CustomAgent(BaseSingleActionAgent):
@property
def input_keys(self):
return ["input"]
def plan(self, intermedate_steps: List[Tuple[AgentAction, str]],
**kwargs: Any) -> Union[AgentAction, AgentFinish]:
return AgentAction(tool="DeepSearch", tool_input=kwargs["input"], log="")
class CustomOutputParser(AgentOutputParser):
def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
# group1 = 调用函数名字
# group2 = 传入参数
match = re.match(r'^[\s\w]*(DeepSearch)\(([^\)]+)\)', llm_output, re.DOTALL)
print(match)
# 如果 llm 没有返回 DeepSearch() 则认为直接结束指令
if not match:
return AgentFinish(
return_values={"output": llm_output.strip()},
log=llm_output,
)
# 否则的话都认为需要调用 Tool
else:
action = match.group(1).strip()
action_input = match.group(2).strip()
return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)
class DeepAgent:
tool_name: str = "DeepSearch"
agent_executor: any
tools: List[Tool]
llm_chain: any
def query(self, related_content: str = "", query: str = ""):
tool_name = self.tool_name
result = self.agent_executor.run(related_content=related_content, input=query ,tool_name=self.tool_name)
return result
def __init__(self, llm: BaseLanguageModel, **kwargs):
tools = [
Tool.from_function(
func=DeepSearch.search,
name="DeepSearch",
description=""
)
]
self.tools = tools
tool_names = [tool.name for tool in tools]
output_parser = CustomOutputParser()
prompt = CustomPromptTemplate(template=agent_template,
tools=tools,
input_variables=["related_content","tool_name", "input", "intermediate_steps"])
llm_chain = LLMChain(llm=llm, prompt=prompt)
self.llm_chain = llm_chain
agent = LLMSingleActionAgent(
llm_chain=llm_chain,
output_parser=output_parser,
stop=["\nObservation:"],
allowed_tools=tool_names
)
agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
self.agent_executor = agent_executor
import requests
RapidAPIKey = "90bbe925ebmsh1c015166fc5e12cp14c503jsn6cca55551ae4"
class DeepSearch:
def search(query: str = ""):
query = query.strip()
if query == "":
return ""
if RapidAPIKey == "":
return "请配置你的 RapidAPIKey"
url = "https://bing-web-search1.p.rapidapi.com/search"
querystring = {"q": query,
"mkt":"zh-cn","textDecorations":"false","setLang":"CN","safeSearch":"Off","textFormat":"Raw"}
headers = {
"Accept": "application/json",
"X-BingApis-SDK": "true",
"X-RapidAPI-Key": RapidAPIKey,
"X-RapidAPI-Host": "bing-web-search1.p.rapidapi.com"
}
response = requests.get(url, headers=headers, params=querystring)
data_list = response.json()['value']
if len(data_list) == 0:
return ""
else:
result_arr = []
result_str = ""
count_index = 0
for i in range(6):
item = data_list[i]
title = item["name"]
description = item["description"]
item_str = f"{title}: {description}"
result_arr = result_arr + [item_str]
result_str = "\n".join(result_arr)
return result_str
...@@ -4,9 +4,7 @@ from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLo ...@@ -4,9 +4,7 @@ from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLo
from configs.model_config import * from configs.model_config import *
import datetime import datetime
from textsplitter import ChineseTextSplitter from textsplitter import ChineseTextSplitter
from typing import List, Tuple, Dict from typing import List
from langchain.docstore.document import Document
import numpy as np
from utils import torch_gc from utils import torch_gc
from tqdm import tqdm from tqdm import tqdm
from pypinyin import lazy_pinyin from pypinyin import lazy_pinyin
......
...@@ -9,11 +9,16 @@ logger = logging.getLogger() ...@@ -9,11 +9,16 @@ logger = logging.getLogger()
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT) logging.basicConfig(format=LOG_FORMAT)
# 在以下字典中修改属性值,以指定本地embedding模型存储位置
# 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
# 此处请写绝对路径
embedding_model_dict = { embedding_model_dict = {
"ernie-tiny": "nghuyong/ernie-3.0-nano-zh", "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
"ernie-base": "nghuyong/ernie-3.0-base-zh", "ernie-base": "nghuyong/ernie-3.0-base-zh",
"text2vec-base": "shibing624/text2vec-base-chinese", "text2vec-base": "shibing624/text2vec-base-chinese",
"text2vec": "GanymedeNil/text2vec-large-chinese", "text2vec": "GanymedeNil/text2vec-large-chinese",
"m3e-small": "moka-ai/m3e-small",
"m3e-base": "moka-ai/m3e-base",
} }
# Embedding model name # Embedding model name
...@@ -25,6 +30,9 @@ EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backe ...@@ -25,6 +30,9 @@ EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backe
# supported LLM models # supported LLM models
# llm_model_dict 处理了loader的一些预设行为,如加载位置,模型名称,模型处理器实例 # llm_model_dict 处理了loader的一些预设行为,如加载位置,模型名称,模型处理器实例
# 在以下字典中修改属性值,以指定本地 LLM 模型存储位置
# 如将 "chatglm-6b" 的 "local_model_path" 由 None 修改为 "User/Downloads/chatglm-6b"
# 此处请写绝对路径
llm_model_dict = { llm_model_dict = {
"chatglm-6b-int4-qe": { "chatglm-6b-int4-qe": {
"name": "chatglm-6b-int4-qe", "name": "chatglm-6b-int4-qe",
...@@ -66,7 +74,7 @@ llm_model_dict = { ...@@ -66,7 +74,7 @@ llm_model_dict = {
"vicuna-13b-hf": { "vicuna-13b-hf": {
"name": "vicuna-13b-hf", "name": "vicuna-13b-hf",
"pretrained_model_name": "vicuna-13b-hf", "pretrained_model_name": "vicuna-13b-hf",
"local_model_path": None, "local_model_path": "/media/checkpoint/vicuna-13b-hf",
"provides": "LLamaLLM" "provides": "LLamaLLM"
}, },
......
...@@ -34,6 +34,7 @@ class ChatGLM(BaseAnswer, LLM, ABC): ...@@ -34,6 +34,7 @@ class ChatGLM(BaseAnswer, LLM, ABC):
self.history_len = history_len self.history_len = history_len
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
print(f"__call:{prompt}")
response, _ = self.checkPoint.model.chat( response, _ = self.checkPoint.model.chat(
self.checkPoint.tokenizer, self.checkPoint.tokenizer,
prompt, prompt,
...@@ -41,6 +42,8 @@ class ChatGLM(BaseAnswer, LLM, ABC): ...@@ -41,6 +42,8 @@ class ChatGLM(BaseAnswer, LLM, ABC):
max_length=self.max_token, max_length=self.max_token,
temperature=self.temperature temperature=self.temperature
) )
print(f"response:{response}")
print(f"+++++++++++++++++++++++++++++++++++")
return response return response
def generatorAnswer(self, prompt: str, def generatorAnswer(self, prompt: str,
......
...@@ -69,7 +69,25 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC): ...@@ -69,7 +69,25 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
self.model_name = model_name self.model_name = model_name
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
pass print(f"__call:{prompt}")
try:
import openai
# Not support yet
openai.api_key = "EMPTY"
openai.api_base = self.api_base_url
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
# create a chat completion
completion = openai.ChatCompletion.create(
model=self.model_name,
messages=self.build_message_list(prompt)
)
print(f"response:{completion.choices[0].message.content}")
print(f"+++++++++++++++++++++++++++++++++++")
return completion.choices[0].message.content
# 将历史对话数组转换为文本格式 # 将历史对话数组转换为文本格式
def build_message_list(self, query) -> Collection[Dict[str, str]]: def build_message_list(self, query) -> Collection[Dict[str, str]]:
......
...@@ -22,7 +22,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor): ...@@ -22,7 +22,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
class LLamaLLM(BaseAnswer, LLM, ABC): class LLamaLLM(BaseAnswer, LLM, ABC):
checkPoint: LoaderCheckPoint = None checkPoint: LoaderCheckPoint = None
history = [] # history = []
history_len: int = 3 history_len: int = 3
max_new_tokens: int = 500 max_new_tokens: int = 500
num_beams: int = 1 num_beams: int = 1
...@@ -88,9 +88,16 @@ class LLamaLLM(BaseAnswer, LLM, ABC): ...@@ -88,9 +88,16 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
return reply return reply
# 将历史对话数组转换为文本格式 # 将历史对话数组转换为文本格式
def history_to_text(self, query): def history_to_text(self, query, history):
"""
历史对话软提示
这段代码首先定义了一个名为 history_to_text 的函数,用于将 self.history
数组转换为所需的文本格式。然后,我们将格式化后的历史文本
再用 self.encode 将其转换为向量表示。最后,将历史对话向量与当前输入的对话向量拼接在一起。
:return:
"""
formatted_history = '' formatted_history = ''
history = self.history[-self.history_len:] if self.history_len > 0 else [] history = history[-self.history_len:] if self.history_len > 0 else []
for i, (old_query, response) in enumerate(history): for i, (old_query, response) in enumerate(history):
formatted_history += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) formatted_history += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
formatted_history += "[Round {}]\n问:{}\n答:".format(len(history), query) formatted_history += "[Round {}]\n问:{}\n答:".format(len(history), query)
...@@ -116,20 +123,6 @@ class LLamaLLM(BaseAnswer, LLM, ABC): ...@@ -116,20 +123,6 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
return input_ids, position_ids, attention_mask return input_ids, position_ids, attention_mask
def generate_softprompt_history_tensors(self, query):
"""
历史对话软提示
这段代码首先定义了一个名为 history_to_text 的函数,用于将 self.history
数组转换为所需的文本格式。然后,我们将格式化后的历史文本
再用 self.encode 将其转换为向量表示。最后,将历史对话向量与当前输入的对话向量拼接在一起。
:return:
"""
# 对话内容
# 处理历史对话
formatted_history = self.history_to_text(query)
return formatted_history
@property @property
def _history_len(self) -> int: def _history_len(self) -> int:
return self.history_len return self.history_len
...@@ -173,18 +166,18 @@ class LLamaLLM(BaseAnswer, LLM, ABC): ...@@ -173,18 +166,18 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
new_tokens = len(output_ids[0]) - len(input_ids[0]) new_tokens = len(output_ids[0]) - len(input_ids[0])
reply = self.decode(output_ids[0][-new_tokens:]) reply = self.decode(output_ids[0][-new_tokens:])
print(f"response:{reply}") print(f"response:{reply}")
self.history = self.history + [[None, reply]] print(f"+++++++++++++++++++++++++++++++++++")
return reply return reply
def generatorAnswer(self, prompt: str, def generatorAnswer(self, prompt: str,
history: List[List[str]] = [], history: List[List[str]] = [],
streaming: bool = False): streaming: bool = False):
if history:
self.history = history
# TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现 # TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现
softprompt = self.generate_softprompt_history_tensors(prompt) softprompt = self.history_to_text(prompt,history=history)
response = self._call(prompt=softprompt, stop=['\n###']) response = self._call(prompt=softprompt, stop=['\n###'])
answer_result = AnswerResult() answer_result = AnswerResult()
answer_result.history = self.history answer_result.history = history + [[None, response]]
answer_result.llm_output = {"answer": response} answer_result.llm_output = {"answer": response}
yield answer_result yield answer_result
...@@ -22,27 +22,6 @@ class MyFAISS(FAISS, VectorStore): ...@@ -22,27 +22,6 @@ class MyFAISS(FAISS, VectorStore):
index_to_docstore_id=index_to_docstore_id, index_to_docstore_id=index_to_docstore_id,
normalize_L2=normalize_L2) normalize_L2=normalize_L2)
# def similarity_search_with_score_by_vector(
# self, embedding: List[float], k: int = 4
# ) -> List[Tuple[Document, float]]:
# faiss = dependable_faiss_import()
# vector = np.array([embedding], dtype=np.float32)
# if self._normalize_L2:
# faiss.normalize_L2(vector)
# scores, indices = self.index.search(vector, k)
# docs = []
# for j, i in enumerate(indices[0]):
# if i == -1:
# # This happens when not enough docs are returned.
# continue
# _id = self.index_to_docstore_id[i]
# doc = self.docstore.search(_id)
# if not isinstance(doc, Document):
# raise ValueError(f"Could not find document for id {_id}, got {doc}")
#
# docs.append((doc, scores[0][j]))
# return docs
def seperate_list(self, ls: List[int]) -> List[List[int]]: def seperate_list(self, ls: List[int]) -> List[List[int]]:
# TODO: 增加是否属于同一文档的判断 # TODO: 增加是否属于同一文档的判断
lists = [] lists = []
...@@ -59,7 +38,11 @@ class MyFAISS(FAISS, VectorStore): ...@@ -59,7 +38,11 @@ class MyFAISS(FAISS, VectorStore):
def similarity_search_with_score_by_vector( def similarity_search_with_score_by_vector(
self, embedding: List[float], k: int = 4 self, embedding: List[float], k: int = 4
) -> List[Document]: ) -> List[Document]:
scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k) faiss = dependable_faiss_import()
vector = np.array([embedding], dtype=np.float32)
if self._normalize_L2:
faiss.normalize_L2(vector)
scores, indices = self.index.search(vector, k)
docs = [] docs = []
id_set = set() id_set = set()
store_len = len(self.index_to_docstore_id) store_len = len(self.index_to_docstore_id)
...@@ -69,7 +52,7 @@ class MyFAISS(FAISS, VectorStore): ...@@ -69,7 +52,7 @@ class MyFAISS(FAISS, VectorStore):
continue continue
_id = self.index_to_docstore_id[i] _id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id) doc = self.docstore.search(_id)
if (not self.chunk_conent) or ("add_context" in doc.metadata and not doc.metadata["add_context"]): if (not self.chunk_conent) or ("context_expand" in doc.metadata and not doc.metadata["context_expand"]):
if not isinstance(doc, Document): if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}") raise ValueError(f"Could not find document for id {_id}, got {doc}")
doc.metadata["score"] = int(scores[0][j]) doc.metadata["score"] = int(scores[0][j])
...@@ -79,11 +62,17 @@ class MyFAISS(FAISS, VectorStore): ...@@ -79,11 +62,17 @@ class MyFAISS(FAISS, VectorStore):
docs_len = len(doc.page_content) docs_len = len(doc.page_content)
for k in range(1, max(i, store_len - i)): for k in range(1, max(i, store_len - i)):
break_flag = False break_flag = False
for l in [i + k, i - k]: if "context_expand_method" in doc.metadata and doc.metadata["context_expand_method"] == "forward":
if 0 <= l < len(self.index_to_docstore_id): expand_range = [i + k]
elif "context_expand_method" in doc.metadata and doc.metadata["context_expand_method"] == "backward":
expand_range = [i - k]
else:
expand_range = [i + k, i - k]
for l in expand_range:
if l not in id_set and 0 <= l < len(self.index_to_docstore_id):
_id0 = self.index_to_docstore_id[l] _id0 = self.index_to_docstore_id[l]
doc0 = self.docstore.search(_id0) doc0 = self.docstore.search(_id0)
if docs_len + len(doc0.page_content) > self.chunk_size: if docs_len + len(doc0.page_content) > self.chunk_size or doc0.metadata["source"] != doc.metadata["source"]:
break_flag = True break_flag = True
break break
elif doc0.metadata["source"] == doc.metadata["source"]: elif doc0.metadata["source"] == doc.metadata["source"]:
...@@ -91,7 +80,7 @@ class MyFAISS(FAISS, VectorStore): ...@@ -91,7 +80,7 @@ class MyFAISS(FAISS, VectorStore):
id_set.add(l) id_set.add(l)
if break_flag: if break_flag:
break break
if (not self.chunk_conent) or ("add_context" in doc.metadata and doc.metadata["add_context"] == False): if (not self.chunk_conent) or ("add_context" in doc.metadata and not doc.metadata["add_context"]):
return docs return docs
if len(id_set) == 0 and self.score_threshold > 0: if len(id_set) == 0 and self.score_threshold > 0:
return [] return []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论