Skip to content

Commit 27d6790

Browse files
committed
modified: dsm_torch.py
1 parent e759e6c commit 27d6790

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

dsm/dsm_torch.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -392,11 +392,10 @@ class DeepConvolutionalSurvivalMachinesTorch(DeepSurvivalMachinesTorch):
392392
Dimensionality of the input features. A tuple (height, width).
393393
k: int
394394
The number of underlying parametric distributions.
395+
embedding: torch.nn.Module
396+
A torch CNN to obtain the representation of the input data.
395397
hidden: int
396398
The number of neurons in each hidden layer.
397-
init: tuple
398-
A tuple for initialization of the parameters for the underlying
399-
distributions. (shape, scale).
400399
dist: str
401400
Choice of the underlying survival distributions.
402401
One of 'Weibull', 'LogNormal'.
@@ -411,8 +410,8 @@ class DeepConvolutionalSurvivalMachinesTorch(DeepSurvivalMachinesTorch):
411410
412411
"""
413412

414-
def __init__(self, inputdim, k, typ='ConvNet',
415-
hidden=None, dist='Weibull',
413+
def __init__(self, inputdim, k,
414+
embedding=None, hidden=None, dist='Weibull',
416415
temp=1000., discount=1.0, optimizer='Adam', risks=1):
417416
super(DeepSurvivalMachinesTorch, self).__init__()
418417

@@ -427,9 +426,13 @@ def __init__(self, inputdim, k, typ='ConvNet',
427426

428427
self._init_dsm_layers(hidden)
429428

430-
self.embedding = create_conv_representation(inputdim=inputdim,
431-
hidden=hidden,
432-
typ='ConvNet')
429+
if embedding is None:
430+
self.embedding = create_conv_representation(inputdim=inputdim,
431+
hidden=hidden,
432+
typ='ConvNet')
433+
else:
434+
self.embedding = embedding
435+
433436

434437
def forward(self, x, risk='1'):
435438
"""The forward function that is called when data is passed through DSM.

0 commit comments

Comments
 (0)