From 42f7d3de14a66f97ab5ba3ff2c6dee1adcf127cc Mon Sep 17 00:00:00 2001 From: Satrajit Ghosh Date: Wed, 3 Apr 2024 15:56:20 -0400 Subject: [PATCH 1/2] Update generation.py --- nobrainer/processing/generation.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/nobrainer/processing/generation.py b/nobrainer/processing/generation.py index bc0da104..b729e68c 100644 --- a/nobrainer/processing/generation.py +++ b/nobrainer/processing/generation.py @@ -5,7 +5,7 @@ from .base import BaseEstimator from .. import losses -from ..dataset import get_dataset +from ..dataset import Dataset class ProgressiveGeneration(BaseEstimator): @@ -147,15 +147,17 @@ def _compile(): if batch_size % self.strategy.num_replicas_in_sync: raise ValueError("batch size must be a multiple of the number of GPUs") - dataset = get_dataset( + dataset = Dataset.from_tfrecords( file_pattern=info.get("file_pattern"), - batch_size=batch_size, num_parallel_calls=num_parallel_calls, volume_shape=(resolution, resolution, resolution), n_classes=1, - scalar_label=True, - normalizer=info.get("normalizer") or normalizer, + scalar_labels=True ) + n_epochs = info.get("epochs") or epochs + dataset.batch(batch_size) \ + .normalize(info.get("normalizer") or normalizer) \ + .repeat(n_epochs) with self.strategy.scope(): # grow the networks by one (2^x) resolution @@ -164,9 +166,7 @@ def _compile(): self.model_.discriminator.add_resolution() _compile() - steps_per_epoch = (info.get("epochs") or epochs) // info.get( - "batch_size" - ) + steps_per_epoch = n_epochs // info.get("batch_size") # save_best_only is set to False as it is an adversarial loss model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( @@ -182,7 +182,7 @@ def _compile(): print("Transition phase") self.model_.fit( - dataset, + dataset.dataset, phase="transition", resolution=resolution, steps_per_epoch=steps_per_epoch, # necessary for repeat dataset @@ -191,7 +191,7 @@ def _compile(): print("Resolution phase") self.model_.fit( - dataset, + dataset.dataset, phase="resolution", resolution=resolution, steps_per_epoch=steps_per_epoch, From 7a954e359a18441ba0bd58d6db5aaecaff0ae400 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 3 Apr 2024 20:00:11 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nobrainer/processing/generation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nobrainer/processing/generation.py b/nobrainer/processing/generation.py index b729e68c..51b46625 100644 --- a/nobrainer/processing/generation.py +++ b/nobrainer/processing/generation.py @@ -152,12 +152,12 @@ def _compile(): num_parallel_calls=num_parallel_calls, volume_shape=(resolution, resolution, resolution), n_classes=1, - scalar_labels=True + scalar_labels=True, ) n_epochs = info.get("epochs") or epochs - dataset.batch(batch_size) \ - .normalize(info.get("normalizer") or normalizer) \ - .repeat(n_epochs) + dataset.batch(batch_size).normalize( + info.get("normalizer") or normalizer + ).repeat(n_epochs) with self.strategy.scope(): # grow the networks by one (2^x) resolution