TR-DPO bug #1971

Closed

Bodoral opened this issue Aug 26, 2024 · 5 comments · Fixed by #1978

Comments

Bodoral commented Aug 26, 2024

I'm running DPO with sync_ref_model enabled. The training script is adapted from alignment-handbook, with DeepSpeed ZeRO stage 3 enabled. However, training crashes at the point of updating the reference model weights. Here is the error I get:

  File "/root/miniconda3/envs/llm_alignment_env/lib/python3.10/site-packages/trl/trainer/utils.py", line 102, in on_step_end             
    self.sync_target_model(model, self.ref_model, args.ref_model_mixup_alpha)                                                            
  File "/root/miniconda3/envs/llm_alignment_env/lib/python3.10/site-packages/trl/trainer/utils.py", line 92, in sync_target_model        
    SyncRefModelCallback._sync_target_model(model, target_model, alpha)                                                                  
  File "/root/miniconda3/envs/llm_alignment_env/lib/python3.10/site-packages/trl/trainer/utils.py", line 84, in _sync_target_model       
    target_param.data.mul_(1.0 - alpha).add_(copy_param.data, alpha=alpha)                                                               
RuntimeError: The size of tensor a (0) must match the size of tensor b (4096) at non-singleton dimension 1

After checking SyncRefModelCallback, it turns out that it only gathers the parameters of the DPO model being trained, but not those of the reference model:

    @staticmethod
    def _sync_target_model(model, target_model, alpha):
        for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
            target_param.data.mul_(1.0 - alpha).add_(copy_param.data, alpha=alpha)

    @staticmethod
    def sync_target_model(model, target_model, alpha):
        deepspeed_plugin = AcceleratorState().deepspeed_plugin
        if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
            with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=0):
                if deepspeed.comm.get_rank() == 0:
                    SyncRefModelCallback._sync_target_model(model, target_model, alpha)
        else:
            SyncRefModelCallback._sync_target_model(model, target_model, alpha)
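
To make the failure concrete: under ZeRO-3, a parameter that has not been gathered is left as an empty placeholder, so the in-place update sees a size-0 tensor on the reference-model side while the gathered training-model parameter has its full shape. A standalone sketch (plain torch, no DeepSpeed; the 4096 shape is only illustrative) that reproduces the same error:

import torch

alpha = 0.6

# Gathered parameter of the model being trained: full shape.
copy_param_data = torch.randn(4096, 4096)

# Reference-model parameter that was never gathered: ZeRO-3 leaves a size-0 placeholder.
target_param_data = torch.empty(0)

# Same in-place update as _sync_target_model; raises:
# RuntimeError: The size of tensor a (0) must match the size of tensor b (4096)
# at non-singleton dimension 1
target_param_data.mul_(1.0 - alpha).add_(copy_param_data, alpha=alpha)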

I fixed this by first gathering the DPO model being trained and saving its parameter data in a list, then gathering the reference model and updating its parameters (the code is in my comment below).

I'm not sure whether this is actually a bug or whether something is wrong with my settings.

OS: Ubuntu 20.04.1
Python: 3.10
packages:
trl==0.9.6
torch==2.3.1
torchaudio==2.4.0
torchvision==0.19.0
transformers==4.44.1
deepspeed==0.14.5
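
For context, the syncing is enabled through the DPO training arguments, roughly like this. This is a minimal sketch rather than my exact script: the model name and dataset are placeholders, the alignment-handbook preprocessing is omitted, and I'm assuming the options live on DPOConfig in this trl version (ref_model_sync_steps in particular is an assumed option name; sync_ref_model and ref_model_mixup_alpha are the names from the description and traceback above).

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_name = "my-org/my-sft-model"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)
ref_model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

training_args = DPOConfig(
    output_dir="dpo-output",
    sync_ref_model=True,        # installs SyncRefModelCallback
    ref_model_mixup_alpha=0.6,  # the alpha used in the in-place update that fails
    ref_model_sync_steps=64,    # assumed option name for the sync interval
    per_device_train_batch_size=1,
)

# Placeholder dataset; alignment-handbook applies its own chat-template
# preprocessing first, which is omitted here.
train_dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized", split="train_prefs")

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()  # crashes at the first reference-model sync step under ZeRO-3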

kashif (Collaborator) commented Aug 26, 2024

Could be an edge case... do you mind sending a PR to check?

kashif (Collaborator) commented Aug 26, 2024

Do you mean something like this:

    @staticmethod
    def sync_target_model(model, target_model, alpha):
        deepspeed_plugin = AcceleratorState().deepspeed_plugin
        if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
            with deepspeed.zero.GatheredParameters(
                list(model.parameters()) + list(target_model.parameters()), modifier_rank=0
            ):
                if deepspeed.comm.get_rank() == 0:
                    SyncRefModelCallback._sync_target_model(model, target_model, alpha)
        else:
            SyncRefModelCallback._sync_target_model(model, target_model, alpha)

Bodoral (Author) commented Aug 28, 2024

Here is how I solved it:

    @staticmethod
    def _sync_target_model_deepspeed(model_parameters, target_model, alpha):
        for target_param, copy_param in zip(target_model.parameters(), model_parameters):
            target_param.data.mul_(1.0 - alpha).add_(copy_param, alpha=alpha)

    @staticmethod
    def _sync_target_model(model, target_model, alpha):
        for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
            target_param.data.mul_(1.0 - alpha).add_(copy_param.data, alpha=alpha)

    @staticmethod
    def _gather_model_parameters(model):
        model_parameters = []
        with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                for p in model.parameters():
                    model_parameters.append(p.data)
        return model_parameters

    @staticmethod
    def sync_target_model(model, target_model, alpha):
        deepspeed_plugin = AcceleratorState().deepspeed_plugin
        if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
            model_parameters = SyncRefModelCallback._gather_model_parameters(model)
            with deepspeed.zero.GatheredParameters(list(target_model.parameters()), modifier_rank=0):
                if deepspeed.comm.get_rank() == 0:
                    SyncRefModelCallback._sync_target_model_deepspeed(model_parameters, target_model, alpha)
        else:
            SyncRefModelCallback._sync_target_model(model, target_model, alpha)

kashif (Collaborator) commented Aug 28, 2024

@Bodoral can you check if it works with the current head?

Bodoral (Author) commented Sep 8, 2024

    @staticmethod
    def sync_target_model(model, target_model, alpha):
        deepspeed_plugin = AcceleratorState().deepspeed_plugin
        if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
            with deepspeed.zero.GatheredParameters(
                list(model.parameters()) + list(target_model.parameters()), modifier_rank=0
            ):
                if deepspeed.comm.get_rank() == 0:
                    SyncRefModelCallback._sync_target_model(model, target_model, alpha)
        else:
            SyncRefModelCallback._sync_target_model(model, target_model, alpha)

Tested this and it works perfectly.
