Skip to content

Commit

Permalink
move random_cutout_mask_area ahead of random_process_func
Browse files — browse the repository at this point in the history
  • Loading branch information
leondgarse committed Jan 11, 2022
1 parent befd497 commit 95b3d03
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
20 changes: 10 additions & 10 deletions data.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,25 +274,25 @@ def prepare_dataset(

ds = ds.map(process_func, num_parallel_calls=AUTOTUNE)

if is_train and random_status >= 0:
random_process_image = RandomProcessImage(img_shape, random_status, random_crop)
random_process_func = lambda xx, yy: (random_process_image.process(xx), yy)
ds = ds.map(random_process_func, num_parallel_calls=AUTOTUNE)

ds = ds.batch(batch_size, drop_remainder=True) # Use batch --> map has slightly effect on dataset reading time, but harm the randomness
if random_cutout_mask_area > 0:
print(">>>> random_cutout_mask_area provided:", random_cutout_mask_area)
mask_height = img_shape[0] * 2 // 5
mask_func = lambda images, labels: (
mask_func = lambda imm, label: (
tf.cond(
tf.random.uniform(()) < random_cutout_mask_area,
lambda: tf.concat([images[:, :-mask_height], tf.zeros_like(images[:, -mask_height:]) + 128], axis=1),
lambda: images,
lambda: tf.concat([imm[:-mask_height], tf.zeros_like(imm[-mask_height:]) + 128], axis=0),
lambda: imm,
),
labels,
label,
)
ds = ds.map(mask_func, num_parallel_calls=AUTOTUNE)

if is_train and random_status >= 0:
random_process_image = RandomProcessImage(img_shape, random_status, random_crop)
random_process_func = lambda xx, yy: (random_process_image.process(xx), yy)
ds = ds.map(random_process_func, num_parallel_calls=AUTOTUNE)

ds = ds.batch(batch_size, drop_remainder=True) # Use batch --> map has slightly effect on dataset reading time, but harm the randomness
if mixup_alpha > 0 and mixup_alpha <= 1:
print(">>>> mixup_alpha provided:", mixup_alpha)
ds = ds.map(lambda xx, yy: mixup((xx - 127.5) * 0.0078125, yy, alpha=mixup_alpha))
Expand Down
1 change: 1 addition & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,4 @@ def train(self, train_schedule, initial_epoch=0):
if self.model is None or self.model.stop_training == True:
print(">>>> But it's an early stop, break...")
break
return initial_epoch

0 comments on commit 95b3d03

Please sign in to comment.