Commit dc0cdfba by imClumsyPanda

Merge branch 'dev'

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# Other files
output/*
log/*
.chroma
vector_store/*
from .chatglm_with_shared_memory_openai_llm import *
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.document_loaders import UnstructuredFileLoader
from models.chatglm_llm import ChatGLM
import sentence_transformers
import os
from configs.model_config import *
import datetime
from typing import List, Union
# return top-k text chunk from vector store
VECTOR_SEARCH_TOP_K = 10
# LLM input history length
LLM_HISTORY_LEN = 3
# Show reply with source text from input document
REPLY_WITH_SOURCE = True
class LocalDocQA:
    llm: object = None
    embeddings: object = None

    def init_cfg(self,
                 embedding_model: str = EMBEDDING_MODEL,
                 embedding_device=EMBEDDING_DEVICE,
                 llm_history_len: int = LLM_HISTORY_LEN,
                 llm_model: str = LLM_MODEL,
                 llm_device=LLM_DEVICE,
                 top_k=VECTOR_SEARCH_TOP_K,
                 ):
        self.llm = ChatGLM()
        self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
                            llm_device=llm_device)
        self.llm.history_len = llm_history_len
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model])
        self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
                                                                           device=embedding_device)
        self.top_k = top_k

    # Accepts a single file, a directory, or a list of files; loads what it can,
    # embeds the documents, and persists a FAISS index under ./vector_store
    def init_knowledge_vector_store(self,
                                    filepath: Union[str, List[str]]):
        if isinstance(filepath, str):
            if not os.path.exists(filepath):
                print("路径不存在")
                return None
            elif os.path.isfile(filepath):
                file = os.path.split(filepath)[-1]
                try:
                    loader = UnstructuredFileLoader(filepath, mode="elements")
                    docs = loader.load()
                    print(f"{file} 已成功加载")
                except Exception:
                    print(f"{file} 未能成功加载")
                    return None
            elif os.path.isdir(filepath):
                docs = []
                for file in os.listdir(filepath):
                    fullfilepath = os.path.join(filepath, file)
                    try:
                        loader = UnstructuredFileLoader(fullfilepath, mode="elements")
                        docs += loader.load()
                        print(f"{file} 已成功加载")
                    except Exception:
                        print(f"{file} 未能成功加载")
        else:
            docs = []
            for file in filepath:
                try:
                    loader = UnstructuredFileLoader(file, mode="elements")
                    docs += loader.load()
                    print(f"{file} 已成功加载")
                except Exception:
                    print(f"{file} 未能成功加载")
        # Only build and persist the index if at least one document loaded;
        # FAISS.from_documents raises on an empty document list
        if len(docs) > 0:
            vector_store = FAISS.from_documents(docs, self.embeddings)
            vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
            vector_store.save_local(vs_path)
            return vs_path
        return None

    def get_knowledge_based_answer(self,
                                   query,
                                   vs_path,
                                   chat_history=None):
        # Avoid the mutable-default-argument pitfall
        if chat_history is None:
            chat_history = []
        prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
已知内容:
{context}
问题:
{question}"""
        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=["context", "question"]
        )
        self.llm.history = chat_history
        vector_store = FAISS.load_local(vs_path, self.embeddings)
        knowledge_chain = RetrievalQA.from_llm(
            llm=self.llm,
            retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
            prompt=prompt
        )
        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
            input_variables=["page_content"], template="{page_content}"
        )
        knowledge_chain.return_source_documents = True
        result = knowledge_chain({"query": query})
        self.llm.history[-1][0] = query
        return result, self.llm.history
from configs.model_config import *
from chains.local_doc_qa import LocalDocQA
# return top-k text chunk from vector store
VECTOR_SEARCH_TOP_K = 10
# LLM input history length
LLM_HISTORY_LEN = 3
# Show reply with source text from input document
REPLY_WITH_SOURCE = True
if __name__ == "__main__":
    local_doc_qa = LocalDocQA()
    local_doc_qa.init_cfg(llm_model=LLM_MODEL,
                          embedding_model=EMBEDDING_MODEL,
                          embedding_device=EMBEDDING_DEVICE,
                          llm_history_len=LLM_HISTORY_LEN,
                          top_k=VECTOR_SEARCH_TOP_K)
    vs_path = None
    while not vs_path:
        filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
        vs_path = local_doc_qa.init_knowledge_vector_store(filepath)
    history = []
    while True:
        query = input("Input your question 请输入问题:")
        resp, history = local_doc_qa.get_knowledge_based_answer(query=query,
                                                                vs_path=vs_path,
                                                                chat_history=history)
        if REPLY_WITH_SOURCE:
            print(resp)
        else:
            print(resp["result"])
import torch.cuda
import torch.backends
embedding_model_dict = {
    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
    "ernie-base": "nghuyong/ernie-3.0-base-zh",
    "text2vec": "GanymedeNil/text2vec-large-chinese",
}

# Embedding model name
EMBEDDING_MODEL = "text2vec"

# Embedding running device
EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# supported LLM models
llm_model_dict = {
    "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
    "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
    "chatglm-6b": "THUDM/chatglm-6b",
}
# LLM model name
LLM_MODEL = "chatglm-6b"
# LLM running device
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
## Issue with Installing Packages Using pip in Anaconda
## Problem
Recently, while running some open-source code, I ran into an issue: after creating a virtual environment with conda and switching to it, installing packages with pip was "ineffective". Here, "ineffective" means that the packages installed with pip did not end up in the new environment.
------
## Analysis
1. First, create a test environment called test: `conda create -n test`
2. Activate the test environment: `conda activate test`
3. Use pip to install numpy: `pip install numpy`. You'll find that numpy already exists in the default environment.
```powershell
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: numpy in c:\programdata\anaconda3\lib\site-packages (1.20.3)
```
4. Check the information of pip: `pip show pip`
```powershell
Name: pip
Version: 21.2.4
Summary: The PyPA recommended tool for installing Python packages.
Home-page: https://pip.pypa.io/
Author: The pip developers
Author-email: distutils-sig@python.org
License: MIT
Location: c:\programdata\anaconda3\lib\site-packages
Requires:
Required-by:
```
5. We can see that the current pip lives in the default conda environment. This explains why packages installed directly with pip never show up in the new virtual environment: the pip being invoked belongs to the default environment, so the requested package either already exists there or gets installed straight into it.
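A quick way to confirm this diagnosis from Python itself is to check which interpreter and which pip actually resolve. This is a minimal sketch added for illustration, using only the standard library:

```python
import shutil
import sys

# Interpreter currently running this script; inside an activated conda env
# this should point at ...\envs\<env-name>\python.exe
print(sys.executable)

# First "pip" executable found on PATH; if it resolves under the base
# anaconda3 directory rather than the env, installs land in the base env
print(shutil.which("pip"))
```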
------
## Solution
1. We can install new packages directly with the conda command, but conda sometimes does not carry a given package or library, so pip is still needed.
2. We can first use conda to install pip into the current virtual environment, and then use that pip to install new packages:
```powershell
# Use conda to install the pip package
(test) PS C:\Users\Administrator> conda install pip
Collecting package metadata (current_repodata.json): done
Solving environment: done
....
done
# Display the information of the current pip, and find that pip is in the test environment
(test) PS C:\Users\Administrator> pip show pip
Name: pip
Version: 21.2.4
Summary: The PyPA recommended tool for installing Python packages.
Home-page: https://pip.pypa.io/
Author: The pip developers
Author-email: distutils-sig@python.org
License: MIT
Location: c:\programdata\anaconda3\envs\test\lib\site-packages
Requires:
Required-by:
# Now use pip to install the numpy package, and it is installed successfully
(test) PS C:\Users\Administrator> pip install numpy
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting numpy
Using cached https://pypi.tuna.tsinghua.edu.cn/packages/4b/23/140ec5a509d992fe39db17200e96c00fd29603c1531ce633ef93dbad5e9e/numpy-1.22.2-cp39-cp39-win_amd64.whl (14.7 MB)
Installing collected packages: numpy
Successfully installed numpy-1.22.2
# Use pip list to view the currently installed packages, no problem
(test) PS C:\Users\Administrator> pip list
Package Version
------------ ---------
certifi 2021.10.8
numpy 1.22.2
pip 21.2.4
setuptools 58.0.4
wheel 0.37.1
wincertstore 0.2
```
## Supplement
1. The reason I didn't notice this earlier is probably that the packages I installed in virtual environments were pinned to specific versions, which replaced the copies in the default environment. Mostly it came down to not reading the output carefully :), otherwise I would have spotted `Successfully uninstalled numpy-xxx` (the **default version**) followed by `Successfully installed numpy-1.20.3` (the **specified version**).
2. While testing, I found that the problem does not occur if a Python version is specified when creating the environment. My guess is that doing so installs pip inside the virtual environment, whereas in our case nothing was installed, pip included, so the default environment's pip was used.
3. One question remains: I believe I did specify a Python version when creating an earlier virtual environment, yet the default environment's pip was still used. I just couldn't reproduce this on two different machines, which is what led to point 2 above.
4. When the problem from point 3 occurred, I solved it with `python -m pip install package-name`, prefixing pip with `python -m`. As for why this works, see the answer on [StackOverflow](https://stackoverflow.com/questions/41060382/using-pip-to-install-packages-to-anaconda-environment):
>1. If you have a non-conda pip as your default pip but conda python as your default python (as below):
>
>```shell
>>which -a pip
>/home/<user>/.local/bin/pip
>/home/<user>/.conda/envs/newenv/bin/pip
>/usr/bin/pip
>
>>which -a python
>/home/<user>/.conda/envs/newenv/bin/python
>/usr/bin/python
>```
>
>2. Then, instead of calling `pip install <package>` directly, you can use the module flag -m in python so that it installs with the anaconda python
>
>```shell
>python -m pip install <package>
>```
>
>3. This will install the package to the anaconda library directory rather than the library directory associated with the (non-anaconda) pip
>4. The reason for doing this is as follows: the pip command references a specific pip file/shortcut (which -a pip will tell you which one). Similarly, the python command references a specific python file (which -a python will tell you which one). For one reason or another, these two commands can become out of sync, so your "default" pip is in a different folder than your default python and therefore is associated with different versions of python.
>5. In contrast, the python -m pip construct does not use the shortcut that the pip command points to. Instead, it asks python to find its pip version and use that version to install a package.
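Building on that answer, a script can guarantee the install targets the interpreter it runs under by invoking pip as a module of `sys.executable`. A minimal sketch (the `pip_install` helper and the numpy target are illustrative):

```python
import subprocess
import sys

def pip_install(package: str) -> None:
    """Install `package` into the environment of the running interpreter.

    sys.executable -m pip sidesteps whichever pip shortcut happens to be
    first on PATH, which is exactly the desync described in point 4 above.
    """
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

if __name__ == "__main__":
    pip_install("numpy")  # example target; any package name works
```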
img/ui1.png changed: 51.4 KB → 344.8 KB
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.document_loaders import UnstructuredFileLoader
from chatglm_llm import ChatGLM
import sentence_transformers
import torch
import os
import readline
# Global Parameters
EMBEDDING_MODEL = "text2vec"
VECTOR_SEARCH_TOP_K = 6
LLM_MODEL = "chatglm-6b"
LLM_HISTORY_LEN = 3
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# Show reply with source text from input document
REPLY_WITH_SOURCE = True
embedding_model_dict = {
    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
    "ernie-base": "nghuyong/ernie-3.0-base-zh",
    "text2vec": "GanymedeNil/text2vec-large-chinese",
}

llm_model_dict = {
    "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
    "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
    "chatglm-6b": "THUDM/chatglm-6b",
}
def init_cfg(LLM_MODEL, EMBEDDING_MODEL, LLM_HISTORY_LEN, V_SEARCH_TOP_K=6):
    global chatglm, embeddings, VECTOR_SEARCH_TOP_K
    VECTOR_SEARCH_TOP_K = V_SEARCH_TOP_K
    chatglm = ChatGLM()
    chatglm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL])
    chatglm.history_len = LLM_HISTORY_LEN
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL])
    embeddings.client = sentence_transformers.SentenceTransformer(embeddings.model_name,
                                                                  device=DEVICE)


def init_knowledge_vector_store(filepath: str):
    if not os.path.exists(filepath):
        print("路径不存在")
        return None
    elif os.path.isfile(filepath):
        file = os.path.split(filepath)[-1]
        try:
            loader = UnstructuredFileLoader(filepath, mode="elements")
            docs = loader.load()
            print(f"{file} 已成功加载")
        except Exception:
            print(f"{file} 未能成功加载")
            return None
    elif os.path.isdir(filepath):
        docs = []
        for file in os.listdir(filepath):
            fullfilepath = os.path.join(filepath, file)
            try:
                loader = UnstructuredFileLoader(fullfilepath, mode="elements")
                docs += loader.load()
                print(f"{file} 已成功加载")
            except Exception:
                print(f"{file} 未能成功加载")
    vector_store = FAISS.from_documents(docs, embeddings)
    return vector_store


def get_knowledge_based_answer(query, vector_store, chat_history=None):
    global chatglm, embeddings
    # Avoid the mutable-default-argument pitfall
    if chat_history is None:
        chat_history = []
    prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
已知内容:
{context}
问题:
{question}"""
    prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["context", "question"]
    )
    chatglm.history = chat_history
    knowledge_chain = RetrievalQA.from_llm(
        llm=chatglm,
        retriever=vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
        prompt=prompt
    )
    knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
        input_variables=["page_content"], template="{page_content}"
    )
    knowledge_chain.return_source_documents = True
    result = knowledge_chain({"query": query})
    chatglm.history[-1][0] = query
    return result, chatglm.history


if __name__ == "__main__":
    init_cfg(LLM_MODEL, EMBEDDING_MODEL, LLM_HISTORY_LEN)
    vector_store = None
    while not vector_store:
        filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
        vector_store = init_knowledge_vector_store(filepath)
    history = []
    while True:
        query = input("Input your question 请输入问题:")
        resp, history = get_knowledge_based_answer(query=query,
                                                   vector_store=vector_store,
                                                   chat_history=history)
        if REPLY_WITH_SOURCE:
            print(resp)
        else:
            print(resp["result"])
from .chatglm_llm import *
@@ -3,8 +3,9 @@ from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
 from transformers import AutoTokenizer, AutoModel
 import torch
+from configs.model_config import LLM_DEVICE
 
-DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+DEVICE = LLM_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None
 CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
@@ -48,12 +49,14 @@ class ChatGLM(LLM):
         self.history = self.history+[[None, response]]
         return response
 
-    def load_model(self, model_name_or_path: str = "THUDM/chatglm-6b"):
+    def load_model(self,
+                   model_name_or_path: str = "THUDM/chatglm-6b",
+                   llm_device=LLM_DEVICE):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name_or_path,
             trust_remote_code=True
         )
-        if torch.cuda.is_available():
+        if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
@@ -61,19 +64,12 @@ class ChatGLM(LLM):
                 .half()
                 .cuda()
             )
-        elif torch.backends.mps.is_available():
-            self.model = (
-                AutoModel.from_pretrained(
-                    model_name_or_path,
-                    trust_remote_code=True)
-                .float()
-                .to('mps')
-            )
         else:
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
                     trust_remote_code=True)
                 .float()
+                .to(llm_device)
             )
         self.model = self.model.eval()
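Read as a whole, this change routes device selection through `configs.model_config.LLM_DEVICE`: `load_model` now takes an `llm_device` argument, keeps the half-precision `.cuda()` path when CUDA is requested, and otherwise falls back to a generic `.float().to(llm_device)` that covers both `mps` and `cpu`. A standalone sketch of the same pattern (the `load_chatglm` helper and its defaults are illustrative, not part of the commit):

```python
import torch
from transformers import AutoModel

def load_chatglm(model_name_or_path: str = "THUDM/chatglm-6b",
                 llm_device: str = "cpu"):
    # fp16 + .cuda() when CUDA is both available and requested;
    # otherwise fp32 moved to the configured device (e.g. "mps" or "cpu")
    model = AutoModel.from_pretrained(model_name_or_path,
                                      trust_remote_code=True)
    if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
        model = model.half().cuda()
    else:
        model = model.float().to(llm_device)
    return model.eval()
```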
@@ -9,3 +9,4 @@ icetk
 cpm_kernels
 faiss-cpu
 gradio>=3.25.0
+detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2
\ No newline at end of file