Commit 48ce8b1c by hzg0601

pull dev

...@@ -23,13 +23,17 @@
🚩 This project does not involve fine-tuning or training, but fine-tuning or training can be used to improve its results.
🐳 Docker image: registry.cn-beijing.aliyuncs.com/isafetech/chatmydata:1.0 (thanks to @InkSong🌲)
💻 Run with: docker run -d -p 80:7860 --gpus all registry.cn-beijing.aliyuncs.com/isafetech/chatmydata:1.0
🌐 [AutoDL image](https://www.codewithgpu.com/i/imClumsyPanda/langchain-ChatGLM/langchain-ChatGLM)
📓 [Run the project online on ModelWhale](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59)
## Changelog
- See the [changelog](docs/CHANGELOG.md)
+ See the [release notes](https://github.com/imClumsyPanda/langchain-ChatGLM/releases)
## Hardware Requirements
...@@ -60,6 +64,23 @@
The default Embedding model used in this project, [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main), takes about 3 GB of GPU memory; it can also be configured to run on the CPU.
## Docker All-in-One Image
🐳 Docker image: `registry.cn-beijing.aliyuncs.com/isafetech/chatmydata:1.0`🌲
💻 Run with a single command:
```shell
docker run -d -p 80:7860 --gpus all registry.cn-beijing.aliyuncs.com/isafetech/chatmydata:1.0
```
- The image is `25.2G` in size, is built from [v0.1.16](https://github.com/imClumsyPanda/langchain-ChatGLM/releases/tag/v0.1.16), and uses `nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04` as its base image
- It ships with two `embedding` models, `m3e-base` and `text2vec-large-chinese`, plus a built-in `fastchat+chatglm-6b`
- The image is intended for convenient one-click deployment; make sure the NVIDIA driver is installed on your Linux distribution
- Note that you do not need to install the CUDA toolkit on the host, but you do need the `NVIDIA Driver` and the `NVIDIA Container Toolkit`; see the [installation guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
- The first pull and the first startup both take some time; on first startup, follow the logs with `docker logs -f <container id>` as shown in the image below
- If startup gets stuck at the `Waiting..` step, use `docker exec -it <container id> bash` to enter the container and check the per-stage logs under `/logs/` (both commands are collected in the snippet below)
![](img/docker_logs.png)
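The two troubleshooting commands mentioned above, collected into one snippet; `<container id>` is a placeholder to replace with your own container id:
```shell
# follow the container's startup logs
docker logs -f <container id>
# open a shell in the container and inspect the per-stage logs
docker exec -it <container id> bash
ls /logs/
```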
## Docker Deployment
To let containers use the host's GPU, the [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-container-toolkit) must be installed on the host. Installation steps:
```shell
...@@ -198,6 +219,7 @@ Web UI supports the following features:
- [ ] Knowledge graph / graph database integration
- [ ] Agent implementation
- [x] Add support for more LLM models
- [x] [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
- [x] [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
- [x] [THUDM/chatglm-6b-int8](https://huggingface.co/THUDM/chatglm-6b-int8)
- [x] [THUDM/chatglm-6b-int4](https://huggingface.co/THUDM/chatglm-6b-int4)
...@@ -221,7 +243,7 @@ Web UI supports the following features:
- [x] Select a knowledge base and start Q&A
- [x] Upload files/folders to a knowledge base
- [x] Knowledge base testing
- - [ ] Delete files from a knowledge base
+ - [x] Delete files from a knowledge base
- [x] Search-engine-based Q&A
- [ ] Add API support
- [x] API deployment via fastapi
...@@ -229,7 +251,7 @@ Web UI supports the following features:
- [x] VUE front end
## Project Discussion Group
- <img src="img/qr_code_32.jpg" alt="QR code" width="300" height="300" />
+ <img src="img/qr_code_42.jpg" alt="QR code" width="300" height="300" />
🎉 WeChat discussion group for the langchain-ChatGLM project. If you are also interested in this project, you are welcome to join the group chat.
...@@ -17,6 +17,7 @@ import models.shared as shared
from agent import bing_search
from langchain.docstore.document import Document
from functools import lru_cache
from textsplitter.zh_title_enhance import zh_title_enhance
# patch HuggingFaceEmbeddings to make it hashable
...@@ -56,7 +57,7 @@ def tree(filepath, ignore_dir_names=None, ignore_file_names=None):
return ret_list, [os.path.basename(p) for p in ret_list]
- def load_file(filepath, sentence_size=SENTENCE_SIZE):
+ def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE):
if filepath.lower().endswith(".md"):
loader = UnstructuredFileLoader(filepath, mode="elements")
docs = loader.load()
...@@ -79,6 +80,8 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE):
loader = UnstructuredFileLoader(filepath, mode="elements")
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
docs = loader.load_and_split(text_splitter=textsplitter)
if using_zh_title_enhance:
docs = zh_title_enhance(docs)
write_check_file(filepath, docs)
return docs
......
...@@ -42,7 +42,9 @@ def start():
@start.command(name="api", context_settings=dict(help_option_names=['-h', '--help']))
@click.option('-i', '--ip', default='0.0.0.0', show_default=True, type=str, help='api_server listen address.')
@click.option('-p', '--port', default=7861, show_default=True, type=int, help='api_server listen port.')
- def start_api(ip, port):
+ @click.option('-k', '--ssl_keyfile', type=int, help='enable api https/wss service, specify the ssl keyfile path.')
+ @click.option('-c', '--ssl_certfile', type=int, help='enable api https/wss service, specify the ssl certificate file path.')
+ def start_api(ip, port, **kwargs):
# Before calling api_start, loadCheckPoint must be called first with the checkpoint-loading parameters.
# In theory this could be wrapped with the click package, but that would be cumbersome and require larger changes,
# so the parser package is still used here, with the parameters from models.loader.args.DEFAULT_ARGS as defaults.
...@@ -51,7 +53,7 @@ def start_api(ip, port):
from models.loader import LoaderCheckPoint
from models.loader.args import DEFAULT_ARGS
shared.loaderCheckPoint = LoaderCheckPoint(DEFAULT_ARGS)
- api_start(host=ip, port=port)
+ api_start(host=ip, port=port, **kwargs)
# # When calling cli_demo through cli.py, the model must be initialized inside cli.py, otherwise an error is raised:
# langchain-ChatGLM: error: unrecognized arguments: start cli
......
...@@ -27,7 +27,6 @@ EMBEDDING_MODEL = "text2vec"
# Embedding running device
EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# supported LLM models
# llm_model_dict handles some preset loader behaviors, such as load location, model name, and the model handler instance
# Modify the attribute values in the dictionary below to specify where local LLM models are stored
...@@ -58,12 +57,29 @@ llm_model_dict = {
"local_model_path": None,
"provides": "ChatGLM"
},
"chatglm2-6b": {
"name": "chatglm2-6b",
"pretrained_model_name": "THUDM/chatglm2-6b",
"local_model_path": None,
"provides": "ChatGLM"
},
"chatglm2-6b-int4": {
"name": "chatglm2-6b-int4",
"pretrained_model_name": "THUDM/chatglm2-6b-int4",
"local_model_path": None,
"provides": "ChatGLM"
},
"chatglm2-6b-int8": {
"name": "chatglm2-6b-int8",
"pretrained_model_name": "THUDM/chatglm2-6b-int8",
"local_model_path": None,
"provides": "ChatGLM"
},
"chatyuan": {
"name": "chatyuan",
"pretrained_model_name": "ClueAI/ChatYuan-large-v2",
"local_model_path": None,
- "provides": None
+ "provides": "MOSSLLM"
},
"moss": {
"name": "moss",
...@@ -77,6 +93,46 @@ llm_model_dict = {
"local_model_path": None,
"provides": "LLamaLLM"
},
# Calling this directly returns a requests.exceptions.ConnectionError; download the model with the snapshot_download function
# from the huggingface_hub package instead. If snapshot_download also returns a network error, retry a few times, which usually works.
# If it still fails, the network is probably behind a firewall (common on servers), and basically the only option is to download
# the model on another machine and then transfer it to the target device.
"bloomz-7b1": {
"name": "bloomz-7b1",
"pretrained_model_name": "bigscience/bloomz-7b1",
"local_model_path": None,
"provides": "MOSSLLM"
},
# In testing, loading bigscience/bloom-3b takes about 170 seconds; it is not yet clear why it is so slow,
# but it is probably related to the custom tokens it has to load.
"bloom-3b": {
"name": "bloom-3b",
"pretrained_model_name": "bigscience/bloom-3b",
"local_model_path": None,
"provides": "MOSSLLM"
},
"baichuan-7b": {
"name": "baichuan-7b",
"pretrained_model_name": "baichuan-inc/baichuan-7B",
"local_model_path": None,
"provides": "MOSSLLM"
},
# For llama-cpp model compatibility issues, see https://github.com/abetlen/llama-cpp-python/issues/204
"ggml-vicuna-13b-1.1-q5": {
"name": "ggml-vicuna-13b-1.1-q5",
"pretrained_model_name": "lmsys/vicuna-13b-delta-v1.1",
# This must point to the downloaded model. If it was downloaded to the default location, it ends up in the user's home directory under
# /.cache/huggingface/hub/models--vicuna--ggml-vicuna-13b-1.1/
# Also, because this project loads models in a fairly strict way, the model file still needs to be renamed manually after downloading
# so that its filename matches the one on the Hugging Face Hub.
# In addition, ggml formats from different periods are not compatible with each other, so different ggml files require different versions
# of the llama-cpp-python package; in testing, pip install does not work well, and the matching wheel has to be downloaded manually from
# https://github.com/abetlen/llama-cpp-python/releases/tag/ and installed.
# In testing, v0.1.63 is compatible with this model's vicuna/ggml-vicuna-13b-1.1/ggml-vic13b-q5_1.bin
"local_model_path": f'''{"/".join(os.path.abspath(__file__).split("/")[:3])}/.cache/huggingface/hub/models--vicuna--ggml-vicuna-13b-1.1/blobs/''',
"provides": "LLamaLLM"
},
# For models called through fastchat, use the following format
"fastchat-chatglm-6b": {
...@@ -84,6 +140,14 @@ llm_model_dict = {
"pretrained_model_name": "chatglm-6b",
"local_model_path": None,
"provides": "FastChatOpenAILLM",  # when using the fastchat api, "provides" must be "FastChatOpenAILLM"
"api_base_url": "http://localhost:8000/v1",  # set to the "api_base_url" of the fastchat service
"api_key": "EMPTY"
},
"fastchat-chatglm2-6b": {
"name": "chatglm2-6b",  # "name" must be set to the "model_name" used by the fastchat service
"pretrained_model_name": "chatglm2-6b",
"local_model_path": None,
"provides": "FastChatOpenAILLM",  # when using the fastchat api, "provides" must be "FastChatOpenAILLM"
"api_base_url": "http://localhost:8000/v1"  # set to the "api_base_url" of the fastchat service
},
...@@ -93,8 +157,18 @@ llm_model_dict = {
"pretrained_model_name": "vicuna-13b-hf",
"local_model_path": None,
"provides": "FastChatOpenAILLM",  # when using the fastchat api, "provides" must be "FastChatOpenAILLM"
- "api_base_url": "http://localhost:8000/v1"  # set to the "api_base_url" of the fastchat service
+ "api_base_url": "http://localhost:8000/v1",  # set to the "api_base_url" of the fastchat service
"api_key": "EMPTY"
},
"openai-chatgpt-3.5": {
"name": "gpt-3.5-turbo",
"pretrained_model_name": "gpt-3.5-turbo",
"provides": "FastChatOpenAILLM",
"local_model_path": None,
"api_base_url": "https://api.openapi.com/v1",
"api_key": ""
},
}
# LLM name
...@@ -128,7 +202,7 @@ PROMPT_TEMPLATE = """已知信息:
根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}"""
- # Number of cached knowledge bases
+ # Number of cached knowledge bases; for the ChatGLM2, ChatGLM2-int4, or ChatGLM2-int8 models, try setting this to 10 if retrieval quality is poor
CACHED_VS_NUM = 1
# Sentence-splitting length
...@@ -174,3 +248,8 @@ BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
# Also, if you are on a server and get "Failed to establish a new connection: [Errno 110] Connection timed out",
# the server is behind a firewall and you need to ask the administrator to whitelist the address; if it is a company server, do not get your hopes up.
BING_SUBSCRIPTION_KEY = ""
# Whether to enable Chinese title enhancement, together with its related settings.
# A title-detection step decides which text chunks are titles and marks them in the metadata;
# each chunk is then concatenated with the nearest preceding title to enrich the text.
ZH_TITLE_ENHANCE = False
...@@ -44,4 +44,12 @@ $ pip install -r requirements.txt
$ python loader/image_loader.py
```
Note: when using `langchain.document_loaders.UnstructuredFileLoader` to ingest unstructured files, additional dependencies may need to be installed depending on the document type; see the [langchain documentation](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html)
## Notes on using llama-cpp models
1. First download the model from the Hugging Face Hub, e.g. [ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin) from https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/; downloading with snapshot_download from the huggingface_hub package is recommended (see the sketch after this list).
2. Rename the downloaded model. Files downloaded via huggingface_hub are stored under randomized names, so the weight file must be renamed back to its original filename, e.g. [ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin).
3. Based on when the downloaded ggml model was produced, work out the matching llama-cpp version, download the corresponding llama-cpp-python wheel, and install it manually; in testing, [ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin) works with llama-cpp-python v0.1.63.
4. Add the downloaded model's information to `llm_model_dict` in configs/model_config.py; make sure the parameters are compatible, since some parameter combinations will raise errors.
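A minimal download-and-rename sketch for steps 1 and 2, assuming the repo and filename above; the target directory is a placeholder to replace with the `local_model_path` you configure in configs/model_config.py:
```python
from pathlib import Path
import shutil

from huggingface_hub import snapshot_download

# Step 1: download the whole repo snapshot into the local Hugging Face cache.
snapshot_dir = Path(snapshot_download(repo_id="vicuna/ggml-vicuna-13b-1.1"))

# Step 2: copy the weight file out under its original filename so that the
# project's loader can find it at the configured local_model_path.
target_dir = Path("/path/to/local_model_path")  # placeholder path
target_dir.mkdir(parents=True, exist_ok=True)
shutil.copy(snapshot_dir / "ggml-vic13b-q5_1.bin", target_dir / "ggml-vic13b-q5_1.bin")
```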
# Starting the API service
## Starting from the py file
The API service can be started by running `api.py` directly; by default it serves http and ws on ip 0.0.0.0 and port 7861.
```shell
python api.py
```
Startup also accepts the model-loading parameters listed in StartOption, as well as IP and port settings.
```shell
python api.py --model-name chatglm-6b-int8 --port 7862
```
## Starting via cli.bat/cli.sh
The service can also be started through the command-line wrapper scripts.
```shell
cli.sh api --help
```
The other configurable parameters are the same as when starting from the py file above.
# Starting the API service over https/wss
## Creating local ssl certificate files
If you do not have an officially issued CA certificate, you can [install mkcert](https://github.com/FiloSottile/mkcert#installation) and generate a local CA certificate with the following commands:
```shell
mkcert -install
mkcert api.example.com 47.123.123.123 localhost 127.0.0.1 ::1
```
Press Enter to accept the defaults and the files are saved in the current directory as two pem files whose names are prefixed with the first domain given in the command.
Then start the service with the two files passed as parameters:
```shell
python api.py --port 7862 --ssl_keyfile api.example.com+4-key.pem --ssl_certfile api.example.com+4.pem
./cli.sh api --port 7862 --ssl_keyfile api.example.com+4-key.pem --ssl_certfile api.example.com+4.pem
```
A similar effect can also be achieved by putting Nginx in front as a reverse proxy; consult the relevant documentation separately.
\ No newline at end of file
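As a quick sanity check, and assuming mkcert installed its local CA into the system trust store and that the FastAPI app exposes its default `/docs` route, a request like the following should succeed over https:
```shell
curl https://localhost:7862/docs
```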
...@@ -5,9 +5,6 @@ from langchain.document_loaders.unstructured import UnstructuredFileLoader
from paddleocr import PaddleOCR
import os
import nltk
from configs.model_config import NLTK_DATA_PATH
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
class UnstructuredPaddleImageLoader(UnstructuredFileLoader):
"""Loader that uses unstructured to load image files, such as PNGs and JPGs."""
...@@ -35,6 +32,10 @@ class UnstructuredPaddleImageLoader(UnstructuredFileLoader):
if __name__ == "__main__":
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import NLTK_DATA_PATH
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base", "samples", "content", "test.jpg")
loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
docs = loader.load()
......
...@@ -65,6 +65,7 @@ class ChatGLM(BaseAnswer, LLM, ABC):
answer_result.history = history
answer_result.llm_output = {"answer": stream_resp}
yield answer_result
self.checkPoint.clear_torch_cache()
else:
response, _ = self.checkPoint.model.chat(
self.checkPoint.tokenizer,
......
...@@ -23,6 +23,7 @@ def _build_message_template() -> Dict[str, str]:
class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
api_base_url: str = "http://localhost:8000/v1"
model_name: str = "chatglm-6b"
max_token: int = 10000
...@@ -31,8 +32,14 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
checkPoint: LoaderCheckPoint = None
history = []
history_len: int = 10
api_key: str = ""
def __init__(self, checkPoint: LoaderCheckPoint = None):
def __init__(self,
checkPoint: LoaderCheckPoint = None,
# api_base_url:str="http://localhost:8000/v1",
# model_name:str="chatglm-6b",
# api_key:str=""
):
super().__init__()
self.checkPoint = checkPoint
...@@ -60,7 +67,7 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
return self.api_base_url
def set_api_key(self, api_key: str):
- pass
+ self.api_key = api_key
def set_api_base_url(self, api_base_url: str):
self.api_base_url = api_base_url
...@@ -73,7 +80,8 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
try:
import openai
# Not support yet
- openai.api_key = "EMPTY"
+ # openai.api_key = "EMPTY"
openai.key = self.api_key
openai.api_base = self.api_base_url
except ImportError:
raise ValueError(
...@@ -116,7 +124,8 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
try:
import openai
# Not support yet
- openai.api_key = "EMPTY"
+ # openai.api_key = "EMPTY"
openai.api_key = self.api_key
openai.api_base = self.api_base_url
except ImportError:
raise ValueError(
......
...@@ -6,14 +6,17 @@ import torch
import transformers
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
- from typing import Optional, List, Dict, Any
+ from typing import Optional, List, Dict, Any,Union
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
AnswerResult)
class InvalidScoreLogitsProcessor(LogitsProcessor):
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ def __call__(self, input_ids: Union[torch.LongTensor,list], scores: Union[torch.FloatTensor,list]) -> torch.FloatTensor:
# llama-cpp models return lists; for compatibility, check the types of input_ids and scores and convert lists to torch.Tensor
input_ids = torch.tensor(input_ids) if isinstance(input_ids,list) else input_ids
scores = torch.tensor(scores) if isinstance(scores,list) else scores
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 5] = 5e4
...@@ -163,7 +166,20 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
self.stopping_criteria = transformers.StoppingCriteriaList()
# observe the output
gen_kwargs.update({'stopping_criteria': self.stopping_criteria})
# llama-cpp model parameters differ considerably from the transformers parameter fields, and calling generate directly returns unsupported-field errors.
# So first check whether the model is a llama-cpp model, then take the intersection of gen_kwargs and the arguments of the model's generate method,
# and pass only the intersecting fields to the model to keep things compatible.
# todo: llama-cpp models have poor compatibility with this framework; consider writing a separate llama_cpp_llm.py module later
if "llama_cpp" in self.checkPoint.model.__str__():
import inspect
common_kwargs_keys = set(inspect.getfullargspec(self.checkPoint.model.generate).args)&set(gen_kwargs.keys())
common_kwargs = {key:gen_kwargs[key] for key in common_kwargs_keys}
#? the llama-cpp model's generate method seems to accept only CPU inputs, and the response is painfully slow
#? why would GPU not be supported? That should not be the case
output_ids = torch.tensor([list(self.checkPoint.model.generate(input_id_i.cpu(),**common_kwargs)) for input_id_i in input_ids])
else:
output_ids = self.checkPoint.model.generate(**gen_kwargs)
new_tokens = len(output_ids[0]) - len(input_ids[0])
reply = self.decode(output_ids[0][-new_tokens:])
......
...@@ -67,9 +67,11 @@ class LoaderCheckPoint:
self.load_in_8bit = params.get('load_in_8bit', False)
self.bf16 = params.get('bf16', False)
def _load_model_config(self, model_name):
if self.model_path:
self.model_path = re.sub("\s","",self.model_path)
checkpoint = Path(f'{self.model_path}')
else:
if not self.no_remote_model:
...@@ -78,10 +80,12 @@ class LoaderCheckPoint:
raise ValueError(
"本地模型local_model_path未配置路径"
)
try:
model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
return model_config
except Exception as e:
print(e)
return checkpoint
def _load_model(self, model_name):
"""
...@@ -93,6 +97,7 @@ class LoaderCheckPoint:
t0 = time.time()
if self.model_path:
self.model_path = re.sub("\s","",self.model_path)
checkpoint = Path(f'{self.model_path}')
else:
if not self.no_remote_model:
...@@ -103,7 +108,7 @@ class LoaderCheckPoint:
)
self.is_llamacpp = len(list(Path(f'{checkpoint}').glob('ggml*.bin'))) > 0
- if 'chatglm' in model_name.lower():
+ if 'chatglm' in model_name.lower() or "chatyuan" in model_name.lower():
LoaderClass = AutoModel
else:
LoaderClass = AutoModelForCausalLM
...@@ -126,8 +131,14 @@ class LoaderCheckPoint:
.half()
.cuda()
)
# support custom cuda devices
elif ":" in self.llm_device:
model = LoaderClass.from_pretrained(checkpoint,
config=self.model_config,
torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
trust_remote_code=True).half().to(self.llm_device)
else:
- from accelerate import dispatch_model
+ from accelerate import dispatch_model,infer_auto_device_map
model = LoaderClass.from_pretrained(checkpoint,
config=self.model_config,
...@@ -151,6 +162,13 @@ class LoaderCheckPoint:
dtype=torch.float16 if not self.load_in_8bit else torch.int8,
max_memory=max_memory,
no_split_module_classes=model._no_split_modules)
# For models other than chatglm and moss, automatic assignment should be used instead of chatglm's configuration approach;
# the layer classes defined by other models are almost never the same as chatglm's and moss's, so using chatglm_auto_configure_device_map
# is guaranteed to fail, whereas infer_auto_device_map may lead to unbalanced load but at least does not error out.
# This was confirmed in testing with the bloom model.
# self.device_map = infer_auto_device_map(model,
# dtype=torch.int8,
# no_split_module_classes=model._no_split_modules)
model = dispatch_model(model, device_map=self.device_map)
else:
...@@ -166,7 +184,7 @@ class LoaderCheckPoint:
elif self.is_llamacpp:
try:
- from models.extensions.llamacpp_model_alternative import LlamaCppModel
+ from llama_cpp import Llama
except ImportError as exc:
raise ValueError(
...@@ -177,7 +195,16 @@ class LoaderCheckPoint:
model_file = list(checkpoint.glob('ggml*.bin'))[0]
print(f"llama.cpp weights detected: {model_file}\n")
- model, tokenizer = LlamaCppModel.from_pretrained(model_file)
+ model = Llama(model_path=model_file._str)
# In testing, loading the tokenizer for llama-cpp-vicuna13b-q5_1 with AutoTokenizer is extremely slow; there should be room for optimization,
# but that would require optimizing huggingface's AutoTokenizer.
# tokenizer = model.tokenizer
# todo: AutoTokenizer is used here; test later whether the model's built-in tokenizer is compatible
#* -> the built-in tokenizer is not compatible with the transformers tokenizer and cannot be used
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
return model, tokenizer
elif self.load_in_8bit:
...@@ -267,10 +294,21 @@ class LoaderCheckPoint:
# When chat or stream_chat is called, input_ids is placed on model.device;
# if transformer.word_embeddings.device differs from model.device, a RuntimeError is raised.
# Therefore transformer.word_embeddings, transformer.final_layernorm, and lm_head are all placed on the first GPU here.
encode = ""
if 'chatglm2' in self.model_name:
device_map = {
f"{layer_prefix}.embedding.word_embeddings": 0,
f"{layer_prefix}.rotary_pos_emb": 0,
f"{layer_prefix}.output_layer": 0,
f"{layer_prefix}.encoder.final_layernorm": 0,
f"base_model.model.output_layer": 0
}
encode = ".encoder"
else:
device_map = {f'{layer_prefix}.word_embeddings': 0,
f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0,
f'base_model.model.lm_head': 0, }
used = 2
gpu_target = 0
for i in range(num_trans_layers):
...@@ -278,7 +316,7 @@ class LoaderCheckPoint:
gpu_target += 1
used = 0
assert gpu_target < num_gpus
- device_map[f'{layer_prefix}.layers.{i}'] = gpu_target
+ device_map[f'{layer_prefix}{encode}.layers.{i}'] = gpu_target
used += 1
return device_map
...@@ -395,7 +433,7 @@ class LoaderCheckPoint:
print(
"如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")
elif torch.has_cuda:
- device_id = "0" if torch.cuda.is_available() else None
+ device_id = "0" if torch.cuda.is_available() and (":" not in self.llm_device) else None
CUDA_DEVICE = f"{self.llm_device}:{device_id}" if device_id else self.llm_device
with torch.cuda.device(CUDA_DEVICE):
torch.cuda.empty_cache()
...@@ -442,5 +480,6 @@ class LoaderCheckPoint:
self.model.transformer.prefix_encoder.float()
except Exception as e:
print("加载PrefixEncoder模型参数失败")
# llama-cpp models (at least vicuna-13b) do not have a separate eval method, so skip calling eval() for them
if not self.is_llamacpp:
self.model = self.model.eval()
...@@ -6,7 +6,7 @@ from models.base import (BaseAnswer,
AnswerResult)
import torch
# todo: the instruction should be rewritten; model performance under this instruction is rather poor
META_INSTRUCTION = \
"""You are an AI assistant whose name is MOSS.
- MOSS is a conversational language model that is developed by Fudan University. It is designed to be helpful, honest, and harmless.
...@@ -20,7 +20,7 @@ META_INSTRUCTION = \
Capabilities and tools that MOSS can possess.
"""
# todo: models respond very slowly under the MOSSLLM class; investigate the cause later
class MOSSLLM(BaseAnswer, LLM, ABC):
max_token: int = 2048
temperature: float = 0.7
...@@ -42,10 +42,11 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
return self.checkPoint
@property
- def set_history_len(self) -> int:
+ def _history_len(self) -> int:
return self.history_len
- def _set_history_len(self, history_len: int) -> None:
+ def set_history_len(self, history_len: int) -> None:
self.history_len = history_len
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
...@@ -59,11 +60,13 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
prompt_w_history = str(history)
prompt_w_history += '<|Human|>: ' + prompt + '<eoh>'
else:
- prompt_w_history = META_INSTRUCTION
+ prompt_w_history = META_INSTRUCTION.replace("MOSS", self.checkPoint.model_name.split("/")[-1])
prompt_w_history += '<|Human|>: ' + prompt + '<eoh>'
inputs = self.checkPoint.tokenizer(prompt_w_history, return_tensors="pt")
with torch.no_grad():
# max_length could probably be set smaller and repetition_penalty larger; otherwise models such as chatyuan and bloom repeat their output to fill up to max
#
outputs = self.checkPoint.model.generate(
inputs.input_ids.cuda(),
attention_mask=inputs.attention_mask.cuda(),
......
...@@ -44,4 +44,5 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
if 'FastChatOpenAILLM' in llm_model_info["provides"]:
modelInsLLM.set_api_base_url(llm_model_info['api_base_url'])
modelInsLLM.call_model_name(llm_model_info['name'])
modelInsLLM.set_api_key(llm_model_info['api_key'])
return modelInsLLM
...@@ -23,9 +23,13 @@ openai
#accelerate~=0.18.0
#peft~=0.3.0
#bitsandbytes; platform_system != "Windows"
#llama-cpp-python==0.1.34; platform_system != "Windows"
#https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"
# To use llama-cpp models such as the quantized vicuna-13b model, the llama-cpp-python package must be installed.
# but!!! in testing pip install does not work well; the wheel has to be downloaded manually from https://github.com/abetlen/llama-cpp-python/releases/
# Also note that ggml formats from different periods are NOT compatible!!! so the required llama-cpp-python version differs as well and has to be determined by manual testing.
# In testing, ggml-vicuna-13b-1.1 works fine with llama-cpp-python 0.1.63.
# However!!! this project controls model loading fairly strictly and its compatibility with llama-cpp-python is poor; many parameter settings cannot be used,
# so unless it is really necessary, it is better not to use llama-cpp.
torch~=2.0.0
pydantic~=1.10.7
starlette~=0.26.1
...@@ -33,5 +37,4 @@ numpy~=1.23.5
tqdm~=4.65.0
requests~=2.28.2
tenacity~=8.2.2
# The charset_normalizer version installed by default is too new and throws `partially initialized module 'charset_normalizer' has no attribute 'md__mypyc' (most likely due to a circular import)`
charset_normalizer==2.1.0
\ No newline at end of file
from configs.model_config import *
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
import nltk
from vectorstores import MyFAISS
from chains.local_doc_qa import load_file
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
if __name__ == "__main__":
filepath = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
"knowledge_base", "samples", "content", "test.txt")
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],
model_kwargs={'device': EMBEDDING_DEVICE})
docs = load_file(filepath, using_zh_title_enhance=True)
vector_store = MyFAISS.from_documents(docs, embeddings)
query = "指令提示技术有什么示例"
search_result = vector_store.similarity_search(query)
print(search_result)
pass
from .chinese_text_splitter import ChineseTextSplitter
from .ali_text_splitter import AliTextSplitter
from .zh_title_enhance import zh_title_enhance
\ No newline at end of file
from langchain.docstore.document import Document
import re
def under_non_alpha_ratio(text: str, threshold: float = 0.5):
"""Checks if the proportion of non-alpha characters in the text snippet exceeds a given
threshold. This helps prevent text like "-----------BREAK---------" from being tagged
as a title or narrative text. The ratio does not count spaces.
Parameters
----------
text
The input string to test
threshold
If the proportion of non-alpha characters exceeds this threshold, the function
returns False
"""
if len(text) == 0:
return False
alpha_count = len([char for char in text if char.strip() and char.isalpha()])
total_count = len([char for char in text if char.strip()])
try:
ratio = alpha_count / total_count
return ratio < threshold
except:
return False
def is_possible_title(
text: str,
title_max_word_length: int = 20,
non_alpha_threshold: float = 0.5,
) -> bool:
"""Checks to see if the text passes all of the checks for a valid title.
Parameters
----------
text
The input text to check
title_max_word_length
The maximum number of words a title can contain
non_alpha_threshold
The minimum number of alpha characters the text needs to be considered a title
"""
# If the text length is 0, it is definitely not a title
if len(text) == 0:
print("Not a title. Text is empty.")
return False
# If the text ends in punctuation, it is not a title
ENDS_IN_PUNCT_PATTERN = r"[^\w\s]\Z"
ENDS_IN_PUNCT_RE = re.compile(ENDS_IN_PUNCT_PATTERN)
if ENDS_IN_PUNCT_RE.search(text) is not None:
return False
# The text length must not exceed the configured value (default 20)
# NOTE(robinson) - splitting on spaces here instead of word tokenizing because it
# is less expensive and actual tokenization doesn't add much value for the length check
if len(text) > title_max_word_length:
return False
# The proportion of non-alphabetic characters in the text must not be too high, otherwise it is not a title
if under_non_alpha_ratio(text, threshold=non_alpha_threshold):
return False
# NOTE(robinson) - Prevent flagging salutations like "To My Dearest Friends," as titles
if text.endswith((",", ".", ",", "。")):
return False
if text.isnumeric():
print(f"Not a title. Text is all numeric:\n\n{text}") # type: ignore
return False
# There should be a digit within the first few characters (5 by default)
if len(text) < 5:
text_5 = text
else:
text_5 = text[:5]
alpha_in_text_5 = sum(list(map(lambda x: x.isnumeric(), list(text_5))))
if not alpha_in_text_5:
return False
return True
def zh_title_enhance(docs: Document) -> Document:
title = None
if len(docs) > 0:
for doc in docs:
if is_possible_title(doc.page_content):
doc.metadata['category'] = 'cn_Title'
title = doc.page_content
elif title:
doc.page_content = f"下文与({title})有关。{doc.page_content}"
return docs
else:
print("文件不存在")
...@@ -6,6 +6,8 @@ from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
import numpy as np
import copy
import os
from configs.model_config import *
class MyFAISS(FAISS, VectorStore):
...@@ -22,6 +24,9 @@ class MyFAISS(FAISS, VectorStore):
docstore=docstore,
index_to_docstore_id=index_to_docstore_id,
normalize_L2=normalize_L2)
self.score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD
self.chunk_size = CHUNK_SIZE
self.chunk_conent = False
def seperate_list(self, ls: List[int]) -> List[List[int]]:
# TODO: add a check for whether items belong to the same document
...@@ -52,7 +57,11 @@ class MyFAISS(FAISS, VectorStore):
if i == -1 or 0 < self.score_threshold < scores[0][j]:
# This happens when not enough docs are returned.
continue
if i in self.index_to_docstore_id:
_id = self.index_to_docstore_id[i]
# proceed with the following operations
else:
continue
doc = self.docstore.search(_id)
if (not self.chunk_conent) or ("context_expand" in doc.metadata and not doc.metadata["context_expand"]):
# if the matched text does not need its context expanded, run the following code
...@@ -113,8 +122,10 @@ class MyFAISS(FAISS, VectorStore):
try:
if isinstance(source, str):
ids = [k for k, v in self.docstore._dict.items() if v.metadata["source"] == source]
vs_path = os.path.join(os.path.split(os.path.split(source)[0])[0], "vector_store")
else:
ids = [k for k, v in self.docstore._dict.items() if v.metadata["source"] in source]
vs_path = os.path.join(os.path.split(os.path.split(source[0])[0])[0], "vector_store")
if len(ids) == 0:
return f"docs delete fail"
else:
...@@ -122,6 +133,9 @@ class MyFAISS(FAISS, VectorStore):
index = list(self.index_to_docstore_id.keys())[list(self.index_to_docstore_id.values()).index(id)]
self.index_to_docstore_id.pop(index)
self.docstore._dict.pop(id)
# TODO: also delete the corresponding id from self.index
# self.index.reset()
self.save_local(vs_path)
return f"docs delete success"
except Exception as e:
print(e)
......
...@@ -15,7 +15,7 @@ COPY . /app
RUN pnpm run build
FROM frontend AS final
COPY --from=frontend /app/dist /app/public
......
...@@ -16,6 +16,24 @@ export const chatfile = (params: any) => {
})
}
export const getKbsList = () => {
return api({
url: '/local_doc_qa/list_knowledge_base',
method: 'get',
})
}
export const deleteKb = (knowledge_base_id: any) => {
return api({
url: '/local_doc_qa/delete_knowledge_base',
method: 'delete',
params: {
knowledge_base_id,
},
})
}
export const getfilelist = (knowledge_base_id: any) => {
return api({
url: '/local_doc_qa/list_files',
...@@ -35,8 +53,8 @@ export const bing_search = (params: any) => {
export const deletefile = (params: any) => {
return api({
url: '/local_doc_qa/delete_file',
- method: 'post',
+ method: 'delete',
- data: JSON.stringify(params),
+ params,
})
}
export const web_url = () => {
...@@ -45,3 +63,18 @@ export const web_url = () => {
export const setapi = () => {
return window.baseApi
}
export const getkblist = (knowledge_base_id: any) => {
return api({
url: '/local_doc_qa/list_knowledge_base',
method: 'get',
params: {},
})
}
export const deletekb = (params: any) => {
return api({
url: '/local_doc_qa/delete_knowledge_base',
method: 'post',
data: JSON.stringify(params),
})
}
...@@ -555,7 +555,7 @@ const options = computed(() => {
return common
})
- function handleSelect(key: 'copyText' | 'delete' | 'toggleRenderType') {
+ function handleSelect(key: string) {
if (key == '清除会话') {
handleClear()
}
...@@ -658,7 +658,6 @@ function searchfun() {
<NDropdown
v-if="isMobile"
:trigger="isMobile ? 'click' : 'hover'"
:placement="!inversion ? 'right' : 'left'"
:options="options"
@select="handleSelect"
>
......
...@@ -3,15 +3,16 @@ import { NButton, NForm, NFormItem, NInput, NPopconfirm } from 'naive-ui'
import { onMounted, ref } from 'vue'
import filelist from './filelist.vue'
import { SvgIcon } from '@/components/common'
- import { deletefile, getfilelist } from '@/api/chat'
+ import { deleteKb, getKbsList } from '@/api/chat'
import { idStore } from '@/store/modules/knowledgebaseid/id'
const items = ref<any>([])
const choice = ref('')
const store = idStore()
onMounted(async () => {
choice.value = store.knowledgeid
- const res = await getfilelist({})
+ const res = await getKbsList()
res.data.data.forEach((item: any) => {
items.value.push({
value: item,
...@@ -52,8 +53,8 @@ const handleClick = () => {
}
}
async function handleDelete(item: any) {
- await deletefile(item.value)
+ await deleteKb(item.value)
- const res = await getfilelist({})
+ const res = await getKbsList()
items.value = []
res.data.data.forEach((item: any) => {
items.value.push({
......
...@@ -218,7 +218,12 @@ def change_chunk_conent(mode, label_conent, history):
def add_vs_name(vs_name, chatbot):
- if vs_name in get_vs_list():
+ if vs_name is None or vs_name.strip() == "" :
vs_status = "知识库名称不能为空,请重新填写知识库名称"
chatbot = chatbot + [[None, vs_status]]
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
visible=False), chatbot, gr.update(visible=False)
elif vs_name in get_vs_list():
vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交"
chatbot = chatbot + [[None, vs_status]]
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
......
...@@ -143,7 +143,7 @@ def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'
# return history + [[None, model_status]]
- def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
+ def get_vector_store(local_doc_qa, vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
filelist = []
if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")):
...@@ -455,6 +455,8 @@ with st.sidebar:
cols = st.columns([12, 10])
kb_name = cols[0].text_input(
'新知识库名称', placeholder='新知识库名称', label_visibility='collapsed')
if 'kb_name' not in st.session_state:
st.session_state.kb_name = kb_name
cols[1].button('新建知识库', on_click=on_new_kb)
vs_path = st.selectbox(
'选择知识库', vs_list, on_change=on_vs_change, key='vs_path')
......