-
Notifications
You must be signed in to change notification settings - Fork 351
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add params check when convert BatchNorm to GroupNorm #390
add params check when convert BatchNorm to GroupNorm #390
Conversation
@JohnlNguyen has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
Hey, these still fail our unit tests |
@JohnlNguyen , the failed case try to convert BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True) to GroupNorm, but it is invalid,how do we ignore this conversion? |
opacus/validators/batch_norm.py
Outdated
if module.num_features % min(32, module.num_features) != 0: | ||
raise UnsupportableModuleError( | ||
"There is no equivalent GroupNorm module to replace BatchNorm with." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if module.num_features % min(32, module.num_features) != 0: | |
raise UnsupportableModuleError( | |
"There is no equivalent GroupNorm module to replace BatchNorm with." | |
) |
Raising an error seems very prohibitive. The default value of 32 was chosen based on the empirical results in the paper, but using a different value for num_groups is still better than disallowing conversion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what was the previous behavior? You would get an error message if and only if you actually ran the module?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gchanan, yes, you only get an error when running the module before.
opacus/validators/batch_norm.py
Outdated
if module.num_features % min(32, module.num_features) != 0: | ||
raise UnsupportableModuleError( | ||
"There is no equivalent GroupNorm module to replace BatchNorm with." | ||
) | ||
return nn.GroupNorm( | ||
min(32, module.num_features), module.num_features, affine=module.affine |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
min(32, module.num_features), module.num_features, affine=module.affine | |
gcd(32, module.num_features), module.num_features, affine=module.affine |
How about replacing min with gcd? This should work for any number of channels, and will distill to InstanceNorm in the extreme case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, the code is changed to use the gcd method.
@XiaobingSuper has updated the pull request. You must reimport the pull request before landing. |
@karthikprasad has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
This PR will do params check when converting BatchNorm to GroupNorm, because GroupNorm should have some pre-request at its initiation step.