FedOpt algorithm not working as expected in cifar10 example #2314
-
Thank you for trying this out and raising the issue! It would be nice if you could share the figures from your other experiments to benefit other people. @holgerroth, can you help answer this question? Thanks.
-
Interesting. Just to confirm, are you using momentum on the server when using FedOpt (see here)? That could explain the different behavior from FedAvg.
-
Yes @holgerroth, I tried momentum with different values and I also tried not using it at all. Although the results changed and some values gave better results than others, they were still poor: I reached at most 0.5 accuracy, which is quite low compared with the other algorithms.
-
That's interesting. So, the problem only comes up when using the pretrained CNN? FedOpt seems to be more sensitive to this initialization. Have you tried reducing the local aggregation_epochs?
-
Yes, even when reducing the local epochs of each client the behaviour stays the same (obviously worse due to the fewer epochs). I also tried MobileNetV2 and ResNet18 with the same settings explained before but without pretrained parameters.
-
Hi @LeandroDiL, do you have any updates on this topic?
-
Hi @holgerroth, I can confirm there is problematic behaviour when using anything other than the ModerateCNN and SimpleCNN. Global model validation metrics get stuck at 0.1 from the first round of aggregation.
-
I see. Can you specify what models and alpha setting you are using? Are the same models working fine with FedAvg and the same alpha setting on CIFAR-10?
-
Yes, this is with alpha 0.6. FedAvg & FedProx work fine. It's a dozen models, from a ResNet-20 to a couple of Transformers; all of them break under FedOpt except for ModerateCNN and, to some extent, SimpleCNN. SimpleCNN underperforms, but at least it converges. The rest get stuck in terms of global model validation accuracy, but locally they do learn (local validation accuracy increases between aggregations). All trained from scratch.
-
Ok. Have you tried different learning rates and momentum for the FedOpt optimizer, maybe even optimizers other than SGD? lr=1 and momentum=0 should behave identically to FedAvg with the SGD optimizer.
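To see why, here is a minimal sketch (illustrative tensors, not NVFlare code): the server treats the negative of (average - global) as a gradient, so a plain SGD step with lr=1 lands exactly on the averaged weights.
```python
import torch

global_w = torch.randn(5)
avg_client_w = torch.randn(5)          # aggregated client weights (the FedAvg result)

param = torch.nn.Parameter(global_w.clone())
opt = torch.optim.SGD([param], lr=1.0, momentum=0.0)
param.grad = global_w - avg_client_w   # pseudo-gradient = -(average - global)
opt.step()                             # param <- param - 1.0 * grad == avg_client_w
assert torch.allclose(param.detach(), avg_client_w)
```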
-
@siomvas, thanks for the additional information.
-
Hi @LeandroDiL, @siomvas, I'm looking into this issue now. Just to confirm, have you also changed the model configuration in config_fed_server.json when running these experiments? Please attach your job configurations and code if possible.
-
Okay, I was able to reproduce the behavior. It has to do with the batch norm layers of these more complex models. When updating the global model using SGD, the batch norm statistics are not included in the parameters the optimizer updates. The FedOpt paper also uses group norm instead of batch norm to avoid these kinds of issues. I provided a workaround for this issue by updating the batch norm parameters using FedAvg and only updating the trainable parameters using the FedOpt optimizer for the global model: #1851
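The idea can be sketched roughly as follows (an illustrative sketch with assumed names, not the actual code in #1851):
```python
import torch

def fedopt_server_update(model, avg_weights, optimizer, device="cpu"):
    """FedOpt step on trainable parameters; FedAvg copy for everything else."""
    trainable = dict(model.named_parameters())
    optimizer.zero_grad()
    for name, param in trainable.items():
        # pseudo-gradient: negative of (average - global)
        param.grad = param.data - avg_weights[name].to(device)
    optimizer.step()
    # copy non-trainable state (e.g. running_mean, running_var) FedAvg-style
    with torch.no_grad():
        state = model.state_dict()
        for name, tensor in avg_weights.items():
            if name not in trainable:
                state[name].copy_(tensor.to(device))
```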
-
I pinpointed this issue/bug when trying to use SCAFFOLD, since that actually (conveniently) breaks, so I could see where the error was; I will open a new bug report for that. This is what I found: the issue is not with batch norm itself, but with the running stats:
```python
>>> [k for k, v in mobile.named_parameters()][:5]
['conv1.weight', 'bn1.weight', 'bn1.bias', 'layers.0.conv1.weight', 'layers.0.bn1.weight']
>>> [k for k in mobile.state_dict()][:7]
['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layers.0.conv1.weight']
```
The weight and bias of BN are getting averaged, but the running stats are not, producing a nonsensical layer, since these are learned in tandem client-side (note that num_batches_tracked is not used in any calculation in the default setting, where BN uses momentum instead). This also applies to other architectural elements; SWIN has a similar non-trainable buffer. Correct me if I'm wrong, but it seems that with #1851 the weights and biases are still getting "FedOpted" while the running stats get averaged, so this should not be expected behaviour, as there will be a mismatch. A quick test with FedAdam using the proposed hparams from the FedOpt paper (client lr=0.03, server lr=0.01) together with #1851 shows there is convergence, but how the mismatch in the affected layers affects model performance is unclear.

To investigate further, I tried combining FedOpt with FedBN, implemented via the task filter mechanism (adding exclude_vars for the BN parameters). But it seems there is currently another bug where FedOpt does not respect task filters. See the fix, which can be added in #1851:
```python
for name, param in self.model.named_parameters():
    param.grad = torch.tensor(-1.0 * model_diff[name]).to(self.device)
    updated_params.append(name)
```
should be
```python
for name, param in self.model.named_parameters():
    if name in model_diff:
        param.grad = torch.tensor(-1.0 * model_diff[name]).to(self.device)
        updated_params.append(name)
```
I believe this should remain open, not as a bug but as a documented issue.
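For clarity, the FedBN-style filtering I mean amounts to dropping the batch norm entries from the weights a client ships back; a hypothetical name-based helper (not NVFlare's actual filter implementation):
```python
import re

def exclude_bn_vars(weights: dict,
                    pattern: str = r"bn|running_(mean|var)|num_batches_tracked") -> dict:
    """Drop entries whose names match the exclude pattern (FedBN-style)."""
    regex = re.compile(pattern)
    return {k: v for k, v in weights.items() if not regex.search(k)}
```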
-
Hi @siomvas, thanks for your test and the additional info. Yes, the desired behavior of batch norm layers with FedOpt is somewhat unclear. That's why many try to avoid using batch norm in FL settings, as in the FedOpt paper, and why I used "workaround" to describe #1851: it uses FedOpt to optimize the global trainable parameters but FedAvg to update any other layers such as batch norm statistics. It remains to be seen whether this approach also works with SWIN architectures. I know it's inconvenient, as most of the pretrained torchvision models use batch norm, but I would recommend looking into models that use group norm instead. Thanks for pointing out the issue with using filters; I added that fix to the PR. I also updated the docstring to document the behavior when using batch norm. It's acceptable to me, as we can match the performance of FedAvg using this workaround and the equivalent SGD settings (lr=1, momentum=0).
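One could also swap batch norm for group norm in an existing model rather than switching architecture; a rough sketch, where the group count of 8 is an arbitrary choice (channel counts must be divisible by it):
```python
import torch.nn as nn
from torchvision import models

def bn_to_gn(module: nn.Module, num_groups: int = 8) -> None:
    """Recursively replace BatchNorm2d layers with GroupNorm."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.GroupNorm(num_groups, child.num_features))
        else:
            bn_to_gn(child, num_groups)

model = models.resnet18(weights=None)
bn_to_gn(model)  # the state_dict no longer contains running stats
```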
-
Hi @siomvas, I ran into a similar situation. I used Adam as the optimizer and Swin as the model on CIFAR-10. After the first epoch, however, the loss, acc1, and acc5 never improved. When I switched to a much smaller model, ResNet56 from timm, the results were very good, as expected.
-
@BitCalSaul, I converted this to an open discussion around FedOpt. Did you confirm that the SWIN architecture can train a good model in centralized training? From your curves, it looks like it doesn't converge at all.
-
Hi @holgerroth, I googled this situation and found this issue. I'm not sure if I understand centralized training correctly, but I guess it means training a model on one GPU?
-
Hello, I can attest to i) the nvflare bug having been resolved, ii) SWIN-T with randomly initialized weights performing around the 40% mark on CIFAR-10 with the inputs upscaled to (224, 224), and iii) SWIN-T with ImageNet weights getting >90%. Without any code snippet it's difficult to comment on what is wrong with your implementation. Does FedAvg work? I am not familiar with the paper or the repo you mentioned, and they don't seem to be placed in the FL context. If you are looking to understand the learning dynamics of SotA architectures when used for FL, you might be interested in the following recent work (disclaimer, the second paper is mine):
-
Hi @holgerroth @siomvas, the issue has been addressed. The code for the implementation of Swin came from the official repo; I tried to use it on CIFAR-10. The bad results, i.e. the stuck loss, came from the LayerNorm in the PatchMerging module. When I removed this LayerNorm, the loss no longer got stuck. This suggests that a model designed for large datasets doesn't necessarily work well on a small dataset. Thank you all for your attention.
-
Hi @holgerroth, I even froze all the batch normalization layers during the model update in FedOpt and updated them using the FedAvg strategy. However, the performance drop still occurs. It is weird.
-
@falibabaei, yes, FedOpt can be a bit tricky to get to work. I would first recommend setting lr=1 and momentum=0 for the server optimizer, which should match FedAvg, and tuning from there.
-
Describe the bug
The FedOpt algorithm is not working as expected in the cifar10 example when I change the model from the pre-existing ModerateCNN to another model like MobileNetV2 or ResNet18. The problem is that the accuracy of the global model is not increasing, or is increasing too slowly, with the FedOpt algorithm, while the other algorithms work just fine even after changing the model.
To Reproduce
Add the new model in 'cifar10_nets.py':
```python
import torch.nn as nn
from torchvision import models

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        model = models.mobilenet_v2(weights='DEFAULT')
        model.classifier = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(1280, 10),
        )
        self.model = model

    def forward(self, x):
        return self.model(x)
```
Import and change the model in 'cifar10_learner.py'.
Launch the example with:
```
./run_simulator.sh cifar10_fedopt 0.1 8 8
```
See the results under the section 'val_acc_global_model' in TensorBoard with:
```
tensorboard --logdir=/tmp/nvflare/sim_cifar10
```
Expected behavior
Reading the algorithm proposed in Reddi, Sashank, et al., "Adaptive Federated Optimization," arXiv preprint arXiv:2003.00295 (2020), I expect to obtain the same performance as FedAvg when using the SGD optimizer with lr = 1.0 and no scheduler, and to obtain better results when changing the optimizer and adding a scheduler.
Screenshots
(Purple = FedAvg, Pink = FedOpt)
Thanks in advance!