Commit 5664d1ff by littlepanda0716

add torch_gc to clear gpu cache in knowledge_based_chatglm.py

Parent 3cbc6aa7
@@ -2,6 +2,19 @@ from langchain.llms.base import LLM
from typing import Optional, List
from langchain.llms.utils import enforce_stop_tokens
from transformers import AutoTokenizer, AutoModel
import torch
DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/chatglm-6b",
@@ -15,6 +28,7 @@ model = (
    .cuda()
)
class ChatGLM(LLM):
    max_token: int = 10000
    temperature: float = 0.1
...
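The hunk above defines torch_gc but this excerpt does not show where it is called. A minimal sketch of how such a helper would typically be invoked after generation; the _call body below is an assumption based on the surrounding ChatGLM class, not part of this diff:

# Hypothetical usage sketch (not part of this commit):
# call torch_gc() after each generation so cached CUDA memory is released.
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
    # model.chat returns (response, history) for THUDM/chatglm-6b
    response, _history = model.chat(tokenizer, prompt, history=[])
    if stop is not None:
        response = enforce_stop_tokens(response, stop)
    torch_gc()  # empty the CUDA cache and collect IPC handles
    return response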