Skip to content

Commit f86d654

Browse files
author
Chirag Nagpal
authored
Update losses.py
1 parent 01bb05c commit f86d654

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

auton_survival/models/dsm/losses.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def _weibull_pdf(model, x, t_horizon, risk='1'):
324324
lpdfs = torch.stack(lpdfs, dim=1)
325325
lpdfs = lpdfs+logits
326326
lpdfs = torch.logsumexp(lpdfs, dim=1)
327-
pdfs.append(lpdfs.detach().numpy())
327+
pdfs.append(lpdfs.detach().cpu().numpy())
328328

329329
return pdfs
330330

@@ -357,7 +357,7 @@ def _weibull_cdf(model, x, t_horizon, risk='1'):
357357
lcdfs = torch.stack(lcdfs, dim=1)
358358
lcdfs = lcdfs+logits
359359
lcdfs = torch.logsumexp(lcdfs, dim=1)
360-
cdfs.append(lcdfs.detach().numpy())
360+
cdfs.append(lcdfs.detach().cpu().numpy())
361361

362362
return cdfs
363363

@@ -424,7 +424,7 @@ def _lognormal_cdf(model, x, t_horizon, risk='1'):
424424
lcdfs = torch.stack(lcdfs, dim=1)
425425
lcdfs = lcdfs+logits
426426
lcdfs = torch.logsumexp(lcdfs, dim=1)
427-
cdfs.append(lcdfs.detach().numpy())
427+
cdfs.append(lcdfs.detach().cpu().numpy())
428428

429429
return cdfs
430430

@@ -461,7 +461,7 @@ def _normal_cdf(model, x, t_horizon, risk='1'):
461461
lcdfs = torch.stack(lcdfs, dim=1)
462462
lcdfs = lcdfs+logits
463463
lcdfs = torch.logsumexp(lcdfs, dim=1)
464-
cdfs.append(lcdfs.detach().numpy())
464+
cdfs.append(lcdfs.detach().cpu().numpy())
465465

466466
return cdfs
467467

@@ -485,7 +485,7 @@ def _normal_mean(model, x, risk='1'):
485485
lmeans = lmeans*logits
486486
lmeans = torch.sum(lmeans, dim=1)
487487

488-
return lmeans.detach().numpy()
488+
return lmeans.detach().cpu().numpy()
489489

490490

491491
def predict_mean(model, x, risk='1'):

0 commit comments

Comments
 (0)