Skip to content

Commit

Permalink
Use in preprocessing (#822)
Browse files Browse the repository at this point in the history
  • Loading branch information
sampathweb authored Aug 30, 2023
1 parent 18101a9 commit 973d9af
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 19 deletions.
7 changes: 4 additions & 3 deletions keras_core/layers/preprocessing/random_brightness.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,18 @@ def call(self, inputs, training=True):
return inputs

def _randomly_adjust_brightness(self, images):
rank = len(images.shape)
images_shape = self.backend.shape(images)
rank = len(images_shape)
if rank == 3:
rgb_delta_shape = (1, 1, 1)
elif rank == 4:
# Keep only the batch dim. This will ensure to have same adjustment
# with in one image, but different across the images.
rgb_delta_shape = [self.backend.shape(images)[0], 1, 1, 1]
rgb_delta_shape = [images_shape[0], 1, 1, 1]
else:
raise ValueError(
"Expected the input image to be rank 3 or 4. Received "
f"inputs.shape={images.shape}"
f"inputs.shape={images_shape}"
)

seed_generator = self._get_seed_generator(self.backend._backend)
Expand Down
7 changes: 5 additions & 2 deletions keras_core/layers/preprocessing/random_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,13 @@ def __init__(self, mode=HORIZONTAL_AND_VERTICAL, seed=None, **kwargs):
self._allow_non_tensor_positional_args = True

def _randomly_flip_inputs(self, inputs):
unbatched = len(inputs.shape) == 3
inputs_shape = self.backend.shape(inputs)
unbatched = len(inputs_shape) == 3
if unbatched:
inputs = self.backend.numpy.expand_dims(inputs, axis=0)
batch_size = self.backend.shape(inputs)[0]
inputs_shape = self.backend.shape(inputs)

batch_size = inputs_shape[0]
flipped_outputs = inputs
seed_generator = self._get_seed_generator(self.backend._backend)
if self.mode == HORIZONTAL or self.mode == HORIZONTAL_AND_VERTICAL:
Expand Down
2 changes: 2 additions & 0 deletions keras_core/layers/preprocessing/random_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def _get_rotation_matrix(self, inputs):
else:
image_height = shape[1]
image_width = shape[2]
image_height = float(image_height)
image_width = float(image_width)

lower = self._factor[0] * 2.0 * self.backend.convert_to_tensor(np.pi)
upper = self._factor[1] * 2.0 * self.backend.convert_to_tensor(np.pi)
Expand Down
14 changes: 8 additions & 6 deletions keras_core/layers/preprocessing/random_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,17 +164,19 @@ def call(self, inputs, training=True):
return inputs

def _randomly_translate_inputs(self, inputs):
unbatched = len(inputs.shape) == 3
inputs_shape = self.backend.shape(inputs)
unbatched = len(inputs_shape) == 3
if unbatched:
inputs = self.backend.numpy.expand_dims(inputs, axis=0)
inputs_shape = self.backend.shape(inputs)

batch_size = self.backend.shape(inputs)[0]
batch_size = inputs_shape[0]
if self.data_format == "channels_first":
height = inputs.shape[-2]
width = inputs.shape[-1]
height = inputs_shape[-2]
width = inputs_shape[-1]
else:
height = inputs.shape[-3]
width = inputs.shape[-2]
height = inputs_shape[-3]
width = inputs_shape[-2]

seed_generator = self._get_seed_generator(self.backend._backend)
height_translate = self.backend.random.uniform(
Expand Down
18 changes: 10 additions & 8 deletions keras_core/layers/preprocessing/random_zoom.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,19 @@ def call(self, inputs, training=True):
return inputs

def _randomly_zoom_inputs(self, inputs):
unbatched = len(inputs.shape) == 3
inputs_shape = self.backend.shape(inputs)
unbatched = len(inputs_shape) == 3
if unbatched:
inputs = self.backend.numpy.expand_dims(inputs, axis=0)
inputs_shape = self.backend.shape(inputs)

batch_size = self.backend.shape(inputs)[0]
batch_size = inputs_shape[0]
if self.data_format == "channels_first":
height = inputs.shape[-2]
width = inputs.shape[-1]
height = inputs_shape[-2]
width = inputs_shape[-1]
else:
height = inputs.shape[-3]
width = inputs.shape[-2]
height = inputs_shape[-3]
width = inputs_shape[-2]

seed_generator = self._get_seed_generator(self.backend._backend)
height_zoom = self.backend.random.uniform(
Expand Down Expand Up @@ -225,8 +227,8 @@ def _get_zoom_matrix(self, zooms, image_height, image_width):
# [0 0 1]]
# where the last entry is implicit.
# zoom matrices are always float32.
x_offset = ((image_width - 1.0) / 2.0) * (1.0 - zooms[:, 0:1])
y_offset = ((image_height - 1.0) / 2.0) * (1.0 - zooms[:, 1:])
x_offset = ((float(image_width) - 1.0) / 2.0) * (1.0 - zooms[:, 0:1])
y_offset = ((float(image_height) - 1.0) / 2.0) * (1.0 - zooms[:, 1:])
return self.backend.numpy.concatenate(
[
zooms[:, 0:1],
Expand Down

0 comments on commit 973d9af

Please sign in to comment.