Skip to content

Commit

Permalink
Add ORT_MIGRAPHX_SET_FAST_MATH env option and api hooks
Browse files Browse the repository at this point in the history
Allow users to set the fast math option for MIGraphX compilation for quantized data types (fp16)
This allows us to toggle whether we can use faster math with the tradeoff of accuracy.
  • Loading branch information
Ted Themistokleous committed Nov 29, 2023
1 parent a352c01 commit 3819961
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 3 deletions.
13 changes: 10 additions & 3 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true);
}

// whether fp16 is enable
const std::string fast_math_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFastMathOptimization);
if (!fast_math_env.empty()) {
fast_math_enable_ = (std::stoi(fast_math_enable_env) == 0 ? false : true);
}

// whether int8 is enabled
const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable);
if (!int8_enable_env.empty()) {
Expand Down Expand Up @@ -168,6 +174,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: "
<< "device_id: " << device_id_
<< ", migraphx_fp16_enable: " << fp16_enable_
<< ", migraphx_fast_math: " << fast_math_enable_
<< ", migraphx_int8_enable: " << int8_enable_
<< ", dump_model_ops: " << dump_model_ops_
<< ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_
Expand Down Expand Up @@ -1145,7 +1152,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
migraphx::quantize_int8(prog, t_, quant_opts);
}
migraphx::compile_options co;
co.set_fast_math(false);
co.set_fast_math(fast_math_enable_);
prog.compile(t_, co);
auto prog_output_shapes = prog.get_output_shapes();
for (std::size_t i = 0; i < output_names.size(); ++i) {
Expand All @@ -1165,7 +1172,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
std::unique_ptr<MIGraphXFuncState> p = std::make_unique<MIGraphXFuncState>();
*p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name],
map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_,
map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_,
map_no_input_shape_[context->node_name], fp16_enable_, fast_math_enable_, int8_enable_,
int8_calibration_cache_available_, dynamic_range_map, dump_model_ops_};
*state = p.release();
return 0;
Expand Down Expand Up @@ -1265,7 +1272,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
}

migraphx::compile_options co;
co.set_fast_math(false);
co.set_fast_math(fast_math_enable);
prog.compile(t, co);
mgx_state->prog = prog;
param_shapes = prog.get_parameter_shapes();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME";
static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH";
static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE";
static const char kSetFastMathOptimization[] = "ORT_MIGRAPHX_SET_FAST_MATH";
}; // namespace migraphx_env_vars

// Information to construct kernel function state.
Expand All @@ -41,6 +42,7 @@ struct MIGraphXFuncState {
OrtMutex* mgx_mu_ptr = nullptr;
bool no_input_shape = false;
bool fp16_enable = false;
bool fast_math_enable = false;
bool int8_enable = false;
bool int8_calibration_cache_available = false;
std::unordered_map<std::string, float> dynamic_range_map;
Expand Down Expand Up @@ -78,6 +80,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {

private:
bool fp16_enable_ = false;
bool fast_math_enable_ = false;
bool int8_enable_ = false;
std::string int8_calibration_cache_name_;
bool int8_calibration_cache_available_ = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace migraphx {
namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
constexpr const char* kFp16Enable = "trt_fp16_enable";
constexpr const char* kFastMathEnable = "migx_fast_math_enable";
constexpr const char* kInt8Enable = "migx_int8_enable";
constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name";
constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table";
Expand All @@ -38,6 +39,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
return Status::OK();
})
.AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kFastMathEnable, info.fast_math_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
.Parse(options));

Expand All @@ -48,6 +50,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
const ProviderOptions options{
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
{migraphx::provider_option_names::kFastMathEnable, MakeStringWithClassicLocale(info.fast_math_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
};
return options;
Expand All @@ -57,6 +60,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap
const ProviderOptions options{
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
{migraphx::provider_option_names::kFastMathEnable, MakeStringWithClassicLocale(info.migraphx_fast_math_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
};
return options;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ struct MIGraphXExecutionProviderInfo {
std::string target_device;
int device_id{0};
bool fp16_enable{false};
bool fast_math_enable{false};
bool int8_enable{false};
std::string int8_calibration_table_name{""};
bool int8_use_native_calibration_table{false};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ struct MIGraphX_Provider : Provider {
info.device_id = options.device_id;
info.target_device = "gpu";
info.fp16_enable = options.migraphx_fp16_enable;
info.fast_math_enable = options.migraphx_fast_math_enable;
info.int8_enable = options.migraphx_int8_enable;
info.int8_calibration_table_name = "";
if (options.migraphx_int8_calibration_table_name != nullptr) {
Expand All @@ -61,6 +62,7 @@ struct MIGraphX_Provider : Provider {
auto& migx_options = *reinterpret_cast<OrtMIGraphXProviderOptions*>(provider_options);
migx_options.device_id = internal_options.device_id;
migx_options.migraphx_fp16_enable = internal_options.fp16_enable;
migx_options.migraphx_fast_math_enable = internal_options.fast_math_enable;
migx_options.migraphx_int8_enable = internal_options.int8_enable;

char* dest = nullptr;
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
0,
0,
0,
0,
nullptr};
for (auto option : it->second) {
if (option.first == "device_id") {
Expand All @@ -752,6 +753,13 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
"[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
}
else if (option.first == "migraphx_set_fast_math") {
if (option.second == "True" || option.second == "true") {
params.migraphx_fast_math_enable = true;
} else {
params.migraphx_fast_math_enable = false;
}
} else if (option.first == "migraphx_int8_enable") {
if (option.second == "True" || option.second == "true") {
params.migraphx_int8_enable = true;
Expand Down

0 comments on commit 3819961

Please sign in to comment.