Fix PGAN issues #308

Open
hvgazula opened this issue Mar 23, 2024 · 7 comments

@hvgazula
Contributor

hvgazula commented Mar 23, 2024

From @satra:

  • There is an issue with loading checkpoints (in multi-GPU case).

Notes (03/22/2024)

  • Could not run the example on dgx100, so moved it to CPU for testing.
  • 4 CPUs, 96 GB: was able to write shards for the different resolutions.
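For context, the training loop in generation.py (see the patch further down) appears to read an `info` mapping for the current resolution through `info.get("file_pattern")`, `info.get("batch_size")`, `info.get("epochs")` and `info.get("normalizer")`, so one set of shards is needed per resolution. A minimal sketch of such a mapping, assuming power-of-two resolutions starting at 4; the helper names, the shard naming scheme and the numbers are hypothetical, not nobrainer's:

```python
def resolution_schedule(start, target):
    """Power-of-two resolutions from start up to target, inclusive."""
    if start <= 0 or target < start:
        raise ValueError("need 0 < start <= target")
    for r in (start, target):
        if r & (r - 1):  # non-zero for anything that is not a power of two
            raise ValueError(f"{r} is not a power of two")
    # bit_length() - 1 equals log2 for an exact power of two
    return [2**k for k in range(start.bit_length() - 1, target.bit_length())]


def build_resolution_info(resolutions, shard_dir, batch_sizes, epochs):
    """Hypothetical per-resolution config using the keys generation.py reads."""
    info = {}
    for res in resolutions:
        info[res] = {
            # hypothetical naming: one shard set per resolution
            "file_pattern": f"{shard_dir}/data-res-{res:03d}-*.tfrec",
            "batch_size": batch_sizes[res],
            "epochs": epochs,
            # "normalizer" omitted: generation.py falls back to its own default
        }
    return info


resolutions = resolution_schedule(4, 32)  # [4, 8, 16, 32]
info = build_resolution_info(resolutions, "shards", {4: 64, 8: 32, 16: 16, 32: 8}, epochs=10)
```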

from ..dataset import get_dataset

get_dataset no longer exists in nobrainer/dataset.py; this import is a relic from a previous iteration. Fixing it also means replacing the following snippet:

dataset = get_dataset(
    file_pattern=info.get("file_pattern"),
    batch_size=batch_size,
    num_parallel_calls=num_parallel_calls,
    volume_shape=(resolution, resolution, resolution),
    n_classes=1,
    scalar_label=True,
    normalizer=info.get("normalizer") or normalizer,
)

@hvgazula hvgazula self-assigned this Mar 23, 2024
hvgazula added a commit that referenced this issue Mar 23, 2024
@hvgazula
Contributor Author

Running the code in the PGAN notebook with the aforementioned fix raises the following error:

[screenshot of the error, not reproduced]

@hvgazula hvgazula changed the title PGAN: get_dataset doesn't exist in dataset.py Fix PGAN issues Mar 23, 2024
@hvgazula
Contributor Author

Note: when running brain_extraction.ipynb, Keras selects the DatasetAdapter at https://github.com/tensorflow/tensorflow/blob/51871ec0c5d2925cbbf7aa539087ac51ea27892e/tensorflow/python/keras/engine/data_adapter.py#L987, and type(x) returns python.data.ops.dataset_ops.BatchDataset.

Adding dataset.batch(1) before the call to fit in generation.py still raises the same error.

@satra
Contributor

satra commented Mar 23, 2024

This should get you past the dataset issues; with it, the notebook appears to train.

diff --git a/nobrainer/processing/generation.py b/nobrainer/processing/generation.py
--- a/nobrainer/processing/generation.py
+++ b/nobrainer/processing/generation.py
@@ -5,7 +5,7 @@ import tensorflow as tf
 
 from .base import BaseEstimator
 from .. import losses
-from ..dataset import get_dataset
+from ..dataset import Dataset
 
 
 class ProgressiveGeneration(BaseEstimator):
@@ -147,15 +147,17 @@ class ProgressiveGeneration(BaseEstimator):
             if batch_size % self.strategy.num_replicas_in_sync:
                 raise ValueError("batch size must be a multiple of the number of GPUs")
 
-            dataset = get_dataset(
+            dataset = Dataset.from_tfrecords(
                 file_pattern=info.get("file_pattern"),
-                batch_size=batch_size,
                 num_parallel_calls=num_parallel_calls,
                 volume_shape=(resolution, resolution, resolution),
                 n_classes=1,
-                scalar_label=True,
-                normalizer=info.get("normalizer") or normalizer,
+                scalar_labels=True
             )
+            n_epochs = info.get("epochs") or epochs
+            dataset.batch(batch_size) \
+            .normalize(info.get("normalizer") or normalizer) \
+            .repeat(n_epochs)
 
             with self.strategy.scope():
                 # grow the networks by one (2^x) resolution
@@ -164,9 +166,7 @@ class ProgressiveGeneration(BaseEstimator):
                     self.model_.discriminator.add_resolution()
                 _compile()
 
-                steps_per_epoch = (info.get("epochs") or epochs) // info.get(
-                    "batch_size"
-                )
+                steps_per_epoch = n_epochs // info.get("batch_size")
 
                 # save_best_only is set to False as it is an adversarial loss
                 model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
@@ -182,7 +182,7 @@ class ProgressiveGeneration(BaseEstimator):
 
             print("Transition phase")
             self.model_.fit(
-                dataset,
+                dataset.dataset,
                 phase="transition",
                 resolution=resolution,
                 steps_per_epoch=steps_per_epoch,  # necessary for repeat dataset
@@ -191,7 +191,7 @@ class ProgressiveGeneration(BaseEstimator):
 
             print("Resolution phase")
             self.model_.fit(
-                dataset,
+                dataset.dataset,
                 phase="resolution",
                 resolution=resolution,
                 steps_per_epoch=steps_per_epoch,
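The patch relies on two properties of nobrainer's Dataset wrapper: the chained `batch(...).normalize(...).repeat(...)` call is never assigned back, which only works if each method updates the wrapped dataset in place and returns `self`; and `fit` receives `dataset.dataset`, the underlying `tf.data.Dataset`, rather than the wrapper. The toy class below illustrates that pattern under those assumptions. `ToyDataset` and `fake_fit` are hypothetical, pure-Python stand-ins (a list plays the role of the `tf.data` pipeline), not nobrainer's implementation.

```python
class ToyDataset:
    """Hypothetical wrapper whose methods mutate the wrapped data and return self."""

    def __init__(self, items):
        # `self.dataset` plays the role of the wrapped tf.data.Dataset
        self.dataset = list(items)

    def batch(self, batch_size):
        # group consecutive items; this toy drops the incomplete last batch
        n_full = len(self.dataset) // batch_size
        self.dataset = [self.dataset[i * batch_size:(i + 1) * batch_size] for i in range(n_full)]
        return self  # returning self is what makes chaining possible

    def normalize(self, normalizer):
        # apply the normalizer to every element of every batch
        self.dataset = [[normalizer(x) for x in b] for b in self.dataset]
        return self

    def repeat(self, n_repeats):
        self.dataset = self.dataset * n_repeats
        return self


def fake_fit(batches):
    """Hypothetical consumer that, like Keras fit, wants the raw batches, not the wrapper."""
    if isinstance(batches, ToyDataset):
        raise TypeError("pass wrapper.dataset, not the wrapper itself")
    return len(batches)  # number of steps it would run


ds = ToyDataset(range(10))
ds.batch(4).normalize(lambda x: x / 10).repeat(2)  # not assigned: ds is updated in place
n_steps = fake_fit(ds.dataset)  # 4 (two full batches, repeated twice)
```

By contrast, transformations on a plain `tf.data.Dataset` return a new dataset and leave the original unchanged, so with a raw `tf.data` pipeline the result of `batch` has to be assigned back.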

@satra
Contributor

satra commented Mar 23, 2024

Added a few updates above, and the generation notebook now completes.

@hvgazula hvgazula added this to the 1.2.1 milestone Mar 23, 2024
@hvgazula
Contributor Author

Thank you very much @satra. dataset.dataset it is.

@hvgazula
Contributor Author

steps_per_epoch = (info.get("epochs") or epochs) // info.get(
    "batch_size"
)

steps_per_epoch should be training_size // batch_size, not epochs // batch_size as written above. Also, there is no need to calculate it explicitly: in the default case (steps_per_epoch=None), fit iterates until the dataset is exhausted. See this
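A minimal sketch of the corrected formula, assuming the number of training volumes is known; `n_volumes` and the helper name are hypothetical. It repeats the multi-GPU batch-size check from generation.py and makes the floor/ceiling choice explicit, since that depends on whether the last partial batch is dropped. An explicit value mainly matters when the dataset is repeated, as in the patch above.

```python
def steps_per_epoch(n_volumes, batch_size, num_replicas=1, drop_remainder=True):
    """Optimizer steps needed to see the training set once at this batch size."""
    if n_volumes <= 0 or batch_size <= 0:
        raise ValueError("n_volumes and batch_size must be positive")
    # mirrors the check in generation.py: the global batch must split evenly across GPUs
    if batch_size % num_replicas:
        raise ValueError("batch size must be a multiple of the number of GPUs")
    if drop_remainder:
        return n_volumes // batch_size  # incomplete last batch is discarded
    return -(-n_volumes // batch_size)  # ceiling division keeps the partial batch


print(steps_per_epoch(100, 8))                        # 12
print(steps_per_epoch(100, 8, drop_remainder=False))  # 13
print(steps_per_epoch(64, 16, num_replicas=4))        # 4
```

For comparison, the expression above gives 20 // 8 = 2 for 20 epochs and a batch size of 8, whatever the number of volumes.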

hvgazula added a commit that referenced this issue Mar 23, 2024
@hvgazula
Contributor Author

TODO: warm start

@satra satra mentioned this issue Apr 3, 2024