invertible_network_utils.py
"""Create invertible mixing networks."""
import numpy as np
import torch
from torch import nn
from scipy.stats import ortho_group
from typing import Union
from typing_extensions import Literal
import encoders


def construct_invertible_mlp(
    n: int = 20,
    n_layers: int = 2,
    n_iter_cond_thresh: int = 10000,
    cond_thresh_ratio: float = 0.25,
    weight_matrix_init: Union[Literal["pcl"], Literal["rvs"]] = "pcl",
    act_fct: Union[
        Literal["relu"], Literal["leaky_relu"], Literal["elu"],
        Literal["smooth_leaky_relu"], Literal["softplus"]
    ] = "leaky_relu",
):
    """
    Create an (approximately) invertible mixing network based on an MLP.
    Based on the mixing code by Hyvarinen et al.

    Args:
        n: Dimensionality of the input and output data.
        n_layers: Number of layers in the MLP.
        n_iter_cond_thresh: How many random matrices to sample when estimating
            the condition-number threshold.
        cond_thresh_ratio: Relative threshold for how much the invertibility
            (based on the condition number) may be violated in each layer.
        weight_matrix_init: How to initialize the weight matrices.
        act_fct: Activation function for the hidden layers.
    """

    class SmoothLeakyReLU(nn.Module):
        def __init__(self, alpha=0.2):
            super().__init__()
            self.alpha = alpha

        def forward(self, x):
            # Smooth approximation of LeakyReLU: alpha * x + (1 - alpha) * softplus(x).
            return self.alpha * x + (1 - self.alpha) * torch.log(1 + torch.exp(x))

    def get_act_fct(act_fct):
        # Map the activation name to (module class, constructor kwargs, size factor).
        if act_fct == "relu":
            return torch.nn.ReLU, {}, 1
        elif act_fct == "leaky_relu":
            return torch.nn.LeakyReLU, {"negative_slope": 0.2}, 1
        elif act_fct == "elu":
            return torch.nn.ELU, {"alpha": 1.0}, 1
        elif act_fct == "max_out":
            raise NotImplementedError()
        elif act_fct == "smooth_leaky_relu":
            return SmoothLeakyReLU, {"alpha": 0.2}, 1
        elif act_fct == "softplus":
            return torch.nn.Softplus, {"beta": 1}, 1
        else:
            raise Exception(f"activation function {act_fct} not defined.")

    layers = []
    act_fct, act_kwargs, act_fac = get_act_fct(act_fct)

    # Subfunction to normalize the mixing matrix.
    def l2_normalize(Amat, axis=0):
        # axis: 0 = column-normalization, 1 = row-normalization
        l2norm = np.sqrt(np.sum(Amat * Amat, axis))
        Amat = Amat / l2norm
        return Amat

    condList = np.zeros([n_iter_cond_thresh])
    if weight_matrix_init == "pcl":
        # Sample a pool of random matrices and use the cond_thresh_ratio quantile
        # of their condition numbers as the acceptance threshold for layer weights.
        for i in range(n_iter_cond_thresh):
            A = np.random.uniform(-1, 1, [n, n])
            A = l2_normalize(A, axis=0)
            condList[i] = np.linalg.cond(A)
        condList.sort()  # ascending order
        condThresh = condList[int(n_iter_cond_thresh * cond_thresh_ratio)]
        print("condition number threshold: {0:f}".format(condThresh))

    for i in range(n_layers):
        lin_layer = nn.Linear(n, n, bias=False)
        if weight_matrix_init == "pcl":
            # Rejection-sample a weight matrix whose condition number is below the threshold.
            condA = condThresh + 1
            while condA > condThresh:
                weight_matrix = np.random.uniform(-1, 1, (n, n))
                weight_matrix = l2_normalize(weight_matrix, axis=0)
                condA = np.linalg.cond(weight_matrix)
                # print("    L{0:d}: cond={1:f}".format(i, condA))
            print(f"layer {i+1}/{n_layers}, condition number: {condA}")
            lin_layer.weight.data = torch.tensor(weight_matrix, dtype=torch.float32)
        elif weight_matrix_init == "rvs":
            # Random orthogonal matrix (condition number 1 by construction).
            weight_matrix = ortho_group.rvs(n)
            lin_layer.weight.data = torch.tensor(weight_matrix, dtype=torch.float32)
        elif weight_matrix_init == "expand":
            pass
        else:
            raise Exception(f"weight matrix {weight_matrix_init} not implemented")
        layers.append(lin_layer)
        if i < n_layers - 1:
            layers.append(act_fct(**act_kwargs))

    mixing_net = nn.Sequential(*layers)

    # Freeze the parameters: the mixing network acts as a fixed generator and is not trained.
    for p in mixing_net.parameters():
        p.requires_grad = False

    return mixing_net
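

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module: build a frozen,
    # approximately invertible mixing network and push a batch of latents through it.
    # The dimensionality, layer count, and batch size below are arbitrary choices.
    g = construct_invertible_mlp(
        n=10, n_layers=3, weight_matrix_init="pcl", act_fct="leaky_relu"
    )
    z = torch.randn(8, 10)            # batch of 8 latent vectors
    x = g(z)                          # observed (mixed) data
    print("output shape:", x.shape)   # expected: torch.Size([8, 10])
    # All parameters are frozen, so the mixing acts as a fixed ground-truth generator.
    assert not any(p.requires_grad for p in g.parameters())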