Skip to content

Commit

Permalink
modified: dcm_utilities.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragnagpal committed Dec 24, 2021
1 parent b6ae902 commit 1f6b01a
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions dsm/contrib/dcm_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def partial_ll_loss(lrisks, tb, eb, eps=1e-2):
plls = lrisks - lrisksdenom
pll = plls[eb == 1]

pll = torch.sum(pll) # pll = tf.reduce_sum(pll)
pll = torch.sum(pll) # pll = tf.reduce_sum(pll)

return -pll

Expand Down Expand Up @@ -74,8 +74,8 @@ def get_survival(lrisks, breslow_splines, t):
return psurv

def get_posteriors(probs):
probs_ = probs+1e-8
return probs-torch.logsumexp(probs, dim=1).reshape(-1,1)
#probs_ = probs+1e-8
return probs-torch.logsumexp(probs, dim=1).reshape(-1,1)

def get_hard_z(gates_prob):
return torch.argmax(gates_prob, dim=1)
Expand All @@ -88,8 +88,8 @@ def repair_probs(probs):
probs[probs<-10] = -10
return probs

def get_likelihood(model, breslow_splines, x, t, e, log=False):
def get_likelihood(model, breslow_splines, x, t, e):

# Function requires numpy/torch

gates, lrisks = model(x)
Expand All @@ -109,7 +109,7 @@ def get_likelihood(model, breslow_splines, x, t, e, log=False):
return probs

def q_function(model, x, t, e, posteriors, typ='soft'):

if typ == 'hard': z = get_hard_z(posteriors)
else: z = sample_hard_z(posteriors)

Expand All @@ -119,7 +119,7 @@ def q_function(model, x, t, e, posteriors, typ='soft'):

loss = 0
for i in range(k):
lrisks_ = lrisks[z == i][:, i]
lrisks_ = lrisks[z == i][:, i]
loss += partial_ll_loss(lrisks_, t[z == i], e[z == i])

#log_smax_loss = -torch.nn.LogSoftmax(dim=1)(gates) # tf.nn.log_softmax(gates)
Expand All @@ -130,7 +130,7 @@ def q_function(model, x, t, e, posteriors, typ='soft'):

return loss

def e_step(model, breslow_splines, x, t, e, log=False):
def e_step(model, breslow_splines, x, t, e):

# TODO: Do this in `Log Space`
if breslow_splines is None:
Expand All @@ -147,13 +147,14 @@ def e_step(model, breslow_splines, x, t, e, log=False):
def m_step(model, optimizer, x, t, e, posteriors, typ='soft'):

optimizer.zero_grad()
loss = q_function(model, x, t, e, posteriors, typ)
loss = q_function(model, x, t, e, posteriors, typ)
loss.backward()
optimizer.step()

return float(loss)

def fit_breslow(model, x, t, e, posteriors=None, smoothing_factor=1e-4, typ='soft'):
def fit_breslow(model, x, t, e, posteriors=None,
smoothing_factor=1e-4, typ='soft'):

# TODO: Make Breslow in Torch !!!

Expand All @@ -173,14 +174,14 @@ def fit_breslow(model, x, t, e, posteriors=None, smoothing_factor=1e-4, typ='sof
breslow_splines = {}
for i in range(model.k):
breslowk = BreslowEstimator().fit(lrisks[:, i][z==i], e[z==i], t[z==i])
breslow_splines[i] = smooth_bl_survival(breslowk,
breslow_splines[i] = smooth_bl_survival(breslowk,
smoothing_factor=smoothing_factor)

return breslow_splines


def train_step(model, x, t, e, breslow_splines, optimizer,
bs=256, seed=100, typ='soft', use_posteriors=False,
bs=256, seed=100, typ='soft', use_posteriors=False,
update_splines_after=10, smoothing_factor=1e-4):

x, t, e = shuffle(x, t, e, random_state=seed)
Expand All @@ -203,7 +204,7 @@ def train_step(model, x, t, e, breslow_splines, optimizer,
posteriors = e_step(model, breslow_splines, xb, tb, eb)

torch.enable_grad()
loss = m_step(model, optimizer, xb, tb, eb, posteriors, typ=typ)
loss = m_step(model, optimizer, xb, tb, eb, posteriors, typ=typ)

with torch.no_grad():
try:
Expand Down

0 comments on commit 1f6b01a

Please sign in to comment.