LMLR.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from lib.lorentz.manifold import CustomLorentz


class LorentzMLR(nn.Module):
    """Multinomial logistic regression (MLR) in the Lorentz model of hyperbolic space.

    Each class is represented by a hyperplane parameterized by a scalar offset ``a`` and a
    spatial direction ``z``; the logit for a class is the signed distance of the input point
    to that hyperplane, scaled by the hyperplane norm.
    """
    def __init__(
            self,
            manifold: CustomLorentz,
            num_features: int,
            num_classes: int
        ):
        super(LorentzMLR, self).__init__()
        self.manifold = manifold
        # Per-class hyperplane offset a and spatial direction z.
        self.a = torch.nn.Parameter(torch.zeros(num_classes,))
        # Pad a 1 into the first column so that z is never the zero vector.
        self.z = torch.nn.Parameter(F.pad(torch.zeros(num_classes, num_features - 2), pad=(1, 0), value=1))
        self.init_weights()

    def forward(self, x):
        # x: points on the hyperboloid, shape (..., num_features); the first coordinate is the
        # time-like component, the remaining num_features - 1 are the spatial components.
        sqrt_mK = 1 / self.manifold.k.sqrt()

        # Hyperplane normal w = (w_t, w_s) in Lorentz coordinates.
        norm_z = torch.norm(self.z, dim=-1)
        w_t = torch.sinh(sqrt_mK * self.a) * norm_z
        w_s = torch.cosh(sqrt_mK * self.a.view(-1, 1)) * self.z

        # Lorentz norm of w; with the construction above this reduces to ||z||.
        beta = torch.sqrt(-w_t**2 + torch.norm(w_s, dim=-1)**2)

        # Lorentz inner product <w, x>_L (signature -, +, ..., +).
        alpha = -w_t * x.narrow(-1, 0, 1) + torch.cosh(sqrt_mK * self.a) * torch.inner(x.narrow(-1, 1, x.shape[-1] - 1), self.z)

        # Distance of x to each class hyperplane.
        d = self.manifold.k.sqrt() * torch.abs(torch.asinh(sqrt_mK * alpha / beta))

        # Signed distance scaled by the hyperplane norm, analogous to <w, x> in Euclidean MLR.
        logits = torch.sign(alpha) * beta * d
        return logits

    def init_weights(self):
        # Uniform initialization scaled by the feature dimension, as in standard linear layers.
        stdv = 1. / math.sqrt(self.z.size(1))
        nn.init.uniform_(self.z, -stdv, stdv)
        nn.init.uniform_(self.a, -stdv, stdv)
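

# Minimal usage sketch (not part of the original class). It assumes that
# CustomLorentz can be constructed as CustomLorentz(k=1.0) and exposes an
# expmap0() helper that maps tangent vectors at the origin onto the
# hyperboloid; check lib/lorentz/manifold.py for the actual signatures.
if __name__ == "__main__":
    manifold = CustomLorentz(k=1.0)  # assumed constructor signature
    mlr = LorentzMLR(manifold, num_features=16, num_classes=10)

    # Build valid Lorentz points: pad a zero time component onto random spatial
    # vectors and map them onto the hyperboloid via the exponential map at the origin.
    v = F.pad(torch.randn(8, 15), pad=(1, 0))  # shape (8, 16), time component = 0
    x = manifold.expmap0(v)                    # assumed CustomLorentz helper

    logits = mlr(x)  # shape (8, 10), one signed-distance logit per class
    loss = F.cross_entropy(logits, torch.randint(0, 10, (8,)))
    print(logits.shape, loss.item())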