Commit 871a8716 by imClumsyPanda

update model_loader

Parent 3712eec6
import gc
# import gc
import traceback
from queue import Queue
from threading import Thread
import threading
from typing import Optional, List, Dict, Any
# from threading import Thread
# import threading
from typing import Optional, List, Dict, Any, TypeVar, Deque
from collections import deque
import torch
import transformers
......@@ -12,13 +12,16 @@ from models.extensions.thread_with_exception import ThreadWithException
import models.shared as shared
class LimitedLengthDict(dict):
K = TypeVar('K')
V = TypeVar('V')
class LimitedLengthDict(Dict[K, V]):
def __init__(self, maxlen=None, *args, **kwargs):
    """Initialize a dict bounded to at most *maxlen* entries.

    Args:
        maxlen: Maximum number of entries to keep; ``None`` means unbounded.
        *args: Forwarded to ``dict.__init__`` for initial contents.
        **kwargs: Forwarded to ``dict.__init__`` for initial contents.
    """
    self.maxlen = maxlen
    # Tracks key insertion order; the left end is the oldest key,
    # which __setitem__ pops (via popleft) once maxlen is reached.
    # NOTE: the diff left two assignments to self._keys; only the
    # annotated one is needed — the first was dead code.
    self._keys: "Deque[K]" = deque()
    super().__init__(*args, **kwargs)
def __setitem__(self, key, value):
def __setitem__(self, key: K, value: V):
if key not in self:
if self.maxlen is not None and len(self) >= self.maxlen:
oldest_key = self._keys.popleft()
......
......@@ -139,9 +139,9 @@ class LoaderCheckPoint:
model = dispatch_model(model, device_map=self.device_map)
else:
print(
"Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been "
"detected.\nFalling back to CPU mode.\n")
# print(
# "Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been "
# "detected.\nFalling back to CPU mode.\n")
model = (
AutoModel.from_pretrained(
checkpoint,
......
Markdown format supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment