Segformer swi #1292
Conversation
Is the PR ready for review?
I had a look anyway and left some comments, though they are just minor notes on naming and docs.
(Several resolved review threads on src/super_gradients/training/models/segmentation_models/segformer.py)
Comments inline; some of them we have already discussed (so just a reminder).
Still missing:
- Unit tests - see inline.
- Integration tests, pretrained model URLs.
- A new phase for the final validation on the average model.
(More resolved review threads on segformer.py)
@@ -148,6 +148,7 @@ class Callbacks:
     DEKR_VISUALIZATION = "DEKRVisualizationCallback"
     ROBOFLOW_RESULT_CALLBACK = "RoboflowResultCallback"
     TIMER = "TimerCallback"
+    SLIDING_WINDOW_INFERENCE = "ChangeToSWI"
Is this the name of the callback?
Fixed on c79d0f2
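For context on what the callback switches to: sliding-window inference (SWI) runs the network over overlapping crops of the full-resolution image and averages the overlapping logits, instead of resizing the whole image. A minimal self-contained sketch of the idea (a hypothetical helper, not this PR's implementation; crop size, stride, and class count are illustrative):

import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def sliding_window_inference(model, image, crop_size=1024, stride=768, num_classes=19):
    """Average per-pixel logits of `model` over overlapping crops of `image` [1, C, H, W]."""
    _, _, h, w = image.shape
    logits = torch.zeros(1, num_classes, h, w, device=image.device)
    counts = torch.zeros(1, 1, h, w, device=image.device)
    # Number of window positions per axis; the last window is clamped to the image edge.
    rows = max(math.ceil((h - crop_size) / stride), 0) + 1
    cols = max(math.ceil((w - crop_size) / stride), 0) + 1
    for i in range(rows):
        top = min(i * stride, max(h - crop_size, 0))
        for j in range(cols):
            left = min(j * stride, max(w - crop_size, 0))
            bottom, right = min(top + crop_size, h), min(left + crop_size, w)
            crop = image[:, :, top:bottom, left:right]
            out = model(crop)
            if out.shape[-2:] != crop.shape[-2:]:  # heads that predict at reduced resolution
                out = F.interpolate(out, size=crop.shape[-2:], mode="bilinear", align_corners=False)
            logits[:, :, top:bottom, left:right] += out
            counts[:, :, top:bottom, left:right] += 1
    return logits / counts  # counts >= 1 everywhere by construction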
@@ -856,6 +856,22 @@ def _infer_global_step(self, context: PhaseContext, is_train_loader: bool):
     return total_steps_in_done + train_loader_length + context.batch_idx


+@register_callback(Callbacks.SLIDING_WINDOW_INFERENCE)
Why choose an inconsistent name? This just adds more confusion...
Fixed on c79d0f2
(More resolved review threads on segformer.py)
def on_validation_loader_start(self, context: PhaseContext) -> None:
    if context.training_params.max_epochs - 1 == context.epoch:
        unwrap_model(context.net).enable_swi()
        context.valid_loader.dataset.transforms.transforms = []
I feel that emptying the transforms list this way is a VERY risky move. It may solve the immediate goal of adding SWI, but if someone less familiar with the implementation details runs into trouble because of this change, the odds they will curse us are pretty high.
Even now we may have incompatible scenarios where this breaks (like quantization). I really feel uneasy about this line. Is there really no other way of solving this?
I agree. It was a quick fix due to time-frame concerns.
@shaydeci, how should I proceed?
Missed this one; I completely agree with @BloodAxe here.
How about explicitly specifying the new transforms we switch to when switching to SWI? Is it always going to be the case that we just empty the transforms?
It does not make sense to me that transforms that might have nothing to do with resolution should be dropped.
In any case, we should restore the transforms we took out after the validation is completed.
I think an appropriate warning/message should be printed as well.
If we drop transforms to keep spatial size, how do we handle normalization transforms then?
I think we need to call iter, otherwise this won't be updated. Have you checked that images actually pass through this empty pipeline?
Since each worker has its own instance of the dataset, it is required to call iter on the loader. This is what we do in our training stage switch callback for YoloX when we want to turn off the transforms:
@register_callback(Callbacks.YOLOX_TRAINING_STAGE_SWITCH)
class YoloXTrainingStageSwitchCallback(TrainingStageSwitchCallbackBase):
    """
    YoloXTrainingStageSwitchCallback

    Training stage switch for YoloX training.
    Disables mosaic, and manipulates YoloX loss to use L1.
    """

    def __init__(self, next_stage_start_epoch: int = 285):
        super(YoloXTrainingStageSwitchCallback, self).__init__(next_stage_start_epoch=next_stage_start_epoch)

    def apply_stage_change(self, context: PhaseContext):
        for transform in context.train_loader.dataset.transforms:
            if hasattr(transform, "close"):
                transform.close()
        # Re-create the loader iterator so each worker's dataset copy picks up the closed transforms.
        iter(context.train_loader)
        context.criterion.use_l1 = True
Fixed on 0ec8a9c
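Pulling the thread's suggestions together, a sketch of the restore-and-warn pattern being requested (the base class, hook names, and import context are assumptions; enable_swi and disable_sliding_window_validation are the names used elsewhere in this PR):

import logging

logger = logging.getLogger(__name__)

class SlidingWindowValidationCallback(Callback):  # base class assumed
    """Switch the final validation epoch to sliding-window inference,
    saving the validation transforms so they can be restored afterwards."""

    def __init__(self):
        self.saved_transforms = None

    def on_validation_loader_start(self, context: PhaseContext) -> None:
        if context.training_params.max_epochs - 1 == context.epoch:
            unwrap_model(context.net).enable_swi()
            # Save the transforms instead of discarding them.
            self.saved_transforms = context.valid_loader.dataset.transforms.transforms
            context.valid_loader.dataset.transforms.transforms = []
            # Each DataLoader worker holds its own dataset copy, so rebuild the
            # iterator for the change to reach the workers.
            iter(context.valid_loader)
            logger.warning("SWI validation: validation transforms temporarily removed.")

    def on_validation_loader_end(self, context: PhaseContext) -> None:
        if self.saved_transforms is not None:
            unwrap_model(context.net).disable_sliding_window_validation()
            context.valid_loader.dataset.transforms.transforms = self.saved_transforms
            iter(context.valid_loader)
            self.saved_transforms = None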
…o have a consistent name for clarity. Added a new phase for the final validation on the average model. Added docstrings.
Pushed a new commit (1b8a739) with new recipes for each SegFormer variant, based on a default recipe.
(More resolved review threads on segformer.py)
… support sliding window inference.
…od set in segmentation_utils.py that could also be used in other models.
…ngWindowValidationCallback.
…was added to the test suite.
Pushed an integration test for SegFormer's models on commit db42e6c.
@@ -29,6 +29,36 @@ def test_pretrained_repvgg_a0_imagenet(self):
         model = models.get(Models.REPVGG_A0, pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)

+    def test_pretrained_segformer_b0_cityscapes(self):
These tests are a bit meaningless. You can take them out.
LGTM on my end.
@@ -43,6 +43,10 @@ def to_one_hot(target: torch.Tensor, num_classes: int, ignore_index: int = None)
     :param num_classes: num of classes in datasets excluding ignore label, this is the output channels of the one hot
                         result.
     :return: one hot tensor with shape [N, num_classes, H, W]

+    Parameters
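For reference, the assumed semantics of a helper like this; a minimal sketch, not the actual implementation in the repo:

import torch

def to_one_hot_sketch(target: torch.Tensor, num_classes: int, ignore_index: int = None) -> torch.Tensor:
    """[N, H, W] integer labels -> [N, num_classes, H, W] one-hot, zeros at ignored pixels."""
    one_hot = torch.zeros(target.size(0), num_classes, *target.shape[1:], device=target.device)
    if ignore_index is None:
        return one_hot.scatter_(1, target.unsqueeze(1), 1.0)
    valid = target != ignore_index
    # Route ignored pixels to channel 0 for the scatter, then zero them in all channels.
    one_hot.scatter_(1, (target * valid).unsqueeze(1), 1.0)
    return one_hot * valid.unsqueeze(1)

# e.g. to_one_hot_sketch(torch.tensor([[[0, 1], [255, 2]]]), num_classes=3, ignore_index=255)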
Why

Parameters
-----------

and not :param ignore_index:?
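Presumably the consistent style the reviewer expects would look something like this (a sketch of the docstring, matching the surrounding :param entries):

def to_one_hot(target: torch.Tensor, num_classes: int, ignore_index: int = None):
    """
    :param target:       tensor of shape [N, H, W] with integer class labels
    :param num_classes:  number of classes in the dataset, excluding the ignore label
    :param ignore_index: label value to exclude from the one-hot encoding, if any
    :return:             one-hot tensor with shape [N, num_classes, H, W]
    """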
def on_test_loader_end(self, context: PhaseContext) -> None:
    unwrap_model(context.net).disable_sliding_window_validation()
    context.test_loader.dataset.transforms.transforms = self.test_loader_transforms
    iter(context.test_loader)
I'm not sure I'm following why we need iter(context.test_loader) here.
A signed PR was opened; please follow the Segformer swi signed #1361 PR.
* Squashed version of branch Segformer_SWI (see Segformer SWI #1292 PR)
* Docstrings fix
* Docstrings fix

Co-authored-by: Louis-Dupont <35190946+Louis-Dupont@users.noreply.github.com>