Skip to content

Commit

Permalink
uk
Browse files Browse the repository at this point in the history
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
  • Loading branch information
bjacob committed Nov 20, 2024
1 parent 4396bf1 commit 5310d5a
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 89 deletions.
55 changes: 18 additions & 37 deletions compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,13 @@ if(NOT IREE_TARGET_BACKEND_ROCM)
return()
endif()

# Check if HIP is installed on system.
# HIP is required to compile ukernels.
# TODO: We can do better than this and ensure that headers are always available.
if(NOT IREE_ROCM_PATH)
set(IREE_ROCM_PATH "/opt/rocm")
endif()
set(IREE_ROCM_VERSION "${IREE_ROCM_PATH}/include/hip/hip_version.h")
if(NOT EXISTS ${IREE_ROCM_VERSION})
message(STATUS
"hip runtime cannot be found in ${IREE_ROCM_PATH}.
Please try setting IREE_ROCM_PATH to rocm directory.
Ukernels will not be compiled.")
return()
endif()


iree_add_all_subdirs()

set(_platform_lib_reldir "iree_platform_libs/rocm")
set(_device_bc_path "${IREE_COMPILER_DYLIB_DIR}/iree_platform_libs/rocm")
set(_amd_ukernel_libs)
set(_amd_ukernel_targets)
function(iree_rocm_bitcode_library)
function(iree_amdgpu_bitcode_library)
cmake_parse_arguments(
_RULE
""
Expand All @@ -45,30 +29,27 @@ function(iree_rocm_bitcode_library)
endif()

set(_ROCM_ARCH "${_RULE_ROCM_ARCH}")
set(OPT_FLAG "-O0")
if(_ROCM_ARCH MATCHES "GFX9")
set(OPT_FLAG "-O3")
endif()
set(_COPTS
"-x" "hip"

# Compile only the device code for the target architecture.
"--offload-device-only"
"--offload-arch=${_ROCM_ARCH}"
"-x" "c"
"-Xclang" "-finclude-default-header"

# Suppress warnings about about ROCM version (we mostly don't care).
"-D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH"
"-std=c23"
"-nogpulib"
"-fno-short-wchar"

# Use the ROCM specified by the IREE cmake variable (instead of guessing
# or failing if ROCM is not on the user's path).
"--rocm-path=${IREE_ROCM_PATH}"
# Target architecture/machine.
"-target" "amdgcn-amd-amdhsa"
"-march=${_ROCM_ARCH}"
"-fgpu-rdc" # NOTE: may not be required for all targets

# Avoid linking in default libraries as we will link them at a later phase.
"-nogpulib"
# Header paths for builtins and our own includes.
"-isystem" "${IREE_CLANG_BUILTIN_HEADERS_PATH}"
"-I${IREE_SOURCE_DIR}/runtime/src"

# Only enable necessary optimizations S.T we can use -O3.
"-Xclang" "-disable-llvm-optzns"
"${OPT_FLAG}"
# Optimized.
"-fno-ident"
"-fvisibility=hidden"
"-O3"

# Object file only in bitcode format:
"-c"
Expand Down Expand Up @@ -127,7 +108,7 @@ endfunction()
# except compile-time cost, so just picked out the popular ones.
set(_ukernel_supported_chips "gfx90a" "gfx942" "gfx1030" "gfx1100")
foreach(_amd_chip ${_ukernel_supported_chips})
iree_rocm_bitcode_library(
iree_amdgpu_bitcode_library(
NAME
rocm_argmax_ukernel
ROCM_ARCH
Expand Down
138 changes: 86 additions & 52 deletions compiler/plugins/target/ROCM/builtins/ukernel/argmax_ukernel.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,50 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <float.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <stdint.h>
#include <string.h>

extern "C" __device__ __attribute__((const)) half __ockl_wfred_max_f16(half);
extern "C" __device__
__attribute__((const)) int64_t __ockl_wfred_min_i64(int64_t);
extern "C" __device__
__attribute__((const)) int32_t __ockl_wfred_min_i32(int32_t);
/*
Bits copied from HIP headers.
*/

__attribute__((const)) _Float16 __ockl_wfred_max_f16(_Float16);
__attribute__((const)) int64_t __ockl_wfred_min_i64(int64_t);
__attribute__((const)) int32_t __ockl_wfred_min_i32(int32_t);
__attribute__((const)) float __ocml_fmax_f32(float, float);
__attribute__((const)) _Float16 __ocml_fmax_f16(_Float16, _Float16);

static inline unsigned int __lane_id() {
return __builtin_amdgcn_mbcnt_hi(-1, __builtin_amdgcn_mbcnt_lo(-1, 0));
}

static inline int __shfl_xor_i(int var, int lane_mask) {
const int width = __builtin_amdgcn_wavefrontsize();
int self = __lane_id();
int index = self ^ lane_mask;
index = index >= ((self + width) & ~(width - 1)) ? self : index;
return __builtin_amdgcn_ds_bpermute(index << 2, var);
}

static inline float __shfl_xor_f(float var, int lane_mask) {
union {
int i;
unsigned u;
float f;
} tmp;
tmp.f = var;
tmp.i = __shfl_xor_i(tmp.i, lane_mask);
return tmp.f;
}

static inline unsigned long long int __ballot(int predicate) {
return __builtin_amdgcn_uicmp(
predicate, 0, 33 /*ICMP_NE from llvm/include/llvm/IR/InstrTypes.h*/);
}

static inline unsigned int __popcll(unsigned long long int input) {
return __builtin_popcountll(input);
}

/*
Constraint/Tiling note:
Expand All @@ -21,22 +57,21 @@ only use single subgroup/warp per workgroup. This constraint is also set during
tiling phase in KernelConfig.
*/

extern "C" __device__ void __iree_uk_rocm_argmax_F32I32(float *inputBuffer,
size_t input_offset,
int32_t *outputBuffer,
size_t output_offset,
size_t reductionSize) {
uint laneID = __builtin_amdgcn_workitem_id_x();
void __iree_uk_rocm_argmax_F32I32(float *inputBuffer, size_t input_offset,
int32_t *outputBuffer, size_t output_offset,
size_t reductionSize) {
const int warpSize = __builtin_amdgcn_wavefrontsize();
uint32_t laneID = __builtin_amdgcn_workitem_id_x();
// Set identity value to handle problem non divisible by subgroupSize.
float laneMax =
laneID >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + laneID];
int32_t laneResult = laneID;

// NOTE: On F32 kernels with clang, reductionSize/blockDim.x has numerical
// inaccuracy.
uint numBatches = (reductionSize + warpSize - 1) / warpSize;
uint32_t numBatches = (reductionSize + warpSize - 1) / warpSize;
for (int i = 1; i < numBatches; ++i) {
uint idx = warpSize * i + laneID;
uint32_t idx = warpSize * i + laneID;
float newIn =
idx >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + idx];
if (newIn == laneMax)
Expand All @@ -50,7 +85,7 @@ extern "C" __device__ void __iree_uk_rocm_argmax_F32I32(float *inputBuffer,
// https://github.com/iree-org/iree/issues/16112.
float wgMax = laneMax;
for (int i = 1; i < warpSize; i *= 2) {
wgMax = __ocml_fmax_f32(__shfl_xor(wgMax, i), wgMax);
wgMax = __ocml_fmax_f32(__shfl_xor_f(wgMax, i), wgMax);
}
// Check if there are multiple max value holders.
uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
Expand All @@ -68,22 +103,21 @@ extern "C" __device__ void __iree_uk_rocm_argmax_F32I32(float *inputBuffer,
outputBuffer[output_offset] = laneResult;
}

extern "C" __device__ void __iree_uk_rocm_argmax_F32I64(float *inputBuffer,
size_t input_offset,
int64_t *outputBuffer,
size_t output_offset,
size_t reductionSize) {
uint laneID = __builtin_amdgcn_workitem_id_x();
void __iree_uk_rocm_argmax_F32I64(float *inputBuffer, size_t input_offset,
int64_t *outputBuffer, size_t output_offset,
size_t reductionSize) {
const int warpSize = __builtin_amdgcn_wavefrontsize();
uint32_t laneID = __builtin_amdgcn_workitem_id_x();
// Set identity value to handle problem non divisible by subgroupSize.
float laneMax =
laneID >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + laneID];
int64_t laneResult = laneID;

// NOTE: On F32 kernels with clang, reductionSize/blockDim.x has numerical
// inaccuracy.
uint numBatches = (reductionSize + warpSize - 1) / warpSize;
uint32_t numBatches = (reductionSize + warpSize - 1) / warpSize;
for (int i = 1; i < numBatches; ++i) {
uint idx = warpSize * i + laneID;
uint32_t idx = warpSize * i + laneID;
float newIn =
idx >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + idx];
if (newIn == laneMax)
Expand All @@ -97,7 +131,7 @@ extern "C" __device__ void __iree_uk_rocm_argmax_F32I64(float *inputBuffer,
// https://github.com/iree-org/iree/issues/16112.
float wgMax = laneMax;
for (int i = 1; i < warpSize; i *= 2) {
wgMax = __ocml_fmax_f32(__shfl_xor(wgMax, i), wgMax);
wgMax = __ocml_fmax_f32(__shfl_xor_f(wgMax, i), wgMax);
}
// Check if there are multiple max value holders.
uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
Expand All @@ -115,31 +149,30 @@ extern "C" __device__ void __iree_uk_rocm_argmax_F32I64(float *inputBuffer,
outputBuffer[output_offset] = laneResult;
}

extern "C" __device__ void __iree_uk_rocm_argmax_F16I32(half *inputBuffer,
size_t input_offset,
int32_t *outputBuffer,
size_t output_offset,
size_t reductionSize) {
half NEG_F16_MAX = __float2half(-65504.0f);
uint laneID = __builtin_amdgcn_workitem_id_x();
void __iree_uk_rocm_argmax_F16I32(_Float16 *inputBuffer, size_t input_offset,
int32_t *outputBuffer, size_t output_offset,
size_t reductionSize) {
const int warpSize = __builtin_amdgcn_wavefrontsize();
_Float16 NEG_F16_MAX = (_Float16)(-65504.0f);
uint32_t laneID = __builtin_amdgcn_workitem_id_x();
// Set identity value to handle problem non divisible by subgroupSize.
half laneMax = laneID >= reductionSize ? NEG_F16_MAX
: inputBuffer[input_offset + laneID];
_Float16 laneMax = laneID >= reductionSize
? NEG_F16_MAX
: inputBuffer[input_offset + laneID];
int32_t laneResult = laneID;

uint numBatches = (reductionSize + warpSize - 1) / warpSize;
uint32_t numBatches = (reductionSize + warpSize - 1) / warpSize;
for (int i = 1; i < numBatches; ++i) {
uint idx = warpSize * i + laneID;
half newIn =
uint32_t idx = warpSize * i + laneID;
_Float16 newIn =
idx >= reductionSize ? NEG_F16_MAX : inputBuffer[input_offset + idx];
if (newIn == laneMax)
continue;
laneMax = __ocml_fmax_f16(newIn, laneMax);
laneResult = newIn == laneMax ? idx : laneResult;
}

// Final reduction with one subgroup
half wgMax = __ockl_wfred_max_f16(laneMax);
_Float16 wgMax = __ockl_wfred_max_f16(laneMax);
// Check if there are multiple max value holders.
uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
// if there is only one max value holder, write and exit.
Expand All @@ -148,6 +181,7 @@ extern "C" __device__ void __iree_uk_rocm_argmax_F16I32(half *inputBuffer,
outputBuffer[output_offset] = laneResult;
return;
}

// if there are multiple max value holder, find smallest index (argmax
// semantics).
int32_t indexVal = wgMax == laneMax ? laneResult : __INT32_MAX__;
Expand All @@ -156,22 +190,22 @@ extern "C" __device__ void __iree_uk_rocm_argmax_F16I32(half *inputBuffer,
outputBuffer[output_offset] = laneResult;
}

extern "C" __device__ void __iree_uk_rocm_argmax_F16I64(half *inputBuffer,
size_t input_offset,
int64_t *outputBuffer,
size_t output_offset,
size_t reductionSize) {
half NEG_F16_MAX = __float2half(-65504.0f);
uint laneID = __builtin_amdgcn_workitem_id_x();
void __iree_uk_rocm_argmax_F16I64(_Float16 *inputBuffer, size_t input_offset,
int64_t *outputBuffer, size_t output_offset,
size_t reductionSize) {
const int warpSize = __builtin_amdgcn_wavefrontsize();
_Float16 NEG_F16_MAX = (_Float16)(-65504.0f);
uint32_t laneID = __builtin_amdgcn_workitem_id_x();
// Set identity value to handle problem non divisible by subgroupSize.
half laneMax = laneID >= reductionSize ? NEG_F16_MAX
: inputBuffer[input_offset + laneID];
_Float16 laneMax = laneID >= reductionSize
? NEG_F16_MAX
: inputBuffer[input_offset + laneID];
int64_t laneResult = laneID;

uint numBatches = (reductionSize + warpSize - 1) / warpSize;
uint32_t numBatches = (reductionSize + warpSize - 1) / warpSize;
for (int i = 1; i < numBatches; ++i) {
uint idx = warpSize * i + laneID;
half newIn =
uint32_t idx = warpSize * i + laneID;
_Float16 newIn =
idx >= reductionSize ? NEG_F16_MAX : inputBuffer[input_offset + idx];
if (newIn == laneMax)
continue;
Expand All @@ -180,7 +214,7 @@ extern "C" __device__ void __iree_uk_rocm_argmax_F16I64(half *inputBuffer,
}

// Final reduction with one subgroup
half wgMax = __ockl_wfred_max_f16(laneMax);
_Float16 wgMax = __ockl_wfred_max_f16(laneMax);
// Check if there are multiple max value holders.
uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
// if there is only one max value holder, write and exit.
Expand Down

0 comments on commit 5310d5a

Please sign in to comment.