Skip to content

Commit

Permalink
[Refactor] Change the meaning of momentum in EMA (#1581)
Browse files Browse the repository at this point in the history
* fix ema momentum meaning, results in codes and config changes

* fix comments

* Update mmedit/engine/hooks/ema.py

Co-authored-by: Yanhong Zeng <zengyh1900@gmail.com>

* complete warning for old user

* fix lint

* fix ut

* test warning

* pytest capture warning

* fix lint

Co-authored-by: Yanhong Zeng <zengyh1900@gmail.com>
  • Loading branch information
plyfager and zengyh1900 authored Jan 18, 2023
1 parent 328c875 commit 6ed0f25
Show file tree
Hide file tree
Showing 35 changed files with 104 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
update_buffers=True,
start_iter=20000)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
update_buffers=True,
start_iter=20000)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
update_buffers=True,
start_iter=20000)

Expand Down
2 changes: 1 addition & 1 deletion configs/biggan/biggan_2xb25-500kiters_cifar10-32x32.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
start_iter=1000)

model = dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
update_buffers=True,
start_iter=20000)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
update_buffers=True,
start_iter=20000)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))))

optim_wrapper = dict(
generator=dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))))

optim_wrapper = dict(
generator=dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@
type='ExponentialMovingAverageHook',
module_keys=('generator_ema'),
interval=1,
interp_cfg=dict(momentum=0.999),
interp_cfg=dict(momentum=0.001),
)
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,6 @@
type='ExponentialMovingAverageHook',
module_keys=('generator_ema'),
interval=1,
interp_cfg=dict(momentum=0.999),
interp_cfg=dict(momentum=0.001),
)
]
3 changes: 2 additions & 1 deletion configs/styleganv1/styleganv1_ffhq-1024x1024_8xb4-25Mimgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
# MODEL
model_wrapper_cfg = dict(find_unused_parameters=True)
ema_half_life = 10. # G_smoothing_kimg
ema_config = dict(interval=1, momentum=0.5**(32. / (ema_half_life * 1000.)))
ema_config = dict(
interval=1, momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))
model = dict(
generator=dict(out_size=1024),
discriminator=dict(in_size=1024),
Expand Down
3 changes: 2 additions & 1 deletion configs/styleganv1/styleganv1_ffhq-256x256_8xb4-25Mimgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
# MODEL
model_wrapper_cfg = dict(find_unused_parameters=True)
ema_half_life = 10. # G_smoothing_kimg
ema_config = dict(interval=1, momentum=0.5**(32. / (ema_half_life * 1000.)))
ema_config = dict(
interval=1, momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))
model = dict(
generator=dict(out_size=256),
discriminator=dict(in_size=256),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
2 changes: 1 addition & 1 deletion configs/styleganv2/stylegan2_c2_8xb4_ffhq-1024x1024.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
2 changes: 1 addition & 1 deletion configs/styleganv2/stylegan2_c2_8xb4_lsun-car-384x512.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
36 changes: 25 additions & 11 deletions mmedit/engine/hooks/ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,22 +69,36 @@ def __init__(self,
getattr(self, interp_mode), **self.interp_cfg)

@staticmethod
def lerp(a, b, momentum=0.999, momentum_nontrainable=0., trainable=True):
"""This is the function to perform linear interpolation between a and
b.
def lerp(a, b, momentum=0.001, momentum_nontrainable=1., trainable=True):
"""Does a linear interpolation of two parameters/ buffers.
Args:
a (float): number a
b (float): bumber b
momentum (float, optional): momentum. Defaults to 0.999.
momentum_nontrainable (float, optional): Defaults to 0.
trainable (bool, optional): trainable flag. Defaults to True.
a (torch.Tensor): Interpolation start point, refer to orig state.
b (torch.Tensor): Interpolation end point, refer to ema state.
momentum (float, optional): The weight for the interpolation
formula. Defaults to 0.001.
momentum_nontrainable (float, optional): The weight for the
interpolation formula used for nontrainable parameters.
Defaults to 1..
trainable (bool, optional): Whether input parameters is trainable.
If set to False, momentum_nontrainable will be used.
Defaults to True.
Returns:
_type_: _description_
torch.Tensor: Interpolation result.
"""
assert 0.0 < momentum < 1.0, 'momentum must be in range (0.0, 1.0)'\
f'but got {momentum}'
assert 0.0 < momentum_nontrainable <= 1.0, (
'momentum_nontrainable must be in range (0.0, 1.0] but got '
f'{momentum_nontrainable}')
if momentum > 0.5:
warnings.warn(
'The value of momentum in EMA is usually a small number,'
'which is different from the conventional notion of '
f'momentum but got {momentum}. Please make sure the '
f'value is correct.')
m = momentum if trainable else momentum_nontrainable
return a + (b - a) * m
return b + (a - b) * m

def every_n_iters(self, runner: Runner, n: int):
"""This is the function to perform every n iterations.
Expand Down
13 changes: 11 additions & 2 deletions mmedit/models/base_models/average_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def __init__(self,
super().__init__(model, interval, device, update_buffers)
assert 0.0 < momentum < 1.0, 'momentum must be in range (0.0, 1.0)'\
f'but got {momentum}'
if momentum > 0.5:
warnings.warn(
'The value of momentum in EMA is usually a small number,'
'which is different from the conventional notion of '
f'momentum but got {momentum}. Please make sure the '
f'value is correct.')
self.momentum = momentum

def avg_func(self, averaged_param: Tensor, source_param: Tensor,
Expand Down Expand Up @@ -230,8 +236,11 @@ def avg_func(self, averaged_param: Tensor, source_param: Tensor,
steps (int): The number of times the parameters have been
updated.
"""
momentum = self.rampup(self.steps, self.ema_kimg, self.ema_rampup,
self.batch_size, self.eps)
momentum = 1. - self.rampup(self.steps, self.ema_kimg, self.ema_rampup,
self.batch_size, self.eps)
if not (0.0 < momentum < 1.0):
warnings.warn('RampUp momentum must be in range (0.0, 1.0)'
f'but got {momentum}')
averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)

def _load_from_state_dict(self, state_dict: dict, prefix: str,
Expand Down
13 changes: 13 additions & 0 deletions tests/test_engine/test_hooks/test_ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,19 @@ def test_ema_hook(self):
assert torch.equal(runner.model.module_a.a, torch.tensor([0.25, 0.5]))
assert torch.equal(ema_states['a'], torch.tensor([0.375, 0.75]))

# test warning
with pytest.warns(UserWarning):
default_config = dict(
module_keys=('module_a_ema', 'module_b_ema'),
interval=1,
interp_cfg=dict(momentum=0.6))
cfg_ = deepcopy(default_config)
ema = ExponentialMovingAverageHook(**cfg_)
ema.lerp(
torch.tensor([0.25, 0.5]),
torch.tensor([0.25, 0.5]),
momentum=0.6)

@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_ema_hook_cuda(self):
ema = ExponentialMovingAverageHook(**self.default_config)
Expand Down
Loading

0 comments on commit 6ed0f25

Please sign in to comment.