# train_cifar.py — SparseSwin training script for CIFAR-10/100 (110 lines, 98 loc).
# Third-party dependencies (PyTorch / torchvision / NumPy).
import torch
import torchvision.transforms as transforms
from torchvision import datasets
import numpy as np
from traintest import train  # project-local: training/validation loop
import build_model as build  # project-local: SparseSwin model builder
# Fix the torch RNG seed for reproducibility (weight init, shuffling,
# random augmentations).  NOTE(review): numpy/python RNGs are not seeded here.
torch.random.manual_seed(1)
# Dataset Config -------------------------------------------
# ImageNet channel statistics, applied to both pipelines.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Normalization is stateless, so a single instance can be shared.
_normalize = transforms.Normalize(mean, std)

data_transform = {
    # Training: convert to tensor, then random crop/flip augmentation.
    'train': transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        _normalize,
    ]),
    # Validation: deterministic resize to the model's 224x224 input.
    'val': transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=None),
        _normalize,
    ]),
}
# Whether torchvision may download the archives.  True is safe: the download
# is skipped automatically when the files already exist on disk, whereas
# False makes the script crash with a RuntimeError on a fresh checkout.
status = True

# Active dataset: CIFAR-10.  To train on CIFAR-100 instead, comment out this
# pair and enable the pair below (also change num_classes in __main__).
train_dataset = datasets.CIFAR10(
    root='./datasets/torch_cifar10/',
    train=True,
    transform=data_transform['train'],
    download=status)
val_dataset = datasets.CIFAR10(
    root='./datasets/torch_cifar10/',
    train=False,
    transform=data_transform['val'],
    download=status)

# Alternative dataset: CIFAR-100 (disabled).
# train_dataset = datasets.CIFAR100(
#     root='./datasets/torch_cifar100/',
#     train=True,
#     transform=data_transform['train'],
#     download=status)
# val_dataset = datasets.CIFAR100(
#     root='./datasets/torch_cifar100/',
#     train=False,
#     transform=data_transform['val'],
#     download=status)
batch_size = 12

# Training batches are reshuffled every epoch.  The validation set is kept in
# deterministic order (fix: it was shuffle=True, which only costs time —
# validation metrics are order-independent).
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True)
def _main():
    """Build SparseSwin-tiny for CIFAR-10 and run the training loop."""
    # Experiment identifiers and hyper-parameters.
    dataset_name = 'cifar10'
    backbone = 'tiny'
    reg_type, reg_lambda = 'l1', 1e-5
    device = torch.device('cuda')
    num_epochs = 1
    log_every = 200
    latent_tokens, latent_dims = 49, 256
    lf = 2

    # Instantiate the model via the project-local builder.
    model = build.buildSparseSwin(
        image_resolution=224,
        swin_type=backbone,
        num_classes=10,
        ltoken_num=latent_tokens,
        ltoken_dims=latent_dims,
        num_heads=16,
        qkv_bias=True,
        lf=lf,
        attn_drop_prob=.0,
        lin_drop_prob=.0,
        freeze_12=False,
        device=device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
    criterion = torch.nn.CrossEntropyLoss()

    # Launch training with periodic logging and validation.
    train(
        train_loader,
        backbone,
        dataset_name,
        num_epochs,
        model,
        lf,
        latent_tokens,
        optimizer,
        criterion,
        device,
        show_per=log_every,
        reg_type=reg_type,
        reg_lambda=reg_lambda,
        validation=val_loader)


if __name__ == '__main__':
    _main()