Skip to content

Commit

Permalink
modified: dsm.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragnagpal committed Apr 2, 2020
1 parent d47ce90 commit 4ae287b
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions dsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,26 @@ def __init__(self, inputdim, k, mlptyp=1, HIDDEN=None, init=False, dist='Weibull
self.k = k

self.mlptype = mlptyp
self.scale = nn.Parameter(-torch.ones(k))
self.shape = nn.Parameter(-torch.ones(k))



self.dist = dist

if self.dist == 'Weibull':

self.act = nn.SELU()

self.scale = nn.Parameter(-torch.ones(k))
self.shape = nn.Parameter(-torch.ones(k))


elif self.dist == 'LogNormal':

self.act = nn.Tanh()
self.scale = nn.Parameter(torch.ones(k))
self.shape = nn.Parameter(torch.ones(k))


self.HIDDEN = HIDDEN


Expand Down Expand Up @@ -96,15 +113,7 @@ def __init__(self, inputdim, k, mlptyp=1, HIDDEN=None, init=False, dist='Weibull
self.scale.data.fill_(init[1])


self.dist = dist

if self.dist == 'Weibull':

self.act = nn.SELU()

elif self.dist == 'LogNormal':

self.act = nn.Tanh()


def forward(self, x, adj=True):

Expand Down

0 comments on commit 4ae287b

Please sign in to comment.