feat(pallas): Optimize Pallas Attention + Benchmark #17328

Open · wants to merge 22 commits into main

Conversation

jon-chuang
Contributor

@jon-chuang jon-chuang commented Aug 28, 2023

Depends on fix #17314 for block_q != block_k.

EDIT: the results below are outdated.

This one-line change, mimicking the Triton code, brought the Pallas delayed-softmax speed way down, to slightly below the in-loop case.

This finicky sensitivity to application code is undesirable; ideally we should not carry tech debt from Triton into Pallas, since Pallas is meant to target more than just Triton.

- acc *= alpha[:, None]
+ # Adding 0 * l_prev is due to a weird compiler bug in Triton
+ acc *= (0 * l_prev + alpha)[:, None]

image

Without exp2 optimization
image

With the full in-loop online softmax (both Pallas and Triton):
Screenshot from 2023-08-30 22-03-04

See update below for latest results.

CC: @sharadmv

@jon-chuang jon-chuang changed the title feat(examples): Add Pallas Benchmark feat(examples): Add Pallas Benchmark + DELAYED_ONLINE_SOFTMAX config Aug 30, 2023
@jon-chuang jon-chuang changed the title feat(examples): Add Pallas Benchmark + DELAYED_ONLINE_SOFTMAX config feat(pallas): Delayed Softmax Norm Trick (Flash Attention 2) + Pallas Attention Benchmark Aug 30, 2023
@jon-chuang jon-chuang force-pushed the jon-chuang/pallas-benchmark branch from 1cac649 to 9ad1478 Compare August 30, 2023 14:10
@jon-chuang jon-chuang changed the title feat(pallas): Delayed Softmax Norm Trick (Flash Attention 2) + Pallas Attention Benchmark feat(pallas): Optimize Pallas Attention (Flash Attention 2 delayed softmax norm, Triton compiler bug tricks) + Pallas Attention Benchmark Aug 30, 2023
@jon-chuang

This comment was marked as outdated.

@jon-chuang

This comment was marked as outdated.

@jon-chuang

This comment was marked as resolved.

@jon-chuang
Contributor Author

jon-chuang commented Sep 2, 2023

Pallas TTGIR:

#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
#blocked5 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked6 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @mha_forward(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c8388608_i32 = arith.constant 8388608 : i32
    %c262144_i32 = arith.constant 262144 : i32
    %c31_i32 = arith.constant 31 : i32
    %cst = arith.constant dense<0xFF800000> : tensor<128xf32, #blocked>
    %c0_i32 = arith.constant 0 : i32
    %cst_0 = arith.constant dense<-2.38197633E+38> : tensor<128x32xf32, #blocked1>
    %cst_1 = arith.constant dense<1.44269502> : tensor<128x32xf32, #blocked1>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked1>
    %cst_3 = arith.constant dense<64> : tensor<32x1xi32, #blocked2>
    %c32_i32 = arith.constant 32 : i32
    %c1_i32 = arith.constant 1 : i32
    %cst_4 = arith.constant dense<64> : tensor<128x1xi32, #blocked1>
    %c128_i32 = arith.constant 128 : i32
    %cst_5 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked1>
    %cst_6 = arith.constant dense<0.000000e+00> : tensor<128xf32, #blocked>
    %0 = tt.get_program_id x : i32
    %1 = tt.get_program_id y : i32
    %2 = tt.get_program_id z : i32
    %3 = arith.muli %0, %c128_i32 : i32
    %4 = arith.muli %1, %c8388608_i32 : i32
    %5 = arith.muli %2, %c262144_i32 : i32
    %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked>
    %7 = tt.splat %3 : (i32) -> tensor<128xi32, #blocked>
    %8 = arith.addi %7, %6 : tensor<128xi32, #blocked>
    %9 = triton_gpu.convert_layout %8 : (tensor<128xi32, #blocked>) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
    %11 = arith.muli %10, %cst_4 : tensor<128x1xi32, #blocked1>
    %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked>
    %13 = triton_gpu.convert_layout %12 : (tensor<64xi32, #blocked>) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
    %14 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>) -> tensor<1x64xi32, #blocked3>
    %15 = tt.broadcast %14 : (tensor<1x64xi32, #blocked3>) -> tensor<128x64xi32, #blocked3>
    %16 = triton_gpu.convert_layout %15 : (tensor<128x64xi32, #blocked3>) -> tensor<128x64xi32, #blocked1>
    %17 = tt.addptr %arg0, %4 : !tt.ptr<f16>, i32
    %18 = tt.addptr %17, %5 : !tt.ptr<f16>, i32
    %19 = tt.splat %18 : (!tt.ptr<f16>) -> tensor<128x1x!tt.ptr<f16>, #blocked1>
    %20 = tt.addptr %19, %11 : tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<128x1xi32, #blocked1>
    %21 = tt.broadcast %20 : (tensor<128x1x!tt.ptr<f16>, #blocked1>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
    %22 = tt.addptr %21, %16 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
    %23 = tt.load %22 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
    %24 = arith.addi %0, %c1_i32 : i32
    %25 = arith.muli %24, %c128_i32 : i32
    %26 = arith.addi %25, %c31_i32 : i32
    %27 = arith.divsi %26, %c32_i32 : i32
    %28 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #blocked>
    %29 = tt.broadcast %14 : (tensor<1x64xi32, #blocked3>) -> tensor<32x64xi32, #blocked3>
    %30 = triton_gpu.convert_layout %29 : (tensor<32x64xi32, #blocked3>) -> tensor<32x64xi32, #blocked2>
    %31 = tt.addptr %arg1, %4 : !tt.ptr<f16>, i32
    %32 = tt.addptr %31, %5 : !tt.ptr<f16>, i32
    %33 = tt.splat %32 : (!tt.ptr<f16>) -> tensor<32x1x!tt.ptr<f16>, #blocked2>
    %34 = tt.addptr %arg2, %4 : !tt.ptr<f16>, i32
    %35 = tt.addptr %34, %5 : !tt.ptr<f16>, i32
    %36 = tt.splat %35 : (!tt.ptr<f16>) -> tensor<32x1x!tt.ptr<f16>, #blocked2>
    %37 = tt.broadcast %10 : (tensor<128x1xi32, #blocked1>) -> tensor<128x32xi32, #blocked1>
    %38:3 = scf.for %arg4 = %c0_i32 to %27 step %c1_i32 iter_args(%arg5 = %cst_5, %arg6 = %cst, %arg7 = %cst_6) -> (tensor<128x64xf32, #blocked1>, tensor<128xf32, #blocked>, tensor<128xf32, #blocked>)  : i32 {
      %50 = arith.muli %arg4, %c32_i32 : i32
      %51 = tt.splat %50 : (i32) -> tensor<32xi32, #blocked>
      %52 = arith.addi %51, %28 : tensor<32xi32, #blocked>
      %53 = triton_gpu.convert_layout %52 : (tensor<32xi32, #blocked>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
      %54 = tt.expand_dims %53 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<32x1xi32, #blocked1>
      %55 = triton_gpu.convert_layout %54 : (tensor<32x1xi32, #blocked1>) -> tensor<32x1xi32, #blocked2>
      %56 = arith.muli %55, %cst_3 : tensor<32x1xi32, #blocked2>
      %57 = tt.addptr %33, %56 : tensor<32x1x!tt.ptr<f16>, #blocked2>, tensor<32x1xi32, #blocked2>
      %58 = tt.broadcast %57 : (tensor<32x1x!tt.ptr<f16>, #blocked2>) -> tensor<32x64x!tt.ptr<f16>, #blocked2>
      %59 = tt.addptr %58, %30 : tensor<32x64x!tt.ptr<f16>, #blocked2>, tensor<32x64xi32, #blocked2>
      %60 = tt.load %59 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16, #blocked2>
      %61 = tt.addptr %36, %56 : tensor<32x1x!tt.ptr<f16>, #blocked2>, tensor<32x1xi32, #blocked2>
      %62 = tt.broadcast %61 : (tensor<32x1x!tt.ptr<f16>, #blocked2>) -> tensor<32x64x!tt.ptr<f16>, #blocked2>
      %63 = tt.addptr %62, %30 : tensor<32x64x!tt.ptr<f16>, #blocked2>, tensor<32x64xi32, #blocked2>
      %64 = tt.load %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16, #blocked2>
      %65 = triton_gpu.convert_layout %60 : (tensor<32x64xf16, #blocked2>) -> tensor<32x64xf16, #shared>
      %66 = tt.trans %65 : (tensor<32x64xf16, #shared>) -> tensor<64x32xf16, #shared1>
      %67 = triton_gpu.convert_layout %66 : (tensor<64x32xf16, #shared1>) -> tensor<64x32xf16, #blocked4>
      %68 = triton_gpu.convert_layout %23 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>>
      %69 = triton_gpu.convert_layout %67 : (tensor<64x32xf16, #blocked4>) -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>>
      %70 = triton_gpu.convert_layout %cst_2 : (tensor<128x32xf32, #blocked1>) -> tensor<128x32xf32, #blocked5>
      %71 = tt.dot %68, %69, %70 {allowTF32 = false} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> -> tensor<128x32xf32, #blocked5>
      %72 = triton_gpu.convert_layout %71 : (tensor<128x32xf32, #blocked5>) -> tensor<128x32xf32, #blocked1>
      %73 = arith.truncf %72 : tensor<128x32xf32, #blocked1> to tensor<128x32xf16, #blocked1>
      %74 = arith.extf %73 : tensor<128x32xf16, #blocked1> to tensor<128x32xf32, #blocked1>
      %75 = arith.mulf %74, %cst_1 : tensor<128x32xf32, #blocked1>
      %76 = triton_gpu.convert_layout %52 : (tensor<32xi32, #blocked>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
      %77 = tt.expand_dims %76 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>) -> tensor<1x32xi32, #blocked3>
      %78 = tt.broadcast %77 : (tensor<1x32xi32, #blocked3>) -> tensor<128x32xi32, #blocked3>
      %79 = triton_gpu.convert_layout %78 : (tensor<128x32xi32, #blocked3>) -> tensor<128x32xi32, #blocked1>
      %80 = "triton_gpu.cmpi"(%37, %79) <{predicate = 5 : i64}> : (tensor<128x32xi32, #blocked1>, tensor<128x32xi32, #blocked1>) -> tensor<128x32xi1, #blocked1>
      %81 = "triton_gpu.select"(%80, %75, %cst_0) : (tensor<128x32xi1, #blocked1>, tensor<128x32xf32, #blocked1>, tensor<128x32xf32, #blocked1>) -> tensor<128x32xf32, #blocked1>
      %82 = "tt.reduce"(%81) <{axis = 1 : i32}> ({
      ^bb0(%arg8: f32, %arg9: f32):
        %109 = "triton_gpu.cmpf"(%arg8, %arg9) <{predicate = 2 : i64}> : (f32, f32) -> i1
        %110 = arith.select %109, %arg8, %arg9 : f32
        tt.reduce.return %110 : f32
      }) : (tensor<128x32xf32, #blocked1>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
      %83 = triton_gpu.convert_layout %82 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128xf32, #blocked>
      %84 = "triton_gpu.cmpf"(%83, %arg6) <{predicate = 2 : i64}> : (tensor<128xf32, #blocked>, tensor<128xf32, #blocked>) -> tensor<128xi1, #blocked>
      %85 = "triton_gpu.select"(%84, %83, %arg6) : (tensor<128xi1, #blocked>, tensor<128xf32, #blocked>, tensor<128xf32, #blocked>) -> tensor<128xf32, #blocked>
      %86 = arith.subf %arg6, %85 : tensor<128xf32, #blocked>
      %87 = tt.pure_extern_elementwise %86 {libname = "libdevice", libpath = "/home/jonch/.local/lib/python3.10/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128xf32, #blocked>) -> tensor<128xf32, #blocked>
      %88 = triton_gpu.convert_layout %85 : (tensor<128xf32, #blocked>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
      %89 = tt.expand_dims %88 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xf32, #blocked1>
      %90 = tt.broadcast %89 : (tensor<128x1xf32, #blocked1>) -> tensor<128x32xf32, #blocked1>
      %91 = arith.subf %81, %90 : tensor<128x32xf32, #blocked1>
      %92 = tt.pure_extern_elementwise %91 {libname = "libdevice", libpath = "/home/jonch/.local/lib/python3.10/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128x32xf32, #blocked1>) -> tensor<128x32xf32, #blocked1>
      %93 = "tt.reduce"(%92) <{axis = 1 : i32}> ({
      ^bb0(%arg8: f32, %arg9: f32):
        %109 = arith.addf %arg8, %arg9 : f32
        tt.reduce.return %109 : f32
      }) : (tensor<128x32xf32, #blocked1>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
      %94 = triton_gpu.convert_layout %93 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128xf32, #blocked>
      %95 = arith.mulf %87, %arg7 : tensor<128xf32, #blocked>
      %96 = arith.addf %94, %95 : tensor<128xf32, #blocked>
      %97 = arith.mulf %arg7, %cst_6 : tensor<128xf32, #blocked>
      %98 = arith.addf %97, %87 : tensor<128xf32, #blocked>
      %99 = triton_gpu.convert_layout %98 : (tensor<128xf32, #blocked>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
      %100 = tt.expand_dims %99 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xf32, #blocked1>
      %101 = tt.broadcast %100 : (tensor<128x1xf32, #blocked1>) -> tensor<128x64xf32, #blocked1>
      %102 = arith.mulf %arg5, %101 : tensor<128x64xf32, #blocked1>
      %103 = arith.truncf %92 : tensor<128x32xf32, #blocked1> to tensor<128x32xf16, #blocked1>
      %104 = triton_gpu.convert_layout %103 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked6}>>
      %105 = triton_gpu.convert_layout %64 : (tensor<32x64xf16, #blocked2>) -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked6}>>
      %106 = triton_gpu.convert_layout %102 : (tensor<128x64xf32, #blocked1>) -> tensor<128x64xf32, #blocked6>
      %107 = tt.dot %104, %105, %106 {allowTF32 = false} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked6}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked6}>> -> tensor<128x64xf32, #blocked6>
      %108 = triton_gpu.convert_layout %107 : (tensor<128x64xf32, #blocked6>) -> tensor<128x64xf32, #blocked1>
      scf.yield %108, %85, %96 : tensor<128x64xf32, #blocked1>, tensor<128xf32, #blocked>, tensor<128xf32, #blocked>
    }
    %39 = triton_gpu.convert_layout %38#2 : (tensor<128xf32, #blocked>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %40 = tt.expand_dims %39 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xf32, #blocked1>
    %41 = tt.broadcast %40 : (tensor<128x1xf32, #blocked1>) -> tensor<128x64xf32, #blocked1>
    %42 = arith.divf %38#0, %41 : tensor<128x64xf32, #blocked1>
    %43 = arith.truncf %42 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1>
    %44 = tt.addptr %arg3, %4 : !tt.ptr<f16>, i32
    %45 = tt.addptr %44, %5 : !tt.ptr<f16>, i32
    %46 = tt.splat %45 : (!tt.ptr<f16>) -> tensor<128x1x!tt.ptr<f16>, #blocked1>
    %47 = tt.addptr %46, %11 : tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<128x1xi32, #blocked1>
    %48 = tt.broadcast %47 : (tensor<128x1x!tt.ptr<f16>, #blocked1>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
    %49 = tt.addptr %48, %16 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
    tt.store %49, %43 {cache = 1 : i32, evict = 1 : i32} : tensor<128x64xf16, #blocked1>
    tt.return
  }
}

Triton TTGIR

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @_fwd_kernel_0d1d2d34d5d6d7d8d9c10d11d12d13c14d15d16d17c18d19d20d21c2223d24d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c1_i32 = arith.constant 1 : i32
    %c3_i32 = arith.constant 3 : i32
    %c192_i64 = arith.constant 192 : i64
    %cst = arith.constant dense<128> : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %cst_0 = arith.constant dense<128> : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %c128_i32 = arith.constant 128 : i32
    %c2_i32 = arith.constant 2 : i32
    %cst_1 = arith.constant dense<64> : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %cst_2 = arith.constant dense<64> : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %c64_i32 = arith.constant 64 : i32
    %c64_i64 = arith.constant 64 : i64
    %c4_i32 = arith.constant 4 : i32
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
    %cst_4 = arith.constant dense<0.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %cst_5 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %cst_6 = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma>
    %c0_i64 = arith.constant 0 : i64
    %c0_i32 = arith.constant 0 : i32
    %cst_7 = arith.constant 1.44269502 : f32
    %0 = tt.get_program_id x : i32
    %1 = tt.get_program_id y : i32
    %2 = arith.muli %1, %arg7 : i32
    %3 = tt.addptr %arg0, %2 : !tt.ptr<f16>, i32
    %4 = arith.muli %0, %c128_i32 : i32
    %5 = arith.extsi %arg8 : i32 to i64
    %6 = arith.extsi %4 : i32 to i64
    %7 = tt.addptr %arg1, %2 : !tt.ptr<f16>, i32
    %8 = arith.extsi %arg11 : i32 to i64
    %9 = tt.addptr %arg2, %2 : !tt.ptr<f16>, i32
    %10 = arith.extsi %arg14 : i32 to i64
    %11 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked2>
    %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %14 = tt.splat %4 : (i32) -> tensor<128xi32, #blocked2>
    %15 = tt.splat %4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %16 = arith.addi %14, %11 : tensor<128xi32, #blocked2>
    %17 = arith.addi %15, %13 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %22 = arith.mulf %arg3, %cst_7 : f32
    %23 = tt.splat %6 : (i64) -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %24 = arith.extsi %12 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %25 = arith.addi %23, %24 : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %26 = tt.expand_dims %25 {axis = 1 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xi64, #blocked>
    %27 = tt.splat %5 : (i64) -> tensor<128x1xi64, #blocked>
    %28 = arith.muli %26, %27 : tensor<128x1xi64, #blocked>
    %29 = tt.splat %3 : (!tt.ptr<f16>) -> tensor<128x1x!tt.ptr<f16>, #blocked>
    %30 = tt.addptr %29, %28 : tensor<128x1x!tt.ptr<f16>, #blocked>, tensor<128x1xi64, #blocked>
    %31 = tt.broadcast %30 : (tensor<128x1x!tt.ptr<f16>, #blocked>) -> tensor<128x64x!tt.ptr<f16>, #blocked>
    %32 = arith.extsi %18 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> to tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %33 = arith.extsi %19 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> to tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %34 = arith.extsi %20 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> to tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %35 = arith.extsi %21 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %36 = tt.expand_dims %32 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x64xi64, #blocked>
    %37 = tt.broadcast %36 : (tensor<1x64xi64, #blocked>) -> tensor<128x64xi64, #blocked>
    %38 = tt.addptr %31, %37 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi64, #blocked>
    %39 = tt.load %38 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked>
    %40 = tt.splat %22 : (f32) -> tensor<128x64xf32, #blocked>
    %41 = arith.extf %39 : tensor<128x64xf16, #blocked> to tensor<128x64xf32, #blocked>
    %42 = arith.mulf %41, %40 : tensor<128x64xf32, #blocked>
    %43 = arith.truncf %42 : tensor<128x64xf32, #blocked> to tensor<128x64xf16, #blocked>
    %44 = triton_gpu.convert_layout %43 : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #shared>
    %45 = arith.addi %0, %c1_i32 : i32
    %46 = arith.muli %45, %c128_i32 : i32
    %47 = tt.expand_dims %33 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi64, #blocked1>
    %48 = tt.splat %7 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked1>
    %49 = tt.addptr %48, %47 : tensor<64x1x!tt.ptr<f16>, #blocked1>, tensor<64x1xi64, #blocked1>
    %50 = tt.broadcast %49 : (tensor<64x1x!tt.ptr<f16>, #blocked1>) -> tensor<64x64x!tt.ptr<f16>, #blocked1>
    %51 = tt.splat %8 : (i64) -> tensor<1x64xi64, #blocked1>
    %52 = tt.splat %10 : (i64) -> tensor<64x1xi64, #blocked>
    %53 = tt.splat %9 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked>
    %54 = tt.broadcast %36 : (tensor<1x64xi64, #blocked>) -> tensor<64x64xi64, #blocked>
    %55 = tt.expand_dims %17 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xi32, #mma>
    %56 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
    %57 = tt.expand_dims %56 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>) -> tensor<1x64xi32, #mma>
    %58 = tt.broadcast %55 : (tensor<128x1xi32, #mma>) -> tensor<128x64xi32, #mma>
    %59 = arith.cmpi sgt, %46, %c0_i32 : i32
    %60 = tt.expand_dims %34 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi64, #blocked1>
    %61 = arith.muli %60, %51 : tensor<1x64xi64, #blocked1>
    %62 = tt.broadcast %61 : (tensor<1x64xi64, #blocked1>) -> tensor<64x64xi64, #blocked1>
    %63 = tt.addptr %50, %62 : tensor<64x64x!tt.ptr<f16>, #blocked1>, tensor<64x64xi64, #blocked1>
    %64 = triton_gpu.alloc_tensor : tensor<4x64x64xf16, #shared1>
    %65 = tt.splat %59 : (i1) -> tensor<64x64xi1, #blocked1>
    %66 = triton_gpu.insert_slice_async %63, %64, %c0_i32, %65 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16>, #blocked1> -> tensor<4x64x64xf16, #shared1>
    triton_gpu.async_commit_group
    %67 = tt.expand_dims %35 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi64, #blocked>
    %68 = arith.muli %67, %52 : tensor<64x1xi64, #blocked>
    %69 = tt.addptr %53, %68 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi64, #blocked>
    %70 = tt.broadcast %69 : (tensor<64x1x!tt.ptr<f16>, #blocked>) -> tensor<64x64x!tt.ptr<f16>, #blocked>
    %71 = tt.addptr %70, %54 : tensor<64x64x!tt.ptr<f16>, #blocked>, tensor<64x64xi64, #blocked>
    %72 = triton_gpu.alloc_tensor : tensor<4x64x64xf16, #shared>
    %73 = tt.splat %59 : (i1) -> tensor<64x64xi1, #blocked>
    %74 = triton_gpu.insert_slice_async %71, %72, %c0_i32, %73 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16>, #blocked> -> tensor<4x64x64xf16, #shared>
    triton_gpu.async_commit_group
    %75 = arith.cmpi sgt, %46, %c64_i32 : i32
    %76 = arith.addi %34, %cst_2 : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %77 = tt.expand_dims %76 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi64, #blocked1>
    %78 = arith.muli %77, %51 : tensor<1x64xi64, #blocked1>
    %79 = tt.broadcast %78 : (tensor<1x64xi64, #blocked1>) -> tensor<64x64xi64, #blocked1>
    %80 = tt.addptr %50, %79 : tensor<64x64x!tt.ptr<f16>, #blocked1>, tensor<64x64xi64, #blocked1>
    %81 = tt.splat %75 : (i1) -> tensor<64x64xi1, #blocked1>
    %82 = triton_gpu.insert_slice_async %80, %66, %c1_i32, %81 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16>, #blocked1> -> tensor<4x64x64xf16, #shared1>
    triton_gpu.async_commit_group
    %83 = arith.addi %35, %cst_1 : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %84 = tt.expand_dims %83 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi64, #blocked>
    %85 = arith.muli %84, %52 : tensor<64x1xi64, #blocked>
    %86 = tt.addptr %53, %85 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi64, #blocked>
    %87 = tt.broadcast %86 : (tensor<64x1x!tt.ptr<f16>, #blocked>) -> tensor<64x64x!tt.ptr<f16>, #blocked>
    %88 = tt.addptr %87, %54 : tensor<64x64x!tt.ptr<f16>, #blocked>, tensor<64x64xi64, #blocked>
    %89 = tt.splat %75 : (i1) -> tensor<64x64xi1, #blocked>
    %90 = triton_gpu.insert_slice_async %88, %74, %c1_i32, %89 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16>, #blocked> -> tensor<4x64x64xf16, #shared>
    triton_gpu.async_commit_group
    %91 = arith.cmpi sgt, %46, %c128_i32 : i32
    %92 = arith.addi %34, %cst_0 : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %93 = tt.expand_dims %92 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi64, #blocked1>
    %94 = arith.muli %93, %51 : tensor<1x64xi64, #blocked1>
    %95 = tt.broadcast %94 : (tensor<1x64xi64, #blocked1>) -> tensor<64x64xi64, #blocked1>
    %96 = tt.addptr %50, %95 : tensor<64x64x!tt.ptr<f16>, #blocked1>, tensor<64x64xi64, #blocked1>
    %97 = tt.splat %91 : (i1) -> tensor<64x64xi1, #blocked1>
    %98 = triton_gpu.insert_slice_async %96, %82, %c2_i32, %97 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16>, #blocked1> -> tensor<4x64x64xf16, #shared1>
    triton_gpu.async_commit_group
    %99 = arith.addi %35, %cst : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %100 = tt.expand_dims %99 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi64, #blocked>
    %101 = arith.muli %100, %52 : tensor<64x1xi64, #blocked>
    %102 = tt.addptr %53, %101 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi64, #blocked>
    %103 = tt.broadcast %102 : (tensor<64x1x!tt.ptr<f16>, #blocked>) -> tensor<64x64x!tt.ptr<f16>, #blocked>
    %104 = tt.addptr %103, %54 : tensor<64x64x!tt.ptr<f16>, #blocked>, tensor<64x64xi64, #blocked>
    %105 = tt.splat %91 : (i1) -> tensor<64x64xi1, #blocked>
    %106 = triton_gpu.insert_slice_async %104, %90, %c2_i32, %105 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16>, #blocked> -> tensor<4x64x64xf16, #shared>
    triton_gpu.async_commit_group
    triton_gpu.async_wait {num = 4 : i32}
    %107 = triton_gpu.extract_slice %98[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<4x64x64xf16, #shared1> to tensor<64x64xf16, #shared1>
    %108 = triton_gpu.extract_slice %106[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<4x64x64xf16, #shared> to tensor<64x64xf16, #shared>
    %109:14 = scf.for %arg21 = %c0_i32 to %46 step %c64_i32 iter_args(%arg22 = %cst_3, %arg23 = %cst_4, %arg24 = %cst_5, %arg25 = %c0_i64, %arg26 = %c0_i64, %arg27 = %98, %arg28 = %106, %arg29 = %107, %arg30 = %108, %arg31 = %c192_i64, %arg32 = %c192_i64, %arg33 = %c128_i32, %arg34 = %c3_i32, %arg35 = %c1_i32) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, i64, i64, tensor<4x64x64xf16, #shared1>, tensor<4x64x64xf16, #shared>, tensor<64x64xf16, #shared1>, tensor<64x64xf16, #shared>, i64, i64, i32, i32, i32)  : i32 {
      %131 = tt.splat %arg21 : (i32) -> tensor<1x64xi32, #mma>
      %132 = arith.addi %131, %57 : tensor<1x64xi32, #mma>
      %133 = tt.broadcast %132 : (tensor<1x64xi32, #mma>) -> tensor<128x64xi32, #mma>
      %134 = "triton_gpu.cmpi"(%58, %133) <{predicate = 5 : i64}> : (tensor<128x64xi32, #mma>, tensor<128x64xi32, #mma>) -> tensor<128x64xi1, #mma>
      %135 = "triton_gpu.select"(%134, %cst_3, %cst_6) : (tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>
      %136 = triton_gpu.convert_layout %44 : (tensor<128x64xf16, #shared>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
      %137 = triton_gpu.convert_layout %arg29 : (tensor<64x64xf16, #shared1>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %138 = tt.dot %136, %137, %135 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
      %139 = "tt.reduce"(%138) <{axis = 1 : i32}> ({
      ^bb0(%arg36: f32, %arg37: f32):
        %189 = "triton_gpu.cmpf"(%arg36, %arg37) <{predicate = 2 : i64}> : (f32, f32) -> i1
        %190 = arith.select %189, %arg36, %arg37 : f32
        tt.reduce.return %190 : f32
      }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %140 = "triton_gpu.cmpf"(%arg24, %139) <{predicate = 2 : i64}> : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %141 = "triton_gpu.select"(%140, %arg24, %139) : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %142 = arith.subf %arg24, %141 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %143 = tt.pure_extern_elementwise %142 {libname = "libdevice", libpath = "/home/jonch/.local/lib/python3.10/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %144 = tt.expand_dims %141 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
      %145 = tt.broadcast %144 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma>
      %146 = arith.subf %138, %145 : tensor<128x64xf32, #mma>
      %147 = tt.pure_extern_elementwise %146 {libname = "libdevice", libpath = "/home/jonch/.local/lib/python3.10/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>
      %148 = arith.mulf %arg23, %cst_4 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %149 = arith.addf %148, %143 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %150 = tt.expand_dims %149 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
      %151 = tt.broadcast %150 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma>
      %152 = arith.mulf %arg22, %151 : tensor<128x64xf32, #mma>
      %153 = arith.truncf %147 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
      %154 = triton_gpu.convert_layout %153 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
      %155 = triton_gpu.convert_layout %arg30 : (tensor<64x64xf16, #shared>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %156 = tt.dot %154, %155, %152 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
      %157 = arith.mulf %arg23, %143 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %158 = "tt.reduce"(%147) <{axis = 1 : i32}> ({
      ^bb0(%arg36: f32, %arg37: f32):
        %189 = arith.addf %arg36, %arg37 : f32
        tt.reduce.return %189 : f32
      }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %159 = arith.addf %157, %158 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %160 = arith.addi %arg25, %c64_i64 : i64
      %161 = arith.addi %arg26, %c64_i64 : i64
      %162 = arith.addi %arg33, %c64_i32 : i32
      %163 = arith.cmpi slt, %162, %46 : i32
      %164 = arith.remsi %arg34, %c4_i32 : i32
      %165 = arith.remsi %arg35, %c4_i32 : i32
      %166 = tt.splat %arg31 : (i64) -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
      %167 = arith.addi %166, %34 : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
      %168 = tt.expand_dims %167 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi64, #blocked1>
      %169 = arith.muli %168, %51 : tensor<1x64xi64, #blocked1>
      %170 = tt.broadcast %169 : (tensor<1x64xi64, #blocked1>) -> tensor<64x64xi64, #blocked1>
      %171 = tt.addptr %50, %170 : tensor<64x64x!tt.ptr<f16>, #blocked1>, tensor<64x64xi64, #blocked1>
      %172 = tt.splat %arg32 : (i64) -> tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      %173 = arith.addi %172, %35 : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      %174 = tt.expand_dims %173 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi64, #blocked>
      %175 = arith.muli %174, %52 : tensor<64x1xi64, #blocked>
      %176 = tt.addptr %53, %175 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi64, #blocked>
      %177 = tt.broadcast %176 : (tensor<64x1x!tt.ptr<f16>, #blocked>) -> tensor<64x64x!tt.ptr<f16>, #blocked>
      %178 = tt.addptr %177, %54 : tensor<64x64x!tt.ptr<f16>, #blocked>, tensor<64x64xi64, #blocked>
      %179 = arith.addi %arg31, %c64_i64 : i64
      %180 = arith.addi %arg32, %c64_i64 : i64
      %181 = tt.splat %163 : (i1) -> tensor<64x64xi1, #blocked1>
      %182 = triton_gpu.insert_slice_async %171, %arg27, %164, %181 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16>, #blocked1> -> tensor<4x64x64xf16, #shared1>
      triton_gpu.async_commit_group
      %183 = tt.splat %163 : (i1) -> tensor<64x64xi1, #blocked>
      %184 = triton_gpu.insert_slice_async %178, %arg28, %164, %183 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16>, #blocked> -> tensor<4x64x64xf16, #shared>
      triton_gpu.async_commit_group
      triton_gpu.async_wait {num = 4 : i32}
      %185 = triton_gpu.extract_slice %182[%165, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<4x64x64xf16, #shared1> to tensor<64x64xf16, #shared1>
      %186 = triton_gpu.extract_slice %184[%165, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<4x64x64xf16, #shared> to tensor<64x64xf16, #shared>
      %187 = arith.addi %arg34, %c1_i32 : i32
      %188 = arith.addi %arg35, %c1_i32 : i32
      scf.yield %156, %159, %141, %160, %161, %182, %184, %185, %186, %179, %180, %162, %187, %188 : tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, i64, i64, tensor<4x64x64xf16, #shared1>, tensor<4x64x64xf16, #shared>, tensor<64x64xf16, #shared1>, tensor<64x64xf16, #shared>, i64, i64, i32, i32, i32
    }
    triton_gpu.async_wait {num = 0 : i32}
    %110 = triton_gpu.convert_layout %109#2 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #blocked2>
    %111 = triton_gpu.convert_layout %109#1 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #blocked2>
    %112 = tt.expand_dims %109#1 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
    %113 = tt.broadcast %112 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma>
    %114 = arith.divf %109#0, %113 : tensor<128x64xf32, #mma>
    %115 = arith.muli %1, %arg20 : i32
    %116 = tt.addptr %arg4, %115 : !tt.ptr<f32>, i32
    %117 = tt.splat %116 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked2>
    %118 = tt.addptr %117, %16 : tensor<128x!tt.ptr<f32>, #blocked2>, tensor<128xi32, #blocked2>
    %119 = tt.pure_extern_elementwise %111 {libname = "libdevice", libpath = "/home/jonch/.local/lib/python3.10/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_log2f"} : (tensor<128xf32, #blocked2>) -> tensor<128xf32, #blocked2>
    %120 = arith.addf %110, %119 : tensor<128xf32, #blocked2>
    tt.store %118, %120 {cache = 1 : i32, evict = 1 : i32} : tensor<128xf32, #blocked2>
    %121 = tt.addptr %arg5, %2 : !tt.ptr<f16>, i32
    %122 = arith.extsi %arg17 : i32 to i64
    %123 = arith.truncf %114 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
    %124 = tt.splat %122 : (i64) -> tensor<128x1xi64, #blocked>
    %125 = arith.muli %26, %124 : tensor<128x1xi64, #blocked>
    %126 = tt.splat %121 : (!tt.ptr<f16>) -> tensor<128x1x!tt.ptr<f16>, #blocked>
    %127 = tt.addptr %126, %125 : tensor<128x1x!tt.ptr<f16>, #blocked>, tensor<128x1xi64, #blocked>
    %128 = tt.broadcast %127 : (tensor<128x1x!tt.ptr<f16>, #blocked>) -> tensor<128x64x!tt.ptr<f16>, #blocked>
    %129 = tt.addptr %128, %37 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi64, #blocked>
    %130 = triton_gpu.convert_layout %123 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #blocked>
    tt.store %129, %130 {cache = 1 : i32, evict = 1 : i32} : tensor<128x64xf16, #blocked>
    tt.return
  }
}

Notes:

  1. Pallas uses plain tt.load and has 0 insert_slice_async ops; Triton has 8.

@jon-chuang
Contributor Author

jon-chuang commented Sep 4, 2023

Update

New results from updating Triton. See the appendix for more details.
image

Experiments

The following optimizations have been applied (a combined reference sketch follows this list):

  1. Delayed softmax trick - 5% (both Triton and Pallas)
  2. Exp2 trick - 2-3%
  3. Scaling q - 1-2% - if applying exp2 trick, scaling q before the loop is faster.
    • However the error tolerance needs to be increased slightly (for dk from 0.08 to 0.1)

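As a combined reference, here is a hedged, pure-JAX sketch (not the PR's Pallas kernel) of how the three applied optimizations fit together in the blocked attention loop: the Flash Attention 2 delayed softmax normalization, exp2 in place of exp, and q scaled by sm_scale * log2(e) once before the loop. The function name, block size, and single-head layout are illustrative assumptions.

import jax.numpy as jnp

def blocked_attention_reference(q, k, v, block_k=32, sm_scale=1.0):
    """Single attention head; k/v are processed in blocks of block_k rows."""
    seq_len, head_dim = q.shape
    log2e = 1.44269504
    # (3) Scale q once, before the loop: folds sm_scale and the base change
    # for exp2 into the query.
    q = (q * (sm_scale * log2e)).astype(jnp.float32)
    acc = jnp.zeros((seq_len, head_dim), jnp.float32)
    m_prev = jnp.full((seq_len,), -jnp.inf, jnp.float32)
    l_prev = jnp.zeros((seq_len,), jnp.float32)
    for start in range(0, k.shape[0], block_k):
        k_blk = k[start:start + block_k].astype(jnp.float32)
        v_blk = v[start:start + block_k].astype(jnp.float32)
        qk = q @ k_blk.T                      # scores, already in the exp2 domain
        m_cur = jnp.maximum(m_prev, qk.max(axis=-1))
        alpha = jnp.exp2(m_prev - m_cur)      # (2) exp2 instead of exp
        p = jnp.exp2(qk - m_cur[:, None])
        l_prev = l_prev * alpha + p.sum(axis=-1)
        # (1) Delayed softmax (FA2): rescale acc by alpha only; divide by the
        # running sum l just once, after the loop.
        acc = acc * alpha[:, None] + p @ v_blk
        m_prev = m_cur
    return acc / l_prev[:, None]
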
The following resulted in no appreciable difference:

  1. acc *= (0 * l_prev + alpha)[:, None]
  2. Swapping the axes for seq_len and num_heads

The following showed an improvement but has not been adopted:

  1. qk dot ordering - 2% - within 97.5% of flash_attn (see the sketch below)
    • performing the qk dot by adding ("+=") the mask, rather than using where directly on the product (1-2%).
    • However, this results in numerical errors (differing NaNs) when segment_ids are used. It passes all the other tests when segment_ids=None. (see e.g. here for more info)
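
For clarity, a hedged plain-JAX illustration (not the kernel code) of the two masking orders compared in item 1; the function names are made up for this example.

import jax.numpy as jnp

def scores_where(q, k, mask):
    # Current approach: compute the product, then mask it with where.
    qk = q @ k.T
    return jnp.where(mask, qk, -jnp.inf)

def scores_add_mask(q, k, mask):
    # "qk dot ordering": start from an additive mask bias and add the dot
    # product into it, instead of applying where on the product.
    bias = jnp.where(mask, 0.0, -jnp.inf)
    return bias + q @ k.T

In the kernel, the second form effectively lets the mask bias act as the accumulator of the qk dot, so no separate select on the product is needed; presumably the differing NaNs with segment_ids come from how fully masked rows propagate through the two forms.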

TODOs:

  • Benchmark bwd pass as well
  • Maybe use bar chart rather than line plot for further graphs?
  • Investigate the near-constant overhead for Pallas
    • especially bad for small sequence lengths

Appendix

16x Aberrant Memory Read Issue

After some digging, I isolated the issue to poor handling of a transpose between the load and the dot (see this discussion: triton-lang/triton#2232).

Triton did not suffer from this issue, as it handles the transpose as an smem layout (when loading from gmem to smem) rather than as an explicit transpose of smem values.

However, after upgrading Triton, this issue no longer exists (it only occurs with the backdated jax-triton dependency).

Ablations (in TFLOPs/s, higher is better)

Note: slightly inaccurate due to (now fixed) Pallas runtime overhead.

With qk dot ordering (+2%)

As we can see, Pallas is faster than Triton in this benchmark.

image

According to Nsight Compute, the kernel itself is actually faster, and this reflects kernel time (not runtime overhead).

Screenshot from 2023-09-04 21-49-12
Screenshot from 2023-09-04 21-49-24
Screenshot from 2023-09-04 21-50-03
Screenshot from 2023-09-04 21-49-53
Screenshot from 2023-09-04 21-49-44

However, in terms of compute efficiency, flash_attn executes 2.4B instructions at 16M/scheduler, whereas pallas executes 4B at 27M/scheduler, and triton executes 3.4B at 24M/scheduler.

Here it is benchmarked with 10 rounds of 1500 iterations each (a sketch of this style of timing loop follows the plot):
image
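
For reference, a hedged sketch of this style of timing loop (rounds of iterations synchronized with block_until_ready); this is an assumption about the methodology, not the PR's benchmark script, and it assumes the benchmarked function returns a single array.

import time
import jax

def bench(fn, *args, rounds=10, iters=1500):
    fn = jax.jit(fn)
    fn(*args).block_until_ready()          # compile and warm up once
    per_iter = []
    for _ in range(rounds):
        t0 = time.perf_counter()
        for _ in range(iters):
            out = fn(*args)
        out.block_until_ready()            # sync once per round (assumption)
        per_iter.append((time.perf_counter() - t0) / iters)
    return min(per_iter)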

Before: runtime overhead (from block_until_ready)
image

This is the same setting (with qk dot ordering) with tf32 accumulation (no appreciable difference)
image

No 0 * l_prev (both Triton and Pallas) (0%)

image

When the Triton kernel applies the qk dot before the causal mask (-2%)

image

Without exp2 (-3%)

As you can see, it affects Triton far more than Pallas.

Screenshot from 2023-09-04 17-48-02

Without delayed softmax (-5%)

Triton also suffers
Screenshot from 2023-09-04 17-20-15

No optimizations applied (-6%)

It seems we are memory-bound, so independent compute optimizations (exp2, delayed softmax) do not add up linearly.
image

Archived

Previous result (before removing runtime overhead)

image

@jon-chuang jon-chuang changed the title feat(pallas): Optimize Pallas Attention (Flash Attention 2 delayed softmax norm, Triton compiler bug tricks) + Pallas Attention Benchmark feat(pallas): Optimize Pallas Attention + Pallas Attention Benchmark Sep 4, 2023
@jon-chuang

This comment was marked as resolved.

@jon-chuang
Contributor Author

jon-chuang commented Sep 4, 2023

Analysis of Kernels (Nsight Compute)

Comparing Pallas, Triton and flash_attn kernels. For Pallas, we use the qk dot ordering version.

As we can see, Pallas is faster than Triton in the benchmark (10 rounds of 150 iterations each)

image

Here it is benchmarked with 10 rounds of 1500 iterations each, giving an identical result.
image

According to Nsight Compute, the kernel times (not including any runtime overhead) agree with this result: flash_attn < pallas < triton.

Screenshot from 2023-09-04 21-49-12
Screenshot from 2023-09-04 21-49-24
Screenshot from 2023-09-04 21-50-03
Screenshot from 2023-09-04 21-49-53
Screenshot from 2023-09-04 21-49-44

However, in terms of compute efficiency, flash_attn executes 2.4B instructions at 16M/scheduler, whereas pallas executes 4B at 27M/scheduler, and triton executes 3.4B at 24M/scheduler.

Pallas optimized TTGIR:

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = false}>
module attributes {"triton_gpu.compute-capability" = 89 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @mha_forward(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c2_i32 = arith.constant 2 : i32
    %cst = arith.constant dense<-2.38197633E+38> : tensor<128x64xf32, #mma>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
    %cst_1 = arith.constant dense<2048> : tensor<128x1xi32, #blocked>
    %cst_2 = arith.constant dense<2048> : tensor<64x1xi32, #blocked>
    %cst_3 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %cst_4 = arith.constant dense<0.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %cst_5 = arith.constant dense<1.442380e+00> : tensor<128x64xf16, #blocked>
    %c128_i32 = arith.constant 128 : i32
    %c16777216_i32 = arith.constant 16777216 : i32
    %c64_i32 = arith.constant 64 : i32
    %c63_i32 = arith.constant 63 : i32
    %0 = tt.get_program_id x : i32
    %1 = tt.get_program_id y : i32
    %2 = tt.get_program_id z : i32
    %3 = arith.muli %0, %c128_i32 : i32
    %4 = arith.muli %1, %c16777216_i32 : i32
    %5 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %7 = tt.splat %3 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %8 = tt.splat %3 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %9 = arith.addi %7, %5 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %10 = arith.addi %8, %6 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %11 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xi32, #blocked>
    %12 = tt.expand_dims %10 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xi32, #mma>
    %13 = arith.muli %11, %cst_1 : tensor<128x1xi32, #blocked>
    %14 = arith.muli %2, %c64_i32 : i32
    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
    %17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %18 = tt.expand_dims %17 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x64xi32, #blocked>
    %19 = tt.broadcast %18 : (tensor<1x64xi32, #blocked>) -> tensor<128x64xi32, #blocked>
    %20 = tt.addptr %arg0, %4 : !tt.ptr<f16, 1>, i32
    %21 = tt.splat %20 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked>
    %22 = tt.addptr %21, %13 : tensor<128x1x!tt.ptr<f16, 1>, #blocked>, tensor<128x1xi32, #blocked>
    %23 = tt.splat %14 : (i32) -> tensor<128x1xi32, #blocked>
    %24 = tt.addptr %22, %23 : tensor<128x1x!tt.ptr<f16, 1>, #blocked>, tensor<128x1xi32, #blocked>
    %25 = tt.broadcast %24 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<128x64x!tt.ptr<f16, 1>, #blocked>
    %26 = tt.addptr %25, %19 : tensor<128x64x!tt.ptr<f16, 1>, #blocked>, tensor<128x64xi32, #blocked>
    %27 = tt.load %26 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked>
    %28 = arith.mulf %27, %cst_5 : tensor<128x64xf16, #blocked>
    %29 = triton_gpu.convert_layout %28 : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #shared>
    %30 = arith.addi %0, %c1_i32 : i32
    %31 = arith.muli %30, %c128_i32 : i32
    %32 = arith.addi %31, %c63_i32 : i32
    %33 = arith.divsi %32, %c64_i32 : i32
    %34 = tt.broadcast %18 : (tensor<1x64xi32, #blocked>) -> tensor<64x64xi32, #blocked>
    %35 = tt.addptr %arg1, %4 : !tt.ptr<f16, 1>, i32
    %36 = tt.splat %35 : (!tt.ptr<f16, 1>) -> tensor<64x1x!tt.ptr<f16, 1>, #blocked>
    %37 = tt.splat %14 : (i32) -> tensor<64x1xi32, #blocked>
    %38 = tt.addptr %arg2, %4 : !tt.ptr<f16, 1>, i32
    %39 = tt.splat %38 : (!tt.ptr<f16, 1>) -> tensor<64x1x!tt.ptr<f16, 1>, #blocked>
    %40 = tt.broadcast %12 : (tensor<128x1xi32, #mma>) -> tensor<128x64xi32, #mma>
    %41 = arith.cmpi sgt, %33, %c0_i32 : i32
    %42 = tt.expand_dims %15 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi32, #blocked>
    %43 = arith.muli %42, %cst_2 : tensor<64x1xi32, #blocked>
    %44 = tt.addptr %36, %43 : tensor<64x1x!tt.ptr<f16, 1>, #blocked>, tensor<64x1xi32, #blocked>
    %45 = tt.addptr %44, %37 : tensor<64x1x!tt.ptr<f16, 1>, #blocked>, tensor<64x1xi32, #blocked>
    %46 = tt.broadcast %45 : (tensor<64x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<64x64x!tt.ptr<f16, 1>, #blocked>
    %47 = tt.addptr %46, %34 : tensor<64x64x!tt.ptr<f16, 1>, #blocked>, tensor<64x64xi32, #blocked>
    %48 = triton_gpu.alloc_tensor : tensor<2x64x64xf16, #shared>
    %49 = tt.splat %41 : (i1) -> tensor<64x64xi1, #blocked>
    %50 = triton_gpu.insert_slice_async %47, %48, %c0_i32, %49 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16, 1>, #blocked> -> tensor<2x64x64xf16, #shared>
    triton_gpu.async_commit_group
    %51 = tt.addptr %39, %43 : tensor<64x1x!tt.ptr<f16, 1>, #blocked>, tensor<64x1xi32, #blocked>
    %52 = tt.addptr %51, %37 : tensor<64x1x!tt.ptr<f16, 1>, #blocked>, tensor<64x1xi32, #blocked>
    %53 = tt.broadcast %52 : (tensor<64x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<64x64x!tt.ptr<f16, 1>, #blocked>
    %54 = tt.addptr %53, %34 : tensor<64x64x!tt.ptr<f16, 1>, #blocked>, tensor<64x64xi32, #blocked>
    %55 = triton_gpu.alloc_tensor : tensor<2x64x64xf16, #shared>
    %56 = triton_gpu.insert_slice_async %54, %55, %c0_i32, %49 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16, 1>, #blocked> -> tensor<2x64x64xf16, #shared>
    triton_gpu.async_commit_group
    triton_gpu.async_wait {num = 0 : i32}
    %57 = triton_gpu.extract_slice %50[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<2x64x64xf16, #shared> to tensor<64x64xf16, #shared>
    %58 = triton_gpu.extract_slice %56[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<2x64x64xf16, #shared> to tensor<64x64xf16, #shared>
    %59:10 = scf.for %arg4 = %c0_i32 to %33 step %c1_i32 iter_args(%arg5 = %cst_0, %arg6 = %cst_3, %arg7 = %cst_4, %arg8 = %50, %arg9 = %56, %arg10 = %57, %arg11 = %58, %arg12 = %c0_i32, %arg13 = %c1_i32, %arg14 = %c0_i32) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<2x64x64xf16, #shared>, tensor<2x64x64xf16, #shared>, tensor<64x64xf16, #shared>, tensor<64x64xf16, #shared>, i32, i32, i32)  : i32 {
      %71 = arith.muli %arg4, %c64_i32 : i32
      %72 = tt.splat %71 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
      %73 = arith.addi %72, %16 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
      %74 = tt.expand_dims %73 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>) -> tensor<1x64xi32, #mma>
      %75 = tt.broadcast %74 : (tensor<1x64xi32, #mma>) -> tensor<128x64xi32, #mma>
      %76 = "triton_gpu.cmpi"(%40, %75) <{predicate = 5 : i64}> : (tensor<128x64xi32, #mma>, tensor<128x64xi32, #mma>) -> tensor<128x64xi1, #mma>
      %77 = "triton_gpu.select"(%76, %cst_0, %cst) : (tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>
      %78 = triton_gpu.convert_layout %29 : (tensor<128x64xf16, #shared>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
      %79 = tt.trans %arg10 : (tensor<64x64xf16, #shared>) -> tensor<64x64xf16, #shared1>
      %80 = triton_gpu.convert_layout %79 : (tensor<64x64xf16, #shared1>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %81 = tt.dot %78, %80, %77 {allowTF32 = false} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
      %82 = arith.truncf %81 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
      %83 = arith.extf %82 : tensor<128x64xf16, #mma> to tensor<128x64xf32, #mma>
      %84 = "tt.reduce"(%83) <{axis = 1 : i32}> ({
      ^bb0(%arg15: f32, %arg16: f32):
        %131 = "triton_gpu.cmpf"(%arg15, %arg16) <{predicate = 2 : i64}> : (f32, f32) -> i1
        %132 = arith.select %131, %arg15, %arg16 : f32
        tt.reduce.return %132 : f32
      }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %85 = "triton_gpu.cmpf"(%84, %arg6) <{predicate = 2 : i64}> : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %86 = "triton_gpu.select"(%85, %84, %arg6) : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %87 = arith.subf %arg6, %86 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %88 = tt.extern_elementwise %87 {libname = "libdevice", libpath = "/home/jonch/Desktop/Programming/mlsys/triton/python/triton/language/../third_party/cuda/lib/libdevice.10.bc", pure = true, symbol = "__nv_exp2f"} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %89 = tt.expand_dims %86 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
      %90 = tt.broadcast %89 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma>
      %91 = arith.subf %83, %90 : tensor<128x64xf32, #mma>
      %92 = tt.extern_elementwise %91 {libname = "libdevice", libpath = "/home/jonch/Desktop/Programming/mlsys/triton/python/triton/language/../third_party/cuda/lib/libdevice.10.bc", pure = true, symbol = "__nv_exp2f"} : (tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>
      %93 = "tt.reduce"(%92) <{axis = 1 : i32}> ({
      ^bb0(%arg15: f32, %arg16: f32):
        %131 = arith.addf %arg15, %arg16 : f32
        tt.reduce.return %131 : f32
      }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %94 = arith.mulf %88, %arg7 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %95 = arith.addf %93, %94 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %96 = arith.mulf %arg7, %cst_4 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %97 = arith.addf %96, %88 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %98 = tt.expand_dims %97 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
      %99 = tt.broadcast %98 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma>
      %100 = arith.mulf %arg5, %99 : tensor<128x64xf32, #mma>
      %101 = arith.truncf %92 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
      %102 = triton_gpu.convert_layout %101 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
      %103 = triton_gpu.convert_layout %arg11 : (tensor<64x64xf16, #shared>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %104 = tt.dot %102, %103, %100 {allowTF32 = false} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
      %105 = arith.addi %arg12, %c1_i32 : i32
      %106 = arith.cmpi slt, %105, %33 : i32
      %107 = arith.addi %arg14, %c1_i32 : i32
      %108 = arith.cmpi uge, %107, %c2_i32 : i32
      %109 = arith.select %108, %c0_i32, %107 : i32
      %110 = arith.muli %105, %c64_i32 : i32
      %111 = tt.splat %110 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      %112 = arith.addi %111, %15 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      %113 = tt.expand_dims %112 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi32, #blocked>
      %114 = arith.muli %113, %cst_2 : tensor<64x1xi32, #blocked>
      %115 = tt.addptr %36, %114 : tensor<64x1x!tt.ptr<f16, 1>, #blocked>, tensor<64x1xi32, #blocked>
      %116 = tt.addptr %115, %37 : tensor<64x1x!tt.ptr<f16, 1>, #blocked>, tensor<64x1xi32, #blocked>
      %117 = tt.broadcast %116 : (tensor<64x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<64x64x!tt.ptr<f16, 1>, #blocked>
      %118 = tt.addptr %117, %34 : tensor<64x64x!tt.ptr<f16, 1>, #blocked>, tensor<64x64xi32, #blocked>
      %119 = tt.addptr %39, %114 : tensor<64x1x!tt.ptr<f16, 1>, #blocked>, tensor<64x1xi32, #blocked>
      %120 = tt.addptr %119, %37 : tensor<64x1x!tt.ptr<f16, 1>, #blocked>, tensor<64x1xi32, #blocked>
      %121 = tt.broadcast %120 : (tensor<64x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<64x64x!tt.ptr<f16, 1>, #blocked>
      %122 = tt.addptr %121, %34 : tensor<64x64x!tt.ptr<f16, 1>, #blocked>, tensor<64x64xi32, #blocked>
      %123 = tt.splat %106 : (i1) -> tensor<64x64xi1, #blocked>
      %124 = triton_gpu.insert_slice_async %118, %arg8, %arg13, %123 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16, 1>, #blocked> -> tensor<2x64x64xf16, #shared>
      triton_gpu.async_commit_group
      %125 = triton_gpu.insert_slice_async %122, %arg9, %arg13, %123 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16, 1>, #blocked> -> tensor<2x64x64xf16, #shared>
      triton_gpu.async_commit_group
      triton_gpu.async_wait {num = 0 : i32}
      %126 = triton_gpu.extract_slice %124[%109, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<2x64x64xf16, #shared> to tensor<64x64xf16, #shared>
      %127 = triton_gpu.extract_slice %125[%109, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<2x64x64xf16, #shared> to tensor<64x64xf16, #shared>
      %128 = arith.addi %arg13, %c1_i32 : i32
      %129 = arith.cmpi uge, %128, %c2_i32 : i32
      %130 = arith.select %129, %c0_i32, %128 : i32
      scf.yield %104, %86, %95, %124, %125, %126, %127, %105, %130, %109 : tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<2x64x64xf16, #shared>, tensor<2x64x64xf16, #shared>, tensor<64x64xf16, #shared>, tensor<64x64xf16, #shared>, i32, i32, i32
    }
    triton_gpu.async_wait {num = 0 : i32}
    %60 = tt.expand_dims %59#2 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
    %61 = tt.broadcast %60 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma>
    %62 = arith.divf %59#0, %61 : tensor<128x64xf32, #mma>
    %63 = arith.truncf %62 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
    %64 = tt.addptr %arg3, %4 : !tt.ptr<f16, 1>, i32
    %65 = tt.splat %64 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked>
    %66 = tt.addptr %65, %13 : tensor<128x1x!tt.ptr<f16, 1>, #blocked>, tensor<128x1xi32, #blocked>
    %67 = tt.addptr %66, %23 : tensor<128x1x!tt.ptr<f16, 1>, #blocked>, tensor<128x1xi32, #blocked>
    %68 = tt.broadcast %67 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<128x64x!tt.ptr<f16, 1>, #blocked>
    %69 = tt.addptr %68, %19 : tensor<128x64x!tt.ptr<f16, 1>, #blocked>, tensor<128x64xi32, #blocked>
    %70 = triton_gpu.convert_layout %63 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #blocked>
    tt.store %69, %70 {cache = 1 : i32, evict = 1 : i32} : tensor<128x64xf16, #blocked>
    tt.return
  }
}

Pallas TTGIR (using a smaller-rank grid; inputs reshaped to (batch_size * n_heads, seq_len, head_dim)):

pallas (seq_len=8192, block_q=128, block_k=64): 29.13662592569987 ms
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = false}>
module attributes {"triton_gpu.compute-capability" = 89 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @mha_forward(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c2_i32 = arith.constant 2 : i32
    %cst = arith.constant dense<-2.38197633E+38> : tensor<128x64xf32, #mma>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
    %cst_1 = arith.constant dense<64> : tensor<128x1xi32, #blocked>
    %cst_2 = arith.constant dense<64> : tensor<64x1xi32, #blocked>
    %cst_3 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %cst_4 = arith.constant dense<0.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %c64_i32 = arith.constant 64 : i32
    %cst_5 = arith.constant dense<1.442380e+00> : tensor<128x64xf16, #blocked>
    %c128_i32 = arith.constant 128 : i32
    %c1048576_i32 = arith.constant 1048576 : i32
    %c63_i32 = arith.constant 63 : i32
    %0 = tt.get_program_id x : i32
    %1 = tt.get_program_id y : i32
    %2 = arith.muli %0, %c128_i32 : i32
    %3 = arith.muli %1, %c1048576_i32 : i32
    %4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %5 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %6 = tt.splat %2 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %7 = tt.splat %2 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %8 = arith.addi %6, %4 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %9 = arith.addi %7, %5 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %10 = tt.expand_dims %8 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xi32, #blocked>
    %11 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xi32, #mma>
    %12 = arith.muli %10, %cst_1 : tensor<128x1xi32, #blocked>
    %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %16 = tt.expand_dims %15 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x64xi32, #blocked>
    %17 = tt.broadcast %16 : (tensor<1x64xi32, #blocked>) -> tensor<128x64xi32, #blocked>
    %18 = tt.addptr %arg0, %3 : !tt.ptr<f16, 1>, i32
    %19 = tt.splat %18 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked>
    %20 = tt.addptr %19, %12 : tensor<128x1x!tt.ptr<f16, 1>, #blocked>, tensor<128x1xi32, #blocked>
    %21 = tt.broadcast %20 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<128x64x!tt.ptr<f16, 1>, #blocked>
    %22 = tt.addptr %21, %17 : tensor<128x64x!tt.ptr<f16, 1>, #blocked>, tensor<128x64xi32, #blocked>
    %23 = tt.load %22 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked>
    %24 = arith.mulf %23, %cst_5 : tensor<128x64xf16, #blocked>
    %25 = triton_gpu.convert_layout %24 : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #shared>
    %26 = arith.addi %0, %c1_i32 : i32
    %27 = arith.muli %26, %c128_i32 : i32
    %28 = arith.addi %27, %c63_i32 : i32
    %29 = arith.divsi %28, %c64_i32 : i32
    %30 = tt.broadcast %16 : (tensor<1x64xi32, #blocked>) -> tensor<64x64xi32, #blocked>
    %31 = tt.addptr %arg1, %3 : !tt.ptr<f16, 1>, i32
    %32 = tt.splat %31 : (!tt.ptr<f16, 1>) -> tensor<64x1x!tt.ptr<f16, 1>, #blocked>
    %33 = tt.addptr %arg2, %3 : !tt.ptr<f16, 1>, i32
    %34 = tt.splat %33 : (!tt.ptr<f16, 1>) -> tensor<64x1x!tt.ptr<f16, 1>, #blocked>
    %35 = tt.broadcast %11 : (tensor<128x1xi32, #mma>) -> tensor<128x64xi32, #mma>
    %36 = arith.cmpi sgt, %29, %c0_i32 : i32
    %37 = tt.expand_dims %13 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi32, #blocked>
    %38 = arith.muli %37, %cst_2 : tensor<64x1xi32, #blocked>
    %39 = tt.addptr %32, %38 : tensor<64x1x!tt.ptr<f16, 1>, #blocked>, tensor<64x1xi32, #blocked>
    %40 = tt.broadcast %39 : (tensor<64x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<64x64x!tt.ptr<f16, 1>, #blocked>
    %41 = tt.addptr %40, %30 : tensor<64x64x!tt.ptr<f16, 1>, #blocked>, tensor<64x64xi32, #blocked>
    %42 = triton_gpu.alloc_tensor : tensor<2x64x64xf16, #shared>
    %43 = tt.splat %36 : (i1) -> tensor<64x64xi1, #blocked>
    %44 = triton_gpu.insert_slice_async %41, %42, %c0_i32, %43 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16, 1>, #blocked> -> tensor<2x64x64xf16, #shared>
    triton_gpu.async_commit_group
    %45 = tt.addptr %34, %38 : tensor<64x1x!tt.ptr<f16, 1>, #blocked>, tensor<64x1xi32, #blocked>
    %46 = tt.broadcast %45 : (tensor<64x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<64x64x!tt.ptr<f16, 1>, #blocked>
    %47 = tt.addptr %46, %30 : tensor<64x64x!tt.ptr<f16, 1>, #blocked>, tensor<64x64xi32, #blocked>
    %48 = triton_gpu.alloc_tensor : tensor<2x64x64xf16, #shared>
    %49 = triton_gpu.insert_slice_async %47, %48, %c0_i32, %43 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16, 1>, #blocked> -> tensor<2x64x64xf16, #shared>
    triton_gpu.async_commit_group
    triton_gpu.async_wait {num = 0 : i32}
    %50 = triton_gpu.extract_slice %44[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<2x64x64xf16, #shared> to tensor<64x64xf16, #shared>
    %51 = triton_gpu.extract_slice %49[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<2x64x64xf16, #shared> to tensor<64x64xf16, #shared>
    %52:10 = scf.for %arg4 = %c0_i32 to %29 step %c1_i32 iter_args(%arg5 = %cst_0, %arg6 = %cst_3, %arg7 = %cst_4, %arg8 = %44, %arg9 = %49, %arg10 = %50, %arg11 = %51, %arg12 = %c0_i32, %arg13 = %c1_i32, %arg14 = %c0_i32) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<2x64x64xf16, #shared>, tensor<2x64x64xf16, #shared>, tensor<64x64xf16, #shared>, tensor<64x64xf16, #shared>, i32, i32, i32)  : i32 {
      %63 = arith.muli %arg4, %c64_i32 : i32
      %64 = tt.splat %63 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
      %65 = arith.addi %64, %14 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
      %66 = tt.expand_dims %65 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>) -> tensor<1x64xi32, #mma>
      %67 = tt.broadcast %66 : (tensor<1x64xi32, #mma>) -> tensor<128x64xi32, #mma>
      %68 = "triton_gpu.cmpi"(%35, %67) <{predicate = 5 : i64}> : (tensor<128x64xi32, #mma>, tensor<128x64xi32, #mma>) -> tensor<128x64xi1, #mma>
      %69 = "triton_gpu.select"(%68, %cst_0, %cst) : (tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>
      %70 = triton_gpu.convert_layout %25 : (tensor<128x64xf16, #shared>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
      %71 = tt.trans %arg10 : (tensor<64x64xf16, #shared>) -> tensor<64x64xf16, #shared1>
      %72 = triton_gpu.convert_layout %71 : (tensor<64x64xf16, #shared1>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %73 = tt.dot %70, %72, %69 {allowTF32 = false} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
      %74 = arith.truncf %73 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
      %75 = arith.extf %74 : tensor<128x64xf16, #mma> to tensor<128x64xf32, #mma>
      %76 = "tt.reduce"(%75) <{axis = 1 : i32}> ({
      ^bb0(%arg15: f32, %arg16: f32):
        %121 = "triton_gpu.cmpf"(%arg15, %arg16) <{predicate = 2 : i64}> : (f32, f32) -> i1
        %122 = arith.select %121, %arg15, %arg16 : f32
        tt.reduce.return %122 : f32
      }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %77 = "triton_gpu.cmpf"(%76, %arg6) <{predicate = 2 : i64}> : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %78 = "triton_gpu.select"(%77, %76, %arg6) : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %79 = arith.subf %arg6, %78 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %80 = tt.extern_elementwise %79 {libname = "libdevice", libpath = "/home/jonch/Desktop/Programming/mlsys/triton/python/triton/language/../third_party/cuda/lib/libdevice.10.bc", pure = true, symbol = "__nv_exp2f"} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %81 = tt.expand_dims %78 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
      %82 = tt.broadcast %81 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma>
      %83 = arith.subf %75, %82 : tensor<128x64xf32, #mma>
      %84 = tt.extern_elementwise %83 {libname = "libdevice", libpath = "/home/jonch/Desktop/Programming/mlsys/triton/python/triton/language/../third_party/cuda/lib/libdevice.10.bc", pure = true, symbol = "__nv_exp2f"} : (tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>
      %85 = "tt.reduce"(%84) <{axis = 1 : i32}> ({
      ^bb0(%arg15: f32, %arg16: f32):
        %121 = arith.addf %arg15, %arg16 : f32
        tt.reduce.return %121 : f32
      }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %86 = arith.mulf %80, %arg7 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %87 = arith.addf %85, %86 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %88 = arith.mulf %arg7, %cst_4 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %89 = arith.addf %88, %80 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %90 = tt.expand_dims %89 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
      %91 = tt.broadcast %90 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma>
      %92 = arith.mulf %arg5, %91 : tensor<128x64xf32, #mma>
      %93 = arith.truncf %84 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
      %94 = triton_gpu.convert_layout %93 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
      %95 = triton_gpu.convert_layout %arg11 : (tensor<64x64xf16, #shared>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %96 = tt.dot %94, %95, %92 {allowTF32 = false} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
      %97 = arith.addi %arg12, %c1_i32 : i32
      %98 = arith.cmpi slt, %97, %29 : i32
      %99 = arith.addi %arg14, %c1_i32 : i32
      %100 = arith.cmpi uge, %99, %c2_i32 : i32
      %101 = arith.select %100, %c0_i32, %99 : i32
      %102 = arith.muli %97, %c64_i32 : i32
      %103 = tt.splat %102 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      %104 = arith.addi %103, %13 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      %105 = tt.expand_dims %104 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi32, #blocked>
      %106 = arith.muli %105, %cst_2 : tensor<64x1xi32, #blocked>
      %107 = tt.addptr %32, %106 : tensor<64x1x!tt.ptr<f16, 1>, #blocked>, tensor<64x1xi32, #blocked>
      %108 = tt.broadcast %107 : (tensor<64x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<64x64x!tt.ptr<f16, 1>, #blocked>
      %109 = tt.addptr %108, %30 : tensor<64x64x!tt.ptr<f16, 1>, #blocked>, tensor<64x64xi32, #blocked>
      %110 = tt.addptr %34, %106 : tensor<64x1x!tt.ptr<f16, 1>, #blocked>, tensor<64x1xi32, #blocked>
      %111 = tt.broadcast %110 : (tensor<64x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<64x64x!tt.ptr<f16, 1>, #blocked>
      %112 = tt.addptr %111, %30 : tensor<64x64x!tt.ptr<f16, 1>, #blocked>, tensor<64x64xi32, #blocked>
      %113 = tt.splat %98 : (i1) -> tensor<64x64xi1, #blocked>
      %114 = triton_gpu.insert_slice_async %109, %arg8, %arg13, %113 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16, 1>, #blocked> -> tensor<2x64x64xf16, #shared>
      triton_gpu.async_commit_group
      %115 = triton_gpu.insert_slice_async %112, %arg9, %arg13, %113 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16, 1>, #blocked> -> tensor<2x64x64xf16, #shared>
      triton_gpu.async_commit_group
      triton_gpu.async_wait {num = 0 : i32}
      %116 = triton_gpu.extract_slice %114[%101, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<2x64x64xf16, #shared> to tensor<64x64xf16, #shared>
      %117 = triton_gpu.extract_slice %115[%101, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<2x64x64xf16, #shared> to tensor<64x64xf16, #shared>
      %118 = arith.addi %arg13, %c1_i32 : i32
      %119 = arith.cmpi uge, %118, %c2_i32 : i32
      %120 = arith.select %119, %c0_i32, %118 : i32
      scf.yield %96, %78, %87, %114, %115, %116, %117, %97, %120, %101 : tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<2x64x64xf16, #shared>, tensor<2x64x64xf16, #shared>, tensor<64x64xf16, #shared>, tensor<64x64xf16, #shared>, i32, i32, i32
    }
    triton_gpu.async_wait {num = 0 : i32}
    %53 = tt.expand_dims %52#2 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
    %54 = tt.broadcast %53 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma>
    %55 = arith.divf %52#0, %54 : tensor<128x64xf32, #mma>
    %56 = arith.truncf %55 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
    %57 = tt.addptr %arg3, %3 : !tt.ptr<f16, 1>, i32
    %58 = tt.splat %57 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked>
    %59 = tt.addptr %58, %12 : tensor<128x1x!tt.ptr<f16, 1>, #blocked>, tensor<128x1xi32, #blocked>
    %60 = tt.broadcast %59 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<128x64x!tt.ptr<f16, 1>, #blocked>
    %61 = tt.addptr %60, %17 : tensor<128x64x!tt.ptr<f16, 1>, #blocked>, tensor<128x64xi32, #blocked>
    %62 = triton_gpu.convert_layout %56 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #blocked>
    tt.store %61, %62 {cache = 1 : i32, evict = 1 : i32} : tensor<128x64xf16, #blocked>
    tt.return
  }
}

Triton TTGIR:

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = false}>
module attributes {"triton_gpu.compute-capability" = 89 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @_fwd_kernel_0d1d2d34d5d6de7de8de9c10de11de12de13c14de15de16de17c18de19de20de21c2223de24de(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg11: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg12: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg13: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg14: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg15: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg16: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg17: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg18: i32, %arg19: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg20: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
    %c3_i32 = arith.constant 3 : i32
    %c192_i64 = arith.constant 192 : i64
    %cst = arith.constant dense<128> : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %cst_0 = arith.constant dense<128> : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %c128_i32 = arith.constant 128 : i32
    %c2_i32 = arith.constant 2 : i32
    %cst_1 = arith.constant dense<64> : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %cst_2 = arith.constant dense<64> : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %c64_i32 = arith.constant 64 : i32
    %c64_i64 = arith.constant 64 : i64
    %c1_i32 = arith.constant 1 : i32
    %c4_i32 = arith.constant 4 : i32
    %cst_3 = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma>
    %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
    %cst_5 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %cst_6 = arith.constant dense<0.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %cst_7 = arith.constant 1.44269502 : f32
    %c0_i32 = arith.constant 0 : i32
    %0 = tt.get_program_id x : i32
    %1 = tt.get_program_id y : i32
    %2 = arith.muli %1, %arg7 : i32
    %3 = tt.addptr %arg0, %2 : !tt.ptr<f16, 1>, i32
    %4 = arith.muli %0, %c128_i32 : i32
    %5 = arith.extsi %arg8 : i32 to i64
    %6 = arith.extsi %4 : i32 to i64
    %7 = tt.addptr %arg1, %2 : !tt.ptr<f16, 1>, i32
    %8 = arith.extsi %arg11 : i32 to i64
    %9 = tt.addptr %arg2, %2 : !tt.ptr<f16, 1>, i32
    %10 = arith.extsi %arg14 : i32 to i64
    %11 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked2>
    %14 = tt.splat %4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %15 = tt.splat %4 : (i32) -> tensor<128xi32, #blocked2>
    %16 = arith.addi %14, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %17 = arith.addi %15, %13 : tensor<128xi32, #blocked2>
    %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %22 = arith.mulf %arg3, %cst_7 : f32
    %23 = tt.splat %6 : (i64) -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %24 = arith.extsi %11 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %25 = arith.addi %23, %24 : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %26 = tt.expand_dims %25 {axis = 1 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xi64, #blocked>
    %27 = tt.splat %5 : (i64) -> tensor<128x1xi64, #blocked>
    %28 = arith.muli %26, %27 : tensor<128x1xi64, #blocked>
    %29 = tt.splat %3 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked>
    %30 = tt.addptr %29, %28 : tensor<128x1x!tt.ptr<f16, 1>, #blocked>, tensor<128x1xi64, #blocked>
    %31 = tt.broadcast %30 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<128x64x!tt.ptr<f16, 1>, #blocked>
    %32 = arith.extsi %18 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> to tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %33 = arith.extsi %19 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> to tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %34 = arith.extsi %20 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> to tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %35 = arith.extsi %21 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %36 = tt.expand_dims %32 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x64xi64, #blocked>
    %37 = tt.broadcast %36 : (tensor<1x64xi64, #blocked>) -> tensor<128x64xi64, #blocked>
    %38 = tt.addptr %31, %37 : tensor<128x64x!tt.ptr<f16, 1>, #blocked>, tensor<128x64xi64, #blocked>
    %39 = tt.load %38 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked>
    %40 = tt.splat %22 : (f32) -> tensor<128x64xf32, #blocked>
    %41 = arith.extf %39 : tensor<128x64xf16, #blocked> to tensor<128x64xf32, #blocked>
    %42 = arith.mulf %41, %40 : tensor<128x64xf32, #blocked>
    %43 = arith.truncf %42 : tensor<128x64xf32, #blocked> to tensor<128x64xf16, #blocked>
    %44 = triton_gpu.convert_layout %43 : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #shared>
    %45 = arith.addi %0, %c1_i32 : i32
    %46 = arith.muli %45, %c128_i32 : i32
    %47 = tt.expand_dims %33 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi64, #blocked1>
    %48 = tt.splat %7 : (!tt.ptr<f16, 1>) -> tensor<64x1x!tt.ptr<f16, 1>, #blocked1>
    %49 = tt.addptr %48, %47 : tensor<64x1x!tt.ptr<f16, 1>, #blocked1>, tensor<64x1xi64, #blocked1>
    %50 = tt.broadcast %49 : (tensor<64x1x!tt.ptr<f16, 1>, #blocked1>) -> tensor<64x64x!tt.ptr<f16, 1>, #blocked1>
    %51 = tt.splat %8 : (i64) -> tensor<1x64xi64, #blocked1>
    %52 = tt.splat %10 : (i64) -> tensor<64x1xi64, #blocked>
    %53 = tt.splat %9 : (!tt.ptr<f16, 1>) -> tensor<64x1x!tt.ptr<f16, 1>, #blocked>
    %54 = tt.broadcast %36 : (tensor<1x64xi64, #blocked>) -> tensor<64x64xi64, #blocked>
    %55 = tt.expand_dims %16 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xi32, #mma>
    %56 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
    %57 = tt.expand_dims %56 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>) -> tensor<1x64xi32, #mma>
    %58 = tt.broadcast %55 : (tensor<128x1xi32, #mma>) -> tensor<128x64xi32, #mma>
    %59 = arith.cmpi sgt, %46, %c0_i32 : i32
    %60 = tt.expand_dims %34 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi64, #blocked1>
    %61 = arith.muli %60, %51 : tensor<1x64xi64, #blocked1>
    %62 = tt.broadcast %61 : (tensor<1x64xi64, #blocked1>) -> tensor<64x64xi64, #blocked1>
    %63 = tt.addptr %50, %62 : tensor<64x64x!tt.ptr<f16, 1>, #blocked1>, tensor<64x64xi64, #blocked1>
    %64 = triton_gpu.alloc_tensor : tensor<4x64x64xf16, #shared1>
    %65 = tt.splat %59 : (i1) -> tensor<64x64xi1, #blocked1>
    %66 = triton_gpu.insert_slice_async %63, %64, %c0_i32, %65 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16, 1>, #blocked1> -> tensor<4x64x64xf16, #shared1>
    triton_gpu.async_commit_group
    %67 = tt.expand_dims %35 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi64, #blocked>
    %68 = arith.muli %67, %52 : tensor<64x1xi64, #blocked>
    %69 = tt.addptr %53, %68 : tensor<64x1x!tt.ptr<f16, 1>, #blocked>, tensor<64x1xi64, #blocked>
    %70 = tt.broadcast %69 : (tensor<64x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<64x64x!tt.ptr<f16, 1>, #blocked>
    %71 = tt.addptr %70, %54 : tensor<64x64x!tt.ptr<f16, 1>, #blocked>, tensor<64x64xi64, #blocked>
    %72 = triton_gpu.alloc_tensor : tensor<4x64x64xf16, #shared>
    %73 = tt.splat %59 : (i1) -> tensor<64x64xi1, #blocked>
    %74 = triton_gpu.insert_slice_async %71, %72, %c0_i32, %73 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16, 1>, #blocked> -> tensor<4x64x64xf16, #shared>
    triton_gpu.async_commit_group
    %75 = arith.cmpi sgt, %46, %c64_i32 : i32
    %76 = arith.addi %34, %cst_2 : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %77 = tt.expand_dims %76 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi64, #blocked1>
    %78 = arith.muli %77, %51 : tensor<1x64xi64, #blocked1>
    %79 = tt.broadcast %78 : (tensor<1x64xi64, #blocked1>) -> tensor<64x64xi64, #blocked1>
    %80 = tt.addptr %50, %79 : tensor<64x64x!tt.ptr<f16, 1>, #blocked1>, tensor<64x64xi64, #blocked1>
    %81 = tt.splat %75 : (i1) -> tensor<64x64xi1, #blocked1>
    %82 = triton_gpu.insert_slice_async %80, %66, %c1_i32, %81 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16, 1>, #blocked1> -> tensor<4x64x64xf16, #shared1>
    triton_gpu.async_commit_group
    %83 = arith.addi %35, %cst_1 : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %84 = tt.expand_dims %83 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi64, #blocked>
    %85 = arith.muli %84, %52 : tensor<64x1xi64, #blocked>
    %86 = tt.addptr %53, %85 : tensor<64x1x!tt.ptr<f16, 1>, #blocked>, tensor<64x1xi64, #blocked>
    %87 = tt.broadcast %86 : (tensor<64x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<64x64x!tt.ptr<f16, 1>, #blocked>
    %88 = tt.addptr %87, %54 : tensor<64x64x!tt.ptr<f16, 1>, #blocked>, tensor<64x64xi64, #blocked>
    %89 = tt.splat %75 : (i1) -> tensor<64x64xi1, #blocked>
    %90 = triton_gpu.insert_slice_async %88, %74, %c1_i32, %89 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16, 1>, #blocked> -> tensor<4x64x64xf16, #shared>
    triton_gpu.async_commit_group
    %91 = arith.cmpi sgt, %46, %c128_i32 : i32
    %92 = arith.addi %34, %cst_0 : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %93 = tt.expand_dims %92 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi64, #blocked1>
    %94 = arith.muli %93, %51 : tensor<1x64xi64, #blocked1>
    %95 = tt.broadcast %94 : (tensor<1x64xi64, #blocked1>) -> tensor<64x64xi64, #blocked1>
    %96 = tt.addptr %50, %95 : tensor<64x64x!tt.ptr<f16, 1>, #blocked1>, tensor<64x64xi64, #blocked1>
    %97 = tt.splat %91 : (i1) -> tensor<64x64xi1, #blocked1>
    %98 = triton_gpu.insert_slice_async %96, %82, %c2_i32, %97 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16, 1>, #blocked1> -> tensor<4x64x64xf16, #shared1>
    triton_gpu.async_commit_group
    %99 = arith.addi %35, %cst : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %100 = tt.expand_dims %99 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi64, #blocked>
    %101 = arith.muli %100, %52 : tensor<64x1xi64, #blocked>
    %102 = tt.addptr %53, %101 : tensor<64x1x!tt.ptr<f16, 1>, #blocked>, tensor<64x1xi64, #blocked>
    %103 = tt.broadcast %102 : (tensor<64x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<64x64x!tt.ptr<f16, 1>, #blocked>
    %104 = tt.addptr %103, %54 : tensor<64x64x!tt.ptr<f16, 1>, #blocked>, tensor<64x64xi64, #blocked>
    %105 = tt.splat %91 : (i1) -> tensor<64x64xi1, #blocked>
    %106 = triton_gpu.insert_slice_async %104, %90, %c2_i32, %105 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16, 1>, #blocked> -> tensor<4x64x64xf16, #shared>
    triton_gpu.async_commit_group
    triton_gpu.async_wait {num = 4 : i32}
    %107 = triton_gpu.extract_slice %98[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<4x64x64xf16, #shared1> to tensor<64x64xf16, #shared1>
    %108 = triton_gpu.extract_slice %106[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<4x64x64xf16, #shared> to tensor<64x64xf16, #shared>
    %109:12 = scf.for %arg21 = %c0_i32 to %46 step %c64_i32 iter_args(%arg22 = %cst_4, %arg23 = %cst_6, %arg24 = %cst_5, %arg25 = %98, %arg26 = %106, %arg27 = %107, %arg28 = %108, %arg29 = %c192_i64, %arg30 = %c192_i64, %arg31 = %c128_i32, %arg32 = %c3_i32, %arg33 = %c0_i32) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<4x64x64xf16, #shared1>, tensor<4x64x64xf16, #shared>, tensor<64x64xf16, #shared1>, tensor<64x64xf16, #shared>, i64, i64, i32, i32, i32)  : i32 {
      %130 = tt.splat %arg21 : (i32) -> tensor<1x64xi32, #mma>
      %131 = arith.addi %130, %57 : tensor<1x64xi32, #mma>
      %132 = tt.broadcast %131 : (tensor<1x64xi32, #mma>) -> tensor<128x64xi32, #mma>
      %133 = "triton_gpu.cmpi"(%58, %132) <{predicate = 5 : i64}> : (tensor<128x64xi32, #mma>, tensor<128x64xi32, #mma>) -> tensor<128x64xi1, #mma>
      %134 = "triton_gpu.select"(%133, %cst_4, %cst_3) : (tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>
      %135 = triton_gpu.convert_layout %44 : (tensor<128x64xf16, #shared>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
      %136 = triton_gpu.convert_layout %arg27 : (tensor<64x64xf16, #shared1>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %137 = tt.dot %135, %136, %134 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
      %138 = "tt.reduce"(%137) <{axis = 1 : i32}> ({
      ^bb0(%arg34: f32, %arg35: f32):
        %185 = arith.maxf %arg34, %arg35 : f32
        tt.reduce.return %185 : f32
      }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %139 = arith.maxf %arg24, %138 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %140 = arith.subf %arg24, %139 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %141 = tt.extern_elementwise %140 {libname = "libdevice", libpath = "/home/jonch/Desktop/Programming/mlsys/triton/python/triton/language/../third_party/cuda/lib/libdevice.10.bc", pure = true, symbol = "__nv_exp2f"} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %142 = tt.expand_dims %139 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
      %143 = tt.broadcast %142 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma>
      %144 = arith.subf %137, %143 : tensor<128x64xf32, #mma>
      %145 = tt.extern_elementwise %144 {libname = "libdevice", libpath = "/home/jonch/Desktop/Programming/mlsys/triton/python/triton/language/../third_party/cuda/lib/libdevice.10.bc", pure = true, symbol = "__nv_exp2f"} : (tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>
      %146 = tt.expand_dims %141 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
      %147 = tt.broadcast %146 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma>
      %148 = arith.mulf %arg22, %147 : tensor<128x64xf32, #mma>
      %149 = arith.truncf %145 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
      %150 = triton_gpu.convert_layout %149 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
      %151 = triton_gpu.convert_layout %arg28 : (tensor<64x64xf16, #shared>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %152 = tt.dot %150, %151, %148 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
      %153 = arith.mulf %arg23, %141 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %154 = "tt.reduce"(%145) <{axis = 1 : i32}> ({
      ^bb0(%arg34: f32, %arg35: f32):
        %185 = arith.addf %arg34, %arg35 : f32
        tt.reduce.return %185 : f32
      }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %155 = arith.addf %153, %154 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      %156 = arith.addi %arg31, %c64_i32 : i32
      %157 = arith.cmpi slt, %156, %46 : i32
      %158 = arith.addi %arg33, %c1_i32 : i32
      %159 = arith.cmpi uge, %158, %c4_i32 : i32
      %160 = arith.select %159, %c0_i32, %158 : i32
      %161 = tt.splat %arg29 : (i64) -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
      %162 = arith.addi %161, %34 : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
      %163 = tt.expand_dims %162 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi64, #blocked1>
      %164 = arith.muli %163, %51 : tensor<1x64xi64, #blocked1>
      %165 = tt.broadcast %164 : (tensor<1x64xi64, #blocked1>) -> tensor<64x64xi64, #blocked1>
      %166 = tt.addptr %50, %165 : tensor<64x64x!tt.ptr<f16, 1>, #blocked1>, tensor<64x64xi64, #blocked1>
      %167 = tt.splat %arg30 : (i64) -> tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      %168 = arith.addi %167, %35 : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
      %169 = tt.expand_dims %168 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi64, #blocked>
      %170 = arith.muli %169, %52 : tensor<64x1xi64, #blocked>
      %171 = tt.addptr %53, %170 : tensor<64x1x!tt.ptr<f16, 1>, #blocked>, tensor<64x1xi64, #blocked>
      %172 = tt.broadcast %171 : (tensor<64x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<64x64x!tt.ptr<f16, 1>, #blocked>
      %173 = tt.addptr %172, %54 : tensor<64x64x!tt.ptr<f16, 1>, #blocked>, tensor<64x64xi64, #blocked>
      %174 = arith.addi %arg29, %c64_i64 : i64
      %175 = arith.addi %arg30, %c64_i64 : i64
      %176 = tt.splat %157 : (i1) -> tensor<64x64xi1, #blocked1>
      %177 = triton_gpu.insert_slice_async %166, %arg25, %arg32, %176 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16, 1>, #blocked1> -> tensor<4x64x64xf16, #shared1>
      triton_gpu.async_commit_group
      %178 = tt.splat %157 : (i1) -> tensor<64x64xi1, #blocked>
      %179 = triton_gpu.insert_slice_async %173, %arg26, %arg32, %178 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16, 1>, #blocked> -> tensor<4x64x64xf16, #shared>
      triton_gpu.async_commit_group
      triton_gpu.async_wait {num = 4 : i32}
      %180 = triton_gpu.extract_slice %177[%160, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<4x64x64xf16, #shared1> to tensor<64x64xf16, #shared1>
      %181 = triton_gpu.extract_slice %179[%160, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<4x64x64xf16, #shared> to tensor<64x64xf16, #shared>
      %182 = arith.addi %arg32, %c1_i32 : i32
      %183 = arith.cmpi uge, %182, %c4_i32 : i32
      %184 = arith.select %183, %c0_i32, %182 : i32
      scf.yield %152, %155, %139, %177, %179, %180, %181, %174, %175, %156, %184, %160 : tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<4x64x64xf16, #shared1>, tensor<4x64x64xf16, #shared>, tensor<64x64xf16, #shared1>, tensor<64x64xf16, #shared>, i64, i64, i32, i32, i32
    }
    triton_gpu.async_wait {num = 0 : i32}
    %110 = tt.expand_dims %109#1 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
    %111 = tt.broadcast %110 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma>
    %112 = arith.divf %109#0, %111 : tensor<128x64xf32, #mma>
    %113 = arith.muli %1, %arg20 : i32
    %114 = tt.addptr %arg4, %113 : !tt.ptr<f32, 1>, i32
    %115 = tt.splat %114 : (!tt.ptr<f32, 1>) -> tensor<128x!tt.ptr<f32, 1>, #blocked2>
    %116 = tt.addptr %115, %17 : tensor<128x!tt.ptr<f32, 1>, #blocked2>, tensor<128xi32, #blocked2>
    %117 = tt.extern_elementwise %109#1 {libname = "libdevice", libpath = "/home/jonch/Desktop/Programming/mlsys/triton/python/triton/language/../third_party/cuda/lib/libdevice.10.bc", pure = true, symbol = "__nv_log2f"} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %118 = arith.addf %109#2, %117 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %119 = triton_gpu.convert_layout %118 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #blocked2>
    tt.store %116, %119 {cache = 1 : i32, evict = 1 : i32} : tensor<128xf32, #blocked2>
    %120 = tt.addptr %arg5, %2 : !tt.ptr<f16, 1>, i32
    %121 = arith.extsi %arg17 : i32 to i64
    %122 = arith.truncf %112 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
    %123 = tt.splat %121 : (i64) -> tensor<128x1xi64, #blocked>
    %124 = arith.muli %26, %123 : tensor<128x1xi64, #blocked>
    %125 = tt.splat %120 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked>
    %126 = tt.addptr %125, %124 : tensor<128x1x!tt.ptr<f16, 1>, #blocked>, tensor<128x1xi64, #blocked>
    %127 = tt.broadcast %126 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<128x64x!tt.ptr<f16, 1>, #blocked>
    %128 = tt.addptr %127, %37 : tensor<128x64x!tt.ptr<f16, 1>, #blocked>, tensor<128x64xi64, #blocked>
    %129 = triton_gpu.convert_layout %122 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #blocked>
    tt.store %128, %129 {cache = 1 : i32, evict = 1 : i32} : tensor<128x64xf16, #blocked>
    tt.return
  }
}

The memory pattern also differs quite dramatically.

Pallas and Triton are nearly identical:
Screenshot from 2023-09-04 22-02-30
Screenshot from 2023-09-04 22-02-42

But flash_attn uses a lot more local loads and stores rather than global loads (in other words, register spilling):
Screenshot from 2023-09-04 22-03-29

Screenshot from 2023-09-04 22-03-15

Since NVCC and the Triton compiler both use LLVM to generate PTX, it is interesting that the flash_attn kernel relies on more ld.local / st.local and uses more registers (255 registers vs. 240 and 227).

@jon-chuang jon-chuang changed the title feat(pallas): Optimize Pallas Attention + Pallas Attention Benchmark feat(pallas): Optimize Pallas Attention + Benchmark Sep 4, 2023
ptillet pushed a commit to triton-lang/triton that referenced this pull request Sep 5, 2023
This bug previously existed, and I verified it in a previous nightly release of Triton (20230714).

However, according to new benchmarks, this bug no longer exists on
Triton main. See:
jax-ml/jax#17328 (comment)
@sharadmv sharadmv self-assigned this Sep 6, 2023
@sharadmv
Collaborator

Thank you for this amazing investigation! I am not actually very experienced with GPU performance and benchmarking so we really appreciate you doing this. Let's try to get your optimizations in once we update Triton internally. Does that sound good?

@jon-chuang
Contributor Author

jon-chuang commented Sep 14, 2023

Hello @sharadmv, no problem.

There's a lot of data to parse through, but I hope that generating and publishing this data helps shed light on the state of things.

@jon-chuang
Contributor Author

jon-chuang commented Sep 14, 2023

Updates (round 2)

Applied new optimizations:

  1. Causal Masking First + No-Mask Loop Body
  2. FP16 acc for the P @ V matmul. See here for the ongoing discussion about numerical accuracy/stability (a sketch of the accumulation-dtype choice follows these lists)

Exploring:

  1. FP16 accumulation for Q @ K matmul
    • not committing due to open questions about numerical stability
    • however, the performance gains (an extra 20%) are very tempting
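
To make the accumulation-dtype knob concrete, here is a minimal host-side JAX sketch (not the Pallas kernel itself; the function name and shapes are assumptions, and whether the Pallas/Triton lowering honors FP16 accumulation exactly this way is also an assumption):

```python
import jax.numpy as jnp
from jax import lax

def pv_matmul(p, v, acc_dtype=jnp.float32):
    # p: [block_q, block_k] fp16 attention probabilities, v: [block_k, head_dim] fp16 values.
    # preferred_element_type requests that the product be produced (and, on most
    # backends, accumulated) in acc_dtype.
    return lax.dot_general(
        p, v,
        dimension_numbers=(((1,), (0,)), ((), ())),  # ordinary matmul contraction
        preferred_element_type=acc_dtype)

# fp32 accumulation (the safe default):        pv_matmul(p, v)
# fp16 accumulation (PV_f16_acc=True below):   pv_matmul(p, v, acc_dtype=jnp.float16)
```

The same choice applies to the Q @ K matmul explored above; the open numerical-stability question is how much precision the fp16 partial sums give up before the softmax rescale.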

Benchmark Settings

Total rounds=150, Outer runs=5, GPU=RTX 4070 Laptop (sm_89)

I was unable to benchmark flash_attn this time due to dependency-install issues. However, Triton serves as the common comparator.

Summary

  1. FP16 accumulation helps a lot (caveat: probably only on consumer GPUs, not data-center cards)
  2. Delayed softmax reciprocal produces no measurable difference in this setting (I previously measured a ~5% difference in the FP32-accumulation setting). Hypothesis: we are even more memory-bottlenecked than before. Solution: explore even more pipeline stages. (See the sketch below for what the delayed reciprocal changes.)

I have yet to explore numerical accuracy (I suspect it may already be lower for Triton/Pallas in the first place).
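
For reference, here is a minimal sketch of the delayed-normalization bookkeeping in plain JAX (host-side, not the Pallas kernel; the block size and shapes are illustrative assumptions). The accumulator stays unnormalized inside the KV loop and the reciprocal is applied exactly once after it:

```python
import jax.numpy as jnp

def attention_delayed_norm(q, k, v, block_k=64):
    # q: [block_q, d]; k, v: [seq_len, d]
    block_q, d = q.shape
    m = jnp.full((block_q,), -jnp.inf)   # running row max of the logits
    l = jnp.zeros((block_q,))            # running row sum of exp(logits - m)
    acc = jnp.zeros((block_q, d))        # unnormalized output accumulator
    for start in range(0, k.shape[0], block_k):
        kb, vb = k[start:start + block_k], v[start:start + block_k]
        s = q @ kb.T                                  # [block_q, block_k] logits
        m_new = jnp.maximum(m, s.max(axis=-1))
        alpha = jnp.exp(m - m_new)                    # rescale factor for the old state
        p = jnp.exp(s - m_new[:, None])
        l = l * alpha + p.sum(axis=-1)
        acc = acc * alpha[:, None] + p @ vb           # no divide inside the loop
        m = m_new
    return acc / l[:, None]                           # delayed reciprocal: one divide at the end
```

The in-loop variant instead keeps acc normalized on every iteration, adding a per-row reciprocal and multiply to each KV block; that extra work is what the delayed variant saves, which is consistent with the observation above that the saving disappears once the kernel is memory-bound.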

Findings

In-loop reciprocal (highest expected stability and accuracy) - 1.21x speedup

image

Delayed softmax reciprocal + Rely on 1/seq_len for numerical stability - Also 1.21x speedup

Explanation: faster compute means we are even more bottlenecked on memory / shared-memory reads, so we no longer see the 5% speedup from reducing compute intensity further. Action: try increasing the number of pipeline stages to keep the compute units fed.

No-Mask Loop Body=True, PV_f16_acc=True, DELAYED_SOFTMAX_NORMALIZE=True
image

Ablation: always apply mask - slightly slower (-1%)

No-Mask Loop Body=False, PV_f16_acc=True, DELAYED_SOFTMAX_NORMALIZE=True
Screenshot from 2023-09-15 06-40-35

Ablation: no PV fp16 acc

No-Mask Loop Body=True, PV_f16_acc=False, DELAYED_SOFTMAX_NORMALIZE=True
image

Exploration: QK fp16 acc - 1.41x speedup

image

@jon-chuang
Copy link
Contributor Author

jon-chuang commented Sep 14, 2023

Update 3:

Explorations with more pipeline stages (in the memory bound setting)

Conclusions: pipelining indeed helps in the memory-bound (more FP16 accumulation) setting, and delayed softmax does help once there is more software prefetching.
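
For concreteness, a hedged sketch of how the pipeline depth is swept on the Triton side (kernel body elided; the names are placeholders, and the 128/64 block sizes are just carried over from the benchmark config). num_stages controls how many K/V blocks the generated code prefetches into shared memory ahead of the dots:

```python
import triton
import triton.language as tl

# One Config per pipeline depth; autotune keeps the fastest for a given SEQ_LEN.
configs = [
    triton.Config({"BLOCK_Q": 128, "BLOCK_K": 64}, num_stages=s, num_warps=4)
    for s in (2, 3, 4, 5)
]

@triton.autotune(configs=configs, key=["SEQ_LEN"])
@triton.jit
def _fwd_kernel(Q, K, V, Out, SEQ_LEN,
                BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr):
    # ... flash-attention inner loop elided ...
    pass
```

On the Pallas side the analogous knob is the num_stages value threaded through to the Triton lowering; the exact keyword has moved around between versions, so treat the plumbing above as an assumption rather than a fixed API.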

Delayed Softmax, 4 stage async read pipeline - 1.22x speedup

image

Delayed Softmax, 5 stage async read pipeline - 1.22x speedup

image

In-Loop Softmax, 4 stage async read pipeline - 1.17x speedup

image

In-Loop Softmax, 5 stage async read pipeline - 1.17x speedup

image

Exploration: QK fp16 acc, 4 stage async read pipeline - 1.44x speedup

image

Exploration: QK fp16 acc, 5 stage async read pipeline - 1.46x speedup

image

Conclusion

Very excited to explore the FP8 setting with even more software pipelining!

@jon-chuang
Contributor Author

jon-chuang commented Sep 15, 2023

It was determined that, for benchmarking on a consumer card (e.g. RTX 4070, sm_89), the FP16-accumulate setting is more reflective of data-center performance, particularly because FP32-accumulate tensor-core throughput is crippled on consumer cards (see e.g. ref1, ref2).

In particular:

  • Tensor Core TFLOPs on my device (184 Tensor Cores, 46 SMs) are:
    • FP16 inputs, FP32 accumulation: 56 TFLOPs (achieved TFLOPs: 35.7, compute utilization: 63%)
    • FP16 inputs, FP16 accumulation: 118 TFLOPs (achieved TFLOPs: 53.2, compute utilization: 45%)
  • In comparison, on data center cards:
    • A100 cards (432 Tensor Cores, 108 SMs, max TFLOPs: 312, achieved TFLOPs: 184, compute utilization: 58%)
    • H100 SXM5 cards (512 Tensor Cores, 128 SMs, max TFLOPs: 1000, achieved TFLOPs: 284, compute utilization: 28.4%)
  • Notes: the TFLOPs calc ignores non-matmul FLOPs since matmul FLOPs dominate by a factor of ~30 (see the sketch after this list). That said, non-matmul FLOPs are much more expensive per FLOP (4-16x depending on dtype/device).
    • Non-matmul FLOPs are most expensive on H100 (FP32: 4x on A100/RTX, 16x on H100)
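
For completeness, a small sketch of the FLOPs accounting behind these achieved-TFLOPs figures (the batch size, head count, and runtime plugged in at the end are assumptions for illustration, not the benchmark's actual configuration):

```python
def attention_tflops(batch, heads, seq_len, head_dim, runtime_s, causal=True):
    # Two matmuls (Q @ K^T and P @ V), each 2*M*N*K FLOPs; softmax FLOPs are
    # ignored since matmul FLOPs dominate by ~30x (see the note above).
    flops = 4 * batch * heads * seq_len ** 2 * head_dim
    if causal:
        flops /= 2  # only the lower-triangular half is computed
    return flops / runtime_s / 1e12

# e.g. an assumed batch=4, heads=32, seq_len=8192, head_dim=64 forward pass
# taking 29.1 ms works out to roughly 37.8 achieved TFLOPs:
print(attention_tflops(4, 32, 8192, 64, 29.1e-3))
```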

My next line of investigation will be FP8 matmul. It was noted that cuDNN FP8 fused attention is slow for small sequence lengths, likely due to scale/unscale overhead. On H100, FP8 matmul + FP16 arithmetic is an extremely interesting direction.
