提交 5d88f715 作者: imClumsyPanda

Merge branch 'dev'

FROM python:3.8
MAINTAINER "chatGLM"
COPY agent /chatGLM/agent
COPY chains /chatGLM/chains
COPY configs /chatGLM/configs
COPY content /chatGLM/content
COPY models /chatGLM/models
COPY nltk_data /chatGLM/content
COPY requirements.txt /chatGLM/
COPY cli_demo.py /chatGLM/
COPY webui.py /chatGLM/
WORKDIR /chatGLM
RUN pip install --user torch torchvision tensorboard cython -i https://pypi.tuna.tsinghua.edu.cn/simple
# RUN pip install --user 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
# RUN pip install --user 'git+https://github.com/facebookresearch/fvcore'
# install detectron2
# RUN git clone https://github.com/facebookresearch/detectron2
RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/ --trusted-host pypi.tuna.tsinghua.edu.cn
CMD ["python","-u", "webui.py"]
...@@ -4,11 +4,11 @@ ...@@ -4,11 +4,11 @@
🌍 [_READ THIS IN ENGLISH_](README_en.md) 🌍 [_READ THIS IN ENGLISH_](README_en.md)
🤖️ 一种利用 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) + [langchain](https://github.com/hwchase17/langchain) 实现的基于本地知识的 ChatGLM 应用。 🤖️ 一种利用 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) + [langchain](https://github.com/hwchase17/langchain) 实现的基于本地知识的 ChatGLM 应用。增加 [clue-ai/ChatYuan](https://github.com/clue-ai/ChatYuan) 项目的模型 [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2) 的支持。
💡 受 [GanymedeNil](https://github.com/GanymedeNil) 的项目 [document.ai](https://github.com/GanymedeNil/document.ai)[AlexZhangji](https://github.com/AlexZhangji) 创建的 [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216) 启发,建立了全部基于开源模型实现的本地知识问答应用。 💡 受 [GanymedeNil](https://github.com/GanymedeNil) 的项目 [document.ai](https://github.com/GanymedeNil/document.ai)[AlexZhangji](https://github.com/AlexZhangji) 创建的 [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216) 启发,建立了全部基于开源模型实现的本地知识问答应用。
✅ 本项目中 Embedding 选用的是 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main),LLM 选用的是 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B)。依托上述模型,本项目可实现全部使用**开源**模型**离线私有部署** ✅ 本项目中 Embedding 默认选用的是 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main),LLM 默认选用的是 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B)。依托上述模型,本项目可实现全部使用**开源**模型**离线私有部署**
⛓️ 本项目实现原理如下图所示,过程包括加载文件 -> 读取文本 -> 文本分割 -> 文本向量化 -> 问句向量化 -> 在文本向量中匹配出与问句向量最相似的`top k`个 -> 匹配出的文本作为上下文和问题一起添加到`prompt`中 -> 提交给`LLM`生成回答。 ⛓️ 本项目实现原理如下图所示,过程包括加载文件 -> 读取文本 -> 文本分割 -> 文本向量化 -> 问句向量化 -> 在文本向量中匹配出与问句向量最相似的`top k`个 -> 匹配出的文本作为上下文和问题一起添加到`prompt`中 -> 提交给`LLM`生成回答。
...@@ -22,9 +22,7 @@ ...@@ -22,9 +22,7 @@
参见 [变更日志](docs/CHANGELOG.md) 参见 [变更日志](docs/CHANGELOG.md)
## 使用方式 ## 硬件需求
### 硬件需求
- ChatGLM-6B 模型硬件需求 - ChatGLM-6B 模型硬件需求
...@@ -38,9 +36,19 @@ ...@@ -38,9 +36,19 @@
本项目中默认选用的 Embedding 模型 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main) 约占用显存 3GB,也可修改为在 CPU 中运行。 本项目中默认选用的 Embedding 模型 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main) 约占用显存 3GB,也可修改为在 CPU 中运行。
## Docker 部署
```commandline
$ docker build -t chatglm:v1.0 .
$ docker run -d --restart=always --name chatglm -p 7860:7860 -v /www/wwwroot/code/langchain-ChatGLM:/chatGLM chatglm
```
## 开发部署
### 软件需求 ### 软件需求
本项目已在 Python 3.8,CUDA 11.7 环境下完成测试。 本项目已在 Python 3.8 - 3.10,CUDA 11.7 环境下完成测试。已在 Windows、ARM 架构的 macOS、Linux 系统中完成测试。
### 从本地加载模型 ### 从本地加载模型
...@@ -123,6 +131,7 @@ Web UI 可以实现如下功能: ...@@ -123,6 +131,7 @@ Web UI 可以实现如下功能:
- [x] THUDM/chatglm-6b - [x] THUDM/chatglm-6b
- [x] THUDM/chatglm-6b-int4 - [x] THUDM/chatglm-6b-int4
- [x] THUDM/chatglm-6b-int4-qe - [x] THUDM/chatglm-6b-int4-qe
- [x] ClueAI/ChatYuan-large-v2
- [ ] Web UI - [ ] Web UI
- [x] 利用 gradio 实现 Web UI DEMO - [x] 利用 gradio 实现 Web UI DEMO
- [x] 添加输出内容及错误提示 - [x] 添加输出内容及错误提示
......
...@@ -42,7 +42,7 @@ async def get_local_doc_qa(): ...@@ -42,7 +42,7 @@ async def get_local_doc_qa():
@app.post("/file") @app.post("/file")
async def upload_file(UserFile: UploadFile=File(...)): async def upload_file(UserFile: UploadFile=File(...),):
global vs_path global vs_path
response = { response = {
"msg": None, "msg": None,
...@@ -67,7 +67,7 @@ async def upload_file(UserFile: UploadFile=File(...)): ...@@ -67,7 +67,7 @@ async def upload_file(UserFile: UploadFile=File(...)):
return response return response
@app.post("/qa") @app.post("/qa")
async def get_answer(UserQuery: Query): async def get_answer(query: str = ""):
response = { response = {
"status": 0, "status": 0,
"message": "", "message": "",
...@@ -76,7 +76,7 @@ async def get_answer(UserQuery: Query): ...@@ -76,7 +76,7 @@ async def get_answer(UserQuery: Query):
global vs_path global vs_path
history = [] history = []
try: try:
resp, history = local_doc_qa.get_knowledge_based_answer(query=UserQuery.query, resp, history = local_doc_qa.get_knowledge_based_answer(query=query,
vs_path=vs_path, vs_path=vs_path,
chat_history=history) chat_history=history)
if REPLY_WITH_SOURCE: if REPLY_WITH_SOURCE:
...@@ -95,9 +95,9 @@ async def get_answer(UserQuery: Query): ...@@ -95,9 +95,9 @@ async def get_answer(UserQuery: Query):
if __name__ == "__main__": if __name__ == "__main__":
uvicorn.run( uvicorn.run(
app='api:app', app=app,
host='0.0.0.0', host='0.0.0.0',
port=8100, port=8100,
reload = True, reload=True,
) )
...@@ -33,6 +33,7 @@ def load_file(filepath): ...@@ -33,6 +33,7 @@ def load_file(filepath):
class LocalDocQA: class LocalDocQA:
llm: object = None llm: object = None
embeddings: object = None embeddings: object = None
top_k: int = VECTOR_SEARCH_TOP_K
def init_cfg(self, def init_cfg(self,
embedding_model: str = EMBEDDING_MODEL, embedding_model: str = EMBEDDING_MODEL,
...@@ -49,9 +50,10 @@ class LocalDocQA: ...@@ -49,9 +50,10 @@ class LocalDocQA:
use_ptuning_v2=use_ptuning_v2) use_ptuning_v2=use_ptuning_v2)
self.llm.history_len = llm_history_len self.llm.history_len = llm_history_len
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], ) self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name, model_kwargs={'device': embedding_device})
device=embedding_device) # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
# device=embedding_device)
self.top_k = top_k self.top_k = top_k
def init_knowledge_vector_store(self, def init_knowledge_vector_store(self,
......
...@@ -19,6 +19,7 @@ llm_model_dict = { ...@@ -19,6 +19,7 @@ llm_model_dict = {
"chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe", "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
"chatglm-6b-int4": "THUDM/chatglm-6b-int4", "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
"chatglm-6b": "THUDM/chatglm-6b", "chatglm-6b": "THUDM/chatglm-6b",
"chatyuan": "ClueAI/ChatYuan-large-v2",
} }
# LLM model name # LLM model name
......
...@@ -95,7 +95,7 @@ Q9: 下载完模型后,如何修改代码以执行本地模型? ...@@ -95,7 +95,7 @@ Q9: 下载完模型后,如何修改代码以执行本地模型?
A9: 模型下载完成后,请在 [configs/model_config.py](../configs/model_config.py) 文件中,对`embedding_model_dict``llm_model_dict`参数进行修改,如把`llm_model_dict` A9: 模型下载完成后,请在 [configs/model_config.py](../configs/model_config.py) 文件中,对`embedding_model_dict``llm_model_dict`参数进行修改,如把`llm_model_dict`
```json ```python
embedding_model_dict = { embedding_model_dict = {
"ernie-tiny": "nghuyong/ernie-3.0-nano-zh", "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
"ernie-base": "nghuyong/ernie-3.0-base-zh", "ernie-base": "nghuyong/ernie-3.0-base-zh",
...@@ -105,7 +105,7 @@ embedding_model_dict = { ...@@ -105,7 +105,7 @@ embedding_model_dict = {
修改为 修改为
```json ```python
embedding_model_dict = { embedding_model_dict = {
"ernie-tiny": "nghuyong/ernie-3.0-nano-zh", "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
"ernie-base": "nghuyong/ernie-3.0-base-zh", "ernie-base": "nghuyong/ernie-3.0-base-zh",
......
...@@ -72,14 +72,27 @@ class ChatGLM(LLM): ...@@ -72,14 +72,27 @@ class ChatGLM(LLM):
response, _ = self.model.chat( response, _ = self.model.chat(
self.tokenizer, self.tokenizer,
prompt, prompt,
history=self.history[-self.history_len:] if self.history_len>0 else [], history=self.history[-self.history_len:] if self.history_len > 0 else [],
max_length=self.max_token, max_length=self.max_token,
temperature=self.temperature, temperature=self.temperature,
) )
torch_gc() torch_gc()
if stop is not None: if stop is not None:
response = enforce_stop_tokens(response, stop) response = enforce_stop_tokens(response, stop)
self.history = self.history+[[None, response]] self.history = self.history + [[None, response]]
return response
def chat(self,
prompt: str) -> str:
response, _ = self.model.chat(
self.tokenizer,
prompt,
history=self.history[-self.history_len:] if self.history_len > 0 else [],
max_length=self.max_token,
temperature=self.temperature,
)
torch_gc()
self.history = self.history + [[None, response]]
return response return response
def load_model(self, def load_model(self,
...@@ -146,7 +159,8 @@ class ChatGLM(LLM): ...@@ -146,7 +159,8 @@ class ChatGLM(LLM):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
self.model.transformer.prefix_encoder.float() self.model.transformer.prefix_encoder.float()
except Exception: except Exception as e:
print(e)
print("加载PrefixEncoder模型参数失败") print("加载PrefixEncoder模型参数失败")
self.model = self.model.eval() self.model = self.model.eval()
langchain>=0.0.124 langchain>=0.0.146
transformers==4.27.1 transformers==4.27.1
unstructured[local-inference] unstructured[local-inference]
layoutparser[layoutmodels,tesseract] layoutparser[layoutmodels,tesseract]
...@@ -9,4 +9,4 @@ icetk ...@@ -9,4 +9,4 @@ icetk
cpm_kernels cpm_kernels
faiss-cpu faiss-cpu
gradio>=3.25.0 gradio>=3.25.0
detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2 #detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2
\ No newline at end of file \ No newline at end of file
...@@ -17,10 +17,10 @@ LLM_HISTORY_LEN = 3 ...@@ -17,10 +17,10 @@ LLM_HISTORY_LEN = 3
def get_vs_list(): def get_vs_list():
if not os.path.exists(VS_ROOT_PATH): if not os.path.exists(VS_ROOT_PATH):
return [] return []
return ["新建知识库"] + os.listdir(VS_ROOT_PATH) return os.listdir(VS_ROOT_PATH)
vs_list = get_vs_list() vs_list = ["新建知识库"] + get_vs_list()
embedding_model_dict_list = list(embedding_model_dict.keys()) embedding_model_dict_list = list(embedding_model_dict.keys())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论