feat(pallas): Optimize Pallas Attention + Benchmark #17328
base: main
Conversation
…fix-causal-upper-bound
DELAYED_ONLINE_SOFTMAX config
Force-pushed the DELAYED_ONLINE_SOFTMAX config branch from 1cac649 to 9ad1478
Pallas TTGIR:
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
#blocked5 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked6 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @mha_forward(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c8388608_i32 = arith.constant 8388608 : i32
%c262144_i32 = arith.constant 262144 : i32
%c31_i32 = arith.constant 31 : i32
%cst = arith.constant dense<0xFF800000> : tensor<128xf32, #blocked>
%c0_i32 = arith.constant 0 : i32
%cst_0 = arith.constant dense<-2.38197633E+38> : tensor<128x32xf32, #blocked1>
%cst_1 = arith.constant dense<1.44269502> : tensor<128x32xf32, #blocked1>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked1>
%cst_3 = arith.constant dense<64> : tensor<32x1xi32, #blocked2>
%c32_i32 = arith.constant 32 : i32
%c1_i32 = arith.constant 1 : i32
%cst_4 = arith.constant dense<64> : tensor<128x1xi32, #blocked1>
%c128_i32 = arith.constant 128 : i32
%cst_5 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked1>
%cst_6 = arith.constant dense<0.000000e+00> : tensor<128xf32, #blocked>
%0 = tt.get_program_id x : i32
%1 = tt.get_program_id y : i32
%2 = tt.get_program_id z : i32
%3 = arith.muli %0, %c128_i32 : i32
%4 = arith.muli %1, %c8388608_i32 : i32
%5 = arith.muli %2, %c262144_i32 : i32
%6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked>
%7 = tt.splat %3 : (i32) -> tensor<128xi32, #blocked>
%8 = arith.addi %7, %6 : tensor<128xi32, #blocked>
%9 = triton_gpu.convert_layout %8 : (tensor<128xi32, #blocked>) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
%11 = arith.muli %10, %cst_4 : tensor<128x1xi32, #blocked1>
%12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked>
%13 = triton_gpu.convert_layout %12 : (tensor<64xi32, #blocked>) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
%14 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>) -> tensor<1x64xi32, #blocked3>
%15 = tt.broadcast %14 : (tensor<1x64xi32, #blocked3>) -> tensor<128x64xi32, #blocked3>
%16 = triton_gpu.convert_layout %15 : (tensor<128x64xi32, #blocked3>) -> tensor<128x64xi32, #blocked1>
%17 = tt.addptr %arg0, %4 : !tt.ptr<f16>, i32
%18 = tt.addptr %17, %5 : !tt.ptr<f16>, i32
%19 = tt.splat %18 : (!tt.ptr<f16>) -> tensor<128x1x!tt.ptr<f16>, #blocked1>
%20 = tt.addptr %19, %11 : tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<128x1xi32, #blocked1>
%21 = tt.broadcast %20 : (tensor<128x1x!tt.ptr<f16>, #blocked1>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%22 = tt.addptr %21, %16 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%23 = tt.load %22 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%24 = arith.addi %0, %c1_i32 : i32
%25 = arith.muli %24, %c128_i32 : i32
%26 = arith.addi %25, %c31_i32 : i32
%27 = arith.divsi %26, %c32_i32 : i32
%28 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #blocked>
%29 = tt.broadcast %14 : (tensor<1x64xi32, #blocked3>) -> tensor<32x64xi32, #blocked3>
%30 = triton_gpu.convert_layout %29 : (tensor<32x64xi32, #blocked3>) -> tensor<32x64xi32, #blocked2>
%31 = tt.addptr %arg1, %4 : !tt.ptr<f16>, i32
%32 = tt.addptr %31, %5 : !tt.ptr<f16>, i32
%33 = tt.splat %32 : (!tt.ptr<f16>) -> tensor<32x1x!tt.ptr<f16>, #blocked2>
%34 = tt.addptr %arg2, %4 : !tt.ptr<f16>, i32
%35 = tt.addptr %34, %5 : !tt.ptr<f16>, i32
%36 = tt.splat %35 : (!tt.ptr<f16>) -> tensor<32x1x!tt.ptr<f16>, #blocked2>
%37 = tt.broadcast %10 : (tensor<128x1xi32, #blocked1>) -> tensor<128x32xi32, #blocked1>
%38:3 = scf.for %arg4 = %c0_i32 to %27 step %c1_i32 iter_args(%arg5 = %cst_5, %arg6 = %cst, %arg7 = %cst_6) -> (tensor<128x64xf32, #blocked1>, tensor<128xf32, #blocked>, tensor<128xf32, #blocked>) : i32 {
%50 = arith.muli %arg4, %c32_i32 : i32
%51 = tt.splat %50 : (i32) -> tensor<32xi32, #blocked>
%52 = arith.addi %51, %28 : tensor<32xi32, #blocked>
%53 = triton_gpu.convert_layout %52 : (tensor<32xi32, #blocked>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%54 = tt.expand_dims %53 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<32x1xi32, #blocked1>
%55 = triton_gpu.convert_layout %54 : (tensor<32x1xi32, #blocked1>) -> tensor<32x1xi32, #blocked2>
%56 = arith.muli %55, %cst_3 : tensor<32x1xi32, #blocked2>
%57 = tt.addptr %33, %56 : tensor<32x1x!tt.ptr<f16>, #blocked2>, tensor<32x1xi32, #blocked2>
%58 = tt.broadcast %57 : (tensor<32x1x!tt.ptr<f16>, #blocked2>) -> tensor<32x64x!tt.ptr<f16>, #blocked2>
%59 = tt.addptr %58, %30 : tensor<32x64x!tt.ptr<f16>, #blocked2>, tensor<32x64xi32, #blocked2>
%60 = tt.load %59 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16, #blocked2>
%61 = tt.addptr %36, %56 : tensor<32x1x!tt.ptr<f16>, #blocked2>, tensor<32x1xi32, #blocked2>
%62 = tt.broadcast %61 : (tensor<32x1x!tt.ptr<f16>, #blocked2>) -> tensor<32x64x!tt.ptr<f16>, #blocked2>
%63 = tt.addptr %62, %30 : tensor<32x64x!tt.ptr<f16>, #blocked2>, tensor<32x64xi32, #blocked2>
%64 = tt.load %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16, #blocked2>
%65 = triton_gpu.convert_layout %60 : (tensor<32x64xf16, #blocked2>) -> tensor<32x64xf16, #shared>
%66 = tt.trans %65 : (tensor<32x64xf16, #shared>) -> tensor<64x32xf16, #shared1>
%67 = triton_gpu.convert_layout %66 : (tensor<64x32xf16, #shared1>) -> tensor<64x32xf16, #blocked4>
%68 = triton_gpu.convert_layout %23 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>>
%69 = triton_gpu.convert_layout %67 : (tensor<64x32xf16, #blocked4>) -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>>
%70 = triton_gpu.convert_layout %cst_2 : (tensor<128x32xf32, #blocked1>) -> tensor<128x32xf32, #blocked5>
%71 = tt.dot %68, %69, %70 {allowTF32 = false} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> -> tensor<128x32xf32, #blocked5>
%72 = triton_gpu.convert_layout %71 : (tensor<128x32xf32, #blocked5>) -> tensor<128x32xf32, #blocked1>
%73 = arith.truncf %72 : tensor<128x32xf32, #blocked1> to tensor<128x32xf16, #blocked1>
%74 = arith.extf %73 : tensor<128x32xf16, #blocked1> to tensor<128x32xf32, #blocked1>
%75 = arith.mulf %74, %cst_1 : tensor<128x32xf32, #blocked1>
%76 = triton_gpu.convert_layout %52 : (tensor<32xi32, #blocked>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
%77 = tt.expand_dims %76 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>) -> tensor<1x32xi32, #blocked3>
%78 = tt.broadcast %77 : (tensor<1x32xi32, #blocked3>) -> tensor<128x32xi32, #blocked3>
%79 = triton_gpu.convert_layout %78 : (tensor<128x32xi32, #blocked3>) -> tensor<128x32xi32, #blocked1>
%80 = "triton_gpu.cmpi"(%37, %79) <{predicate = 5 : i64}> : (tensor<128x32xi32, #blocked1>, tensor<128x32xi32, #blocked1>) -> tensor<128x32xi1, #blocked1>
%81 = "triton_gpu.select"(%80, %75, %cst_0) : (tensor<128x32xi1, #blocked1>, tensor<128x32xf32, #blocked1>, tensor<128x32xf32, #blocked1>) -> tensor<128x32xf32, #blocked1>
%82 = "tt.reduce"(%81) <{axis = 1 : i32}> ({
^bb0(%arg8: f32, %arg9: f32):
%109 = "triton_gpu.cmpf"(%arg8, %arg9) <{predicate = 2 : i64}> : (f32, f32) -> i1
%110 = arith.select %109, %arg8, %arg9 : f32
tt.reduce.return %110 : f32
}) : (tensor<128x32xf32, #blocked1>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%83 = triton_gpu.convert_layout %82 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128xf32, #blocked>
%84 = "triton_gpu.cmpf"(%83, %arg6) <{predicate = 2 : i64}> : (tensor<128xf32, #blocked>, tensor<128xf32, #blocked>) -> tensor<128xi1, #blocked>
%85 = "triton_gpu.select"(%84, %83, %arg6) : (tensor<128xi1, #blocked>, tensor<128xf32, #blocked>, tensor<128xf32, #blocked>) -> tensor<128xf32, #blocked>
%86 = arith.subf %arg6, %85 : tensor<128xf32, #blocked>
%87 = tt.pure_extern_elementwise %86 {libname = "libdevice", libpath = "/home/jonch/.local/lib/python3.10/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128xf32, #blocked>) -> tensor<128xf32, #blocked>
%88 = triton_gpu.convert_layout %85 : (tensor<128xf32, #blocked>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%89 = tt.expand_dims %88 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xf32, #blocked1>
%90 = tt.broadcast %89 : (tensor<128x1xf32, #blocked1>) -> tensor<128x32xf32, #blocked1>
%91 = arith.subf %81, %90 : tensor<128x32xf32, #blocked1>
%92 = tt.pure_extern_elementwise %91 {libname = "libdevice", libpath = "/home/jonch/.local/lib/python3.10/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128x32xf32, #blocked1>) -> tensor<128x32xf32, #blocked1>
%93 = "tt.reduce"(%92) <{axis = 1 : i32}> ({
^bb0(%arg8: f32, %arg9: f32):
%109 = arith.addf %arg8, %arg9 : f32
tt.reduce.return %109 : f32
}) : (tensor<128x32xf32, #blocked1>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%94 = triton_gpu.convert_layout %93 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128xf32, #blocked>
%95 = arith.mulf %87, %arg7 : tensor<128xf32, #blocked>
%96 = arith.addf %94, %95 : tensor<128xf32, #blocked>
%97 = arith.mulf %arg7, %cst_6 : tensor<128xf32, #blocked>
%98 = arith.addf %97, %87 : tensor<128xf32, #blocked>
%99 = triton_gpu.convert_layout %98 : (tensor<128xf32, #blocked>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%100 = tt.expand_dims %99 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xf32, #blocked1>
%101 = tt.broadcast %100 : (tensor<128x1xf32, #blocked1>) -> tensor<128x64xf32, #blocked1>
%102 = arith.mulf %arg5, %101 : tensor<128x64xf32, #blocked1>
%103 = arith.truncf %92 : tensor<128x32xf32, #blocked1> to tensor<128x32xf16, #blocked1>
%104 = triton_gpu.convert_layout %103 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked6}>>
%105 = triton_gpu.convert_layout %64 : (tensor<32x64xf16, #blocked2>) -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked6}>>
%106 = triton_gpu.convert_layout %102 : (tensor<128x64xf32, #blocked1>) -> tensor<128x64xf32, #blocked6>
%107 = tt.dot %104, %105, %106 {allowTF32 = false} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked6}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked6}>> -> tensor<128x64xf32, #blocked6>
%108 = triton_gpu.convert_layout %107 : (tensor<128x64xf32, #blocked6>) -> tensor<128x64xf32, #blocked1>
scf.yield %108, %85, %96 : tensor<128x64xf32, #blocked1>, tensor<128xf32, #blocked>, tensor<128xf32, #blocked>
}
%39 = triton_gpu.convert_layout %38#2 : (tensor<128xf32, #blocked>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%40 = tt.expand_dims %39 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xf32, #blocked1>
%41 = tt.broadcast %40 : (tensor<128x1xf32, #blocked1>) -> tensor<128x64xf32, #blocked1>
%42 = arith.divf %38#0, %41 : tensor<128x64xf32, #blocked1>
%43 = arith.truncf %42 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1>
%44 = tt.addptr %arg3, %4 : !tt.ptr<f16>, i32
%45 = tt.addptr %44, %5 : !tt.ptr<f16>, i32
%46 = tt.splat %45 : (!tt.ptr<f16>) -> tensor<128x1x!tt.ptr<f16>, #blocked1>
%47 = tt.addptr %46, %11 : tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<128x1xi32, #blocked1>
%48 = tt.broadcast %47 : (tensor<128x1x!tt.ptr<f16>, #blocked1>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%49 = tt.addptr %48, %16 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
tt.store %49, %43 {cache = 1 : i32, evict = 1 : i32} : tensor<128x64xf16, #blocked1>
tt.return
}
}

Triton TTGIR:
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @_fwd_kernel_0d1d2d34d5d6d7d8d9c10d11d12d13c14d15d16d17c18d19d20d21c2223d24d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c1_i32 = arith.constant 1 : i32
%c3_i32 = arith.constant 3 : i32
%c192_i64 = arith.constant 192 : i64
%cst = arith.constant dense<128> : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%cst_0 = arith.constant dense<128> : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%c128_i32 = arith.constant 128 : i32
%c2_i32 = arith.constant 2 : i32
%cst_1 = arith.constant dense<64> : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%cst_2 = arith.constant dense<64> : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%c64_i32 = arith.constant 64 : i32
%c64_i64 = arith.constant 64 : i64
%c4_i32 = arith.constant 4 : i32
%cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
%cst_4 = arith.constant dense<0.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%cst_5 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%cst_6 = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma>
%c0_i64 = arith.constant 0 : i64
%c0_i32 = arith.constant 0 : i32
%cst_7 = arith.constant 1.44269502 : f32
%0 = tt.get_program_id x : i32
%1 = tt.get_program_id y : i32
%2 = arith.muli %1, %arg7 : i32
%3 = tt.addptr %arg0, %2 : !tt.ptr<f16>, i32
%4 = arith.muli %0, %c128_i32 : i32
%5 = arith.extsi %arg8 : i32 to i64
%6 = arith.extsi %4 : i32 to i64
%7 = tt.addptr %arg1, %2 : !tt.ptr<f16>, i32
%8 = arith.extsi %arg11 : i32 to i64
%9 = tt.addptr %arg2, %2 : !tt.ptr<f16>, i32
%10 = arith.extsi %arg14 : i32 to i64
%11 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked2>
%12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%14 = tt.splat %4 : (i32) -> tensor<128xi32, #blocked2>
%15 = tt.splat %4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%16 = arith.addi %14, %11 : tensor<128xi32, #blocked2>
%17 = arith.addi %15, %13 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%22 = arith.mulf %arg3, %cst_7 : f32
%23 = tt.splat %6 : (i64) -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%24 = arith.extsi %12 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%25 = arith.addi %23, %24 : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%26 = tt.expand_dims %25 {axis = 1 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xi64, #blocked>
%27 = tt.splat %5 : (i64) -> tensor<128x1xi64, #blocked>
%28 = arith.muli %26, %27 : tensor<128x1xi64, #blocked>
%29 = tt.splat %3 : (!tt.ptr<f16>) -> tensor<128x1x!tt.ptr<f16>, #blocked>
%30 = tt.addptr %29, %28 : tensor<128x1x!tt.ptr<f16>, #blocked>, tensor<128x1xi64, #blocked>
%31 = tt.broadcast %30 : (tensor<128x1x!tt.ptr<f16>, #blocked>) -> tensor<128x64x!tt.ptr<f16>, #blocked>
%32 = arith.extsi %18 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> to tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%33 = arith.extsi %19 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> to tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%34 = arith.extsi %20 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> to tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%35 = arith.extsi %21 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%36 = tt.expand_dims %32 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x64xi64, #blocked>
%37 = tt.broadcast %36 : (tensor<1x64xi64, #blocked>) -> tensor<128x64xi64, #blocked>
%38 = tt.addptr %31, %37 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi64, #blocked>
%39 = tt.load %38 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked>
%40 = tt.splat %22 : (f32) -> tensor<128x64xf32, #blocked>
%41 = arith.extf %39 : tensor<128x64xf16, #blocked> to tensor<128x64xf32, #blocked>
%42 = arith.mulf %41, %40 : tensor<128x64xf32, #blocked>
%43 = arith.truncf %42 : tensor<128x64xf32, #blocked> to tensor<128x64xf16, #blocked>
%44 = triton_gpu.convert_layout %43 : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #shared>
%45 = arith.addi %0, %c1_i32 : i32
%46 = arith.muli %45, %c128_i32 : i32
%47 = tt.expand_dims %33 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi64, #blocked1>
%48 = tt.splat %7 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked1>
%49 = tt.addptr %48, %47 : tensor<64x1x!tt.ptr<f16>, #blocked1>, tensor<64x1xi64, #blocked1>
%50 = tt.broadcast %49 : (tensor<64x1x!tt.ptr<f16>, #blocked1>) -> tensor<64x64x!tt.ptr<f16>, #blocked1>
%51 = tt.splat %8 : (i64) -> tensor<1x64xi64, #blocked1>
%52 = tt.splat %10 : (i64) -> tensor<64x1xi64, #blocked>
%53 = tt.splat %9 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked>
%54 = tt.broadcast %36 : (tensor<1x64xi64, #blocked>) -> tensor<64x64xi64, #blocked>
%55 = tt.expand_dims %17 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xi32, #mma>
%56 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
%57 = tt.expand_dims %56 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>) -> tensor<1x64xi32, #mma>
%58 = tt.broadcast %55 : (tensor<128x1xi32, #mma>) -> tensor<128x64xi32, #mma>
%59 = arith.cmpi sgt, %46, %c0_i32 : i32
%60 = tt.expand_dims %34 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi64, #blocked1>
%61 = arith.muli %60, %51 : tensor<1x64xi64, #blocked1>
%62 = tt.broadcast %61 : (tensor<1x64xi64, #blocked1>) -> tensor<64x64xi64, #blocked1>
%63 = tt.addptr %50, %62 : tensor<64x64x!tt.ptr<f16>, #blocked1>, tensor<64x64xi64, #blocked1>
%64 = triton_gpu.alloc_tensor : tensor<4x64x64xf16, #shared1>
%65 = tt.splat %59 : (i1) -> tensor<64x64xi1, #blocked1>
%66 = triton_gpu.insert_slice_async %63, %64, %c0_i32, %65 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16>, #blocked1> -> tensor<4x64x64xf16, #shared1>
triton_gpu.async_commit_group
%67 = tt.expand_dims %35 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi64, #blocked>
%68 = arith.muli %67, %52 : tensor<64x1xi64, #blocked>
%69 = tt.addptr %53, %68 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi64, #blocked>
%70 = tt.broadcast %69 : (tensor<64x1x!tt.ptr<f16>, #blocked>) -> tensor<64x64x!tt.ptr<f16>, #blocked>
%71 = tt.addptr %70, %54 : tensor<64x64x!tt.ptr<f16>, #blocked>, tensor<64x64xi64, #blocked>
%72 = triton_gpu.alloc_tensor : tensor<4x64x64xf16, #shared>
%73 = tt.splat %59 : (i1) -> tensor<64x64xi1, #blocked>
%74 = triton_gpu.insert_slice_async %71, %72, %c0_i32, %73 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16>, #blocked> -> tensor<4x64x64xf16, #shared>
triton_gpu.async_commit_group
%75 = arith.cmpi sgt, %46, %c64_i32 : i32
%76 = arith.addi %34, %cst_2 : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%77 = tt.expand_dims %76 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi64, #blocked1>
%78 = arith.muli %77, %51 : tensor<1x64xi64, #blocked1>
%79 = tt.broadcast %78 : (tensor<1x64xi64, #blocked1>) -> tensor<64x64xi64, #blocked1>
%80 = tt.addptr %50, %79 : tensor<64x64x!tt.ptr<f16>, #blocked1>, tensor<64x64xi64, #blocked1>
%81 = tt.splat %75 : (i1) -> tensor<64x64xi1, #blocked1>
%82 = triton_gpu.insert_slice_async %80, %66, %c1_i32, %81 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16>, #blocked1> -> tensor<4x64x64xf16, #shared1>
triton_gpu.async_commit_group
%83 = arith.addi %35, %cst_1 : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%84 = tt.expand_dims %83 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi64, #blocked>
%85 = arith.muli %84, %52 : tensor<64x1xi64, #blocked>
%86 = tt.addptr %53, %85 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi64, #blocked>
%87 = tt.broadcast %86 : (tensor<64x1x!tt.ptr<f16>, #blocked>) -> tensor<64x64x!tt.ptr<f16>, #blocked>
%88 = tt.addptr %87, %54 : tensor<64x64x!tt.ptr<f16>, #blocked>, tensor<64x64xi64, #blocked>
%89 = tt.splat %75 : (i1) -> tensor<64x64xi1, #blocked>
%90 = triton_gpu.insert_slice_async %88, %74, %c1_i32, %89 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16>, #blocked> -> tensor<4x64x64xf16, #shared>
triton_gpu.async_commit_group
%91 = arith.cmpi sgt, %46, %c128_i32 : i32
%92 = arith.addi %34, %cst_0 : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%93 = tt.expand_dims %92 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi64, #blocked1>
%94 = arith.muli %93, %51 : tensor<1x64xi64, #blocked1>
%95 = tt.broadcast %94 : (tensor<1x64xi64, #blocked1>) -> tensor<64x64xi64, #blocked1>
%96 = tt.addptr %50, %95 : tensor<64x64x!tt.ptr<f16>, #blocked1>, tensor<64x64xi64, #blocked1>
%97 = tt.splat %91 : (i1) -> tensor<64x64xi1, #blocked1>
%98 = triton_gpu.insert_slice_async %96, %82, %c2_i32, %97 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16>, #blocked1> -> tensor<4x64x64xf16, #shared1>
triton_gpu.async_commit_group
%99 = arith.addi %35, %cst : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%100 = tt.expand_dims %99 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi64, #blocked>
%101 = arith.muli %100, %52 : tensor<64x1xi64, #blocked>
%102 = tt.addptr %53, %101 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi64, #blocked>
%103 = tt.broadcast %102 : (tensor<64x1x!tt.ptr<f16>, #blocked>) -> tensor<64x64x!tt.ptr<f16>, #blocked>
%104 = tt.addptr %103, %54 : tensor<64x64x!tt.ptr<f16>, #blocked>, tensor<64x64xi64, #blocked>
%105 = tt.splat %91 : (i1) -> tensor<64x64xi1, #blocked>
%106 = triton_gpu.insert_slice_async %104, %90, %c2_i32, %105 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16>, #blocked> -> tensor<4x64x64xf16, #shared>
triton_gpu.async_commit_group
triton_gpu.async_wait {num = 4 : i32}
%107 = triton_gpu.extract_slice %98[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<4x64x64xf16, #shared1> to tensor<64x64xf16, #shared1>
%108 = triton_gpu.extract_slice %106[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<4x64x64xf16, #shared> to tensor<64x64xf16, #shared>
%109:14 = scf.for %arg21 = %c0_i32 to %46 step %c64_i32 iter_args(%arg22 = %cst_3, %arg23 = %cst_4, %arg24 = %cst_5, %arg25 = %c0_i64, %arg26 = %c0_i64, %arg27 = %98, %arg28 = %106, %arg29 = %107, %arg30 = %108, %arg31 = %c192_i64, %arg32 = %c192_i64, %arg33 = %c128_i32, %arg34 = %c3_i32, %arg35 = %c1_i32) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, i64, i64, tensor<4x64x64xf16, #shared1>, tensor<4x64x64xf16, #shared>, tensor<64x64xf16, #shared1>, tensor<64x64xf16, #shared>, i64, i64, i32, i32, i32) : i32 {
%131 = tt.splat %arg21 : (i32) -> tensor<1x64xi32, #mma>
%132 = arith.addi %131, %57 : tensor<1x64xi32, #mma>
%133 = tt.broadcast %132 : (tensor<1x64xi32, #mma>) -> tensor<128x64xi32, #mma>
%134 = "triton_gpu.cmpi"(%58, %133) <{predicate = 5 : i64}> : (tensor<128x64xi32, #mma>, tensor<128x64xi32, #mma>) -> tensor<128x64xi1, #mma>
%135 = "triton_gpu.select"(%134, %cst_3, %cst_6) : (tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>
%136 = triton_gpu.convert_layout %44 : (tensor<128x64xf16, #shared>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%137 = triton_gpu.convert_layout %arg29 : (tensor<64x64xf16, #shared1>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%138 = tt.dot %136, %137, %135 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
%139 = "tt.reduce"(%138) <{axis = 1 : i32}> ({
^bb0(%arg36: f32, %arg37: f32):
%189 = "triton_gpu.cmpf"(%arg36, %arg37) <{predicate = 2 : i64}> : (f32, f32) -> i1
%190 = arith.select %189, %arg36, %arg37 : f32
tt.reduce.return %190 : f32
}) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%140 = "triton_gpu.cmpf"(%arg24, %139) <{predicate = 2 : i64}> : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%141 = "triton_gpu.select"(%140, %arg24, %139) : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%142 = arith.subf %arg24, %141 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%143 = tt.pure_extern_elementwise %142 {libname = "libdevice", libpath = "/home/jonch/.local/lib/python3.10/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%144 = tt.expand_dims %141 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
%145 = tt.broadcast %144 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma>
%146 = arith.subf %138, %145 : tensor<128x64xf32, #mma>
%147 = tt.pure_extern_elementwise %146 {libname = "libdevice", libpath = "/home/jonch/.local/lib/python3.10/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>
%148 = arith.mulf %arg23, %cst_4 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%149 = arith.addf %148, %143 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%150 = tt.expand_dims %149 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
%151 = tt.broadcast %150 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma>
%152 = arith.mulf %arg22, %151 : tensor<128x64xf32, #mma>
%153 = arith.truncf %147 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
%154 = triton_gpu.convert_layout %153 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%155 = triton_gpu.convert_layout %arg30 : (tensor<64x64xf16, #shared>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%156 = tt.dot %154, %155, %152 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
%157 = arith.mulf %arg23, %143 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%158 = "tt.reduce"(%147) <{axis = 1 : i32}> ({
^bb0(%arg36: f32, %arg37: f32):
%189 = arith.addf %arg36, %arg37 : f32
tt.reduce.return %189 : f32
}) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%159 = arith.addf %157, %158 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%160 = arith.addi %arg25, %c64_i64 : i64
%161 = arith.addi %arg26, %c64_i64 : i64
%162 = arith.addi %arg33, %c64_i32 : i32
%163 = arith.cmpi slt, %162, %46 : i32
%164 = arith.remsi %arg34, %c4_i32 : i32
%165 = arith.remsi %arg35, %c4_i32 : i32
%166 = tt.splat %arg31 : (i64) -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%167 = arith.addi %166, %34 : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%168 = tt.expand_dims %167 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi64, #blocked1>
%169 = arith.muli %168, %51 : tensor<1x64xi64, #blocked1>
%170 = tt.broadcast %169 : (tensor<1x64xi64, #blocked1>) -> tensor<64x64xi64, #blocked1>
%171 = tt.addptr %50, %170 : tensor<64x64x!tt.ptr<f16>, #blocked1>, tensor<64x64xi64, #blocked1>
%172 = tt.splat %arg32 : (i64) -> tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%173 = arith.addi %172, %35 : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%174 = tt.expand_dims %173 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi64, #blocked>
%175 = arith.muli %174, %52 : tensor<64x1xi64, #blocked>
%176 = tt.addptr %53, %175 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi64, #blocked>
%177 = tt.broadcast %176 : (tensor<64x1x!tt.ptr<f16>, #blocked>) -> tensor<64x64x!tt.ptr<f16>, #blocked>
%178 = tt.addptr %177, %54 : tensor<64x64x!tt.ptr<f16>, #blocked>, tensor<64x64xi64, #blocked>
%179 = arith.addi %arg31, %c64_i64 : i64
%180 = arith.addi %arg32, %c64_i64 : i64
%181 = tt.splat %163 : (i1) -> tensor<64x64xi1, #blocked1>
%182 = triton_gpu.insert_slice_async %171, %arg27, %164, %181 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16>, #blocked1> -> tensor<4x64x64xf16, #shared1>
triton_gpu.async_commit_group
%183 = tt.splat %163 : (i1) -> tensor<64x64xi1, #blocked>
%184 = triton_gpu.insert_slice_async %178, %arg28, %164, %183 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64x!tt.ptr<f16>, #blocked> -> tensor<4x64x64xf16, #shared>
triton_gpu.async_commit_group
triton_gpu.async_wait {num = 4 : i32}
%185 = triton_gpu.extract_slice %182[%165, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<4x64x64xf16, #shared1> to tensor<64x64xf16, #shared1>
%186 = triton_gpu.extract_slice %184[%165, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<4x64x64xf16, #shared> to tensor<64x64xf16, #shared>
%187 = arith.addi %arg34, %c1_i32 : i32
%188 = arith.addi %arg35, %c1_i32 : i32
scf.yield %156, %159, %141, %160, %161, %182, %184, %185, %186, %179, %180, %162, %187, %188 : tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, i64, i64, tensor<4x64x64xf16, #shared1>, tensor<4x64x64xf16, #shared>, tensor<64x64xf16, #shared1>, tensor<64x64xf16, #shared>, i64, i64, i32, i32, i32
}
triton_gpu.async_wait {num = 0 : i32}
%110 = triton_gpu.convert_layout %109#2 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #blocked2>
%111 = triton_gpu.convert_layout %109#1 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #blocked2>
%112 = tt.expand_dims %109#1 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
%113 = tt.broadcast %112 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma>
%114 = arith.divf %109#0, %113 : tensor<128x64xf32, #mma>
%115 = arith.muli %1, %arg20 : i32
%116 = tt.addptr %arg4, %115 : !tt.ptr<f32>, i32
%117 = tt.splat %116 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked2>
%118 = tt.addptr %117, %16 : tensor<128x!tt.ptr<f32>, #blocked2>, tensor<128xi32, #blocked2>
%119 = tt.pure_extern_elementwise %111 {libname = "libdevice", libpath = "/home/jonch/.local/lib/python3.10/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_log2f"} : (tensor<128xf32, #blocked2>) -> tensor<128xf32, #blocked2>
%120 = arith.addf %110, %119 : tensor<128xf32, #blocked2>
tt.store %118, %120 {cache = 1 : i32, evict = 1 : i32} : tensor<128xf32, #blocked2>
%121 = tt.addptr %arg5, %2 : !tt.ptr<f16>, i32
%122 = arith.extsi %arg17 : i32 to i64
%123 = arith.truncf %114 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
%124 = tt.splat %122 : (i64) -> tensor<128x1xi64, #blocked>
%125 = arith.muli %26, %124 : tensor<128x1xi64, #blocked>
%126 = tt.splat %121 : (!tt.ptr<f16>) -> tensor<128x1x!tt.ptr<f16>, #blocked>
%127 = tt.addptr %126, %125 : tensor<128x1x!tt.ptr<f16>, #blocked>, tensor<128x1xi64, #blocked>
%128 = tt.broadcast %127 : (tensor<128x1x!tt.ptr<f16>, #blocked>) -> tensor<128x64x!tt.ptr<f16>, #blocked>
%129 = tt.addptr %128, %37 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi64, #blocked>
%130 = triton_gpu.convert_layout %123 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #blocked>
tt.store %129, %130 {cache = 1 : i32, evict = 1 : i32} : tensor<128x64xf16, #blocked>
tt.return
}
}

Notes:
Update
New results from updating Triton. See the appendix for more details.

Experiments
The following optimizations have been applied:
The following result in no appreciable difference:
The following has shown improvement but has not been adopted:
TODOs:
Appendix

16x Aberrant Memory Read Issue
After some digging I isolated the issue to poor handling of the transpose in between. Triton did not suffer this issue, as it handled the transpose as an smem layout (when loading from gmem to smem) rather than as an explicit transpose of smem values. However, after upgrading Triton this issue no longer exists (it only appears with a backdated jax-triton dependency). A sketch of the two ways to express the QK dot is given below.

Ablations (in TFLOPs/s, higher is better)
Note: slightly inaccurate due to (now fixed) Pallas runtime overhead. With qk dot ordering (
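To make the transpose and dot-ordering discussion concrete, here is a minimal sketch (not this PR's actual kernel code) of two ways the QK score matmul can be written inside a kernel body. The first forms K^T explicitly, which in the old lowering showed up as a tt.trans on shared memory; the second contracts the head dimension of both operands via dot_general, leaving the "transpose" to the layout the compiler picks. Whether the second form actually avoids an explicit smem transpose depends on the Pallas/Triton version, per the discussion above.

```python
import jax
import jax.numpy as jnp

def qk_scores_explicit_transpose(q, k):
    # Materializes the transpose of the K tile before the matmul; in the old
    # lowering this became an explicit smem transpose (tt.trans in the Pallas TTGIR).
    return jnp.dot(q, k.T)

def qk_scores_dot_general(q, k):
    # Contracts the head (last) dimension of both q and k directly, expressing
    # q @ k^T without writing a transpose, so the compiler can fold it into a layout.
    return jax.lax.dot_general(q, k, dimension_numbers=(((1,), (1,)), ((), ())))
```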
This bug previously existed and I verified it in a previous nightly release of Triton (20230714). However, according to new benchmarks, this bug no longer exists on Triton main. See: jax-ml/jax#17328 (comment)
Thank you for this amazing investigation! I am not actually very experienced with GPU performance and benchmarking, so we really appreciate you doing this. Let's try to get your optimizations in once we update Triton internally. Does that sound good?
Hello @sharadmv, no problem. There's a lot of data to parse through, but I hope to shed light on the state of things by generating this data and making it available.
Updates (round 2)
Applied new optimizations:
Exploring:
Benchmark Settings
Total rounds=150, outer runs=5, GPU=RTX 4070 Laptop (sm_89). I was unable to benchmark flash_attn this time due to dependency installation issues, so we will use Triton as the common comparator.
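For reference, a minimal timing harness consistent with the settings above might look like the following. This is a hypothetical sketch, not the benchmark script used for these numbers; it assumes the usual 4·batch·heads·seq_len²·head_dim flop count for forward attention (halved for causal) when converting time to TFLOP/s, and illustrative shape defaults.

```python
import time
import jax

def bench_tflops(fn, *args, rounds=150, outer_runs=5,
                 batch=4, heads=16, seq_len=4096, head_dim=64, causal=True):
    # Flop count for the two attention matmuls (QK^T and PV); causal masking
    # roughly halves the useful work.
    flops = 4 * batch * heads * seq_len ** 2 * head_dim
    if causal:
        flops //= 2
    jax.block_until_ready(fn(*args))  # compile and warm up outside the timed region
    best = float("inf")
    for _ in range(outer_runs):
        start = time.perf_counter()
        for _ in range(rounds):
            out = fn(*args)
        jax.block_until_ready(out)
        best = min(best, (time.perf_counter() - start) / rounds)
    return flops / best / 1e12  # TFLOP/s, using the fastest outer run
```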
Summary
I have yet to explore numerical accuracy (I suspect numerical accuracy may be lower for Triton/Pallas in the first place).

Findings

In-loop reciprocal (highest expected stability and accuracy) - 1.21x speedup

Delayed softmax reciprocal + rely on 1/seq_len for numerical stability - also 1.21x speedup (see the sketch after these findings)
Explanation: faster compute means the kernel is more memory/smem-read bottlenecked, so we no longer see a 5% speedup from reducing compute intensity further. Action: one can try to increase pipeline stages in order to better feed compute.
No-Mask Loop Body=True, PV_f16_acc=True, DELAYED_SOFTMAX_NORMALIZE=True

Ablation: always apply mask - slightly slower (-1%)
No-Mask Loop Body=False, PV_f16_acc=True, DELAYED_SOFTMAX_NORMALIZE=True

Ablation: no PV fp16 acc
No-Mask Loop Body=True, PV_f16_acc=False, DELAYED_SOFTMAX_NORMALIZE=True

Exploration: QK fp16 acc - 1.41x speedup
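As a point of reference for the in-loop vs. delayed reciprocal findings, here is a minimal JAX/NumPy-style sketch of online softmax with delayed normalization (the DELAYED_SOFTMAX_NORMALIZE idea): the running accumulator stays unnormalized and is divided by the row sum exactly once after the K/V loop, rather than paying a divide (or reciprocal) every iteration. This is illustrative, not the PR's Pallas kernel.

```python
import jax.numpy as jnp

def attention_delayed_normalize(q, k, v, block_k=32):
    seq_q, head_dim = q.shape
    m = jnp.full((seq_q,), -jnp.inf)          # running row max
    l = jnp.zeros((seq_q,))                   # running row sum of exp
    acc = jnp.zeros((seq_q, head_dim))        # unnormalized output accumulator
    for start in range(0, k.shape[0], block_k):
        kb, vb = k[start:start + block_k], v[start:start + block_k]
        s = q @ kb.T                          # scores for this K/V block
        m_new = jnp.maximum(m, s.max(axis=1))
        alpha = jnp.exp(m - m_new)            # rescales the old accumulator and sum
        p = jnp.exp(s - m_new[:, None])
        l = l * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ vb
        m = m_new
    return acc / l[:, None]                   # the single, delayed normalization
```

The in-loop variant instead normalizes acc on every iteration; the two are algebraically equivalent, which is consistent with both showing the same 1.21x speedup once the kernel is memory bound.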
Update 3
Explorations with more pipeline stages (in the memory-bound setting).
Conclusions: pipelining indeed helps the memory-bound (more FP16 acc) setting, and delayed softmax does help when there is more software prefetching. A sketch of the prefetch pattern appears at the end of this update.

Delayed Softmax, 4-stage async read pipeline - 1.22x speedup
Delayed Softmax, 5-stage async read pipeline - 1.22x speedup
In-Loop Softmax, 4-stage async read pipeline - 1.17x speedup
In-Loop Softmax, 5-stage async read pipeline - 1.17x speedup
Exploration: QK fp16 acc, 4-stage async read pipeline - 1.44x speedup
Exploration: QK fp16 acc, 5-stage async read pipeline - 1.46x speedup

Conclusion
Very excited to explore the FP8 setting with even more software pipelining!
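To illustrate what the 4- and 5-stage async read pipelines are doing, here is a schematic (pure Python, with hypothetical helper names) of the prefetch pattern: with S stages, S-1 K/V block copies are kept in flight in shared memory while the current block is consumed, which is the same insert_slice_async / async_wait structure visible in the Triton TTGIR above.

```python
def pipelined_kv_loop(num_blocks, num_stages, issue_async_copy, wait_for_copy, compute_block):
    # Prime the pipeline: start num_stages - 1 copies before any compute happens.
    for i in range(min(num_stages - 1, num_blocks)):
        issue_async_copy(i)
    for i in range(num_blocks):
        wait_for_copy(i)                  # block until K/V tile i is resident in smem
        nxt = i + num_stages - 1
        if nxt < num_blocks:
            issue_async_copy(nxt)         # keep num_stages - 1 copies in flight
        compute_block(i)                  # QK^T, online softmax update, PV for tile i
```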
It was determined: for benchmarking on a commercial card (e.g. RTX 4070
In particular:
My next line of investigation will be FP8 matmul. It was noted that cuDNN FP8 fused attention is slow for small seq len, likely due to scale/unscale overhead. On H100, FP8 matmul + FP16 arithmetic is an extremely interesting direction.
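To show where the scale/unscale overhead comes from, here is a rough per-tensor-scaled FP8 matmul sketch (my own illustration, not cuDNN's or this PR's code). The extra amax reductions, scaling multiplies, and the final unscale are fixed work that does not shrink with sequence length, which would explain why FP8 fused attention loses at small seq_len.

```python
import jax.numpy as jnp

def fp8_scaled_matmul(a, b, out_dtype=jnp.float16):
    # Per-tensor scales chosen so the largest magnitude maps to e4m3's max normal (448).
    a_scale = jnp.max(jnp.abs(a)) / 448.0
    b_scale = jnp.max(jnp.abs(b)) / 448.0
    a8 = (a / a_scale).astype(jnp.float8_e4m3fn)
    b8 = (b / b_scale).astype(jnp.float8_e4m3fn)
    # For illustration the product is taken in f32 after upcasting; a real kernel
    # would feed the FP8 operands to the tensor cores and only unscale the accumulator.
    acc = jnp.dot(a8.astype(jnp.float32), b8.astype(jnp.float32))
    return (acc * (a_scale * b_scale)).astype(out_dtype)
```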
This bug previously existed and I verified it in a previous nightly release of Triton (20230714). However, according to new benchmarks, this bug no longer exists on Triton main. See: jax-ml/jax#17328 (comment)
Depends on fix #17314 for block_q != block_k.
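For context on the block_q != block_k case, the causal loop bound computed in the Pallas TTGIR above (the `(pid + 1) * 128 + 31` then divide-by-32 sequence, with block_q=128 and block_k=32) amounts to the following; the function and variable names here are mine, not the kernel's.

```python
def causal_kv_block_bound(q_block_idx, block_q, block_k):
    # Number of K/V blocks a causal query block must visit: everything up to and
    # including the K/V block containing this Q block's last query row.
    last_q_row_exclusive = (q_block_idx + 1) * block_q
    return (last_q_row_exclusive + block_k - 1) // block_k  # ceiling division
```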
Results: EDIT: This one-line change, mimicking the Triton code, brought the speed of the Pallas delayed-softmax variant way down, to slightly below the in-loop case.
This finicky sensitivity to the application code is undesirable, and ideally we should not carry tech debt from Triton into Pallas, as Pallas is meant to target more than just Triton.
Without exp2 optimization

With the full in-loop online softmax (both Pallas and Triton):

See update below for latest results.
CC: @sharadmv