@@ -392,11 +392,10 @@ class DeepConvolutionalSurvivalMachinesTorch(DeepSurvivalMachinesTorch):
392
392
Dimensionality of the input features. A tuple (height, width).
393
393
k: int
394
394
The number of underlying parametric distributions.
395
+ embedding: torch.nn.Module
396
+ A torch CNN to obtain the representation of the input data.
395
397
hidden: int
396
398
The number of neurons in each hidden layer.
397
- init: tuple
398
- A tuple for initialization of the parameters for the underlying
399
- distributions. (shape, scale).
400
399
dist: str
401
400
Choice of the underlying survival distributions.
402
401
One of 'Weibull', 'LogNormal'.
@@ -411,8 +410,8 @@ class DeepConvolutionalSurvivalMachinesTorch(DeepSurvivalMachinesTorch):
411
410
412
411
"""
413
412
414
- def __init__ (self , inputdim , k , typ = 'ConvNet' ,
415
- hidden = None , dist = 'Weibull' ,
413
+ def __init__ (self , inputdim , k ,
414
+ embedding = None , hidden = None , dist = 'Weibull' ,
416
415
temp = 1000. , discount = 1.0 , optimizer = 'Adam' , risks = 1 ):
417
416
super (DeepSurvivalMachinesTorch , self ).__init__ ()
418
417
@@ -427,9 +426,13 @@ def __init__(self, inputdim, k, typ='ConvNet',
427
426
428
427
self ._init_dsm_layers (hidden )
429
428
430
- self .embedding = create_conv_representation (inputdim = inputdim ,
431
- hidden = hidden ,
432
- typ = 'ConvNet' )
429
+ if embedding is None :
430
+ self .embedding = create_conv_representation (inputdim = inputdim ,
431
+ hidden = hidden ,
432
+ typ = 'ConvNet' )
433
+ else :
434
+ self .embedding = embedding
435
+
433
436
434
437
def forward (self , x , risk = '1' ):
435
438
"""The forward function that is called when data is passed through DSM.
0 commit comments