Skip to content

Commit

Permalink
resolved #308 (comment)
Browse files Browse the repository at this point in the history
  • Loading branch information
hvgazula committed Mar 23, 2024
1 parent 4991db6 commit 66edbb2
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions nobrainer/processing/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,17 @@ def _compile():
if info.get("normalizer") or normalizer:
dataset = dataset.normalize(normalizer)

n_epochs = info.get("epochs") or epochs
dataset = dataset.repeat(n_epochs).batch(batch_size)
steps_per_epoch = dataset.get_steps_per_epoch()

with self.strategy.scope():
# grow the networks by one (2^x) resolution
if resolution > self.current_resolution_:
self.model_.generator.add_resolution()
self.model_.discriminator.add_resolution()
_compile()

steps_per_epoch = dataset.get_steps_per_epoch()

# save_best_only is set to False as it is an adversarial loss
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
str(model_dir),
Expand All @@ -180,7 +182,7 @@ def _compile():

print("Transition phase")
self.model_.fit(
dataset,
dataset.dataset,
phase="transition",
resolution=resolution,
steps_per_epoch=steps_per_epoch, # necessary for repeat dataset
Expand All @@ -189,7 +191,7 @@ def _compile():

print("Resolution phase")
self.model_.fit(
dataset,
dataset.dataset,
phase="resolution",
resolution=resolution,
steps_per_epoch=steps_per_epoch,
Expand Down

0 comments on commit 66edbb2

Please sign in to comment.