fattorib
Owner

This PR adds a fused scan kernel for the forward and backward passes. It is around 2x faster than the Triton kernel I originally wrote. It should also use less memory, because activations are recomputed in the backward pass rather than stored (memory benchmarks pending).
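
As a rough illustration of the recomputation idea (not the kernel in this PR), here is a minimal pure-Python sketch. It assumes a first-order linear recurrence `h[t] = a[t] * h[t-1] + b[t]`, which this description does not specify, and the names `scan_forward` and `scan_backward_recompute` are hypothetical. The backward pass receives only the inputs and the upstream gradients, rebuilds the hidden states, and then runs the reverse scan.

```python
from typing import List, Tuple


def scan_forward(a: List[float], b: List[float], h0: float = 0.0) -> List[float]:
    """Sequential scan for the assumed recurrence h[t] = a[t] * h[t-1] + b[t]."""
    h = []
    prev = h0
    for a_t, b_t in zip(a, b):
        prev = a_t * prev + b_t
        h.append(prev)
    return h


def scan_backward_recompute(
    a: List[float], b: List[float], grad_h: List[float], h0: float = 0.0
) -> Tuple[List[float], List[float]]:
    """Backward pass that takes only the inputs and upstream gradients.

    The hidden states are recomputed here instead of being read from
    buffers saved during the forward pass.
    """
    h = scan_forward(a, b, h0)  # recomputation step
    grad_a = [0.0] * len(a)
    grad_b = [0.0] * len(b)
    carry = 0.0  # gradient reaching h[t] from later time steps
    for t in reversed(range(len(a))):
        g = grad_h[t] + carry            # total dL/dh[t]
        h_prev = h[t - 1] if t > 0 else h0
        grad_a[t] = g * h_prev           # dh[t]/da[t] = h[t-1]
        grad_b[t] = g                    # dh[t]/db[t] = 1
        carry = g * a[t]                 # dh[t]/dh[t-1] = a[t]
    return grad_a, grad_b


if __name__ == "__main__":
    a = [0.5, 0.9, 0.8]
    b = [1.0, 2.0, 3.0]
    grad_a, grad_b = scan_backward_recompute(a, b, grad_h=[1.0, 1.0, 1.0])
    print(grad_a, grad_b)
```

In this sketch nothing beyond the inputs has to persist between the forward and backward calls; the cost is one extra forward scan. A GPU kernel would presumably do the recomputation tile by tile in on-chip memory, which this sketch does not attempt to model.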

Benchmarks

The kernel is also benchmarked against `torch.nn.functional.scaled_dot_product_attention` with hd=64.

Forward Pass

bs=8, d_model=1024
[image: forward pass benchmark]

Forward + Backward

bs=8, d_model=1024
[image: forward + backward benchmark]

@fattorib fattorib assigned fattorib and unassigned fattorib Oct 30, 2024
@svladusic

LGTM 👍

@fattorib
Owner Author

> LGTM 👍

Thanks Stefan, appreciate your detailed review. Merging now :)

@fattorib fattorib merged commit 206a3aa into main Oct 30, 2024
2 checks passed
@fattorib fattorib deleted the fused-kernel branch November 7, 2024 13:37