pnn.py
""" Basic implementation of a Probabilistic Neural Network (PNN). This is a
neural network that outputs the mean and variance of a standard normal.
"""
from typing import Callable, Sequence, Tuple
import numpy as np
from pytorch_lightning import LightningModule
import torch
import torch.nn.functional as F
from simple_uq.util.mlp import MLP


class PNN(LightningModule):
"""Probabilistic neural network (PNN) outputting Gaussian distribution.
This model is implemented as a two headed neural network. The two heads
output the mean and logvariance of a multi-variate normal.
"""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        # Parameters for the encoder network.
        encoder_hidden_sizes: Sequence[int],
        encoder_output_dim: int,
        # Parameters for the mean and logvar heads.
        mean_hidden_sizes: Sequence[int],
        logvar_hidden_sizes: Sequence[int],
        hidden_activation: Callable[[torch.Tensor], torch.Tensor] = F.relu,
        learning_rate: float = 1e-3,
    ):
"""Constructor.
Args:
input_dim: Dimension of input data.
output_dim: Dimesnion of data outputted.
hidden_activation: Hidden activation function.
encoder_hidden_sizes: List of the hidden sizes for the encoder.
encoder_output_dim: Dimension of the data outputted by the encoder.
mean_hidden_sizes: List of hidden sizes for mean head.
logvar_hidden_sizes: List of hidden sizes for logvar head.
"""
        super().__init__()
        self._learning_rate = learning_rate
        # Shared encoder trunk; its latent output feeds both heads.
        self.encoder = MLP(
            input_dim=input_dim,
            output_dim=encoder_output_dim,
            hidden_sizes=encoder_hidden_sizes,
            hidden_activation=hidden_activation,
        )
        # Head predicting the mean of the Gaussian.
        self.mean_head = MLP(
            input_dim=encoder_output_dim,
            output_dim=output_dim,
            hidden_sizes=mean_hidden_sizes,
            hidden_activation=hidden_activation,
        )
        # Head predicting the log-variance of the Gaussian.
        self.logvar_head = MLP(
            input_dim=encoder_output_dim,
            output_dim=output_dim,
            hidden_sizes=logvar_hidden_sizes,
            hidden_activation=hidden_activation,
        )

    def get_mean_and_standard_deviation(
        self,
        x_data: np.ndarray,
        device: str = "cpu",
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Get the mean and standard deviation prediction.

        Args:
            x_data: The data in numpy ndarray form.
            device: The device to use. Should be the same as the device
                the model is currently on.

        Returns:
            Mean and standard deviation as ndarrays.
        """
        with torch.no_grad():
            # The legacy torch.Tensor(x_data, device=device) constructor only
            # accepts CPU; as_tensor handles any device and keeps float32.
            mean, logvar = self.forward(
                torch.as_tensor(x_data, dtype=torch.float32, device=device)
            )
            mean = mean.cpu().numpy()
            # The network predicts log-variance, so std = exp(logvar / 2).
            std = (logvar / 2).exp().cpu().numpy()
        return mean, std

    def forward(
        self,
        x_data: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get the mean and log-variance prediction.

        Args:
            x_data: The data in tensor form.

        Returns:
            Mean and log variance as tensors.
        """
        latent = self.encoder(x_data)
        return self.mean_head(latent), self.logvar_head(latent)

    def training_step(
        self,
        batch: Tuple[torch.Tensor, torch.Tensor],
        batch_idx: int,
    ) -> torch.Tensor:
        """Do a training step.

        Args:
            batch: The x and y data to train on.
            batch_idx: Index of the batch.

        Returns:
            The loss.
        """
        x_data, y_data = batch
        mean, logvar = self.forward(x_data)
        loss = torch.mean(self.compute_nll(mean, logvar, y_data))
        self.log("train_loss", loss)
        return loss

    def validation_step(
        self,
        batch: Tuple[torch.Tensor, torch.Tensor],
        batch_idx: int,
    ) -> None:
        """Do a validation step.

        Args:
            batch: The x and y data to validate on.
            batch_idx: Index of the batch.
        """
        x_data, y_data = batch
        mean, logvar = self.forward(x_data)
        loss = torch.mean(self.compute_nll(mean, logvar, y_data))
        self.log("validation_loss", loss)

    def test_step(
        self,
        batch: Tuple[torch.Tensor, torch.Tensor],
        batch_idx: int,
    ) -> None:
        """Do a test step.

        Args:
            batch: The x and y data to test on.
            batch_idx: Index of the batch.
        """
        x_data, y_data = batch
        mean, logvar = self.forward(x_data)
        loss = torch.mean(self.compute_nll(mean, logvar, y_data))
        self.log("test_loss", loss)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configure the optimizer.

        Returns:
            Optimizer.
        """
        return torch.optim.Adam(self.parameters(), lr=self._learning_rate)

    def compute_nll(
        self,
        mean: torch.Tensor,
        logvar: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the loss as negative log likelihood.

        Args:
            mean: The mean prediction for labels.
            logvar: The log-variance prediction for labels.
            labels: The observed labels of the data.

        Returns:
            The negative log likelihood of each point (up to an additive
            constant and a factor of 1/2).
        """
        # Gaussian NLL: 0.5 * [log(2 * pi) + logvar + (y - mean)^2 / var].
        # Dropping constants and the 1/2 leaves exp(-logvar) * sqdiffs + logvar.
        sqdiffs = (mean - labels) ** 2
        return torch.exp(-logvar) * sqdiffs + logvar
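

# Hedged usage sketch (not part of the original module): a quick shape check
# of the two heads and a one-epoch fit on random data. Assumes
# simple_uq.util.mlp.MLP is importable and pytorch_lightning is installed;
# all sizes and names below are illustrative, not prescribed by the module.
if __name__ == "__main__":
    from pytorch_lightning import Trainer
    from torch.utils.data import DataLoader, TensorDataset

    model = PNN(
        input_dim=3,
        output_dim=1,
        encoder_hidden_sizes=[32, 32],
        encoder_output_dim=16,
        mean_hidden_sizes=[16],
        logvar_hidden_sizes=[16],
    )
    x = np.random.randn(64, 3).astype(np.float32)
    y = np.random.randn(64, 1).astype(np.float32)

    # Both heads map to output_dim, so each returned array is (64, 1).
    mean, std = model.get_mean_and_standard_deviation(x)
    print(mean.shape, std.shape)

    # One short fit to confirm the NLL objective runs end to end.
    loader = DataLoader(
        TensorDataset(torch.from_numpy(x), torch.from_numpy(y)),
        batch_size=16,
    )
    Trainer(max_epochs=1, logger=False).fit(model, loader)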