-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathresnext_util.py
28 lines (23 loc) · 1.25 KB
/
resnext_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from torch import nn
import resnext
def generate_resnext_model(mode, model_depth = 101, n_classes = 400, sample_size = 112, sample_duration = 16, resnet_shortcut = 'B', resnext_cardinality = 32):
    """Build a ResNeXt model by dispatching to ``resnext.resnet{50,101,152}``.

    Args:
        mode: ``'score'`` or ``'feature'``. ``'score'`` passes ``last_fc=True``
            to the builder and ``'feature'`` passes ``last_fc=False``
            (presumably this toggles the final fully-connected layer, so the
            model returns class scores vs. pooled features -- confirm in
            ``resnext``).
        model_depth: one of 50, 101 or 152; selects the builder function.
        n_classes: forwarded as ``num_classes``.
        sample_size: forwarded as ``sample_size`` (assumed to be the input
            spatial resolution -- TODO confirm in ``resnext``).
        sample_duration: forwarded as ``sample_duration`` (assumed to be the
            clip length in frames -- TODO confirm in ``resnext``).
        resnet_shortcut: forwarded as ``shortcut_type``.
        resnext_cardinality: forwarded as ``cardinality``.

    Returns:
        Whatever the selected ``resnext.resnetNN`` builder returns.

    Raises:
        ValueError: if ``mode`` or ``model_depth`` is not a supported value.
            Explicit checks are used instead of ``assert`` so validation still
            happens under ``python -O``.
    """
    valid_modes = ('score', 'feature')
    if mode not in valid_modes:
        raise ValueError(
            'mode must be one of {!r}, got {!r}'.format(valid_modes, mode))

    # Map each supported depth to the name of its builder in ``resnext``.
    # The attribute is resolved lazily below, so only the selected builder
    # is looked up.
    builder_names = {50: 'resnet50', 101: 'resnet101', 152: 'resnet152'}
    if model_depth not in builder_names:
        raise ValueError(
            'model_depth must be one of {!r}, got {!r}'.format(
                sorted(builder_names), model_depth))

    builder = getattr(resnext, builder_names[model_depth])
    return builder(num_classes=n_classes,
                   shortcut_type=resnet_shortcut,
                   cardinality=resnext_cardinality,
                   sample_size=sample_size,
                   sample_duration=sample_duration,
                   last_fc=(mode == 'score'))