diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 48eb4b78..c6ae7002 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -603,7 +603,7 @@ def backward(ctx, grad_output, *args): if ctx.needs_input_grad[1]: if grad_weight_sync: grad_weight_sync.wait() - if grad_bias and grad_bias_sync: + if grad_bias is not None and grad_bias_sync is not None: grad_bias_sync.wait() return grad_input, grad_weight, grad_bias, None, None, None, None, None, None