Fix dtype mismatch in fused_linear_cross_entropy_forward
Fixes linkedin#305

Fix a dtype mismatch in the `fused_linear_cross_entropy_forward` function.

* Cast `logits_chunk` to the data type of `_input_chunk` before performing operations on it.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/linkedin/Liger-Kernel/issues/305?shareId=XXXX-XXXX-XXXX-XXXX).
kostum123 committed Oct 12, 2024
1 parent ff6650b commit d4504c4
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -70,7 +70,7 @@ def fused_linear_cross_entropy_forward(
 n_non_ignore = (target_chunk != ignore_index).sum().item()

 # when doing CE, use the upcasted precision
-logits_chunk = logits_chunk.float()
+logits_chunk = logits_chunk.to(_input_chunk.dtype)

 # ensure _input and target are contiguous
 logits_chunk = logits_chunk.contiguous()
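
The effect of the one-line change can be illustrated with a small standalone sketch. It is not taken from the repository: the tensor shapes, the `weight` tensor, and the later `torch.mm` call are assumptions made for illustration, since the diff does not show where the mismatch actually surfaces. It only shows that `.float()` always produces float32, while `.to(_input_chunk.dtype)` keeps the logits in the same dtype as the input, so a subsequent same-dtype matrix multiply succeeds.

```python
import torch

# Hypothetical stand-ins for the tensors named in the diff; shapes are arbitrary.
_input_chunk = torch.randn(4, 8, dtype=torch.bfloat16)
weight = torch.randn(16, 8, dtype=torch.bfloat16)

logits_chunk = _input_chunk @ weight.t()  # bfloat16, shape (4, 16)

# Behavior before the commit: always upcast to float32.
upcast = logits_chunk.float()

# Behavior after the commit: match the dtype of _input_chunk.
recast = logits_chunk.to(_input_chunk.dtype)

# Assumption: a later step multiplies a logits-shaped tensor with _input_chunk.
# torch.mm requires both operands to share a dtype, so the upcast tensor fails.
try:
    torch.mm(upcast.t(), _input_chunk)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# With matching dtypes the same multiply goes through.
product = torch.mm(recast.t(), _input_chunk)  # shape (16, 8)

print(upcast.dtype, recast.dtype, mismatch_raised, product.dtype)
```

One consequence worth keeping in mind: when `_input_chunk` is a reduced-precision tensor, the cross-entropy step now runs in that precision rather than in float32, so the retained "use the upcasted precision" comment only holds when the input is already float32.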