Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Change the meaning of momentum in EMA #1581

Merged
merged 13 commits into from
Jan 18, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
update_buffers=True,
start_iter=20000)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
update_buffers=True,
start_iter=20000)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
update_buffers=True,
start_iter=20000)

Expand Down
2 changes: 1 addition & 1 deletion configs/biggan/biggan_2xb25-500kiters_cifar10-32x32.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
start_iter=1000)

model = dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
update_buffers=True,
start_iter=20000)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.9999,
momentum=0.0001,
update_buffers=True,
start_iter=20000)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

model = dict(
type='MSPIEStyleGAN2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))))

optim_wrapper = dict(
generator=dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))))

optim_wrapper = dict(
generator=dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@
type='ExponentialMovingAverageHook',
module_keys=('generator_ema'),
interval=1,
interp_cfg=dict(momentum=0.999),
interp_cfg=dict(momentum=0.001),
)
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,6 @@
type='ExponentialMovingAverageHook',
module_keys=('generator_ema'),
interval=1,
interp_cfg=dict(momentum=0.999),
interp_cfg=dict(momentum=0.001),
)
]
3 changes: 2 additions & 1 deletion configs/styleganv1/styleganv1_ffhq-1024x1024_8xb4-25Mimgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
# MODEL
model_wrapper_cfg = dict(find_unused_parameters=True)
ema_half_life = 10. # G_smoothing_kimg
ema_config = dict(interval=1, momentum=0.5**(32. / (ema_half_life * 1000.)))
ema_config = dict(
interval=1, momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))
model = dict(
generator=dict(out_size=1024),
discriminator=dict(in_size=1024),
Expand Down
3 changes: 2 additions & 1 deletion configs/styleganv1/styleganv1_ffhq-256x256_8xb4-25Mimgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
# MODEL
model_wrapper_cfg = dict(find_unused_parameters=True)
ema_half_life = 10. # G_smoothing_kimg
ema_config = dict(interval=1, momentum=0.5**(32. / (ema_half_life * 1000.)))
ema_config = dict(
interval=1, momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))
model = dict(
generator=dict(out_size=256),
discriminator=dict(in_size=256),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
2 changes: 1 addition & 1 deletion configs/styleganv2/stylegan2_c2_8xb4_ffhq-1024x1024.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
2 changes: 1 addition & 1 deletion configs/styleganv2/stylegan2_c2_8xb4_lsun-car-384x512.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
ema_config=dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.))),
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
Expand Down
29 changes: 18 additions & 11 deletions mmedit/engine/hooks/ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,22 +69,29 @@ def __init__(self,
getattr(self, interp_mode), **self.interp_cfg)

@staticmethod
def lerp(a, b, momentum=0.999, momentum_nontrainable=0., trainable=True):
"""This is the function to perform linear interpolation between a and
b.
def lerp(a, b, momentum=0.001, momentum_nontrainable=1., trainable=True):
"""Does a linear interpolation of two parameters/ buffers.

Args:
a (float): number a
b (float): bumber b
momentum (float, optional): momentum. Defaults to 0.999.
momentum_nontrainable (float, optional): Defaults to 0.
trainable (bool, optional): trainable flag. Defaults to True.

a (torch.Tensor): Interpolation start point, refer to orig state.
b (torch.Tensor): Interpolation end point, refer to ema state.
momentum (float, optional): The weight for the interpolation
formula. Defaults to 0.001.
momentum_nontrainable (float, optional): The weight for the
interpolation formula used for nontrainable parameters.
Defaults to 1..
trainable (bool, optional): Whether input parameters is trainable.
If set to False, momentum_nontrainable will be used.
Defaults to True.
Returns:
_type_: _description_
torch.Tensor: Interpolation result.
"""
assert 0.0 < momentum < 1.0, 'momentum must be in range (0.0, 1.0)'\
f'but got {momentum}'
assert 0.0 < momentum_nontrainable <= 1.0, 'momentum_nontrainable'\
plyfager marked this conversation as resolved.
Show resolved Hide resolved
f'must be in range (0.0, 1.0] but got {momentum_nontrainable}'
m = momentum if trainable else momentum_nontrainable
return a + (b - a) * m
return b + (a - b) * m
plyfager marked this conversation as resolved.
Show resolved Hide resolved

def every_n_iters(self, runner: Runner, n: int):
"""This is the function to perform every n iterations.
Expand Down
7 changes: 5 additions & 2 deletions mmedit/models/base_models/average_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,11 @@ def avg_func(self, averaged_param: Tensor, source_param: Tensor,
steps (int): The number of times the parameters have been
updated.
"""
momentum = self.rampup(self.steps, self.ema_kimg, self.ema_rampup,
self.batch_size, self.eps)
momentum = 1. - self.rampup(self.steps, self.ema_kimg, self.ema_rampup,
plyfager marked this conversation as resolved.
Show resolved Hide resolved
self.batch_size, self.eps)
if not (0.0 < momentum < 1.0):
warnings.warn('RampUp momentum must be in range (0.0, 1.0)'
f'but got {momentum}')
averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)

def _load_from_state_dict(self, state_dict: dict, prefix: str,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_base_models/test_average_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class TestExponentialMovingAverage(TestCase):
@classmethod
def setUpClass(cls):
cls.default_cfg = dict(
interval=1, momentum=0.9999, update_buffers=True)
interval=1, momentum=0.0001, update_buffers=True)

def test_init(self):
cfg = deepcopy(self.default_cfg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def setup_class(cls):
cls.ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

cls.loss_config = dict(
r1_loss_weight=10. / 2. * d_reg_interval,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def setup_class(cls):
cls.ema_config = dict(
type='ExponentialMovingAverage',
interval=1,
momentum=0.5**(32. / (ema_half_life * 1000.)))
momentum=1. - (0.5**(32. / (ema_half_life * 1000.))))

cls.loss_config = dict(
r1_loss_weight=10. / 2. * d_reg_interval,
Expand Down