-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
48 lines (39 loc) · 1.71 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from resnet import *
import gpytorch
import math
import torch.nn as nn
import torch
class GaussianProcessLayer(gpytorch.models.AdditiveGridInducingVariationalGP):
    """Variational GP output layer built on gpytorch's additive grid-inducing GP.

    The base class is given ``grid_bounds`` wrapped in a one-element list,
    together with ``num_dim``, ``grid_size``, ``mixing_params=False`` and
    ``sum_output=False``. Presumably each additive component shares that
    single 1-D grid; confirm against the installed gpytorch version.

    Args:
        num_dim: Number of input dimensions, forwarded to the base class.
        grid_bounds: ``(lower, upper)`` range of the inducing-point grid.
        grid_size: Number of inducing points in the grid.
    """

    def __init__(self, num_dim, grid_bounds=(-10., 10.), grid_size=64):
        super().__init__(
            grid_size=grid_size,
            grid_bounds=[grid_bounds],
            num_dim=num_dim,
            mixing_params=False,
            sum_output=False,
        )
        # Scaled RBF covariance; the RBF kernel is built with its default arguments.
        base_kernel = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(base_kernel)
        self.mean_module = gpytorch.means.ConstantMean()
        # Kept as an attribute; not read anywhere else in this class.
        self.grid_bounds = grid_bounds

    def forward(self, x):
        """Return a ``MultivariateNormal`` built from the mean and covariance modules at ``x``."""
        # Arguments evaluate left to right: mean first, then covariance.
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )
class ResNetFeatureExtractor(nn.Module):
    """Wrap a ResNet constructor so it can serve as a DKL feature extractor.

    Args:
        resnet_type: Callable invoked as ``resnet_type(pretrained=...)``. The
            module it returns must expose an ``out_dim`` attribute.
        pretrained: Forwarded to ``resnet_type``. Defaults to ``True``, the
            value that used to be hard-coded. Passing ``False`` presumably
            skips loading pretrained weights (for example, when a trained
            checkpoint is loaded afterwards); confirm against the ``resnet``
            module.
    """

    def __init__(self, resnet_type, *, pretrained=True):
        super(ResNetFeatureExtractor, self).__init__()
        # The attribute name ``classifier`` is kept on purpose: nn.Module uses
        # it as the ``state_dict`` key prefix, so renaming it would break
        # loading of existing checkpoints.
        self.classifier = resnet_type(pretrained=pretrained)
        # Width of the produced features. Presumably this is what the GP
        # layer's ``num_dim`` should match; confirm at call sites.
        self.out_dim = self.classifier.out_dim

    def forward(self, x):
        """Return the wrapped network's output for ``x`` unchanged."""
        return self.classifier(x)
class DKLModel(gpytorch.Module):
    """Deep kernel learning model: a feature extractor feeding a GP layer.

    Args:
        feature_extractor: Module applied to the raw input. Its output is
            rescaled into ``grid_bounds`` and passed to the GP layer.
        num_dim: Number of feature dimensions given to ``GaussianProcessLayer``.
            This presumably must equal the extractor's output width
            (e.g. ``feature_extractor.out_dim``); confirm at call sites.
        grid_bounds: ``(lower, upper)`` range that features are rescaled into.
            The GP layer's inducing grid uses the same range.
        grid_size: Number of inducing points in the GP layer's grid. This
            parameter is optional. Its default of 64 equals
            ``GaussianProcessLayer``'s default, so existing callers build the
            same model as before.
    """

    def __init__(self, feature_extractor, num_dim, grid_bounds=(-10., 10.), grid_size=64):
        super(DKLModel, self).__init__()
        # Submodule registration order (feature_extractor, then gp_layer) is
        # unchanged, so parameter ordering matches existing checkpoints.
        self.feature_extractor = feature_extractor
        self.gp_layer = GaussianProcessLayer(
            num_dim=num_dim, grid_bounds=grid_bounds, grid_size=grid_size
        )
        self.grid_bounds = grid_bounds
        self.num_dim = num_dim
        self.grid_size = grid_size

    def forward(self, x):
        """Extract features from ``x``, rescale them into ``grid_bounds``, and apply the GP layer."""
        features = self.feature_extractor(x)
        # Rescale so features lie within the same bounds as the GP layer's grid.
        # NOTE(review): gpytorch's scale_to_bounds appears to use the min/max
        # of the tensor it is given. If so, the scaling depends on the batch
        # contents and is degenerate when all features are equal. Confirm for
        # the installed gpytorch version.
        features = gpytorch.utils.grid.scale_to_bounds(
            features, self.grid_bounds[0], self.grid_bounds[1]
        )
        return self.gp_layer(features)