Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MIOPEN_BETA_API defines around f8 #2430

Merged
merged 21 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions driver/conv_driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,11 @@ static inline miopenDataType_t DataTypeFromShortString(const std::string& type)
static const std::unordered_map<std::string, miopenDataType_t> conv_map = {
{"fp32", miopenFloat},
{"fp16", miopenHalf},
{"bf16", miopenBFloat16},
{"fp8", miopenFloat8},
{"bf8", miopenBFloat8}};
{"bf16", miopenBFloat16}};
#ifdef MIOPEN_BETA_API
conv_map.insert({"fp8", miopenFloat8});
conv_map.insert({"bf8", miopenBFloat8});
#endif

const auto res = conv_map.find(type);
if(res != conv_map.end())
Expand Down Expand Up @@ -688,17 +690,21 @@ int ConvDriver<Tgpu, Tref>::GetandSetData()
std::vector<int> wei_len = GetWeightTensorLengthsFromCmdLine();

SetTensorNd(inputTensor, in_len, inflags.GetValueStr("in_layout"), data_type);
#ifdef MIOPEN_BETA_API
if(inflags.GetValueStr("in_cast_type") != "-1")
{
const auto in_cast_type = DataTypeFromShortString(inflags.GetValueStr("in_cast_type"));
miopenSetTensorCastType(inputTensor, in_cast_type);
}
#endif
SetTensorNd(weightTensor, wei_len, inflags.GetValueStr("fil_layout"), data_type);
#ifdef MIOPEN_BETA_API
if(inflags.GetValueStr("wei_cast_type") != "-1")
{
const auto in_cast_type = DataTypeFromShortString(inflags.GetValueStr("wei_cast_type"));
miopenSetTensorCastType(weightTensor, in_cast_type);
}
#endif

if(inflags.GetValueInt("tensor_vect") == 1 && data_type == miopenInt8)
{
Expand Down Expand Up @@ -730,11 +736,13 @@ int ConvDriver<Tgpu, Tref>::GetandSetData()
miopenDataType_t y_type =
(data_type == miopenInt8 || data_type == miopenInt8x4) ? miopenInt32 : data_type;
SetTensorNd(outputTensor, out_len, inflags.GetValueStr("out_layout"), y_type);
#ifdef MIOPEN_BETA_API
if(inflags.GetValueStr("out_cast_type") != "-1")
{
const auto out_cast_type = DataTypeFromShortString(inflags.GetValueStr("out_cast_type"));
miopenSetTensorCastType(outputTensor, out_cast_type);
}
#endif

if(inflags.GetValueInt("bias") != 0)
{
Expand Down Expand Up @@ -1247,7 +1255,11 @@ int ConvDriver<Tgpu, Tref>::AllocateBuffersAndCopy()
bool is_int8 = data_type == miopenInt8 || data_type == miopenInt8x4;
// Data generated for very low precision types follows the same constraints whether it's fp8,
// bfp8 or even if the interim tensors are being cast
#ifdef MIOPEN_BETA_API
bool is_fp8 = data_type == miopenFloat8 || data_type == miopenBFloat8 || TensorsCasted();
#else
bool is_fp8 = false;
#endif
size_t in_sz = GetTensorSize(inputTensor);
size_t wei_sz = GetTensorSize(weightTensor);
size_t out_sz = GetTensorSize(outputTensor);
Expand Down Expand Up @@ -1619,9 +1631,12 @@ bool ConvDriver<Tgpu, Tref>::UseGPUReference()
if(!miopen::IsDisabled(MIOPEN_DRIVER_USE_GPU_REFERENCE{}))
{
if((miopen_type<Tref>{} == miopenFloat &&
(miopen_type<Tgpu>{} == miopenFloat || miopen_type<Tgpu>{} == miopenHalf ||
miopen_type<Tgpu>{} == miopenBFloat16 || miopen_type<Tgpu>{} == miopenFloat8 ||
miopen_type<Tgpu>{} == miopenBFloat8)) ||
(miopen_type<Tgpu>{} == miopenFloat || miopen_type<Tgpu>{} == miopenHalf ||
miopen_type<Tgpu>{} == miopenBFloat16
#ifdef MIOPEN_BETA_API
|| miopen_type<Tgpu>{} == miopenFloat8 || miopen_type<Tgpu>{} == miopenBFloat8
#endif
)) ||
(miopen_type<Tref>{} == miopenInt32 && miopen_type<Tgpu>{} == miopenInt8))
return true;
else
Expand Down
2 changes: 2 additions & 0 deletions driver/driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ inline void Driver::InitDataType<bfloat16>()
{
data_type = miopenBFloat16;
}
#ifdef MIOPEN_BETA_API
template <>
inline void Driver::InitDataType<float8>()
{
Expand All @@ -271,6 +272,7 @@ inline void Driver::InitDataType<bfloat8>()
{
data_type = miopenBFloat8;
}
#endif
// "std::is_same<Tgpu, float>{}" used to avoid "static_assert" compilation error,
// which occurs when the condition does not depend in any way on the template parameters.
template <typename Tgpu>
Expand Down
2 changes: 2 additions & 0 deletions driver/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ int main(int argc, char* argv[])
{
drv = new ConvDriver<int8_t, int32_t>();
}
#ifdef MIOPEN_BETA_API
else if(base_arg == "convfp8")
{
drv = new ConvDriver<float8, float>();
Expand All @@ -92,6 +93,7 @@ int main(int argc, char* argv[])
{
drv = new ConvDriver<bfloat8, float>();
}
#endif
else if(base_arg == "CBAInfer")
{
drv = new CBAInferFusionDriver<float, double>();
Expand Down
15 changes: 14 additions & 1 deletion include/miopen/miopen.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,13 @@ typedef enum
miopenStatusVersionMismatch = 10, /*!< Version mismatch of the supplied binary data argument. */
} miopenStatus_t;

#ifdef MIOPEN_BETA_API
/*! @brief Rounding modes for the 8-bit floating point data types.
 *
 * One of these values is supplied as the value of the
 * MIOPEN_CONVOLUTION_ATTRIB_FP8_ROUNDING_MODE convolution attribute.
 */
typedef enum
{
miopenF8RoundingModeStandard = 0, /*!< Standard rounding mode */
miopenF8RoundingModeStochastic = 1, /*!< Stochastic rounding mode */
} miopenF8RoundingMode_t;
#endif

/*! @brief Get character string for an error code.
*
Expand Down Expand Up @@ -355,8 +357,13 @@ typedef enum
miopenBFloat16 = 5, /*!< 16-bit binary floating point (8-bit exponent, 7-bit fraction)
(Partially supported) */
miopenDouble = 6, /*!< 64-bit floating point (Partially supported) */
#ifdef MIOPEN_BETA_API
miopenFloat8 = 7,
miopenBFloat8 = 8
miopenBFloat8 = 8,
#else
//miopenReserved1 = 7,
//miopenReserved2 = 8,
#endif
cderb marked this conversation as resolved.
Show resolved Hide resolved
} miopenDataType_t;

/*! @ingroup tensor
Expand Down Expand Up @@ -601,11 +608,15 @@ typedef enum
MIOPEN_CONVOLUTION_ATTRIB_DETERMINISTIC =
1, /*!< Restrict MIOpen convolutions to kernels which produce numerically deterministic
results. 0 - disabled (default), 1 - enabled >*/
#ifdef MIOPEN_BETA_API
cderb marked this conversation as resolved.
Show resolved Hide resolved
MIOPEN_CONVOLUTION_ATTRIB_FP8_ROUNDING_MODE =
2, /*!<Specifies the rounding mode for the 8-bit floating data types. Currently, two
rounding modes are supported: miopenF8RoundingModeStandard and
miopenF8RoundingModeStochastic. These are listed as part of the miopenF8RoundingMode_t
enum.>*/
#else
//miopenReserved1 = 2,
#endif
} miopenConvolutionAttrib_t;

/** @addtogroup tensor
Expand Down Expand Up @@ -723,6 +734,7 @@ MIOPEN_EXPORT miopenStatus_t miopenSetTensorDescriptor(miopenTensorDescriptor_t
const int* dimsA,
const int* stridesA);

#ifdef MIOPEN_BETA_API
/*! @brief Set the tensor cast type
*
* For tensors where the cast_type attribute is set, the tensor elements would be converted to the
Expand All @@ -734,6 +746,7 @@ MIOPEN_EXPORT miopenStatus_t miopenSetTensorDescriptor(miopenTensorDescriptor_t
*/
MIOPEN_EXPORT miopenStatus_t miopenSetTensorCastType(miopenTensorDescriptor_t tensorDesc,
miopenDataType_t cast_type);
#endif

/*! @brief Set shape of N-dimensional tensor
*
Expand Down
2 changes: 2 additions & 0 deletions src/check_numerics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ std::string GetKernelName(miopenDataType_t data_type)
case miopenFloat: return {"check_numerics_fp32"};
case miopenHalf: return {"check_numerics_fp16"};
case miopenBFloat16: return {"check_numerics_bf16"};
#ifdef MIOPEN_BETA_API
case miopenFloat8: return {"check_numerics_fp8"};
case miopenBFloat8: return {"check_numerics_bf8"};
#endif
case miopenInt32:
case miopenInt8:
case miopenInt8x4:
Expand Down
4 changes: 4 additions & 0 deletions src/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ void ConvolutionAttribute::Set(miopenConvolutionAttrib_t attr, int value)
std::to_string(value));
deterministic.value = value;
}
#ifdef MIOPEN_BETA_API
else if(attr == MIOPEN_CONVOLUTION_ATTRIB_FP8_ROUNDING_MODE)
{
const auto rounding_mode = static_cast<miopenF8RoundingMode_t>(value);
Expand All @@ -529,6 +530,7 @@ void ConvolutionAttribute::Set(miopenConvolutionAttrib_t attr, int value)
std::to_string(value));
fp8rounding_mode.rounding_mode = rounding_mode;
}
#endif
else
{
MIOPEN_THROW(miopenStatusBadParm,
Expand All @@ -541,8 +543,10 @@ int ConvolutionAttribute::Get(miopenConvolutionAttrib_t attr) const
{
if(attr == MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL)
return gfx90aFp16alt.value;
#ifdef MIOPEN_BETA_API
else if(attr == MIOPEN_CONVOLUTION_ATTRIB_FP8_ROUNDING_MODE)
return static_cast<int>(fp8rounding_mode.rounding_mode);
#endif
else if(attr == MIOPEN_CONVOLUTION_ATTRIB_DETERMINISTIC)
return deterministic.value;
MIOPEN_THROW(miopenStatusBadParm,
Expand Down
29 changes: 28 additions & 1 deletion src/gemm_v2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ FlagsForRocblasFp32Fp16Call(const miopen::GemmDescriptor& desc) // bool gfx90aFp
#if USE_ROCBLAS_GEMM_EX3
static inline rocblas_computetype rocBlasComputeType_ex3(const miopen::GemmDescriptor& desc)
{
#ifdef MIOPEN_BETA_API
if(desc.a_cast_type == miopenFloat8 && desc.b_cast_type == miopenFloat8)
return rocblas_compute_type_f8_f8_f32;
else if(desc.a_cast_type == miopenFloat8 && desc.b_cast_type == miopenBFloat8)
Expand All @@ -103,6 +104,10 @@ static inline rocblas_computetype rocBlasComputeType_ex3(const miopen::GemmDescr
return rocblas_compute_type_bf8_bf8_f32;
else
return rocblas_compute_type_f32;
#else
(void)(desc);
return rocblas_compute_type_f32;
#endif
}
#endif

Expand All @@ -118,11 +123,13 @@ static inline rocblas_datatype rocBlasComputeType(const miopen::GemmDescriptor&

/// Translates a MIOpen data type into the corresponding rocblas_datatype value.
///
/// Half precision is always handled. The two 8-bit floating types are handled
/// only when MIOPEN_BETA_API is defined, because their enumerators are declared
/// only under that guard.
///
/// @param data_type MIOpen data type to translate.
/// @return The matching rocblas_datatype enumerator.
/// @throws miopenStatusInternalError (via MIOPEN_THROW) for any other data type.
auto rocBlasDataType(miopenDataType_t data_type)
{
#ifdef MIOPEN_BETA_API
    if(data_type == miopenFloat8)
        return rocblas_datatype::rocblas_datatype_f8_r;
    if(data_type == miopenBFloat8)
        return rocblas_datatype::rocblas_datatype_bf8_r;
#endif
    // Kept outside the guard so the half-precision mapping does not depend on
    // whether the beta API is enabled.
    if(data_type == miopenHalf)
        return rocblas_datatype::rocblas_datatype_f16_r;
    MIOPEN_THROW(miopenStatusInternalError, "Invalid data type passed");
}
Expand All @@ -146,8 +153,10 @@ rocblas_status miopen_rocblas_gemm_ex3(const miopen::Handle& handle,
float alpha = gemm_desc.alpha;
float beta = gemm_desc.beta;
auto flags = FlagsForRocblasFp32Fp16Call(gemm_desc);
#ifdef MIOPEN_BETA_API
if(gemm_desc.conv_attributes.fp8rounding_mode.Get() == miopenF8RoundingModeStochastic)
flags = flags | rocblas_gemm_flags::rocblas_gemm_flags_stochastic_rounding;
#endif

rb_status = // cppcheck-suppress redundantInitialization
rocblas_gemm_ex3(handle.rhandle().get(),
Expand Down Expand Up @@ -485,6 +494,7 @@ miopenStatus_t CallGemm(const Handle& handle,
const auto is_gfx94x = miopen::StartsWith(handle.GetDeviceName(), "gfx94");
// We need ex3 API if any of the dataType or the cast type is an 8-bit floating type
const auto needs_ex3 = [&]() {
#ifdef MIOPEN_BETA_API
if((gemm_desc.dataType == miopenFloat8 || gemm_desc.dataType == miopenBFloat8) ||
(gemm_desc.a_cast_type == miopenFloat8 ||
gemm_desc.a_cast_type == miopenBFloat8) ||
Expand All @@ -493,6 +503,9 @@ miopenStatus_t CallGemm(const Handle& handle,
return true;
else
return false;
#else
return false;
#endif
}();
// ex3 API only works on the gfx94x ASIC;
if(needs_ex3)
Expand Down Expand Up @@ -607,6 +620,7 @@ miopenStatus_t CallGemm(const Handle& handle,
}
break;

#ifdef MIOPEN_BETA_API
case miopenFloat8:
case miopenBFloat8: {
const auto is_gfx94x = miopen::StartsWith(handle.GetDeviceName(), "gfx94");
Expand All @@ -620,6 +634,7 @@ miopenStatus_t CallGemm(const Handle& handle,
"8-bit floating types are only supported on gfx94x");
};
break;
#endif

case miopenDouble: {
MIOPEN_THROW(miopenStatusBadParm,
Expand Down Expand Up @@ -743,6 +758,7 @@ miopenStatus_t CallGemmStridedBatched(const Handle& handle,
const auto is_gfx94x = miopen::StartsWith(handle.GetDeviceName(), "gfx94");
// We need ex3 API if any of the dataType or the cast type is an 8-bit floating type
const auto needs_ex3 = [&]() {
#ifdef MIOPEN_BETA_API
if((gemm_desc.dataType == miopenFloat8 || gemm_desc.dataType == miopenBFloat8) ||
(gemm_desc.a_cast_type == miopenFloat8 ||
gemm_desc.a_cast_type == miopenBFloat8) ||
Expand All @@ -751,6 +767,9 @@ miopenStatus_t CallGemmStridedBatched(const Handle& handle,
return true;
else
return false;
#else
return false;
#endif
}();
// ex3 API only works on the gfx94x ASIC;
if(needs_ex3)
Expand Down Expand Up @@ -878,6 +897,7 @@ miopenStatus_t CallGemmStridedBatched(const Handle& handle,
}
break;

#ifdef MIOPEN_BETA_API
case miopenFloat8:
case miopenBFloat8: {
const auto is_gfx94x = miopen::StartsWith(handle.GetDeviceName(), "gfx94");
Expand All @@ -892,6 +912,7 @@ miopenStatus_t CallGemmStridedBatched(const Handle& handle,

break;
}
#endif

case miopenDouble: {
MIOPEN_THROW(miopenStatusBadParm,
Expand Down Expand Up @@ -1016,6 +1037,7 @@ miopenStatus_t CallGemmStridedBatchedSequential(const Handle& handle,
const auto is_gfx94x = miopen::StartsWith(handle.GetDeviceName(), "gfx94");
// We need ex3 API if any of the dataType or the cast type is an 8-bit floating type
const auto needs_ex3 = [&]() {
#ifdef MIOPEN_BETA_API
if((gemm_desc.dataType == miopenFloat8 || gemm_desc.dataType == miopenBFloat8) ||
(gemm_desc.a_cast_type == miopenFloat8 ||
gemm_desc.a_cast_type == miopenBFloat8) ||
Expand All @@ -1024,6 +1046,9 @@ miopenStatus_t CallGemmStridedBatchedSequential(const Handle& handle,
return true;
else
return false;
#else
return false;
#endif
}();
// ex3 API only works on the gfx94x ASIC;
if(needs_ex3)
Expand Down Expand Up @@ -1148,6 +1173,7 @@ miopenStatus_t CallGemmStridedBatchedSequential(const Handle& handle,
}
break;

#ifdef MIOPEN_BETA_API
case miopenFloat8:
case miopenBFloat8: {
const auto is_gfx94x = miopen::StartsWith(handle.GetDeviceName(), "gfx94");
Expand All @@ -1162,6 +1188,7 @@ miopenStatus_t CallGemmStridedBatchedSequential(const Handle& handle,

break;
}
#endif

case miopenDouble: {
MIOPEN_THROW(miopenStatusBadParm,
Expand Down
10 changes: 10 additions & 0 deletions src/include/miopen/conv/problem_description.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,10 @@ inline std::string GetDataTypeName(miopenDataType_t data_type)
case miopenInt32: return "INT32";
case miopenBFloat16: return "BF16";
case miopenDouble: return "FP64";
#ifdef MIOPEN_BETA_API
case miopenFloat8: return "FP8";
case miopenBFloat8: return "BFP8";
#endif
}

return "Unknown(" + std::to_string(data_type) + ")";
Expand Down Expand Up @@ -350,13 +352,21 @@ struct ProblemDescription : ProblemDescriptionBase
}
/// Reports whether the input, weights or output data type is miopenFloat8.
/// Always returns false when MIOPEN_BETA_API is not defined.
bool IsFp8() const
{
#ifdef MIOPEN_BETA_API
    // Checked in the order input, weights, output; stops at the first match.
    if(GetInDataType() == miopenFloat8)
        return true;
    if(GetWeightsDataType() == miopenFloat8)
        return true;
    return GetOutDataType() == miopenFloat8;
#else
    return false;
#endif
}
/// Reports whether the input, weights or output data type is miopenBFloat8.
/// Always returns false when MIOPEN_BETA_API is not defined.
bool IsBfp8() const
{
#ifdef MIOPEN_BETA_API
    // Checked in the order input, weights, output; stops at the first match.
    if(GetInDataType() == miopenBFloat8)
        return true;
    if(GetWeightsDataType() == miopenBFloat8)
        return true;
    return GetOutDataType() == miopenBFloat8;
#else
    return false;
#endif
}
bool IsTensorsCasted() const
{
Expand Down
Loading