"""Triton matmul kernels invoked from JAX through ``jax_triton``.

Contains two copies of the OpenAI Triton matmul kernel (one plain, one wrapped
in ``triton.autotune`` / ``triton.heuristics``), a ``jax.jit`` wrapper for each,
and a small script that compares both wrappers against ``a @ b``.
"""
import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl
from .matmul_perf_model import early_config_prune, estimate_matmul_time # relative path to matmul_perf_model (https://github.com/openai/triton/blob/main/python/triton/ops/matmul_perf_model.py)

##################
# OpenAI Matmul kernel
# (https://github.com/openai/triton/blob/main/python/triton/ops/matmul.py#L63)
##################


### Kernel without metaparams config
@triton.jit
def matmul_kernel(A, B, C, M, N, K,
                  stride_am, stride_ak,
                  stride_bk, stride_bn,
                  stride_cm, stride_cn,
                  dot_out_dtype: tl.constexpr,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                  GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
                  ):
    # Computes one (BLOCK_M, BLOCK_N) tile of C = A @ B per program.
    # Program axis 0 selects the output tile, program axis 1 selects the K-split.
    # EVEN_K must be a real boolean: when it is truthy the loads below are
    # unmasked along K, which is only valid if K is a multiple of
    # BLOCK_K * SPLIT_K.
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # row/column indices are wrapped with % M and % N, so the loads only need
    # a mask along K
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            # mask out the K positions that fall past the end of the matrices
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
        acc += tl.dot(a, b, out_dtype=dot_out_dtype)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        # partial sums from each K-split are accumulated into C
        tl.atomic_add(C, acc, mask=mask)


### Kernel with metaparams configs passed via autotune
@triton.autotune(
    configs=[
        triton.Config(
            {
                'BLOCK_M': 16,
                'BLOCK_N': 32,
                'BLOCK_K': 32,
                'SPLIT_K': 1,
            },
            num_stages=2,
            num_warps=2,
        ),
    ],
    key=['M', 'N', 'K'],
)
@triton.heuristics(
    {
        'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
    }
)
@triton.jit
def matmul_kernel_autotuned(A, B, C, M, N, K,
                            stride_am, stride_ak,
                            stride_bk, stride_bn,
                            stride_cm, stride_cn,
                            dot_out_dtype: tl.constexpr,
                            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                            GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
                            ):
    # Same body as matmul_kernel; BLOCK_M / BLOCK_N / BLOCK_K / SPLIT_K come
    # from the autotune config and EVEN_K from the heuristic above.
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            # mask out the K positions that fall past the end of the matrices
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
        acc += tl.dot(a, b, out_dtype=dot_out_dtype)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        # partial sums from each K-split are accumulated into C
        tl.atomic_add(C, acc, mask=mask)


##################
# Jax-Triton kernels
##################

@jax.jit
def matmul_kernel_jax(a: jnp.ndarray, b: jnp.ndarray):
    """Multiply two 2-D arrays with ``matmul_kernel`` and fixed metaparams.

    Args:
        a: Left operand of shape ``(M, K)``.
        b: Right operand of shape ``(K, N)``.

    Returns:
        The result of ``jt.triton_call`` for an ``(M, N)`` output with the
        dtype of ``a``.

    Raises:
        AssertionError: If the inner dimensions of ``a`` and ``b`` differ.
    """
    # 1) Define outputs shape
    M, K_A = a.shape
    K_B, N = b.shape
    assert K_A == K_B, 'incompatible dimensions'
    K = K_A
    outputs_shape = (M, N)
    outputs_dtype = a.dtype

    # 2) Define params, grid
    BLOCK_M = 16
    BLOCK_N = 16
    BLOCK_K, SPLIT_K = 16, 1
    # EVEN_K must be a boolean. Passing the bare remainder (truthy whenever K
    # is NOT a multiple of the K tile) selected the unmasked load path for
    # exactly the shapes that need masking.
    EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
    GROUP_M = 8

    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),
        META['SPLIT_K'],
    )

    stride_am, stride_ak = jt.strides_from_shape(a.shape)
    stride_bk, stride_bn = jt.strides_from_shape(b.shape)
    stride_cm, stride_cn = jt.strides_from_shape(outputs_shape)

    out_shape = jax.ShapeDtypeStruct(shape=outputs_shape, dtype=outputs_dtype)
    return jt.triton_call(
        a,
        b,
        M=M,
        N=N,
        K=K,
        stride_am=stride_am,
        stride_ak=stride_ak,
        stride_bk=stride_bk,
        stride_bn=stride_bn,
        stride_cm=stride_cm,
        stride_cn=stride_cn,
        kernel=matmul_kernel,
        out_shape=out_shape,
        grid=grid,
        input_output_aliases=None,
        zeroed_outputs=(0,),
        num_warps=4,
        num_stages=2,
        debug=False,
        dot_out_dtype=tl.float32,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        GROUP_M=GROUP_M,
        SPLIT_K=SPLIT_K,
        EVEN_K=EVEN_K,
    )


@jax.jit
def matmul_kernel_jax_autotuned(a: jnp.ndarray, b: jnp.ndarray):
    """Multiply two 2-D arrays with ``matmul_kernel_autotuned``.

    Block sizes, ``SPLIT_K`` and ``EVEN_K`` are not passed here; they are
    supplied by the ``triton.autotune`` / ``triton.heuristics`` decorators on
    the kernel.

    Args:
        a: Left operand of shape ``(M, K)``.
        b: Right operand of shape ``(K, N)``.

    Returns:
        The result of ``jt.triton_call`` for an ``(M, N)`` output with the
        dtype of ``a``.

    Raises:
        AssertionError: If the inner dimensions of ``a`` and ``b`` differ.
    """
    # 1) Define outputs shape
    M, K_A = a.shape
    K_B, N = b.shape
    assert K_A == K_B, 'incompatible dimensions'
    K = K_A
    outputs_shape = (M, N)
    outputs_dtype = a.dtype

    # 2) Define grid
    GROUP_M = 8
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),
        META['SPLIT_K'],
    )

    # 3) Define parameters
    metaparams = {
        'dot_out_dtype': tl.float32,
        'GROUP_M': GROUP_M,
    }

    stride_am, stride_ak = jt.strides_from_shape(a.shape)
    stride_bk, stride_bn = jt.strides_from_shape(b.shape)
    stride_cm, stride_cn = jt.strides_from_shape(outputs_shape)

    out_shape = jax.ShapeDtypeStruct(shape=outputs_shape, dtype=outputs_dtype)
    return jt.triton_call(
        a,
        b,
        M=M,
        N=N,
        K=K,
        stride_am=stride_am,
        stride_ak=stride_ak,
        stride_bk=stride_bk,
        stride_bn=stride_bn,
        stride_cm=stride_cm,
        stride_cn=stride_cn,
        kernel=matmul_kernel_autotuned,
        out_shape=out_shape,
        grid=grid,
        input_output_aliases=None,
        zeroed_outputs=(),
        num_warps=4,
        num_stages=2,
        debug=False,
        **metaparams
    )


def _compare_with_jax(matmul_fn, h, w, num_trials=4):
    """Print PASSED/FAILED for ``matmul_fn`` against ``a @ b`` on random inputs.

    Each trial draws two ``(h, w)`` float16 matrices from ``PRNGKey(trial)``
    and compares the results element-wise with ``atol=1e-2``.
    """
    for seed in range(num_trials):
        k1, k2 = jax.random.split(jax.random.PRNGKey(seed))
        a_val = jax.random.normal(k1, (h, w), dtype=jnp.float16)
        b_val = jax.random.normal(k2, (h, w), dtype=jnp.float16)
        jax_triton_result = matmul_fn(a_val, b_val)
        jax_result = a_val @ b_val
        if not jnp.isclose(jax_triton_result, jax_result, atol=1e-2).all():
            print(f"FAILED:\ntriton={jax_triton_result}, jax={jax_result}")
        else:
            print("PASSED")


if __name__ == "__main__":
    h, w = 4, 4

    # Kernel with hand-picked metaparams. K=4 is smaller than BLOCK_K, so this
    # exercises the masked-load (EVEN_K == False) path.
    _compare_with_jax(matmul_kernel_jax, h, w)

    print("#############\n#############")

    # Kernel with autotuned metaparams
    _compare_with_jax(matmul_kernel_jax_autotuned, h, w)

    # Drop into an interactive shell for manual inspection.
    from IPython import embed
    embed()