Commit

Revert "Revert "export.py: fix custom SDPA type conversion logic & re…
Browse files Browse the repository at this point in the history
…-enable for bflo…" (pytorch#1197)" (pytorch#1199)
  • Loading branch information
swolchok authored Sep 25, 2024
1 parent c34efd5 commit 8446605
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions torchchat/export.py
@@ -199,7 +199,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
             input_pos[-1].item(),
             seqlen,
         )
-        output = output.view(bsz, seqlen, self.dim).to(dtype=q.dtype)
+        output = output.view(bsz, seqlen, self.dim).to(dtype=x.dtype)
         return self.wo(output)
 
 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
@@ -291,11 +291,7 @@ def export_for_et(model, device, output_path) -> str:
         model = model.to(dtype=target_precision)
         state_dict_dtype = target_precision
 
-    # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
-    # support anything but bfloat32, and our attempt to use it anyway by converting
-    # to and from float causes other errors.)
-    if target_precision != torch.bfloat16:
-        replace_attention_with_custom_sdpa_attention(model)
+    replace_attention_with_custom_sdpa_attention(model)
 
     with torch.nn.attention.sdpa_kernel(
         [torch.nn.attention.SDPBackend.MATH]
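For context, a minimal, self-contained sketch of the dtype handling this commit fixes; it is not torchchat's actual attention class, and the class name, projections, and the float32 upcast below are illustrative assumptions. It shows why the SDPA output is cast back to the activation dtype x.dtype rather than q.dtype: if q has been upcast for the custom op, q.dtype would leave the output in float32 and mismatch a bfloat16 output projection, consistent with the removed comment about conversion errors that previously forced disabling custom SDPA for bfloat16.

# A minimal sketch, NOT torchchat's actual module. Assumptions: the custom SDPA op
# is stood in for by F.scaled_dot_product_attention, and q/k/v are upcast to
# float32 before the op.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SketchSDPAAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seqlen, _ = x.shape
        # Project and reshape to (bsz, n_heads, seqlen, head_dim).
        q = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.wk(x).view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        # Assume the op wants float32 inputs on CPU, so upcast them here.
        q, k, v = q.float(), k.float(), v.float()
        output = F.scaled_dot_product_attention(q, k, v)
        output = output.transpose(1, 2).reshape(bsz, seqlen, self.dim)
        # Cast back to the *activation* dtype (x.dtype), not q.dtype: q was upcast
        # above, so q.dtype would leave the output in float32 and make the bfloat16
        # wo projection reject it. This mirrors the one-line change in the diff.
        return self.wo(output.to(dtype=x.dtype))

# Usage: with the module in bfloat16, the output dtype stays bfloat16.
attn = SketchSDPAAttention(dim=64, n_heads=4).to(dtype=torch.bfloat16)
x = torch.randn(1, 8, 64, dtype=torch.bfloat16)
print(attn(x).dtype)  # torch.bfloat16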
