diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl index c12fe526ca..062ba34ee7 100644 --- a/src/device/intrinsics/wmma.jl +++ b/src/device/intrinsics/wmma.jl @@ -14,7 +14,8 @@ const map_ptx_to_jl_array = Dict( "s8" => Int8, "s32" => Int32, "f16" => Float16, - "f32" => Float32 + "f32" => Float32, + "f64" => Float64 ) # Maps PTX types to Julia fragment types @@ -23,7 +24,8 @@ const map_ptx_to_jl_frag = Dict( "s8" => UInt32, "s32" => Int32, "f16" => NTuple{2, VecElement{Float16}}, - "f32" => Float32 + "f32" => Float32, + "f64" => Float64 ) # Maps matrix & PTX types to fragment sizes @@ -40,43 +42,51 @@ const map_frag_sizes = Dict( "a.f16.m16n16k16" => 8, "a.f16.m8n32k16" => 8, "a.f16.m32n8k16" => 8, + + "a.f64.m8n8k4" => 1, # B "b.u8.m16n16k16" => 2, "b.u8.m8n32k16" => 4, "b.u8.m32n8k16" => 1, - + "b.s8.m16n16k16" => 2, "b.s8.m8n32k16" => 4, "b.s8.m32n8k16" => 1, - + "b.f16.m16n16k16" => 8, "b.f16.m8n32k16" => 8, "b.f16.m32n8k16" => 8, + + "b.f64.m8n8k4" => 1, # C "c.s32.m16n16k16" => 8, "c.s32.m8n32k16" => 8, "c.s32.m32n8k16" => 8, - + "c.f16.m16n16k16" => 4, "c.f16.m8n32k16" => 4, "c.f16.m32n8k16" => 4, - + "c.f32.m16n16k16" => 8, "c.f32.m8n32k16" => 8, "c.f32.m32n8k16" => 8, + + "c.f64.m8n8k4" => 1, # D "d.s32.m16n16k16" => 8, "d.s32.m8n32k16" => 8, "d.s32.m32n8k16" => 8, - + "d.f16.m16n16k16" => 4, "d.f16.m8n32k16" => 4, "d.f16.m32n8k16" => 4, - + "d.f32.m16n16k16" => 8, "d.f32.m8n32k16" => 8, "d.f32.m32n8k16" => 8, - ) + + "d.f64.m8n8k4" => 1, + ) # Maps PTX AS to CUDA.AS const map_ptx_as_to_as_ty = Dict( @@ -87,6 +97,10 @@ const map_ptx_as_to_as_ty = Dict( # Valid WMMA Operation configurations: Shape (M,N,K), Matrix, Element Type +# Double-Precision Floating Point +const ldst_double_ab_ops = [(8,8,4)], ["a", "b"], ["f64"] +const ldst_double_cd_ops = [(8,8,4)], ["c", "d"], ["f64"] +const wmma_double_ops = [(8,8,4)], ["f64"], ["f64"], ["f64"] # Half-Precision Floating Point const ldst_half_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["f16"] const ldst_half_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["f16", "f32"] @@ -97,11 +111,12 @@ const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"] const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"] const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops, - ldst_int_ab_ops, ldst_int_cd_ops) -const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops) + ldst_int_ab_ops, ldst_int_cd_ops, + ldst_double_ab_ops, ldst_double_cd_ops) +const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_double_ops) # Valid WMMA operation shapes -const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)] +const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16), (8,8,4)] ################################################################################ # HELPER FUNCTIONS @@ -309,7 +324,7 @@ for ops in all_wmma_ops, # Name of the LLVM intrinsic # If integer/sub-byte/bit A/B types, name is determined by A/B types - if d_elem_type == "s32" + if d_elem_type == "s32" || d_elem_type == "f64" llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$a_elem_type" # Name of the Julia wrapper function func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type]), "_"))