diff --git a/dsm/utilities.py b/dsm/utilities.py index bc25cca..486e12c 100644 --- a/dsm/utilities.py +++ b/dsm/utilities.py @@ -57,7 +57,7 @@ def pretrain_dsm(model, t_train, e_train, t_valid, e_valid, risks=model.risks) premodel.double() - optimizer = get_optimizer(model, lr) + optimizer = get_optimizer(premodel, lr) oldcost = float('inf') patience = 0