diff --git a/opacus/validators/batch_norm.py b/opacus/validators/batch_norm.py index 7f6cd20f..58891482 100644 --- a/opacus/validators/batch_norm.py +++ b/opacus/validators/batch_norm.py @@ -15,6 +15,7 @@ import logging from typing import List, Union +import math import torch.nn as nn @@ -86,7 +87,7 @@ def _batchnorm_to_groupnorm(module: BATCHNORM) -> nn.GroupNorm: paper *Group Normalization* https://arxiv.org/abs/1803.08494 """ return nn.GroupNorm( - min(32, module.num_features), module.num_features, affine=module.affine + math.gcd(32, module.num_features), module.num_features, affine=module.affine )