Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Change the meaning of momentum in EMA #1581

Merged
merged 13 commits into from
Jan 18, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
update_buffers=True,
start_iter=20000)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
update_buffers=True,
start_iter=20000)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
update_buffers=True,
start_iter=20000)

Expand Down
2 changes: 1 addition & 1 deletion configs/biggan/biggan_2xb25-500kiters_cifar10-32x32.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
start_iter=1000)

model = dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
update_buffers=True,
start_iter=20000)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
update_buffers=True,
start_iter=20000)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))))

optim_wrapper = dict(
generator=dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))))

optim_wrapper = dict(
generator=dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@
type='ExponentialMovingAverageHook',
module_keys=('generator_ema'),
interval=1,
interp_cfg=dict(momentum=0.999),
interp_cfg=dict(momentum=0.001),
)
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,6 @@
type='ExponentialMovingAverageHook',
module_keys=('generator_ema'),
interval=1,
interp_cfg=dict(momentum=0.999),
interp_cfg=dict(momentum=0.001),
)
]
3 changes: 2 additions & 1 deletion configs/styleganv1/styleganv1_ffhq-1024x1024_8xb4-25Mimgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
# MODEL
model_wrapper_cfg = dict(find_unused_parameters=True)
ema_half_life = 10. # G_smoothing_kimg
ema_config = dict(interval=1, momentum=0.5**(32. / (ema_half_life * 1000.)))
ema_config = dict(
interval=1, momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))
model = dict(
generator=dict(out_size=1024),
discriminator=dict(in_size=1024),
Expand Down
3 changes: 2 additions & 1 deletion configs/styleganv1/styleganv1_ffhq-256x256_8xb4-25Mimgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
# MODEL
model_wrapper_cfg = dict(find_unused_parameters=True)
ema_half_life = 10. # G_smoothing_kimg
ema_config = dict(interval=1, momentum=0.5**(32. / (ema_half_life * 1000.)))
ema_config = dict(
interval=1, momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))
model = dict(
generator=dict(out_size=256),
discriminator=dict(in_size=256),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
2 changes: 1 addition & 1 deletion configs/styleganv2/stylegan2_c2_8xb4_ffhq-1024x1024.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
2 changes: 1 addition & 1 deletion configs/styleganv2/stylegan2_c2_8xb4_lsun-car-384x512.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
36 changes: 25 additions & 11 deletions mmedit/engine/hooks/ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,22 +69,36 @@ def __init__(self,
getattr(self, interp_mode), **self.interp_cfg)

@staticmethod
def lerp(a, b, momentum=0.999, momentum_nontrainable=0., trainable=True):
"""This is the function to perform linear interpolation between a and
b.
def lerp(a, b, momentum=0.001, momentum_nontrainable=1., trainable=True):
"""Does a linear interpolation of two parameters/buffers.

Args:
a (float): number a
b (float): bumber b
momentum (float, optional): momentum. Defaults to 0.999.
momentum_nontrainable (float, optional): Defaults to 0.
trainable (bool, optional): trainable flag. Defaults to True.

a (torch.Tensor): Interpolation start point, refers to the orig state.
b (torch.Tensor): Interpolation end point, refers to the ema state.
momentum (float, optional): The weight for the interpolation
formula. Defaults to 0.001.
momentum_nontrainable (float, optional): The weight for the
interpolation formula used for nontrainable parameters.
Defaults to 1.0.
trainable (bool, optional): Whether the input parameter is trainable.
If set to False, momentum_nontrainable will be used.
Defaults to True.
Returns:
_type_: _description_
torch.Tensor: Interpolation result.
"""
assert 0.0 < momentum < 1.0, 'momentum must be in range (0.0, 1.0)'\
f'but got {momentum}'
assert 0.0 < momentum_nontrainable <= 1.0, (
'momentum_nontrainable must be in range (0.0, 1.0] but got '
f'{momentum_nontrainable}')
if momentum > 0.5:
warnings.warn(
'The value of momentum in EMA is usually a small number,'
'which is different from the conventional notion of '
f'momentum but got {momentum}. Please make sure the '
f'value is correct.')
m = momentum if trainable else momentum_nontrainable
return a + (b - a) * m
return b + (a - b) * m
plyfager marked this conversation as resolved.
Show resolved Hide resolved

def every_n_iters(self, runner: Runner, n: int):
"""This is the function to perform every n iterations.
Expand Down
13 changes: 11 additions & 2 deletions mmedit/models/base_models/average_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def __init__(self,
super().__init__(model, interval, device, update_buffers)
assert 0.0 < momentum < 1.0, 'momentum must be in range (0.0, 1.0)'\
f'but got {momentum}'
if momentum > 0.5:
warnings.warn(
'The value of momentum in EMA is usually a small number,'
'which is different from the conventional notion of '
f'momentum but got {momentum}. Please make sure the '
f'value is correct.')
self.momentum = momentum

def avg_func(self, averaged_param: Tensor, source_param: Tensor,
Expand Down Expand Up @@ -230,8 +236,11 @@ def avg_func(self, averaged_param: Tensor, source_param: Tensor,
steps (int): The number of times the parameters have been
updated.
"""
momentum = self.rampup(self.steps, self.ema_kimg, self.ema_rampup,
self.batch_size, self.eps)
momentum = 1. - self.rampup(self.steps, self.ema_kimg, self.ema_rampup,
plyfager marked this conversation as resolved.
Show resolved Hide resolved
self.batch_size, self.eps)
if not (0.0 < momentum < 1.0):
warnings.warn('RampUp momentum must be in range (0.0, 1.0)'
f'but got {momentum}')
averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)

def _load_from_state_dict(self, state_dict: dict, prefix: str,
Expand Down
13 changes: 13 additions & 0 deletions tests/test_engine/test_hooks/test_ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,19 @@ def test_ema_hook(self):
assert torch.equal(runner.model.module_a.a, torch.tensor([0.25, 0.5]))
assert torch.equal(ema_states['a'], torch.tensor([0.375, 0.75]))

# test warning
with pytest.warns(UserWarning):
default_config = dict(
module_keys=('module_a_ema', 'module_b_ema'),
interval=1,
interp_cfg=dict(momentum=0.6))
cfg_ = deepcopy(default_config)
ema = ExponentialMovingAverageHook(**cfg_)
ema.lerp(
torch.tensor([0.25, 0.5]),
torch.tensor([0.25, 0.5]),
momentum=0.6)

@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_ema_hook_cuda(self):
ema = ExponentialMovingAverageHook(**self.default_config)
Expand Down
Loading