
Introduce DropBlock2D regularization layer. #166

Merged
LukeWood merged 10 commits into keras-team:master from feature-137/add-dropblock on Mar 18, 2022

Conversation

sebastian-sz
Contributor

DropBlock2D is a regularization layer (similar to Dropout) that is better suited to convolutional networks.

Linked Issue: #137

This is mostly a Keras layer wrapper around the original implementation.
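
For context, a minimal usage sketch (argument names follow the `dropout_rate` / `dropblock_size` API this PR ends up with; the import path and shapes are assumptions):

import tensorflow as tf
from keras_cv.layers import DropBlock2D  # assumed export path

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(64, 3, activation="relu", input_shape=(32, 32, 3)),
    DropBlock2D(dropout_rate=0.1, dropblock_size=7),  # regularizes feature maps, not the raw input
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10),
])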

@LukeWood
Contributor

LukeWood commented Mar 7, 2022

Thanks for the PR! Going forward we will be using the BaseImageAugmentationLayer. I'll have an example ready for that shortly; I had to finish the base layer first.

keras-team/keras#16161

@sebastian-sz
Contributor Author

sebastian-sz commented Mar 8, 2022

@LukeWood Can we discuss it? I did not want to use BaseImageAugmentationLayer because DropBlock2D does not perform image augmentation: it is not meant to be placed before the model to modify the input, but inside the model to modify the feature maps.

Also, the functionality of BaseImageAugmentationLayer (vectorized map, access to label modifications, etc.) is not necessary here, as the implementation above is already vectorized and only operates on the output of the previous layer.
I think it is much closer to Dropout than to the other preprocessing layers.

I do not want to break layer consistency in this repo, but as I said, I'm not sure BaseImageAugmentationLayer is necessary here.

@bhack
Contributor

bhack commented Mar 8, 2022

I don't know if it makes sense to inherit from base_layer.BaseRandomLayer, as in https://github.com/keras-team/keras/blob/master/keras/layers/regularization/dropout.py#L26
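
For reference, a rough sketch of what that inheritance could look like (assuming the Keras 2.8 layout, where BaseRandomLayer lives in keras.engine.base_layer; this is not the PR's implementation):

from keras.engine import base_layer  # Keras 2.8 location of BaseRandomLayer

class DropBlock2D(base_layer.BaseRandomLayer):
    """Skeleton only; the mask-sampling logic from the PR is omitted."""

    def __init__(self, keep_probability=0.9, dropblock_size=7, seed=None, **kwargs):
        # BaseRandomLayer owns the RNG; with force_generator=True the layer can
        # draw its block masks from self._random_generator.
        super().__init__(seed=seed, force_generator=True, **kwargs)
        self._keep_probability = keep_probability
        self._dropblock_size = dropblock_size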

@sebastian-sz
Contributor Author

@bhack would you also say that the layer should inherit from BaseImageAugmentationLayer?

@bhack
Contributor

bhack commented Mar 8, 2022

I meant that we could evaluate whether this layer could inherit from BaseRandomLayer, like the linked dropout.

@LukeWood
Contributor

LukeWood commented Mar 8, 2022

@LukeWood Can we discuss it? I did not want to use BaseImageAugmentationLayer because DropBlock2D does not perform image augmentation: it is not meant to be placed before the model to modify the input, but inside the model to modify the feature maps.

Also, the functionality of BaseImageAugmentationLayer (vectorized map, access to label modifications, etc.) is not necessary here, as the implementation above is already vectorized and only operates on the output of the previous layer.
I think it is much closer to Dropout than to the other preprocessing layers.

I do not want to break layer consistency in this repo, but as I said, I'm not sure BaseImageAugmentationLayer is necessary here.

Ah, I see. I did not realize how this was meant to be used. Let me take a look and review, then.

@LukeWood LukeWood self-requested a review March 8, 2022 21:39
@sebastian-sz
Contributor Author

@LukeWood thanks!

Contributor

@LukeWood LukeWood left a comment

So far looking pretty good, some minor comments but overall great layer!

Args:
    keep_probability: float. Probability of keeping a unit. Defaults to 0.9.
Contributor

We might need to make this a tf.Variable internally so it doesn't get baked into the graph, to support a keep_probability schedule callback. Not sure yet; no action item needed on your side. Still doing some digging here.
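
A sketch of the idea (hypothetical, not part of the PR): keep the probability in a non-trainable variable so a callback can update it between batches without the value being baked into the traced graph.

import tensorflow as tf

# Sketch: a non-trainable variable can be reassigned from a callback between
# batches; a plain Python float would get frozen into the traced graph.
keep_probability = tf.Variable(0.9, trainable=False, dtype=tf.float32)
keep_probability.assign(0.85)  # e.g. done by a schedule callback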

Contributor Author

Yes, I was also thinking about a schedule callback. Something like:

class DropBlockScheduleCallback(tf.keras.callbacks.Callback):
    def __init__(self, total_steps: int):
        super().__init__()
        self._total_steps = total_steps
        self._current_step = 0

    def on_train_batch_end(self, batch, logs=None):
        self._current_step += 1
        for layer in self.model.layers:
            if isinstance(layer, DropBlock2D):
                new_keep_probability = self._calc_new_keep_probability(
                    layer._keep_probability  # TODO: expose publicly
                )
                layer._keep_probability.assign(new_keep_probability)

    def _calc_new_keep_probability(self, current_keep_probability):
        # Linearly anneal the keep probability over the course of training.
        current_ratio = self._current_step / self._total_steps
        return 1 - current_ratio * (1 - current_keep_probability)

Although I'm not sure:

  1. Is there any way to get total_steps without requiring it to be provided explicitly?
  2. Is iterating over all model layers on every batch too slow? Maybe the DropBlock layers could be collected in on_train_begin and reused in on_train_batch_end (see the sketch below)?

If I'm thinking in the right direction, I could add this in this PR, or it could be a separate issue. Please let me know.
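
For point 1, the callback's self.params may already expose 'steps' and 'epochs' during fit (worth verifying). For point 2, a sketch of caching the layers once (the class name is hypothetical, and it assumes _keep_probability is a tf.Variable as discussed above):

import tensorflow as tf

class CachedDropBlockSchedule(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        # Collect the DropBlock2D layers once instead of scanning every batch.
        self._dropblock_layers = [
            layer for layer in self.model.layers if isinstance(layer, DropBlock2D)
        ]

    def on_train_batch_end(self, batch, logs=None):
        for layer in self._dropblock_layers:
            # Update layer._keep_probability here, as in the callback above.
            pass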

@lucasdavid lucasdavid Mar 15, 2022

I don't know if this is easy to implement, but it would be very cool if it could accept lambdas or Schedules:

global_step = tf.Variable(0, trainable=False)
db = DropBlock2D(keep_probability=lambda: 1 / (1 + global_step))
# or ...
first_keep_prob = 1.0
db = DropBlock2D(
  keep_probability=tf.keras.optimizers.schedules.InverseTimeDecay(
    first_keep_prob,
    decay_steps=steps_per_epoch*epochs,
    decay_rate=0.96))
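
One way the layer could support both forms is to resolve the value at call time; a sketch (the function and argument names are illustrative, not part of the PR):

import tensorflow as tf

def resolve_keep_probability(keep_probability, step):
    # Accept a plain float, a LearningRateSchedule, or a zero-argument callable.
    if isinstance(keep_probability, tf.keras.optimizers.schedules.LearningRateSchedule):
        return keep_probability(step)
    if callable(keep_probability):
        return keep_probability()
    return keep_probability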

Contributor Author

Yeah, I'm not sure what the optimal solution for scheduling this parameter would be. I would probably add it in a separate PR (+ separate issue), as there is still some work to do on the basic functionality.

Contributor

I'll discuss this with the team and let you know. It would be good to get this sorted out because we actually will have the same problem with a few other components.

@sebastian-sz sebastian-sz force-pushed the feature-137/add-dropblock branch from 9cfea20 to 7620851 Compare March 10, 2022 19:40
@sebastian-sz sebastian-sz requested a review from LukeWood March 10, 2022 19:57
@sebastian-sz
Contributor Author

@qlzh727 @LukeWood
There is an issue where, if dropblock_size (height or width) is equal to or larger than the input feature map, the layer outputs NaN values.
What is your take on validating the input size in the call method?

def call(self, x, training=None):
    ...
    # fail fast instead of silently producing NaNs
    tf.Assert(tf.less(self._dropblock_height, height), data=["dropblock height must be smaller than the feature map height"])
    tf.Assert(tf.less(self._dropblock_width, width), data=["dropblock width must be smaller than the feature map width"])

@sebastian-sz sebastian-sz requested a review from qlzh727 March 16, 2022 17:51
@qlzh727
Member

qlzh727 commented Mar 16, 2022

@qlzh727 @LukeWood There is an issue where, if dropblock_size (height or width) is equal to or larger than the input feature map, the layer outputs NaN values. What is your take on validating the input size in the call method?

def call(self, x, training=None):
    ...
    # fail fast instead of silently producing NaNs
    tf.Assert(tf.less(self._dropblock_height, height), data=["dropblock height must be smaller than the feature map height"])
    tf.Assert(tf.less(self._dropblock_width, width), data=["dropblock width must be smaller than the feature map width"])

I think we should add such validation logic so that the mask block size is never bigger than the actual input.

@bhack
Contributor

bhack commented Mar 16, 2022

I think we should add such validation logic so that the mask block size is never bigger than the actual input.

Isn't this a user-controlled parameter?
In other cases we don't protect against misuse when we consider it the caller's responsibility, as in #144 (comment).

Member

@qlzh727 qlzh727 left a comment

Thanks for the update.

@sebastian-sz sebastian-sz requested a review from qlzh727 March 17, 2022 06:56
Member

@qlzh727 qlzh727 left a comment

Thanks for the PR.

@LukeWood
Contributor

LukeWood commented Mar 17, 2022

Alright! This looks pretty great @sebastian-sz. Let me discuss the scheduling issue and then we can merge!

@LukeWood LukeWood merged commit e8769bb into keras-team:master Mar 18, 2022
@lucasdavid

lucasdavid commented Mar 24, 2022

Saving a model that contains this layer and re-loading it into memory raises an "unexpected argument trainable" error. Snippet:

import tensorflow as tf
from keras import layers
from keras_cv.layers import DropBlock2D  # assuming this export path

model = tf.keras.Sequential([
  layers.Conv2D(32, 3, input_shape=(64, 64, 3)),
  DropBlock2D(0.1, 16),
])
model.save('saved_model')
tf.keras.models.load_model('saved_model')

Adding **kwargs in the layer's initializer and passing it to the superclass fixes it:

@tf.keras.utils.register_keras_serializable('layers')
class DropBlock2D(BaseRandomLayer):
  def __init__(
      self,
      dropout_rate,
      dropblock_size,
      data_format=None,
      seed=None,
      name=None,
      **kwargs,
  ):
    # Forwarding **kwargs lets Keras pass reload-time arguments like `trainable`.
    super().__init__(seed=seed, name=name, force_generator=True, **kwargs)

@bhack
Contributor

bhack commented Mar 24, 2022

Yes, thanks. I think that, more generally, we need to add serialization tests at some point.
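
Something along these lines, for example (a sketch; the import path and argument values are assumptions):

import tensorflow as tf
from keras_cv.layers import DropBlock2D  # assumed export path

class DropBlock2DSerializationTest(tf.test.TestCase):
    def test_saved_model_round_trip(self):
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(8, 3, input_shape=(32, 32, 3)),
            DropBlock2D(0.1, 7),
        ])
        path = self.get_temp_dir()
        model.save(path)
        restored = tf.keras.models.load_model(
            path, custom_objects={"DropBlock2D": DropBlock2D}
        )
        self.assertLen(restored.layers, 2)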

@sebastian-sz sebastian-sz deleted the feature-137/add-dropblock branch March 30, 2022 08:59
ianstenbit pushed a commit to ianstenbit/keras-cv that referenced this pull request Aug 6, 2022
* Added DropBlock2D regularization layer.

* Forced DropBlock2D to use generator. Fixed XLA test.

* Dropblock2D PR fixes.

* Added copyright to conv_utils.py

* Dropblock2D: changed keep_probability to dropout_rate.

* Allow non-square dropblock_size in DropBlock2D.

* Refactored DropBlock2D to use tf.shape.

* Changed tf.debugging to tf.test.TestCase assertions.

* Renamed seed_drop_rate to gamma.

* Expanded DropBlock2D's docstring.