-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: cifar10.py
85 lines (66 loc) · 2.45 KB
/
cifar10.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""Train and evaluate the TP dual-ViT broad/fine classifier on CIFAR-10.

Builds the CIFAR-10 multi-label datamodule, wraps the TP dual-ViT model in a
``BroadFineModelLM`` LightningModule, and runs ``Trainer.fit`` (plus
``Trainer.test``) with checkpointing, LR monitoring, and CSV logging.

All heavy work (dataset download/setup, model and Trainer construction,
training) now happens inside ``main()`` under the ``__main__`` guard, so
importing this module has no side effects beyond constant definitions.
"""
import warnings

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from transformers import logging

from dualvit.constants import IMG_SIZE, LEARNING_RATE
from dualvit.factory import TPModelFactory
from dualvit.lightning.data.cifar import CIFAR10MultiLabelDataModule
from dualvit.lightning.data.common import test_transform, train_transform
from dualvit.lightning.loss import BELMode
from dualvit.lightning.modulev2 import VALIDATION_METRIC_NAME, BroadFineModelLM

# Resume control: when True, main() first evaluates the checkpoint at
# CKPT_PATH and then resumes training from it.
LOAD_CKPT = False
CKPT_PATH = "logs/cifar10/tpdualvit-p16-224/lightning_logs/version_1/checkpoints/modelepoch=31-val_af=0.957.ckpt"

# CIFAR-10: 10 fine labels; the datamodule groups them into 2 broad labels.
NUM_FINE_CLASSES = 10
NUM_BROAD_CLASSES = 2


def _build_datamodule():
    """Create and fully initialize the CIFAR-10 multi-label datamodule.

    ``is_test=True`` would use a small subset; ``False`` uses the full dataset.
    """
    datamodule = CIFAR10MultiLabelDataModule(
        is_test=False,
        train_transform=train_transform,
        val_transform=test_transform,
    )
    datamodule.prepare_data()
    datamodule.setup()
    return datamodule


def _build_module():
    """Create the LightningModule wrapping the TP dual-ViT CIFAR-10 model."""
    return BroadFineModelLM(
        model=TPModelFactory.get_model("CIFAR10"),
        num_fine_outputs=NUM_FINE_CLASSES,
        num_broad_outputs=NUM_BROAD_CLASSES,
        lr=LEARNING_RATE,
        loss_mode=BELMode.CLUSTER,
    )


def _build_trainer():
    """Create the Trainer with checkpointing, LR monitoring, and CSV logging."""
    checkpoint_callback = ModelCheckpoint(
        # Monitor the validation metric; with mode="max", higher is better
        # (the metric name suggests accuracy/F-score, not a loss).
        monitor=VALIDATION_METRIC_NAME,
        filename="model" + "{epoch:02d}" + f"-{{{VALIDATION_METRIC_NAME}:.3f}}",
        save_top_k=2,
        mode="max",
    )
    kwargs = {
        "max_epochs": 100,
        "accelerator": "gpu",
        # NOTE(review): `gpus` / `resume_from_checkpoint` are the PL 1.x API;
        # use `devices` / `Trainer.fit(ckpt_path=...)` if upgrading to 2.x.
        "gpus": 1,
        "logger": CSVLogger(save_dir=f"logs/cifar10/tpdualvit-p16-{IMG_SIZE}"),
        "deterministic": True,
        "callbacks": [
            LearningRateMonitor(logging_interval="step"),
            TQDMProgressBar(refresh_rate=10),
            checkpoint_callback,
        ],
        "num_sanity_val_steps": 5,
        "gradient_clip_val": 1,
    }
    if LOAD_CKPT:
        kwargs["resume_from_checkpoint"] = CKPT_PATH
    return Trainer(**kwargs)


def main():
    """Run the full train/evaluate pipeline."""
    logging.set_verbosity_warning()
    warnings.filterwarnings("ignore")

    datamodule = _build_datamodule()
    l_module = _build_module()
    trainer = _build_trainer()

    if LOAD_CKPT:
        # Evaluate the restored checkpoint before resuming training.
        # Ensure test dataset is defined else comment out.
        trainer.test(l_module, datamodule=datamodule, ckpt_path=CKPT_PATH)
    trainer.fit(l_module, datamodule=datamodule)
    # Ensure test dataset is defined else comment out.
    trainer.test(l_module, datamodule=datamodule)


if __name__ == "__main__":
    main()