Skip to content

Commit

Permalink
UNet1D changes, validation F2=0.7724.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexklibisz committed Aug 10, 2017
1 parent a10c64e commit 45c0384
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions deepcalcium/models/spikes/unet_1d_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def on_epoch_end(self, epoch, logs):
dpi=120)


def unet1d(window_shape=(128,), nb_filters_base=32, conv_kernel_init='he_normal', prop_dropout_base=0.08, margin=2):
def unet1d(window_shape=(128,), nb_filters_base=32, conv_kernel_init='he_normal', prop_dropout_base=0.1, margin=4):
"""Builds and returns the UNet architecture using Keras.
# Arguments
window_shape: tuple of one integer defining the input/output window shape.
Expand Down Expand Up @@ -203,10 +203,9 @@ def __init__(self, cpdir='%s/spikes_unet1d' % CHECKPOINTS_DIR,
if not path.exists(self.cpdir):
mkdir(self.cpdir)

def fit(self, dataset_paths, shape=(4096,), error_margin=1.,
def fit(self, dataset_paths, shape=(4096,), error_margin=4.,
batch=20, nb_epochs=20, val_type='random_split', prop_trn=0.8,
prop_val=0.2, nb_folds=5, keras_callbacks=[],
optimizer=Adam(0.002)):
prop_val=0.2, nb_folds=5, keras_callbacks=[], optimizer=Adam(0.001)):
"""Constructs model based on parameters and trains with the given data.
Internally, the function uses a local function to abstract the training
for both validation types.
Expand Down Expand Up @@ -248,7 +247,7 @@ def _fit_single(idxs_trn, idxs_val, model_summary=False):
metrics = [F2, prec, reca, ytspks, ypspks]

def loss(yt, yp):
return weighted_binary_crossentropy(yt, yp, weightpos=3.0)
return weighted_binary_crossentropy(yt, yp, weightpos=2.0)
custom_objects = {o.__name__: o for o in metrics + [loss]}

# Define, compile network.
Expand Down

0 comments on commit 45c0384

Please sign in to comment.