Skip to content

Commit

Permalink
modified: dsm/dsm_api.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragnagpal committed Jan 26, 2021
1 parent cc3399a commit 5252d6f
Showing 1 changed file with 26 additions and 14 deletions.
40 changes: 26 additions & 14 deletions dsm/dsm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,22 @@ def _prepocess_training_data(self, x, t, e, vsize, random_state):
t_train = torch.from_numpy(t_train).double()
e_train = torch.from_numpy(e_train).double()

vsize = int(vsize*x_train.shape[0])
if vsize is not None:

x_val, t_val, e_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:]
x_train = x_train[:-vsize]
t_train = t_train[:-vsize]
e_train = e_train[:-vsize]
vsize = int(vsize*x_train.shape[0])

x_val, t_val, e_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:]
x_train = x_train[:-vsize]
t_train = t_train[:-vsize]
e_train = e_train[:-vsize]

return (x_train, t_train, e_train,
x_val, t_val, e_val)

else:
return (x_train, t_train, e_train,
x_train, t_train, e_train)

return (x_train, t_train, e_train,
x_val, t_val, e_val)

def predict_mean(self, x, risk=1):
r"""Returns the mean Time-to-Event \( t \)
Expand Down Expand Up @@ -359,15 +366,20 @@ def _prepocess_training_data(self, x, t, e, vsize, random_state):
t_train = torch.from_numpy(t_train).double()
e_train = torch.from_numpy(e_train).double()

vsize = int(vsize*x_train.shape[0])
x_val, t_val, e_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:]
if vsize is not None:

vsize = int(vsize*x_train.shape[0])
x_val, t_val, e_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:]

x_train = x_train[:-vsize]
t_train = t_train[:-vsize]
e_train = e_train[:-vsize]
x_train = x_train[:-vsize]
t_train = t_train[:-vsize]
e_train = e_train[:-vsize]

return (x_train, t_train, e_train,
x_val, t_val, e_val)
return (x_train, t_train, e_train,
x_val, t_val, e_val)
else:
return (x_train, t_train, e_train,
x_train, t_train, e_train)


class DeepConvolutionalSurvivalMachines(DSMBase):
Expand Down

0 comments on commit 5252d6f

Please sign in to comment.