# models.py
import torch
import torch.nn as nn

import gpytorch
from gpytorch.kernels import RBFKernel, LinearKernel


class Policy(nn.Module):
    # Deep neural-network policy: maps states to a diagonal Gaussian over actions.
    def __init__(self, num_inputs, num_outputs):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(num_inputs, 64)
        self.affine2 = nn.Linear(64, 64)

        # Action-mean head, initialised with small weights and zero bias.
        self.action_mean = nn.Linear(64, num_outputs)
        self.action_mean.weight.data.mul_(0.1)
        self.action_mean.bias.data.mul_(0.0)

        # State-independent log standard deviation of the action distribution.
        self.action_log_std = nn.Parameter(torch.zeros(1, num_outputs))

    def forward(self, x):
        x = torch.tanh(self.affine1(x))
        x = torch.tanh(self.affine2(x))

        action_mean = self.action_mean(x)
        action_log_std = self.action_log_std.expand_as(action_mean)
        action_std = torch.exp(action_log_std)

        return action_mean, action_log_std, action_std
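
# A minimal, hypothetical usage sketch (not part of the original file): actions can be
# drawn from the diagonal Gaussian parameterised by Policy.forward. The helper name
# `select_action` is an assumption for illustration only.
def select_action(policy, state):
    # state: tensor of shape (1, num_inputs); returns one sampled action of shape (1, num_outputs)
    action_mean, _, action_std = policy(state)
    return torch.normal(action_mean, action_std)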

class FeatureExtractor(nn.Module):
    # Common feature extractor shared between the state-value V(s) and
    # action-value Q(s, a) function approximations.
    def __init__(self, num_inputs, num_outputs):
        super(FeatureExtractor, self).__init__()
        self.affine1 = nn.Linear(num_inputs, 64)
        self.affine2 = nn.Linear(64, 48)
        self.affine3 = nn.Linear(48, num_outputs)

    def forward(self, x):
        x = torch.tanh(self.affine1(x))
        x = torch.tanh(self.affine2(x))
        x = torch.tanh(self.affine3(x))
        return x

class Value(gpytorch.models.ExactGP, nn.Module):
    # Monte-Carlo PG estimator: only the state-value V(s) approximation is needed,
    # i.e. feature extractor + value head.
    # Bayesian Quadrature PG estimator: both the state-value V(s) and action-value
    # Q(s, a) approximations are needed, i.e. feature extractor + value head + GP head.
    def __init__(self,
                 NN_num_inputs,
                 pg_estimator,
                 fisher_num_inputs=None,
                 gp_likelihood=None):
        # fisher_num_inputs equals svd_low_rank, because the Fisher kernel is
        # linearly approximated through fast SVD.
        if pg_estimator == 'MC':
            nn.Module.__init__(self)
        else:
            gpytorch.models.ExactGP.__init__(self, None, None, gp_likelihood)

        self.NN_num_inputs = NN_num_inputs
        NN_num_outputs = 10
        self.feature_extractor = FeatureExtractor(NN_num_inputs,
                                                  NN_num_outputs)

        # value_head computes the state-value approximation V(s), which is
        # subsequently used for the GAE estimates.
        self.value_head = nn.Linear(NN_num_outputs, 1)
        self.value_head.weight.data.mul_(0.1)
        self.value_head.bias.data.mul_(0.0)

        if pg_estimator == 'BQ':
            grid_size = 128
            # Analogous to value_head, the following constructs the GP head for the
            # action-value approximation Q(s, a). Note that V(s) and Q(s, a) share
            # the same feature extractor for the states "s".
            self.mean_module = gpytorch.means.ConstantMean()

            # The first NN_num_outputs columns of the GP input feed the state kernel.
            state_kernel_active_dims = torch.tensor(list(
                range(NN_num_outputs)))
            # Columns [NN_num_outputs, GP_input.shape[1] - 1] of the GP input feed
            # the Fisher kernel.
            fisher_kernel_active_dims = torch.tensor(
                list(range(fisher_num_inputs +
                           NN_num_outputs))[NN_num_outputs:])

            self.covar_module_2 = LinearKernel(
                active_dims=fisher_kernel_active_dims)
            self.covar_module_1 = gpytorch.kernels.AdditiveStructureKernel(
                gpytorch.kernels.ScaleKernel(
                    gpytorch.kernels.GridInterpolationKernel(
                        RBFKernel(ard_num_dims=1),
                        grid_size=grid_size,
                        num_dims=1)),
                num_dims=NN_num_outputs,
                active_dims=state_kernel_active_dims)

    def nn_forward(self, x):
        # Invokes the value head to compute the state-value function V(s), which is
        # subsequently used for computing the GAE estimates.
        extracted_features = self.feature_extractor(x[:, :self.NN_num_inputs])
        state_value_estimate = self.value_head(extracted_features)
        return state_value_estimate

    def forward(self,
                x,
                state_multiplier,
                fisher_multiplier,
                only_state_kernel=False):
        # Invokes the GP head for the action-value function Q(s, a), although Q(s, a)
        # is never computed explicitly. Instead, we implicitly compute
        # (K + sigma^2 I)^{-1} * A^{GAE}, which subsequently provides the BQ policy
        # gradient estimate.
        extracted_features = self.feature_extractor(x[:, :self.NN_num_inputs])
        extracted_features = gpytorch.utils.grid.scale_to_bounds(
            extracted_features, -1, 1)

        if only_state_kernel:
            # Used for computing the inverse vanilla gradient covariance
            # (Cov^{BQ})^{-1} or natural gradient covariance (Cov^{NBQ})^{-1}.
            mean_x = self.mean_module(extracted_features)
            # Implicitly builds (c_1 K_s + sigma^2 I), which allows the MVM
            # (c_1 K_s + sigma^2 I)^{-1} * v to be computed efficiently.
            covar_x = gpytorch.lazy.ConstantMulLazyTensor(
                self.covar_module_1(extracted_features), state_multiplier)
            return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

        GP_input = torch.cat([extracted_features, x[:, self.NN_num_inputs:]],
                             1)
        mean_x = self.mean_module(GP_input)
        # Implicitly builds (c_1 K_s + c_2 K_f + sigma^2 I), which allows the MVM
        # (c_1 K_s + c_2 K_f + sigma^2 I)^{-1} * v to be computed efficiently.
        covar_x = gpytorch.lazy.ConstantMulLazyTensor(self.covar_module_1(GP_input), state_multiplier) + \
            gpytorch.lazy.ConstantMulLazyTensor(self.covar_module_2(GP_input), fisher_multiplier)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
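

# A minimal smoke-test sketch (not part of the original file). The dimensions below
# are arbitrary assumptions chosen for illustration; the BQ variant additionally needs
# a gpytorch likelihood and a Fisher-feature dimension (fisher_num_inputs), so only
# the Monte-Carlo ('MC') configuration is exercised here.
if __name__ == "__main__":
    num_inputs, num_actions = 8, 2
    policy = Policy(num_inputs, num_actions)
    value = Value(num_inputs, pg_estimator='MC')

    states = torch.randn(5, num_inputs)
    action_mean, action_log_std, action_std = policy(states)  # each (5, num_actions)
    state_values = value.nn_forward(states)                   # (5, 1)
    print(action_mean.shape, action_std.shape, state_values.shape)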