Commit 13b41553 by imClumsyPanda

update project to v0.1.3

# Byte-compiled / optimized / DLL files
__pycache__/
*/**/__pycache__/
*.py[cod]
*$py.class
@@ -164,3 +165,6 @@ output/*
log/*
.chroma
vector_store/*
llm/*
embedding/*
# 贡献指南
欢迎!我们是一个非常友好的社区,非常高兴您想要帮助我们让这个应用程序变得更好。但是,请您遵循一些通用准则以保持组织有序。
1. 确保为您要修复的错误或要添加的功能创建了一个[问题](https://github.com/imClumsyPanda/langchain-ChatGLM/issues),尽可能保持它们小。
2. 请使用 `git pull --rebase` 来拉取和衍合上游的更新。
3. 将提交合并为格式良好的提交。在提交说明中单独一行提到要解决的问题,如`Fix #<bug>`(有关更多可以使用的关键字,请参见[将拉取请求链接到问题](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue))。
4. 推送到`dev`。在说明中提到正在解决的问题。
---
# Contribution Guide
Welcome! We're a pretty friendly community, and we're thrilled that you want to help make this app even better. However, we ask that you follow some general guidelines to keep things organized around here.
1. Make sure an [issue](https://github.com/imClumsyPanda/langchain-ChatGLM/issues) is created for the bug you're about to fix, or feature you're about to add. Keep them as small as possible.
2. Please use `git pull --rebase` to fetch updates from the upstream and rebase onto them.
3. Rebase commits into well-formatted commits. Mention the issue being resolved in the commit message on a line all by itself like `Fixes #<bug>` (refer to [Linking a pull request to an issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue) for more keywords you can use).
4. Push into `dev`. Mention which bug is being resolved in the description.
@@ -18,6 +18,7 @@
🚩 This project does not involve fine-tuning or training; however, fine-tuning or training can be employed to optimize the effectiveness of this project.
📓 [ModelWhale online notebook](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59)
## Changelog
@@ -115,7 +116,7 @@ python webui.py
Note: Before executing, check that the `$HOME/.cache/huggingface/` folder has at least 15 GB of free space.
The resulting interface is shown below:
![webui](img/webui_0419.png)
The Web UI supports the following features:
1. Automatically reads the `LLM` and `embedding` model enumerations in `configs/model_config.py`, allowing you to select and reload the model by clicking `重新加载模型`.
@@ -207,5 +208,12 @@ ChatGLM's answer after using LangChain to access the README.md file of the ChatG
- [ ] Add Web UI DEMO
- [x] Implement Web UI DEMO using Gradio
- [x] Add output and error messages
- [x] Citation callout
- [ ] Knowledge base management
- [x] QA based on selected knowledge base
- [x] Add files/folder to knowledge base
- [ ] Delete files from knowledge base
- [ ] Implement Web UI DEMO using Streamlit
- [ ] Add support for API deployment
- [x] Use fastapi to implement API
- [ ] Implement Web UI DEMO for API calls
from configs.model_config import *
from chains.local_doc_qa import LocalDocQA
import os
import nltk
import uvicorn
from fastapi import FastAPI, File, UploadFile
from pydantic import BaseModel
from starlette.responses import RedirectResponse

app = FastAPI()

global local_doc_qa, vs_path

nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path

# return top-k text chunk from vector store
VECTOR_SEARCH_TOP_K = 10

# LLM input history length
LLM_HISTORY_LEN = 3

# Show reply with source text from input document
REPLY_WITH_SOURCE = False


class Query(BaseModel):
    query: str


@app.get('/')
async def document():
    return RedirectResponse(url="/docs")


@app.on_event("startup")
async def get_local_doc_qa():
    global local_doc_qa
    local_doc_qa = LocalDocQA()
    local_doc_qa.init_cfg(llm_model=LLM_MODEL,
                          embedding_model=EMBEDDING_MODEL,
                          embedding_device=EMBEDDING_DEVICE,
                          llm_history_len=LLM_HISTORY_LEN,
                          top_k=VECTOR_SEARCH_TOP_K)


@app.post("/file")
async def upload_file(UserFile: UploadFile = File(...)):
    global vs_path
    response = {
        "msg": None,
        "status": 0
    }
    try:
        # save the uploaded file to the local content directory, then index it
        filepath = './content/' + UserFile.filename
        content = await UserFile.read()
        with open(filepath, 'wb') as f:
            f.write(content)
        vs_path, files = local_doc_qa.init_knowledge_vector_store(filepath)
        response = {
            'msg': 'success' if len(files) > 0 else 'fail',
            'status': 1 if len(files) > 0 else 0,
            'loaded_files': files
        }
    except Exception as err:
        response["msg"] = str(err)
    return response


@app.post("/qa")
async def get_answer(UserQuery: Query):
    response = {
        "status": 0,
        "message": "",
        "answer": None
    }
    global vs_path
    history = []
    try:
        resp, history = local_doc_qa.get_knowledge_based_answer(query=UserQuery.query,
                                                                vs_path=vs_path,
                                                                chat_history=history)
        if REPLY_WITH_SOURCE:
            response["answer"] = resp
        else:
            response['answer'] = resp["result"]
        response["message"] = 'successful'
        response["status"] = 1
    except Exception as err:
        response["message"] = str(err)
    return response


if __name__ == "__main__":
    uvicorn.run(
        app='api:app',
        host='0.0.0.0',
        port=8100,
        reload=True,
    )
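A minimal client sketch for the two endpoints above, assuming the server is running locally on port 8100; the file name and question below are placeholders:

```python
# Hypothetical client for api.py; adjust the base URL and file path to your setup.
import requests

BASE_URL = "http://127.0.0.1:8100"

# upload a local document so it gets indexed into the vector store
with open("sample.txt", "rb") as f:
    r = requests.post(f"{BASE_URL}/file", files={"UserFile": f})
print(r.json())  # e.g. {'msg': 'success', 'status': 1, 'loaded_files': [...]}

# ask a question against the uploaded knowledge base
r = requests.post(f"{BASE_URL}/qa", json={"query": "这份文档讲了什么?"})
print(r.json())  # e.g. {'status': 1, 'message': 'successful', 'answer': '...'}
```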
@@ -9,6 +9,7 @@ import os
from configs.model_config import *
import datetime
from typing import List
from textsplitter import ChineseTextSplitter

# return top-k text chunk from vector store
VECTOR_SEARCH_TOP_K = 6

@@ -16,8 +17,17 @@ VECTOR_SEARCH_TOP_K = 6
# LLM input history length
LLM_HISTORY_LEN = 3


def load_file(filepath):
    if filepath.lower().endswith(".pdf"):
        loader = UnstructuredFileLoader(filepath)
        textsplitter = ChineseTextSplitter(pdf=True)
        docs = loader.load_and_split(textsplitter)
    else:
        loader = UnstructuredFileLoader(filepath, mode="elements")
        textsplitter = ChineseTextSplitter(pdf=False)
        docs = loader.load_and_split(text_splitter=textsplitter)
    return docs


class LocalDocQA:
@@ -43,7 +53,9 @@ class LocalDocQA:
        self.top_k = top_k

    def init_knowledge_vector_store(self,
                                    filepath: str or List[str],
                                    vs_path: str or os.PathLike = None):
        loaded_files = []
        if isinstance(filepath, str):
            if not os.path.exists(filepath):
                print("路径不存在")
@@ -51,10 +63,11 @@ class LocalDocQA:
            elif os.path.isfile(filepath):
                file = os.path.split(filepath)[-1]
                try:
                    docs = load_file(filepath)
                    print(f"{file} 已成功加载")
                    loaded_files.append(filepath)
                except Exception as e:
                    print(e)
                    print(f"{file} 未能成功加载")
                    return None
            elif os.path.isdir(filepath):
@@ -62,25 +75,33 @@ class LocalDocQA:
                for file in os.listdir(filepath):
                    fullfilepath = os.path.join(filepath, file)
                    try:
                        docs += load_file(fullfilepath)
                        print(f"{file} 已成功加载")
                        loaded_files.append(fullfilepath)
                    except Exception as e:
                        print(e)
                        print(f"{file} 未能成功加载")
        else:
            docs = []
            for file in filepath:
                try:
                    docs += load_file(file)
                    print(f"{file} 已成功加载")
                    loaded_files.append(file)
                except Exception as e:
                    print(e)
                    print(f"{file} 未能成功加载")
        if vs_path and os.path.isdir(vs_path):
            vector_store = FAISS.load_local(vs_path, self.embeddings)
            vector_store.add_documents(docs)
        else:
            if not vs_path:
                vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
            vector_store = FAISS.from_documents(docs, self.embeddings)
        vector_store.save_local(vs_path)
        return vs_path if len(docs) > 0 else None, loaded_files

    def get_knowledge_based_answer(self,
                                   query,
...
import os
import pinecone
from tqdm import tqdm
from langchain.llms import OpenAI
from langchain.text_splitter import SpacyTextSplitter
from langchain.document_loaders import TextLoader
from langchain.document_loaders import DirectoryLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Pinecone

# Configuration
openai_key = "your key"  # obtained after registering at openai.com
pinecone_key = "your key"  # obtained after registering at app.pinecone.io
pinecone_index = "your index"  # obtained from app.pinecone.io
pinecone_environment = "your Environment"  # after logging into pinecone, check the Environment on the indexes page
pinecone_namespace = "your Namespace"  # created automatically if it does not exist

# proxy settings for reaching the services; adjust to your own environment
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'

# initialize pinecone
pinecone.init(
    api_key=pinecone_key,
    environment=pinecone_environment
)
index = pinecone.Index(pinecone_index)

# initialize the OpenAI embeddings
embeddings = OpenAIEmbeddings(openai_api_key=openai_key)

# initialize the text splitter
text_splitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=1000, chunk_overlap=200)

# read all files with the .txt suffix under the directory
loader = DirectoryLoader('../docs', glob="**/*.txt", loader_cls=TextLoader)
documents = loader.load()

# split the documents with the text splitter
split_text = text_splitter.split_documents(documents)
try:
    for document in tqdm(split_text):
        # embed each chunk and store it in pinecone
        Pinecone.from_documents([document], embeddings, index_name=pinecone_index)
except Exception as e:
    print(f"Error: {e}")
    quit()
@@ -6,7 +6,7 @@ import nltk
nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path

# return top-k text chunk from vector store
VECTOR_SEARCH_TOP_K = 6

# LLM input history length
LLM_HISTORY_LEN = 3

@@ -24,7 +24,7 @@ if __name__ == "__main__":
    vs_path = None
    while not vs_path:
        filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
        vs_path, _ = local_doc_qa.init_knowledge_vector_store(filepath)
    history = []
    while True:
        query = input("Input your question 请输入问题:")
...
@@ -24,6 +24,13 @@ llm_model_dict = {
# LLM model name
LLM_MODEL = "chatglm-6b"

# Use p-tuning-v2 PrefixEncoder
USE_PTUNING_V2 = False

# LLM running device
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
VS_ROOT_PATH = "./vector_store/"
UPLOAD_ROOT_PATH = "./content/"
## Changelog

**[2023/04/15]**
1. Refactored the project structure, keeping the command-line demo [cli_demo.py](../cli_demo.py) and the Web UI demo [webui.py](../webui.py) in the project root;
2. Improved the Web UI: after launching, the model is now loaded first using the default options in [configs/model_config.py](../configs/model_config.py), and error messages have been added;
3. Expanded the FAQ.

**[2023/04/12]**
1. Replaced the sample files in the Web UI to avoid the file-encoding issue that prevented them from being read on Ubuntu;
2. Replaced the prompt template in `knowledge_based_chatglm.py` to avoid garbled ChatGLM output caused by a prompt template mixing Chinese and English.

**[2023/04/11]**
1. Added Web UI V0.1 (thanks to [@liangtongt](https://github.com/liangtongt));
2. Added an FAQ section to `README.md` (thanks to [@calcitem](https://github.com/calcitem) and [@bolongliu](https://github.com/bolongliu));
3. Added automatic detection of whether `cuda`, `mps`, or `cpu` is available as the running device for the LLM and Embedding models;
4. Added a check on `filepath` in `knowledge_based_chatglm.py`: in addition to importing a single file, a folder path can now be passed as input; every file in the folder is traversed and the command line reports whether each one was loaded successfully.

**[2023/04/09]**
1. Replaced the previously used `ChatVectorDBChain` with `RetrievalQA` from `langchain`, which effectively reduces the out-of-memory stops that occurred after asking 2-3 questions;
2. Added settings for the `EMBEDDING_MODEL`, `VECTOR_SEARCH_TOP_K`, `LLM_MODEL`, `LLM_HISTORY_LEN`, and `REPLY_WITH_SOURCE` parameters in `knowledge_based_chatglm.py`;
3. Added `chatglm-6b-int4` and `chatglm-6b-int4-qe`, which require less GPU memory, as LLM model options;
4. Fixed code errors in `README.md` (thanks to [@calcitem](https://github.com/calcitem)).

**[2023/04/07]**
1. Fixed an issue where loading the ChatGLM model used twice the expected GPU memory (thanks to [@suc16](https://github.com/suc16) and [@myml](https://github.com/myml));
2. Added a GPU memory cleanup mechanism;
3. Added `nghuyong/ernie-3.0-nano-zh` and `nghuyong/ernie-3.0-base-zh` as Embedding model options, which use less GPU memory than `GanymedeNil/text2vec-large-chinese` (thanks to [@lastrei](https://github.com/lastrei)).
### FAQ
Q1: Which file formats does this project support?
A1: Currently txt, docx, md, and pdf files have been tested and are supported; for more file formats, see the [langchain documentation](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html). Note that documents containing special characters may currently fail to load.
---
Q2: What should I do if installing `detectron2` fails while running `pip install -r requirements.txt`?
A2: If you do not need to read `pdf` files, `detectron2` does not have to be installed; if you need high-accuracy text extraction from `pdf` files, it is recommended to install it as follows:
```commandline
$ git clone https://github.com/facebookresearch/detectron2.git
$ cd detectron2
$ pip install -e .
```
---
Q3: How do I fix the `Resource punkt not found.` error raised by the Python package `nltk`?
A3: Download https://github.com/nltk/nltk_data/raw/gh-pages/packages/tokenizers/punkt.zip, extract the `packages/tokenizers` folder it contains, and place it under the `nltk_data/tokenizers` storage path.
The `nltk_data` storage path can be queried via `nltk.data.path`.
---
Q4: How do I fix the `Resource averaged_perceptron_tagger not found.` error raised by the Python package `nltk`?
A4: Download https://github.com/nltk/nltk_data/blob/gh-pages/packages/taggers/averaged_perceptron_tagger.zip, extract it, and place it under the `nltk_data/taggers` storage path.
The `nltk_data` storage path can be queried via `nltk.data.path`.
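To check which directories `nltk` searches (and therefore where the unpacked `tokenizers`/`taggers` folders should go), a quick check like the following works; the printed paths depend on your environment:

```python
# List the directories nltk will search for punkt / averaged_perceptron_tagger data.
import nltk

for path in nltk.data.path:
    print(path)
```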
---
Q5: Can this project run in Colab?
A5: You can try running it in Colab with the chatglm-6b-int4 model. Note that if you want to run the Web UI in Colab, the `share` parameter in `demo.queue(concurrency_count=3).launch(server_name='0.0.0.0', share=False, inbrowser=False)` in `webui.py` needs to be set to `True`, as shown below.
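A minimal sketch of the modified call (only `share` changes; the other arguments stay as in the original line):

```python
# webui.py: enable a public Gradio link so the UI is reachable from Colab
demo.queue(concurrency_count=3).launch(
    server_name='0.0.0.0', share=True, inbrowser=False)
```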
---
Q6: What should I do if installing packages with pip inside Anaconda has no effect?
A6: This is a system environment issue; see [在Anaconda中使用pip安装包无效问题](在Anaconda中使用pip安装包无效问题.md) for details.
---
Q7: How do I download the models required by this project to a local path?
A7: All models used in this project are open-source models downloadable from `huggingface.com`. Taking the default `chatglm-6b` and `text2vec-large-chinese` models as an example, they can be downloaded with the following commands:
```shell
# install git lfs
$ git lfs install
# download the LLM model
$ git clone https://huggingface.co/THUDM/chatglm-6b /your_path/chatglm-6b
# download the Embedding model
$ git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese /your_path/text2vec
# when a model needs to be updated, open its folder and pull the latest model files/code
$ git pull
```
---
Q8: What can I do if downloading models from `huggingface.com` is slow?
A8: You can use the Baidu Netdisk links below for the model weight files used in this project:
- ernie-3.0-base-zh.zip 链接: https://pan.baidu.com/s/1CIvKnD3qzE-orFouA8qvNQ?pwd=4wih
- ernie-3.0-nano-zh.zip 链接: https://pan.baidu.com/s/1Fh8fgzVdavf5P1omAJJ-Zw?pwd=q6s5
- text2vec-large-chinese.zip 链接: https://pan.baidu.com/s/1sMyPzBIXdEzHygftEoyBuA?pwd=4xs7
- chatglm-6b-int4-qe.zip 链接: https://pan.baidu.com/s/1DDKMOMHtNZccOOBGWIOYww?pwd=22ji
- chatglm-6b-int4.zip 链接: https://pan.baidu.com/s/1pvZ6pMzovjhkA6uPcRLuJA?pwd=3gjd
- chatglm-6b.zip 链接: https://pan.baidu.com/s/1B-MpsVVs1GHhteVBetaquw?pwd=djay
---
Q9: After the models are downloaded, how do I modify the code to run the local models?
A9: Once the models are downloaded, modify the `embedding_model_dict` and `llm_model_dict` parameters in [configs/model_config.py](../configs/model_config.py). For example, change
```python
embedding_model_dict = {
"ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
"ernie-base": "nghuyong/ernie-3.0-base-zh",
"text2vec": "GanymedeNil/text2vec-large-chinese"
}
```
to
```python
embedding_model_dict = {
"ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
"ernie-base": "nghuyong/ernie-3.0-base-zh",
"text2vec": "/Users/liuqian/Downloads/ChatGLM-6B/text2vec-large-chinese"
}
```
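`llm_model_dict` is modified the same way. A hypothetical example, assuming the dictionary keys match the model names listed in the changelog and using the local clone path from Q7 as a placeholder:

```python
llm_model_dict = {
    "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
    "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
    "chatglm-6b": "/your_path/chatglm-6b"  # local path instead of the THUDM/chatglm-6b hub id
}
```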
# Installation
## Environment Check
```shell
# First, make sure your machine has Python 3.8 or above installed
$ python --version
Python 3.8.13
# If the version is lower, you can use conda to create an environment
$ conda create -p /your_path/env_name python=3.8
# Activate the environment
$ source activate /your_path/env_name
# Deactivate the environment
$ source deactivate /your_path/env_name
# Remove the environment
$ conda env remove -p /your_path/env_name
```
## Project Dependencies
```shell
# Clone the repository
$ git clone https://github.com/imClumsyPanda/langchain-ChatGLM.git
# Install dependencies
$ pip install -r requirements.txt
```
Note: When using `langchain.document_loaders.UnstructuredFileLoader` to load unstructured files, additional dependency packages may need to be installed depending on the document type; see the [langchain documentation](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html) for details. A minimal loading sketch follows.
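Assuming a placeholder file `docs/sample.md` (not shipped with the project):

```python
# Minimal example: load a single unstructured file and count the extracted elements.
from langchain.document_loaders import UnstructuredFileLoader

loader = UnstructuredFileLoader("docs/sample.md", mode="elements")
docs = loader.load()
print(f"Loaded {len(docs)} document elements")
```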
import json
from langchain.llms.base import LLM
from typing import Optional, List
from langchain.llms.utils import enforce_stop_tokens
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
from configs.model_config import LLM_DEVICE
@@ -84,12 +85,26 @@ class ChatGLM(LLM):
    def load_model(self,
                   model_name_or_path: str = "THUDM/chatglm-6b",
                   llm_device=LLM_DEVICE,
                   use_ptuning_v2=False,
                   device_map: Optional[Dict[str, int]] = None,
                   **kwargs):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )
        model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)

        if use_ptuning_v2:
            try:
                prefix_encoder_file = open('ptuning-v2/config.json', 'r')
                prefix_encoder_config = json.loads(prefix_encoder_file.read())
                prefix_encoder_file.close()
                model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
                model_config.prefix_projection = prefix_encoder_config['prefix_projection']
            except Exception:
                print("加载PrefixEncoder config.json失败")

        if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
            # decide whether to use multi-GPU deployment based on the number of GPUs on the current device
            num_gpus = torch.cuda.device_count()
@@ -97,6 +112,7 @@ class ChatGLM(LLM):
                self.model = (
                    AutoModel.from_pretrained(
                        model_name_or_path,
                        config=model_config,
                        trust_remote_code=True,
                        **kwargs)
                    .half()
@@ -111,12 +127,34 @@ class ChatGLM(LLM):
                device_map = auto_configure_device_map(num_gpus)
                self.model = dispatch_model(model, device_map=device_map)
                self.model = (
                    AutoModel.from_pretrained(
                        model_name_or_path,
                        config=model_config,
                        trust_remote_code=True)
                    .half()
                    .cuda()
                )
        else:
            self.model = (
                AutoModel.from_pretrained(
                    model_name_or_path,
                    config=model_config,
                    trust_remote_code=True)
                .float()
                .to(llm_device)
            )

        if use_ptuning_v2:
            try:
                prefix_state_dict = torch.load('ptuning-v2/pytorch_model.bin')
                new_prefix_state_dict = {}
                for k, v in prefix_state_dict.items():
                    if k.startswith("transformer.prefix_encoder."):
                        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                self.model.transformer.prefix_encoder.float()
            except Exception:
                print("加载PrefixEncoder模型参数失败")

        self.model = self.model.eval()
If you fine-tuned the model with [p-tuning-v2](https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning), you can place the resulting PrefixEncoder in this folder.
Only the model's *config.json* and *pytorch_model.bin* need to be placed here,
then check *"使用p-tuning-v2微调过的模型"* ("model fine-tuned with p-tuning-v2") when loading the model.
import os
import subprocess
import re


def get_latest_tag():
    output = subprocess.check_output(['git', 'tag'])
    tags = output.decode('utf-8').split('\n')[:-1]
    latest_tag = sorted(tags, key=lambda t: tuple(map(int, re.match(r'v(\d+)\.(\d+)\.(\d+)', t).groups())))[-1]
    return latest_tag


def update_version_number(latest_tag, increment):
    major, minor, patch = map(int, re.match(r'v(\d+)\.(\d+)\.(\d+)', latest_tag).groups())
    if increment == 'X':
        major += 1
        minor, patch = 0, 0
    elif increment == 'Y':
        minor += 1
        patch = 0
    elif increment == 'Z':
        patch += 1
    new_version = f"v{major}.{minor}.{patch}"
    return new_version


def main():
    print("当前最近的Git标签:")
    latest_tag = get_latest_tag()
    print(latest_tag)

    print("请选择要递增的版本号部分(X, Y, Z):")
    increment = input().upper()
    while increment not in ['X', 'Y', 'Z']:
        print("输入错误,请输入X, Y或Z:")
        increment = input().upper()

    new_version = update_version_number(latest_tag, increment)
    print(f"新的版本号为:{new_version}")

    print("确认更新版本号并推送到远程仓库?(y/n)")
    confirmation = input().lower()
    if confirmation == 'y':
        subprocess.run(['git', 'tag', new_version])
        subprocess.run(['git', 'push', 'origin', new_version])
        print("新版本号已创建并推送到远程仓库。")
    else:
        print("操作已取消。")


if __name__ == '__main__':
    main()
from .chinese_text_splitter import *
from langchain.text_splitter import CharacterTextSplitter
import re
from typing import List


class ChineseTextSplitter(CharacterTextSplitter):
    def __init__(self, pdf: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.pdf = pdf

    def split_text(self, text: str) -> List[str]:
        if self.pdf:
            # collapse the extra blank lines and whitespace that PDF extraction tends to introduce
            text = re.sub(r"\n{3,}", "\n", text)
            text = re.sub(r"\s", " ", text)
            text = text.replace("\n\n", "")
        # split on Chinese/fullwidth sentence-ending punctuation, keeping trailing closing quotes with the sentence
        sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')  # del :;
        sent_list = []
        for ele in sent_sep_pattern.split(text):
            if sent_sep_pattern.match(ele) and sent_list:
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list
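For reference, a quick usage sketch (the sample sentence is arbitrary) showing how the splitter breaks Chinese text into sentences:

```python
# Split a short Chinese paragraph into sentences with ChineseTextSplitter.
from textsplitter import ChineseTextSplitter

splitter = ChineseTextSplitter(pdf=False)
chunks = splitter.split_text("今天天气很好。我们去公园散步吧!你觉得怎么样?")
print(chunks)  # expected: ['今天天气很好。', '我们去公园散步吧!', '你觉得怎么样?']
```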