Skip to content

Commit

Permalink
add random_cutout_or_cutout_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
leondgarse committed Jan 11, 2022
1 parent 9c26e05 commit dae425d
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 20 deletions.
72 changes: 52 additions & 20 deletions data.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def tf_imread(file_path):


class RandomProcessImage:
def __init__(self, img_shape=(112, 112), random_status=2, random_crop=None):
def __init__(self, img_shape=(112, 112), random_status=2, random_crop=None, random_cutout_mask_area=0):
self.img_shape, self.random_status, self.random_crop = img_shape[:2], random_status, random_crop
if random_status >= 100:
magnitude = 5 * random_status / 100
Expand All @@ -78,10 +78,25 @@ def __init__(self, img_shape=(112, 112), random_status=2, random_crop=None):
import augment

aa = augment.RandAugment(magnitude=magnitude, cutout_const=40)
aa.available_ops = ["AutoContrast", "Equalize", "Color", "Contrast", "Brightness", "Sharpness", "Cutout"]
self.process = lambda img: aa.distort(tf.image.random_flip_left_right(img))
if random_cutout_mask_area > 0:
print(">>>> random_cutout_mask_area provided:", random_cutout_mask_area)
# aa.available_ops = ["AutoContrast", "Equalize", "Color", "Contrast", "Brightness", "Sharpness"]
# random_cutout = 1 / len(aa.available_ops)
# self.process = lambda img: aa.distort(
# random_cutout_or_cutout_mask(tf.image.random_flip_left_right(img), img_shape, random_cutout_mask_area, random_cutout)
# )
aa.available_ops = ["AutoContrast", "Equalize", "Color", "Contrast", "Brightness", "Sharpness", "Cutout"]
self.process = lambda img: aa.distort(
random_cutout_or_cutout_mask(tf.image.random_flip_left_right(img), img_shape, random_cutout_mask_area, random_cutout=0)
)
else:
aa.available_ops = ["AutoContrast", "Equalize", "Color", "Contrast", "Brightness", "Sharpness", "Cutout"]
self.process = lambda img: aa.distort(tf.image.random_flip_left_right(img))
else:
self.process = lambda img: self.tf_buildin_image_random(img)
if random_cutout_mask_area > 0:
self.process = lambda img: self.tf_buildin_image_random(random_cutout_or_cutout_mask(img, random_cutout_mask_area, random_cutout=0))
else:
self.process = lambda img: self.tf_buildin_image_random(img)

def tf_buildin_image_random(self, img):
if self.random_status >= 0:
Expand All @@ -102,6 +117,38 @@ def tf_buildin_image_random(self, img):
return img


def random_cutout_or_cutout_mask(image, image_shape, random_cutout_mask_area=0.3, random_cutout=0, pad_size=20, replace=128):
from augment import cutout

# image_hh, image_ww = image.shape[:2]
image_hh, image_ww = image_shape[:2]
# mask_height = img_shape[0] * 3 // 5
min_hh, max_hh = int(float(image_hh) * 0.55), int(float(image_ww) * 0.7)
random_height = lambda: tf.random.uniform((), min_hh, max_hh, dtype=tf.int32)

cutout_func = lambda imm: tf.cond(
tf.random.uniform(()) < random_cutout,
lambda: cutout(imm, pad_size=pad_size, replace=replace),
lambda: imm,
)

if random_cutout > 0:
mask_func = lambda imm: tf.cond(
tf.random.uniform(()) < random_cutout_mask_area,
# lambda: tf.concat([imm[:mask_height], tf.zeros_like(imm[mask_height:]) + 128], axis=0),
lambda: tf.image.pad_to_bounding_box(imm[: random_height()] - replace, 0, 0, image_hh, image_ww) + replace,
lambda: cutout_func(cutout_func(imm)), # randaug num_layers=2
)
else:
mask_func = lambda imm: tf.cond(
tf.random.uniform(()) < random_cutout_mask_area,
# lambda: tf.concat([imm[:mask_height], tf.zeros_like(imm[mask_height:]) + 128], axis=0),
lambda: tf.image.pad_to_bounding_box(imm[: random_height()] - replace, 0, 0, image_hh, image_ww) + replace,
lambda: imm,
)
return mask_func(image)


def sample_beta_distribution(size, concentration_0=0.4, concentration_1=0.4):
gamma_1_sample = tf.random.gamma(shape=[size], alpha=concentration_1)
gamma_2_sample = tf.random.gamma(shape=[size], alpha=concentration_0)
Expand Down Expand Up @@ -274,23 +321,8 @@ def prepare_dataset(

ds = ds.map(process_func, num_parallel_calls=AUTOTUNE)

if random_cutout_mask_area > 0:
print(">>>> random_cutout_mask_area provided:", random_cutout_mask_area)
# mask_height = img_shape[0] * 2 // 5
random_height = lambda: tf.random.uniform((), int(img_shape[0] * 0.55), int(img_shape[0] * 0.7), dtype=tf.int32)
mask_func = lambda imm, label: (
tf.cond(
tf.random.uniform(()) < random_cutout_mask_area,
# lambda: tf.concat([imm[:-mask_height], tf.zeros_like(imm[-mask_height:]) + 128], axis=0),
lambda: tf.image.pad_to_bounding_box(imm[:random_height()] - 128, 0, 0, img_shape[0], img_shape[1]) + 128,
lambda: imm,
),
label,
)
ds = ds.map(mask_func, num_parallel_calls=AUTOTUNE)

if is_train and random_status >= 0:
random_process_image = RandomProcessImage(img_shape, random_status, random_crop)
random_process_image = RandomProcessImage(img_shape, random_status, random_crop, random_cutout_mask_area)
random_process_func = lambda xx, yy: (random_process_image.process(xx), yy)
ds = ds.map(random_process_func, num_parallel_calls=AUTOTUNE)

Expand Down
1 change: 1 addition & 0 deletions losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def call(self, y_true, norm_logits):
# [MagFace: A Universal Representation for Face Recognition and Quality Assessment](https://arxiv.org/pdf/2103.06627.pdf)
@keras.utils.register_keras_serializable(package="keras_insightface")
class MagFaceLoss(ArcfaceLossSimple):
""" Another set for fine-tune is: min_feature_norm, max_feature_norm, min_margin, max_margin, regularizer_loss_lambda = 1, 51, 0.45, 1, 5 """
def __init__(
self,
min_feature_norm=10.0,
Expand Down

0 comments on commit dae425d

Please sign in to comment.