-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathEvaluation.py
59 lines (36 loc) · 1.61 KB
/
Evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import logging
import Strings
from models.VAE import VAEModel
from models.STIBwoIR import STIBwoIRModel
from models.STIB import STIBModel
from models.CVAE import CVAEModel
from models.CVIB import CVIBModel
class Evaluation:
    """Build and evaluate one of the project's models, selected by ``args.model``.

    Parameters
    ----------
    args :
        Run configuration. Presumably an ``argparse.Namespace`` (confirm
        against the caller). Only ``args.model`` is read here; the whole
        object is also forwarded to the model constructor as ``args=``.
    dataset :
        Forwarded unchanged to the model constructor as ``dataset=``.
    z0_size, z1_size, y_size, x_size : int, keyword-only
        Size arguments forwarded to the model constructor under the same
        keyword names. The defaults are the values that were previously
        hard-coded for every model, so existing callers are unaffected.
    """

    def __init__(self, args, dataset, *, z0_size=2, z1_size=1, y_size=2, x_size=2):
        self.log = logging.getLogger(__name__)
        self.args = args
        self.dataset = dataset
        # Kept together so every model is constructed with the same sizes.
        self.model_sizes = {
            "z0_size": z0_size,
            "z1_size": z1_size,
            "y_size": y_size,
            "x_size": x_size,
        }

    def evaluate(self):
        """Construct the model named by ``self.args.model`` and evaluate it.

        The selected model class is instantiated with ``dataset``, ``args``
        and the configured sizes. Then ``buildModel()`` and
        ``evaluateModel()`` are called on it, in that order.

        If ``self.args.model`` matches none of the known model names, an
        error is logged and the method returns without building anything.
        No exception is raised, which matches the previous behaviour.
        """
        # Lazy %-args: the message is only formatted if INFO is enabled.
        self.log.info("Evaluating %s", self.args.model)

        # A dispatch table replaces the repeated if/elif chain. It is built
        # here rather than at class level so the Strings constants are only
        # resolved on use. This assumes the Strings.* names are distinct,
        # hashable values (presumably str); TODO confirm.
        model_classes = {
            Strings.VAE: VAEModel,
            Strings.STIB_WO_IR: STIBwoIRModel,
            Strings.STIB: STIBModel,
            Strings.CVAE: CVAEModel,
            Strings.CVIB: CVIBModel,
        }

        model_cls = model_classes.get(self.args.model)
        if model_cls is None:
            # Include the offending value so misconfiguration is diagnosable.
            self.log.error(
                "Model to evaluate not found! Got %r, expected one of: %s",
                self.args.model,
                ", ".join(str(name) for name in model_classes),
            )
            return

        model = model_cls(dataset=self.dataset, args=self.args, **self.model_sizes)
        model.buildModel()
        model.evaluateModel()