
Segformer swi #1292

Closed
wants to merge 23 commits into from

Conversation

Yael-Baron (Contributor)

No description provided.

@Louis-Dupont (Contributor) left a comment:

Is this PR ready for review?
I had a look anyway and left some comments, just minor notes on naming/docs.

@shaydeci (Contributor) left a comment:

Comments inline; some of them we have already discussed (so just a reminder).
Still missing:

  • Unit tests - see inline.
  • Integration tests, pretrained model URLs.
  • New phase for the final validation on the averaged model.

@@ -148,6 +148,7 @@ class Callbacks:
DEKR_VISUALIZATION = "DEKRVisualizationCallback"
ROBOFLOW_RESULT_CALLBACK = "RoboflowResultCallback"
TIMER = "TimerCallback"
SLIDING_WINDOW_INFERENCE = "ChangeToSWI"
Contributor:

Is this the name of the callback?

Contributor Author:

Fixed on c79d0f2

@@ -856,6 +856,22 @@ def _infer_global_step(self, context: PhaseContext, is_train_loader: bool):
return total_steps_in_done + train_loader_length + context.batch_idx


@register_callback(Callbacks.SLIDING_WINDOW_INFERENCE)
Contributor:

Why choose an inconsistent name? This just adds more confusion...

Contributor Author:

Fixed on c79d0f2

def on_validation_loader_start(self, context: PhaseContext) -> None:
if context.training_params.max_epochs - 1 == context.epoch:
unwrap_model(context.net).enable_swi()
context.valid_loader.dataset.transforms.transforms = []
Contributor:

I feel that emptying the transforms list in this way is a VERY risky move. It may solve the immediate goal of adding SWI, but if someone less familiar with the implementation details runs into trouble because of this change, the odds they would curse us are pretty high.
Even now we may have some incompatible scenarios where this would break (like quantization). I really feel uneasy about this line. Is there really no other way of solving this?

Contributor Author:

I agree. It was a quick fix due to time frame concerns.
@shaydeci how should I proceed?

Contributor:

Missed this one; completely agree with @BloodAxe here.
How about explicitly stating which transforms we switch to when switching to SWI? Is it always going to be the case that we just empty the transforms?
It does not make sense to me that transforms that might have nothing to do with resolution should be dropped...

In any case, we should restore the transforms we took out once the validation is completed.
I think an appropriate warning/message should be printed as well.
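
A minimal sketch of that save/restore pattern, for illustration only. The hook names, enable_swi(), disable_sliding_window_validation(), unwrap_model, and the iter() calls follow the snippets quoted in this PR; everything else (class name, restore point) is an assumption, not the final implementation:

import warnings

class SlidingWindowValidationCallback:
    """Hypothetical sketch: swap the validation transforms out for SWI on the
    last epoch, warn about it, and restore them once validation is done.
    unwrap_model and the enable/disable methods are the helpers used
    elsewhere in this PR."""

    def __init__(self):
        self._saved_transforms = None

    def on_validation_loader_start(self, context) -> None:
        if context.epoch == context.training_params.max_epochs - 1:
            unwrap_model(context.net).enable_swi()
            # Keep the original pipeline instead of discarding it.
            self._saved_transforms = context.valid_loader.dataset.transforms.transforms
            context.valid_loader.dataset.transforms.transforms = []
            warnings.warn("SWI enabled: validation transforms were temporarily removed.")
            iter(context.valid_loader)  # re-create workers so they see the change

    def on_validation_loader_end(self, context) -> None:
        if self._saved_transforms is not None:
            unwrap_model(context.net).disable_sliding_window_validation()
            # Put the original transforms back and refresh the workers again.
            context.valid_loader.dataset.transforms.transforms = self._saved_transforms
            self._saved_transforms = None
            iter(context.valid_loader)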

Contributor:

If we drop transforms to keep the spatial size, how do we handle normalization transforms then?

@shaydeci (Contributor), Aug 1, 2023:

I think we need to call iter, otherwise this won't be updated. Have you checked that images actually pass through this empty pipeline?

Since each worker has its own instance of the dataset, calling iter on the loader is required: it re-creates the workers, so they pick up the mutated dataset. This is why we do it in our training stage switch callback for YoloX when we want to turn off the transforms:

@register_callback(Callbacks.YOLOX_TRAINING_STAGE_SWITCH)
class YoloXTrainingStageSwitchCallback(TrainingStageSwitchCallbackBase):
    """
    YoloXTrainingStageSwitchCallback

    Training stage switch for YoloX training.
    Disables mosaic, and manipulates YoloX loss to use L1.

    """

    def __init__(self, next_stage_start_epoch: int = 285):
        super(YoloXTrainingStageSwitchCallback, self).__init__(next_stage_start_epoch=next_stage_start_epoch)

    def apply_stage_change(self, context: PhaseContext):
        for transform in context.train_loader.dataset.transforms:
            if hasattr(transform, "close"):
                transform.close()
        iter(context.train_loader)
        context.criterion.use_l1 = True

Contributor Author:

Fixed on 0ec8a9c

@Yael-Baron (Contributor Author):

Pushed a new commit (1b8a739) with new recipes for each Segformer variant, based on a default recipe.

@Yael-Baron (Contributor Author):

Pushed integration tests for the Segformer models in commit db42e6c.

@@ -29,6 +29,36 @@ def test_pretrained_repvgg_a0_imagenet(self):
model = models.get(Models.REPVGG_A0, pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)

def test_pretrained_segformer_b0_cityscapes(self):
Contributor:

These tests are a bit meaningless. You can take them out.

@shaydeci (Contributor) left a comment:

LGTM on my end.

@@ -43,6 +43,10 @@ def to_one_hot(target: torch.Tensor, num_classes: int, ignore_index: int = None)
:param num_classes: num of classes in datasets excluding ignore label, this is the output channels of the one hot
result.
:return: one hot tensor with shape [N, num_classes, H, W]

Parameters
Contributor:

Why

Parameters
-----------

and not :param ignore_index?
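
For illustration, a sketch of that docstring kept in the single :param style the reviewer is suggesting. The target shape and the ignore_index description are assumptions, not taken from the PR:

import torch

def to_one_hot(target: torch.Tensor, num_classes: int, ignore_index: int = None):
    """
    Convert an integer label map to a one-hot tensor.

    :param target:       class-index tensor, assumed shape [N, H, W]
    :param num_classes:  num of classes in the dataset excluding the ignore label;
                         this is the output channels of the one-hot result
    :param ignore_index: label value excluded from the encoding (assumption: pixels
                         with this value get an all-zero vector)
    :return:             one-hot tensor with shape [N, num_classes, H, W]
    """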

def on_test_loader_end(self, context: PhaseContext) -> None:
unwrap_model(context.net).disable_sliding_window_validation()
context.test_loader.dataset.transforms.transforms = self.test_loader_transforms
iter(context.test_loader)
Contributor:

I'm not sure I'm following why we need iter(context.test_loader) here.

@Yael-Baron mentioned this pull request on Aug 10, 2023.

@Yael-Baron (Contributor Author):

A signed PR was opened; please follow PR Segformer swi signed #1361.

@Yael-Baron closed this on Aug 10, 2023.
Yael-Baron added a commit that referenced this pull request Aug 15, 2023
* Squashed version of branch Segformer_SWI (see Segformer SWI #1292 PR )

* Docstrings fix

* Docstrings fix

---------

Co-authored-by: Louis-Dupont <35190946+Louis-Dupont@users.noreply.github.com>