From 45c0384e5c2ad7ab49e93626450897a12c089e7f Mon Sep 17 00:00:00 2001 From: Alex Klibisz Date: Thu, 10 Aug 2017 09:10:56 -0400 Subject: [PATCH] UNet1D changes, validation F2=0.7724. --- deepcalcium/models/spikes/unet_1d_segmentation.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/deepcalcium/models/spikes/unet_1d_segmentation.py b/deepcalcium/models/spikes/unet_1d_segmentation.py index 7d6d50f..6246ce2 100644 --- a/deepcalcium/models/spikes/unet_1d_segmentation.py +++ b/deepcalcium/models/spikes/unet_1d_segmentation.py @@ -45,7 +45,7 @@ def on_epoch_end(self, epoch, logs): dpi=120) -def unet1d(window_shape=(128,), nb_filters_base=32, conv_kernel_init='he_normal', prop_dropout_base=0.08, margin=2): +def unet1d(window_shape=(128,), nb_filters_base=32, conv_kernel_init='he_normal', prop_dropout_base=0.1, margin=4): """Builds and returns the UNet architecture using Keras. # Arguments window_shape: tuple of one integer defining the input/output window shape. @@ -203,10 +203,9 @@ def __init__(self, cpdir='%s/spikes_unet1d' % CHECKPOINTS_DIR, if not path.exists(self.cpdir): mkdir(self.cpdir) - def fit(self, dataset_paths, shape=(4096,), error_margin=1., + def fit(self, dataset_paths, shape=(4096,), error_margin=4., batch=20, nb_epochs=20, val_type='random_split', prop_trn=0.8, - prop_val=0.2, nb_folds=5, keras_callbacks=[], - optimizer=Adam(0.002)): + prop_val=0.2, nb_folds=5, keras_callbacks=[], optimizer=Adam(0.001)): """Constructs model based on parameters and trains with the given data. Internally, the function uses a local function to abstract the training for both validation types. @@ -248,7 +247,7 @@ def _fit_single(idxs_trn, idxs_val, model_summary=False): metrics = [F2, prec, reca, ytspks, ypspks] def loss(yt, yp): - return weighted_binary_crossentropy(yt, yp, weightpos=3.0) + return weighted_binary_crossentropy(yt, yp, weightpos=2.0) custom_objects = {o.__name__: o for o in metrics + [loss]} # Define, compile network.