-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathauto.py
146 lines (127 loc) · 4.88 KB
/
auto.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from collections import OrderedDict
from utils import Struct
from dataset import DyRDataset, StaticKGDataset
from trainer import BaseTrainer, MixTrainer
from tester import BaseTester
from model.dyr import DyRMLP, BaseDyR
from model.baselines import TransE, DistMult, SimplE, RGCNLinkPredict
# Registry of supported model identifiers, one row per identifier:
#   (identifier, dataset class, model class, trainer class, tester class)
# The four public lookup tables below are derived from this one table, so a
# new identifier is registered in all of them at once and cannot drift.
_REGISTRY = (
    ("dyr", DyRDataset, DyRMLP, BaseTrainer, BaseTester),
    ("transe", StaticKGDataset, TransE, MixTrainer, BaseTester),
    ("distmult", StaticKGDataset, DistMult, MixTrainer, BaseTester),
    ("simple", StaticKGDataset, SimplE, MixTrainer, BaseTester),
    ("basedyr", DyRDataset, BaseDyR, BaseTrainer, BaseTester),
    ("rgcn", StaticKGDataset, RGCNLinkPredict, MixTrainer, BaseTester),
)

# identifier -> dataset class (used by AutoDataset.for_model)
DATASET_MAPPING = OrderedDict((row[0], row[1]) for row in _REGISTRY)
# identifier -> model class (used by AutoModel.for_model)
MODEL_MAPPING = OrderedDict((row[0], row[2]) for row in _REGISTRY)
# identifier -> trainer class (used by AutoTrainer.for_model)
TRAINER_MAPPING = OrderedDict((row[0], row[3]) for row in _REGISTRY)
# identifier -> tester class (used by AutoTester.for_model)
TESTER_MAPPING = OrderedDict((row[0], row[4]) for row in _REGISTRY)
class AutoDataset:
    r"""
    This is a generic dataset class that will be instantiated as one of the dataset classes of the library
    when created with the :meth:`~auto.AutoDataset.for_model` class method.

    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """

    def __init__(self):
        # Factory-only class: direct construction is always an error.
        raise EnvironmentError(
            "AutoDataset is designed to be instantiated "
            "using the `AutoDataset.for_model(params, device)` method."
        )

    @classmethod
    def for_model(cls, params, device):
        """Build the dataset registered for ``params.model``.

        Args:
            params: Configuration object. Its ``model`` attribute selects the
                entry in ``DATASET_MAPPING``; the whole object is also passed
                to the dataset constructor.
            device: Forwarded unchanged to the dataset constructor.

        Returns:
            ``DATASET_MAPPING[params.model](params, device)``.

        Raises:
            ValueError: If ``params.model`` is not a key of ``DATASET_MAPPING``.
        """
        try:
            dataset_class = DATASET_MAPPING[params.model]
        except KeyError:
            raise ValueError(
                f"Unrecognized model identifier: {params.model}. "
                f"Should be one of {', '.join(DATASET_MAPPING.keys())}"
            ) from None
        # Only the lookup is guarded, so a KeyError raised inside the dataset
        # constructor is never misreported as an unknown identifier.
        return dataset_class(params, device)
class AutoModel:
    r"""
    This is a generic model class that will be instantiated as one of the model classes of the library
    when created with the :meth:`~auto.AutoModel.for_model` class method.

    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """

    def __init__(self):
        # Factory-only class: direct construction is always an error.
        raise EnvironmentError(
            "AutoModel is designed to be instantiated "
            "using the `AutoModel.for_model(params, device)` method."
        )

    @classmethod
    def for_model(cls, params, device):
        """Build the model registered for ``params.model``.

        Args:
            params: Configuration object. Its ``model`` attribute selects the
                entry in ``MODEL_MAPPING``. For the ``"naive"``/``"capse"``
                identifiers, ``params.decoder`` is unpacked as keyword
                arguments into a ``Struct`` that is passed instead of
                ``params``.
            device: Passed to the model constructor as the ``device`` keyword.

        Returns:
            An instance of ``MODEL_MAPPING[params.model]``.

        Raises:
            ValueError: If ``params.model`` is not a key of ``MODEL_MAPPING``.
        """
        try:
            model_class = MODEL_MAPPING[params.model]
        except KeyError:
            raise ValueError(
                f"Unrecognized model identifier: {params.model}. "
                f"Should be one of {', '.join(MODEL_MAPPING.keys())}"
            ) from None
        if params.model in ("naive", "capse"):
            # NOTE(review): neither "naive" nor "capse" is a key of the
            # MODEL_MAPPING defined in this module, so this branch only runs
            # if those identifiers are registered elsewhere — confirm it is
            # still needed.
            args = Struct(**params.decoder)
            return model_class(args, device=device)
        return model_class(params, device=device)
class AutoTrainer:
    r"""
    This is a generic trainer class that will be instantiated as one of the trainer classes of the library
    when created with the :meth:`~auto.AutoTrainer.for_model` class method.

    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """

    def __init__(self):
        # Factory-only class: direct construction is always an error.
        raise EnvironmentError(
            "AutoTrainer is designed to be instantiated "
            "using the `AutoTrainer.for_model(dataset, params, model, device)` method."
        )

    @classmethod
    def for_model(cls, dataset, params, model, device):
        """Build the trainer registered for ``params.model``.

        Args:
            dataset: Forwarded unchanged to the trainer constructor.
            params: Configuration object. Its ``model`` attribute selects the
                entry in ``TRAINER_MAPPING``; the whole object is also
                forwarded.
            model: Forwarded unchanged to the trainer constructor.
            device: Forwarded unchanged to the trainer constructor.

        Returns:
            ``TRAINER_MAPPING[params.model](dataset, params, model, device)``.

        Raises:
            ValueError: If ``params.model`` is not a key of ``TRAINER_MAPPING``.
        """
        try:
            trainer_class = TRAINER_MAPPING[params.model]
        except KeyError:
            raise ValueError(
                f"Unrecognized trainer identifier: {params.model}. "
                f"Should be one of {', '.join(TRAINER_MAPPING.keys())}"
            ) from None
        return trainer_class(dataset, params, model, device)
class AutoTester:
    r"""
    This is a generic tester class that will be instantiated as one of the tester classes of the library
    when created with the :meth:`~auto.AutoTester.for_model` class method.

    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """

    def __init__(self):
        # Factory-only class: direct construction is always an error.
        raise EnvironmentError(
            "AutoTester is designed to be instantiated "
            "using the `AutoTester.for_model(dataset, params, model, device)` method."
        )

    @classmethod
    def for_model(cls, dataset, params, model, device):
        """Build the tester registered for ``params.model``.

        Args:
            dataset: Forwarded unchanged to the tester constructor.
            params: Configuration object. Its ``model`` attribute selects the
                entry in ``TESTER_MAPPING``; the whole object is also
                forwarded.
            model: Forwarded unchanged to the tester constructor.
            device: Forwarded unchanged to the tester constructor.

        Returns:
            ``TESTER_MAPPING[params.model](dataset, params, model, device)``.

        Raises:
            ValueError: If ``params.model`` is not a key of ``TESTER_MAPPING``.
        """
        try:
            tester_class = TESTER_MAPPING[params.model]
        except KeyError:
            raise ValueError(
                f"Unrecognized tester identifier: {params.model}. "
                f"Should be one of {', '.join(TESTER_MAPPING.keys())}"
            ) from None
        return tester_class(dataset, params, model, device)