set the default to use set_to_none for clearing gradients in BF16 optimizer. #5434

Merged · 8 commits · Apr 23, 2024

deepspeed/runtime/bf16_optimizer.py: 15 changes (12 additions & 3 deletions)
```diff
@@ -341,7 +341,7 @@ def _update_hp_grad(self, lp, group_idx, param_idx, clear_lp_grads):

         # clear gradients
         if clear_lp_grads:
-            lp.grad._zero()
+            lp.grad.zero_()

     @torch.no_grad()
     def _update_hp_grads_func(self, clear_lp_grads=False):
```

tjruwase marked this conversation as resolved.
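A note on this hunk: `_zero()` is not a public `torch.Tensor` method, so the old line would have raised `AttributeError` when reached; `zero_()` is the standard in-place zeroing op. The in-place part is what matters for graph replay, since a captured CUDA graph keeps writing to the buffer addresses it saw at capture time. A minimal sketch of that property (tensor names are illustrative, not from the PR):

```python
import torch

g = torch.ones(4)            # stands in for a low-precision grad (lp.grad)
addr = g.data_ptr()          # the address a captured graph would replay against

g.zero_()                    # in-place: contents cleared, storage unchanged
assert g.data_ptr() == addr and torch.all(g == 0)

g = torch.zeros_like(g)      # rebinding allocates fresh storage instead
assert g.data_ptr() != addr  # a replayed graph would still target the old buffer
```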
```diff
@@ -441,11 +441,20 @@ def clear_hp_grads(self):
             self.fp32_groups_has_gradients[i] = [False] * len(group)

     def clear_lp_grads(self):
+        # using zero_() fixed memory address for graph replay
+        set_to_none = set_to_none = False if self.graph_harvesting else True
```
Contributor commented:

Suggested change:

```diff
-        set_to_none = set_to_none = False if self.graph_harvesting else True
+        set_to_none = False if self.graph_harvesting else True
```

Contributor Author commented:

fixed and reverified
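Aside on the exchange above: `a = a = expr` is chained assignment in Python; the right-hand side is evaluated once and the name is simply bound twice, so the duplication was a readability problem, not a behavior bug. A quick illustration:

```python
# Chained assignment binds each target (here, the same name twice) to one value.
graph_harvesting = False
set_to_none = set_to_none = False if graph_harvesting else True
assert set_to_none is True
```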

```diff
+        zero_grads_list = []
         for group in self.bf16_groups:
             for param in group:
-                if param.grad is not None:
-                    # Using zero_() fixed memory address for graph replay
-                    param.grad.zero_()
-                    assert param.grad.grad_fn == None
+                if set_to_none:
+                    param.grad = None
+                elif param.grad is not None:
+                    zero_grads_list.append(param.grad)
+        if not set_to_none and len(zero_grads_list) > 0:
+            torch._foreach_zero_(zero_grads_list)

     def state_dict(self):
         state_dict = {}
```
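For context on the merged behavior: the default path now matches `torch.optim.Optimizer.zero_grad(set_to_none=True)`, dropping the `.grad` references so their storage can be freed and the next backward pass allocates fresh tensors, while the graph-harvesting path keeps the existing tensors and clears them with a single fused `torch._foreach_zero_` call so their addresses stay valid across graph replays. A self-contained sketch of the same pattern (the `clear_grads` helper and its parameters are hypothetical; only the clearing logic mirrors the diff):

```python
import torch

def clear_grads(params, graph_harvesting=False):
    # Default: release gradient storage entirely (set_to_none semantics).
    # Under graph harvesting, zero the tensors in place instead, keeping
    # the buffer addresses that a captured graph replays against.
    set_to_none = not graph_harvesting
    zero_grads_list = []
    for p in params:
        if set_to_none:
            p.grad = None
        elif p.grad is not None:
            zero_grads_list.append(p.grad)
    if not set_to_none and len(zero_grads_list) > 0:
        # One fused kernel launch instead of a zero_() call per tensor.
        torch._foreach_zero_(zero_grads_list)

# Usage: build two parameters, populate grads, then clear both ways.
p1 = torch.randn(3, requires_grad=True)
p2 = torch.randn(3, requires_grad=True)
(p1.sum() + p2.sum()).backward()
clear_grads([p1, p2])                          # grads are now None
(p1.sum() + p2.sum()).backward()
clear_grads([p1, p2], graph_harvesting=True)   # grads zeroed in place
assert p1.grad is not None and torch.all(p1.grad == 0)
```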