diff --git a/lib/mtl/device.jl b/lib/mtl/device.jl index 2931597fc..0747006e4 100644 --- a/lib/mtl/device.jl +++ b/lib/mtl/device.jl @@ -91,7 +91,7 @@ MTLDevice(i::Integer) = devices()[i] # family # -export supports_family, is_m3, is_m2, is_m1 +export supports_family, is_m4, is_m3, is_m2, is_m1 @cenum MTLGPUFamily::NSInteger begin MTLGPUFamilyMetal3 = 5001 # Metal 3 support @@ -121,5 +121,7 @@ is_m1(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple7) && !supports_family(dev, MTLGPUFamilyApple8) is_m2(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple8) && !supports_family(dev, MTLGPUFamilyApple9) -is_m3(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple9) -is_m4(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple9) +is_m3(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple9) && + occursin("M3", String(dev.name)) +is_m4(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple9) && + occursin("M4", String(dev.name))