Skip to content

Commit

Permalink
rocblas-remove-int8x4-uses(01) Removed support of INT8x4 type from th…
Browse files Browse the repository at this point in the history
…e library (except miopen.h, some tests and driver)
  • Loading branch information
atamazov committed Oct 4, 2023
1 parent 3413d2d commit d14242f
Show file tree
Hide file tree
Showing 44 changed files with 150 additions and 261 deletions.
11 changes: 5 additions & 6 deletions include/miopen/miopen.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,12 +346,11 @@ MIOPEN_DECLARE_OBJECT(miopenReduceTensorDescriptor);
*/
typedef enum
{
miopenHalf = 0, /*!< 16-bit floating point (Fully supported) */
miopenFloat = 1, /*!< 32-bit floating point (Fully supported) */
miopenInt32 = 2, /*!< 32-bit int point (Partially supported) */
miopenInt8 = 3, /*!< 8-bit int point (Partially supported) */
miopenInt8x4 =
4, /*!< Pack of four 8-bit int points in NCHW_VECT_C format (Partially supported) */
miopenHalf = 0, /*!< 16-bit floating point (Fully supported) */
miopenFloat = 1, /*!< 32-bit floating point (Fully supported) */
miopenInt32 = 2, /*!< 32-bit int point (Partially supported) */
miopenInt8 = 3, /*!< 8-bit int point (Partially supported) */
miopenInt8x4 = 4, /*!< Pack of four Int8 in NCHW_VECT_C format (Support discontinued) */
miopenBFloat16 = 5, /*!< 16-bit binary floating point (8-bit exponent, 7-bit fraction)
(Partially supported) */
miopenDouble = 6, /*!< 64-bit floating point (Partially supported) */
Expand Down
2 changes: 1 addition & 1 deletion src/check_numerics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ std::string GetKernelName(miopenDataType_t data_type)
case miopenBFloat8: return {"check_numerics_bf8"};
case miopenInt32:
case miopenInt8:
case miopenInt8x4:
case miopenInt8x4: // Support discontinued.
case miopenDouble:
default: return {""};
}
Expand Down
2 changes: 1 addition & 1 deletion src/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ ConvolutionDescriptor::GetForwardOutputTensorWithLayout(const TensorDescriptor&
std::vector<std::size_t> out_strides;
tensor_layout_to_strides(
out_lens, default_layout, yLayout, xDesc.GetVectorLength(), out_strides);
return {(xDesc.GetType() == miopenInt8 || xDesc.GetType() == miopenInt8x4
return {(xDesc.GetType() == miopenInt8
? (yType)
: xDesc.GetType()), // TODO: This function overrides the output type with
// essentially the input which is incorrect.
Expand Down
52 changes: 15 additions & 37 deletions src/gemm_v2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,6 @@
/// "disabled expansion of recursive macro" injected by rocblas headers.
#define AVOID_ROCBLAS_WRAPPERS_204 (MIOPEN_ROCBLAS_VERSION_FLAT >= 2004000)

/// Maintain API compatibility with various rocBLAS version
#define USE_GEMM_FLAGS_PACK_INT8X4 \
((MIOPEN_ROCBLAS_VERSION_FLAT >= 2038000) && (MIOPEN_ROCBLAS_VERSION_FLAT < 4000000))

/// Maintain API compatibility for versions not supporting FP16 alternate implementations
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (MIOPEN_ROCBLAS_VERSION_FLAT >= 2043000)
/// Some 2.42 versions have rocblas_gemm_flags_fp16_alt_impl, but
Expand Down Expand Up @@ -109,7 +105,7 @@ static inline rocblas_datatype rocBlasComputeType(const miopen::GemmDescriptor&
{
// Complex compute types are only supported in newer version of the API
assert(desc.dataType == desc.a_cast_type && desc.dataType == desc.b_cast_type);
if(desc.dataType == miopenInt8 || desc.dataType == miopenInt8x4)
if(desc.dataType == miopenInt8)
return rocblas_datatype::rocblas_datatype_i32_r;
else
return rocblas_datatype::rocblas_datatype_f32_r;
Expand Down Expand Up @@ -439,7 +435,6 @@ miopenStatus_t CallGemm(const Handle& handle,

switch(gemm_desc.dataType)
{
case miopenInt8x4:
case miopenInt8: {
assert(gemm_desc.k % 4 == 0);

Expand Down Expand Up @@ -471,12 +466,7 @@ miopenStatus_t CallGemm(const Handle& handle,
rocBlasComputeType(gemm_desc), // rocblas_datatype::rocblas_datatype_i32_r,
rocblas_gemm_algo::rocblas_gemm_algo_standard,
0,
#if USE_GEMM_FLAGS_PACK_INT8X4
rocblas_gemm_flags_pack_int8x4
#else
0
#endif
);
0);
}
break;
case miopenInt32: break;
Expand Down Expand Up @@ -620,9 +610,9 @@ miopenStatus_t CallGemm(const Handle& handle,
};
break;

case miopenInt8x4:
case miopenDouble: {
MIOPEN_THROW(miopenStatusBadParm,
"miopenDouble data type not supported by MIOpenGEMM.");
MIOPEN_THROW(miopenStatusBadParm, "Unknown or unsupported data type.");
};
break;
}
Expand Down Expand Up @@ -692,7 +682,6 @@ miopenStatus_t CallGemmStridedBatched(const Handle& handle,

switch(gemm_desc.dataType)
{
case miopenInt8x4:
case miopenInt8: {
assert(gemm_desc.k % 4 == 0);

Expand Down Expand Up @@ -728,12 +717,7 @@ miopenStatus_t CallGemmStridedBatched(const Handle& handle,
rocblas_datatype::rocblas_datatype_i32_r,
rocblas_gemm_algo::rocblas_gemm_algo_standard,
0,
#if USE_GEMM_FLAGS_PACK_INT8X4
rocblas_gemm_flags_pack_int8x4
#else
0
#endif
);
0);
}
break;
case miopenInt32: break;
Expand Down Expand Up @@ -892,10 +876,10 @@ miopenStatus_t CallGemmStridedBatched(const Handle& handle,
break;
}

case miopenInt8x4:
case miopenDouble: {
MIOPEN_THROW(miopenStatusBadParm,
"miopenDouble data type not supported by MIOpenGEMM.");
}
MIOPEN_THROW(miopenStatusBadParm, "Unknown or unsupported data type.");
};
break;
}

Expand Down Expand Up @@ -967,7 +951,6 @@ miopenStatus_t CallGemmStridedBatchedSequential(const Handle& handle,

switch(gemm_desc.dataType)
{
case miopenInt8x4:
case miopenInt8: {
assert(gemm_desc.k % 4 == 0);

Expand Down Expand Up @@ -1001,12 +984,7 @@ miopenStatus_t CallGemmStridedBatchedSequential(const Handle& handle,
rocBlasComputeType(gemm_desc), // rocblas_datatype::rocblas_datatype_i32_r,
rocblas_gemm_algo::rocblas_gemm_algo_standard,
0,
#if USE_GEMM_FLAGS_PACK_INT8X4
rocblas_gemm_flags_pack_int8x4
#else
0
#endif
);
0);
}
}
break;
Expand Down Expand Up @@ -1162,10 +1140,10 @@ miopenStatus_t CallGemmStridedBatchedSequential(const Handle& handle,
break;
}

case miopenInt8x4:
case miopenDouble: {
MIOPEN_THROW(miopenStatusBadParm,
"miopenDouble data type not supported by MIOpenGEMM.");
}
MIOPEN_THROW(miopenStatusBadParm, "Unknown or unsupported data type.");
};
break;
}

Expand Down Expand Up @@ -1195,7 +1173,7 @@ GemmDescriptor CreateGemmDescriptorConvFwd(const TensorDescriptor& wDesc,
{
#ifndef NDEBUG
assert(wDesc.GetType() == xDesc.GetType());
if(wDesc.GetType() != miopenInt8 && wDesc.GetType() != miopenInt8x4)
if(wDesc.GetType() != miopenInt8)
assert(wDesc.GetType() == yDesc.GetType());
#endif

Expand Down Expand Up @@ -1350,7 +1328,7 @@ GemmDescriptor CreateGemmDescriptorConvCNHWFwd(const TensorDescriptor& wDesc,
{
#ifndef NDEBUG
assert(wDesc.GetType() == xDesc.GetType());
if(wDesc.GetType() != miopenInt8 && wDesc.GetType() != miopenInt8x4)
if(wDesc.GetType() != miopenInt8)
assert(wDesc.GetType() == yDesc.GetType());
#endif

Expand Down Expand Up @@ -1454,7 +1432,7 @@ GemmDescriptor CreateGemmStridedBatchedDescriptorConv1x1Fwd(const TensorDescript
{
#ifndef NDEBUG
assert(wDesc.GetType() == xDesc.GetType());
if(wDesc.GetType() != miopenInt8 && wDesc.GetType() != miopenInt8x4)
if(wDesc.GetType() != miopenInt8)
assert(wDesc.GetType() == yDesc.GetType());
#else
(void)yDesc;
Expand Down
6 changes: 5 additions & 1 deletion src/hip/batched_transpose_sol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,11 @@ BatchedTransposeSolution::BatchedTransposeSolution(const ExecutionContext& ctx,
uint32_t width_)
: data_type(data_type_), batch(batch_), height(height_), width(width_)
{
if(data_type == miopenInt8x4 || data_type == miopenDouble)
if(!(data_type == miopenHalf //
|| data_type == miopenFloat //
|| data_type == miopenInt32 //
|| data_type == miopenInt8 //
|| data_type == miopenBFloat16))
MIOPEN_THROW("These data type are not supported");
num_cu = ctx.GetStream().GetMaxComputeUnits();
std::size_t data_size = miopen::GetTypeSize(data_type);
Expand Down
12 changes: 6 additions & 6 deletions src/include/miopen/datatype.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ inline std::string GetDataType(miopenDataType_t type)
type_str = "bfloat16";
}
break;
case miopenInt8x4:
case miopenInt8x4: {
type_str = "UNSUPPORTED_TYPE";
}
break;
case miopenInt8: {
type_str = "int8_t";
}
Expand Down Expand Up @@ -137,7 +140,6 @@ inline KernelBuildParameters GetDataTypeKBP(miopenDataType_t type)
int use_fp16x8 = 0;
int use_fp32 = 0;
int use_int8 = 0;
int use_int8x4 = 0;
int use_int32 = 0;
int use_bfp16 = 0;
int use_fp64 = 0;
Expand All @@ -150,15 +152,14 @@ inline KernelBuildParameters GetDataTypeKBP(miopenDataType_t type)
case miopenHalf: use_fp16 = 1; break;
case miopenFloat: use_fp32 = 1; break;
case miopenInt8: use_int8 = 1; break;
case miopenInt8x4: use_int8x4 = 1; break;
case miopenBFloat16: use_bfp16 = 1; break;
case miopenInt32: use_int32 = 1; break;
case miopenDouble: use_fp64 = 1; break;
case miopenFloat8: use_fp8 = 1; break;
case miopenBFloat8: use_bfp8 = 1; break;
case miopenInt8x4: // fallthrough
default:
MIOPEN_THROW(
"Only float, half, bfloat16, int8, int8x4, float8, bfloat8 data type is supported.");
MIOPEN_THROW("Only float, half, bfloat16, int8, float8, bfloat8 data types are supported.");
break;
}

Expand All @@ -168,7 +169,6 @@ inline KernelBuildParameters GetDataTypeKBP(miopenDataType_t type)
{"MIOPEN_USE_FP16x8", use_fp16x8},
{"MIOPEN_USE_FP32", use_fp32},
{"MIOPEN_USE_INT8", use_int8},
{"MIOPEN_USE_INT8x4", use_int8x4},
{"MIOPEN_USE_BFP16", use_bfp16},
{"MIOPEN_USE_INT32", use_int32},
{"MIOPEN_USE_RNE_BFLOAT16", use_rne_bfloat16},
Expand Down
4 changes: 2 additions & 2 deletions src/include/miopen/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,13 @@ inline std::size_t GetTypeSize(miopenDataType_t d)
case miopenFloat: return 4;
case miopenHalf:
case miopenBFloat16: return 2;
case miopenInt8x4:
case miopenInt8x4: break;
case miopenInt8:
case miopenFloat8:
case miopenBFloat8: return 1;
case miopenDouble: return 8;
}
MIOPEN_THROW("Unknown data type");
MIOPEN_THROW("Unknown or unsupported data type");
}

template <class X, class Y>
Expand Down
2 changes: 1 addition & 1 deletion src/include/miopen/visit_float.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ void visit_float(miopenDataType_t t, F f)
}
case miopenFloat8:
case miopenBFloat8:
case miopenInt8x4:
case miopenInt8: {
f(as_float<int8_t>{});
break;
Expand All @@ -92,6 +91,7 @@ void visit_float(miopenDataType_t t, F f)
f(as_float<double>{});
break;
}
case miopenInt8x4: MIOPEN_THROW("miopenInt8x4: Support discontinued.");
}
}

Expand Down
6 changes: 0 additions & 6 deletions src/kernels/MIOpenIm2d2Col.cl
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@
#define MIOPEN_USE_INT8 0
#endif

#ifndef MIOPEN_USE_INT8x4
#define MIOPEN_USE_INT8x4 0
#endif

#ifndef MIOPEN_USE_INT32
#define MIOPEN_USE_INT32 0
#endif
Expand All @@ -58,8 +54,6 @@

#if MIOPEN_USE_INT8 || MIOPEN_USE_FP8 || MIOPEN_USE_BFP8
typedef char data_t;
#elif MIOPEN_USE_INT8x4
typedef uint data_t;
#elif MIOPEN_USE_INT32
typedef int data_t;
#elif(MIOPEN_USE_FP16 || MIOPEN_USE_BFP16)
Expand Down
6 changes: 0 additions & 6 deletions src/kernels/MIOpenIm3d2Col.cl
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,12 @@
#define MIOPEN_USE_INT8 0
#endif

#ifndef MIOPEN_USE_INT8x4
#define MIOPEN_USE_INT8x4 0
#endif

#ifndef MIOPEN_USE_INT32
#define MIOPEN_USE_INT32 0
#endif

#if MIOPEN_USE_INT8
typedef char data_t;
#elif MIOPEN_USE_INT8x4
typedef uint data_t;
#elif MIOPEN_USE_INT32
typedef int data_t;
#elif(MIOPEN_USE_FP16 || MIOPEN_USE_BFP16)
Expand Down
6 changes: 1 addition & 5 deletions src/kernels/MIOpenSubTensorOpWithScalarKernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,13 @@
#define MIOPEN_USE_INT8 0
#endif

#ifndef MIOPEN_USE_INT8x4
#define MIOPEN_USE_INT8x4 0
#endif

#ifndef MIOPEN_USE_INT32
#define MIOPEN_USE_INT32 0
#endif

#include "float_types.h"

#if MIOPEN_USE_INT8 == 1 || MIOPEN_USE_INT8x4 == 1
#if MIOPEN_USE_INT8 == 1
#define _FLOAT char
#endif

Expand Down
6 changes: 1 addition & 5 deletions src/kernels/MIOpenSubTensorOpWithSubTensorKernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,7 @@
#define MIOPEN_USE_INT8 0
#endif

#ifndef MIOPEN_USE_INT8x4
#define MIOPEN_USE_INT8x4 0
#endif

#if MIOPEN_USE_INT8 == 1 || MIOPEN_USE_INT8x4 == 1
#if MIOPEN_USE_INT8 == 1
#define _FLOAT char
#ifndef FLT_MAX
#define MAX_VAL 127 /* max value */
Expand Down
6 changes: 1 addition & 5 deletions src/kernels/MIOpenSubTensorOpWithTransformKernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,7 @@
#define MIOPEN_USE_INT8 0
#endif

#ifndef MIOPEN_USE_INT8x4
#define MIOPEN_USE_INT8x4 0
#endif

#if MIOPEN_USE_INT8 == 1 || MIOPEN_USE_INT8x4 == 1
#if MIOPEN_USE_INT8 == 1
#define _FLOAT char
#ifndef FLT_MAX
#define MAX_VAL 127 /* max value */
Expand Down
6 changes: 0 additions & 6 deletions src/kernels/MIOpenUtilKernels4.cl
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@
#define MIOPEN_USE_INT8 0
#endif

#ifndef MIOPEN_USE_INT8x4
#define MIOPEN_USE_INT8x4 0
#endif

#ifndef MIOPEN_USE_INT32
#define MIOPEN_USE_INT32 0
#endif
Expand All @@ -58,8 +54,6 @@

#if MIOPEN_USE_INT8 || MIOPEN_USE_FP8 || MIOPEN_USE_BFP8
typedef char data_t;
#elif MIOPEN_USE_INT8x4
typedef uint data_t;
#elif MIOPEN_USE_INT32
typedef int data_t;
#elif(MIOPEN_USE_FP16 || MIOPEN_USE_BFP16)
Expand Down
Loading

0 comments on commit d14242f

Please sign in to comment.