-
Notifications
You must be signed in to change notification settings - Fork 17
/
hnn.py
102 lines (78 loc) · 3.63 KB
/
hnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# This file is from https://github.com/greydanus/hamiltonian-nn/blob/master/hnn.py
# For ablation study
import torch
import numpy as np
from nn_models import MLP
from utils import rk4
class HNN(torch.nn.Module):
    '''Learn arbitrary vector fields that are sums of conservative and solenoidal fields.

    The wrapped ``differentiable_model`` maps a state batch ``x`` to two scalar
    potentials ``[F1, F2]`` per sample (see ``forward``). The vector field is
    built from their gradients with respect to ``x``:

    * conservative part: ``grad F1``
    * solenoidal part:   ``grad F2 @ M.T``, with ``M`` from ``permutation_tensor``

    With ``baseline=True`` the model output is used directly as the time
    derivative (a plain neural-ODE style model) and no gradients are taken.

    Args:
        input_dim: width of the state vector; sets the size of ``M``.
        differentiable_model: module mapping ``x`` to a ``[batch, 2]`` tensor of
            potentials, or (in baseline mode) directly to the time derivative.
        field_type: ``'solenoidal'`` computes only the solenoidal part,
            ``'conservative'`` only the conservative part; any other value
            computes both and sums them.
        baseline: if True, bypass the Hamiltonian construction entirely.
        assume_canonical_coords: if True, ``M`` is the block matrix
            ``[[0, I], [-I, 0]]``; otherwise an antisymmetric sign pattern.
    '''

    def __init__(self, input_dim, differentiable_model, field_type='solenoidal',
                 baseline=False, assume_canonical_coords=True):
        super(HNN, self).__init__()
        self.baseline = baseline
        self.differentiable_model = differentiable_model
        # Must be set before permutation_tensor(), which reads it.
        self.assume_canonical_coords = assume_canonical_coords
        # Plain tensor attribute (not a registered buffer), so Module.to()/.cuda()
        # do not move it; time_derivative casts it to the input's device/dtype.
        self.M = self.permutation_tensor(input_dim)  # Levi-Civita permutation tensor
        self.field_type = field_type

    def forward(self, x):
        '''Run the wrapped model (traditional forward pass).

        Returns the raw model output in baseline mode; otherwise the ``[batch, 2]``
        output split into a pair ``(F1, F2)`` of ``[batch, 1]`` tensors.
        '''
        if self.baseline:
            return self.differentiable_model(x)
        y = self.differentiable_model(x)
        assert y.dim() == 2 and y.shape[1] == 2, "Output tensor should have shape [batch_size, 2]"
        return y.split(1, 1)

    def rk4_time_derivative(self, x, dt):
        '''Delegate to ``rk4`` (from ``utils``) with ``time_derivative`` as the field.

        Called with ``t=0``; whether the result is the increment or the new state
        is defined by ``rk4`` itself -- TODO confirm against utils.
        '''
        return rk4(fun=self.time_derivative, y0=x, t=0, dt=dt)

    def time_derivative(self, x, t=None, separate_fields=False):
        '''Evaluate the learned vector field at ``x``.

        Args:
            x: state batch. In non-baseline mode ``x`` must take part in autograd
                (e.g. ``requires_grad=True``), because the field is built with
                ``torch.autograd.grad`` with respect to ``x``.
            t: ignored; accepted so the method fits ``f(x, t)`` call signatures.
            separate_fields: if True (non-baseline only), return
                ``[conservative_field, solenoidal_field]`` instead of their sum.

        Returns:
            A tensor shaped like ``x``, or a two-element list of such tensors.
        '''
        # Neural-ODE-style vector field: the model predicts dx/dt directly.
        if self.baseline:
            return self.differentiable_model(x)

        # Neural-Hamiltonian-style vector field built from the two potentials.
        F1, F2 = self.forward(x)

        # Components that are not requested stay at zero.
        conservative_field = torch.zeros_like(x)
        solenoidal_field = torch.zeros_like(x)

        if self.field_type != 'solenoidal':
            # grad F1 is the conservative field as-is. Multiplying by an identity
            # matrix is a numerical no-op, and a freshly built CPU/float32 identity
            # breaks CUDA and float64 inputs, so the gradient is used directly.
            conservative_field = torch.autograd.grad(F1.sum(), x, create_graph=True)[0]

        if self.field_type != 'conservative':
            dF2 = torch.autograd.grad(F2.sum(), x, create_graph=True)[0]
            # self.M is not a registered buffer, so it has to be matched to x here.
            M = self.M.to(device=x.device, dtype=x.dtype)
            solenoidal_field = dF2 @ M.t()

        if separate_fields:
            return [conservative_field, solenoidal_field]
        return conservative_field + solenoidal_field

    def int_wrapper(self, t, x):
        '''Adapter taking ``(t, x)`` (time first, as integrators usually expect); ``t`` is ignored.'''
        return self.time_derivative(x)

    def permutation_tensor(self, n):
        '''Build the ``n x n`` matrix that turns ``grad F2`` into the solenoidal field.

        With ``assume_canonical_coords`` this is ``[[0, I], [-I, 0]]`` (for even
        ``n``). Assuming the state is ordered ``(q, p)``, ``grad H @ M.T`` then
        equals ``(dH/dp, -dH/dq)``. Otherwise it returns an antisymmetric matrix
        with zero diagonal and +/-1 off-diagonal entries, which the original
        author calls the Levi-Civita permutation tensor.
        '''
        if self.assume_canonical_coords:
            eye = torch.eye(n)
            return torch.cat([eye[n//2:], -eye[:n//2]])

        M = torch.ones(n, n)
        M *= 1 - torch.eye(n)   # clear the diagonal
        M[::2] *= -1            # alternate signs by row ...
        M[:, ::2] *= -1         # ... and by column (result is symmetric)
        # Negate the strict upper triangle so that M becomes antisymmetric.
        upper = torch.triu(torch.ones(n, n), diagonal=1).bool()
        M[upper] *= -1
        return M
class PixelHNN(torch.nn.Module):
    '''HNN that operates on the latent codes of an autoencoder.

    ``forward`` encodes an observation, applies one Euler step of size 1 in
    latent space using the HNN vector field, and decodes the result.

    Args:
        input_dim: width of the latent vector ``z`` produced by
            ``autoencoder.encode``; used for both the MLP input and the HNN.
        hidden_dim: hidden width of the MLP.
        autoencoder: object exposing ``encode(x)`` and ``decode(z)``.
        field_type: forwarded to ``HNN``.
        nonlinearity: forwarded to ``MLP``.
        baseline: if True, the MLP predicts the latent time derivative directly
            (output width ``input_dim``) instead of two potentials.
    '''

    def __init__(self, input_dim, hidden_dim, autoencoder,
                 field_type='solenoidal', nonlinearity='tanh', baseline=False):
        super(PixelHNN, self).__init__()
        self.autoencoder = autoencoder
        self.baseline = baseline
        # Baseline mode predicts dz/dt directly; HNN mode predicts two potentials.
        output_dim = input_dim if baseline else 2
        nn_model = MLP(input_dim, hidden_dim, output_dim, nonlinearity)
        self.hnn = HNN(input_dim, differentiable_model=nn_model, field_type=field_type, baseline=baseline)

    def encode(self, x):
        '''Map an observation to its latent code via the autoencoder.'''
        return self.autoencoder.encode(x)

    def decode(self, z):
        '''Map a latent code back to observation space via the autoencoder.'''
        return self.autoencoder.decode(z)

    def time_derivative(self, z, separate_fields=False):
        '''Latent vector field at ``z``; see ``HNN.time_derivative``.'''
        # Must be passed by keyword: the second positional parameter of
        # HNN.time_derivative is ``t``, so a positional value would land there
        # and separate_fields would silently stay False.
        return self.hnn.time_derivative(z, separate_fields=separate_fields)

    def forward(self, x):
        '''Encode, take one unit-step Euler update in latent space, then decode.'''
        z = self.encode(x)
        z_next = z + self.time_derivative(z)
        return self.decode(z_next)