From 935928a6f76e98d11da4871225756bdfe14069be Mon Sep 17 00:00:00 2001 From: Carsten Bauer Date: Thu, 3 Mar 2022 10:43:48 +0100 Subject: [PATCH 1/2] naive attempt to add f64 wmma support --- src/device/intrinsics/wmma.jl | 41 ++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl index c12fe526ca..8658b56d73 100644 --- a/src/device/intrinsics/wmma.jl +++ b/src/device/intrinsics/wmma.jl @@ -14,7 +14,8 @@ const map_ptx_to_jl_array = Dict( "s8" => Int8, "s32" => Int32, "f16" => Float16, - "f32" => Float32 + "f32" => Float32, + "f64" => Float64 ) # Maps PTX types to Julia fragment types @@ -23,7 +24,8 @@ const map_ptx_to_jl_frag = Dict( "s8" => UInt32, "s32" => Int32, "f16" => NTuple{2, VecElement{Float16}}, - "f32" => Float32 + "f32" => Float32, + "f64" => Float64 ) # Maps matrix & PTX types to fragment sizes @@ -40,43 +42,51 @@ const map_frag_sizes = Dict( "a.f16.m16n16k16" => 8, "a.f16.m8n32k16" => 8, "a.f16.m32n8k16" => 8, + + "a.f64.m8n8k4" => 8, # B "b.u8.m16n16k16" => 2, "b.u8.m8n32k16" => 4, "b.u8.m32n8k16" => 1, - + "b.s8.m16n16k16" => 2, "b.s8.m8n32k16" => 4, "b.s8.m32n8k16" => 1, - + "b.f16.m16n16k16" => 8, "b.f16.m8n32k16" => 8, "b.f16.m32n8k16" => 8, + + "b.f64.m8n8k4" => 8, # C "c.s32.m16n16k16" => 8, "c.s32.m8n32k16" => 8, "c.s32.m32n8k16" => 8, - + "c.f16.m16n16k16" => 4, "c.f16.m8n32k16" => 4, "c.f16.m32n8k16" => 4, - + "c.f32.m16n16k16" => 8, "c.f32.m8n32k16" => 8, "c.f32.m32n8k16" => 8, + + "c.f64.m8n8k4" => 8, # D "d.s32.m16n16k16" => 8, "d.s32.m8n32k16" => 8, "d.s32.m32n8k16" => 8, - + "d.f16.m16n16k16" => 4, "d.f16.m8n32k16" => 4, "d.f16.m32n8k16" => 4, - + "d.f32.m16n16k16" => 8, "d.f32.m8n32k16" => 8, "d.f32.m32n8k16" => 8, - ) + + "d.f64.m8n8k4" => 8, + ) # Maps PTX AS to CUDA.AS const map_ptx_as_to_as_ty = Dict( @@ -87,6 +97,10 @@ const map_ptx_as_to_as_ty = Dict( # Valid WMMA Operation configurations: Shape (M,N,K), Matrix, Element Type +# Double-Precision Floating Point +const ldst_double_ab_ops = [(8,8,4)], ["a", "b"], ["f64"] +const ldst_double_cd_ops = [(8,8,4)], ["c", "d"], ["f64"] +const wmma_double_ops = [(8,8,4)], ["f64"], ["f64"], ["f64"] # Half-Precision Floating Point const ldst_half_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["f16"] const ldst_half_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["f16", "f32"] @@ -97,11 +111,12 @@ const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"] const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"] const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops, - ldst_int_ab_ops, ldst_int_cd_ops) -const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops) + ldst_int_ab_ops, ldst_int_cd_ops, + ldst_double_ab_ops, ldst_double_cd_ops) +const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_double_ops) # Valid WMMA operation shapes -const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)] +const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16), (8,8,4)] ################################################################################ # HELPER FUNCTIONS @@ -309,7 +324,7 @@ for ops in all_wmma_ops, # Name of the LLVM intrinsic # If integer/sub-byte/bit A/B types, name is determined by A/B types - if d_elem_type == "s32" + if d_elem_type == "s32" || d_elem_type == "f64" llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$a_elem_type" # Name of the Julia wrapper function func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type]), "_")) From d067d84b5842fd62a7a86ddd670e9fe824e00b82 Mon Sep 17 00:00:00 2001 From: Carsten Bauer Date: Fri, 4 Mar 2022 16:25:45 +0100 Subject: [PATCH 2/2] fix wmma f64 fragment sizes --- src/device/intrinsics/wmma.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl index 8658b56d73..062ba34ee7 100644 --- a/src/device/intrinsics/wmma.jl +++ b/src/device/intrinsics/wmma.jl @@ -43,7 +43,7 @@ const map_frag_sizes = Dict( "a.f16.m8n32k16" => 8, "a.f16.m32n8k16" => 8, - "a.f64.m8n8k4" => 8, + "a.f64.m8n8k4" => 1, # B "b.u8.m16n16k16" => 2, "b.u8.m8n32k16" => 4, @@ -57,7 +57,7 @@ const map_frag_sizes = Dict( "b.f16.m8n32k16" => 8, "b.f16.m32n8k16" => 8, - "b.f64.m8n8k4" => 8, + "b.f64.m8n8k4" => 1, # C "c.s32.m16n16k16" => 8, "c.s32.m8n32k16" => 8, @@ -71,7 +71,7 @@ const map_frag_sizes = Dict( "c.f32.m8n32k16" => 8, "c.f32.m32n8k16" => 8, - "c.f64.m8n8k4" => 8, + "c.f64.m8n8k4" => 1, # D "d.s32.m16n16k16" => 8, "d.s32.m8n32k16" => 8, @@ -85,7 +85,7 @@ const map_frag_sizes = Dict( "d.f32.m8n32k16" => 8, "d.f32.m32n8k16" => 8, - "d.f64.m8n8k4" => 8, + "d.f64.m8n8k4" => 1, ) # Maps PTX AS to CUDA.AS