Skip to content

Commit

Permalink
[MIGraphX EP] Add migraphx ep save load compiles (microsoft#20643) (m…
Browse files Browse the repository at this point in the history
…icrosoft#42)

Adds the ability for MIGraphX EP to save off or load compiled models to
save time between inferences.

Via Command line

User should be able to set the save ability with
ORT_MIGRAPHX_SAVE_COMPILED_MODEL
ORT_MIGRAPHX_SAVE_COMPILE_PATH

User should be able to set the load ability with
ORT_MIGRAPHX_LOAD_COMPILED_MODEL
ORT_MIGRAPHX_LOAD_COMPILE_PATH

via Onnxruntime API

migx_save_compiled_model
migx_save_model_name
migx_load_compiled_model
migx_load_model_name

The motivation for this is to leverage MIGraphX's existing API to
save/load models after our compile step of graph optimization. For
larger models or models which were compiled with additional tuning
steps, this saves time after first compile and inference run, and thus
speeds up the user experience in order to encourage development.

---------

Co-authored-by: Ted Themistokleous <tedthemistokleous@amd.com>
  • Loading branch information
TedThemistokleous and Ted Themistokleous authored Jun 21, 2024
1 parent 69e234a commit 2e20edc
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 86 deletions.
4 changes: 4 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,10 @@ typedef struct OrtMIGraphXProviderOptions {
int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true
int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, noznero = true
const char* migraphx_save_model_path; // migraphx model path name
int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true
const char* migraphx_load_model_path; // migraphx model path name
} OrtMIGraphXProviderOptions;

/** \brief OpenVINO Provider Options
Expand Down
231 changes: 148 additions & 83 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME";
static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH";
static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE";
static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL";
static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH";
static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL";
static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH";

}; // namespace migraphx_env_vars

// Information to construct kernel function state.
Expand All @@ -39,6 +44,10 @@ struct MIGraphXFuncState {
bool int8_enable = false;
bool int8_calibration_cache_available = false;
std::unordered_map<std::string, float> dynamic_range_map;
bool save_compiled_mode = false;
std::string save_compiled_path;
bool load_compiled_mode = false;
std::string load_compiled_path;
bool dump_model_ops = false;
};

Expand Down Expand Up @@ -82,7 +91,11 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
bool int8_calibration_cache_available_ = false;
bool int8_use_native_migraphx_calibration_table_ = false;
std::string calibration_cache_path_;
std::unordered_map<std::string, float> dynamic_range_map_;
std::unordered_map<std::string, float> dynamic_range_map;
bool save_compiled_model_ = false;
std::string save_compiled_path_;
bool load_compiled_model_ = false;
std::string load_compiled_path_;
bool dump_model_ops_ = false;
migraphx::target t_;
OrtMutex mgx_mu_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ constexpr const char* kFp16Enable = "trt_fp16_enable";
constexpr const char* kInt8Enable = "migx_int8_enable";
constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name";
constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table";
constexpr const char* kSaveCompiledModel = "migx_save_compiled_model";
constexpr const char* kSaveModelPath = "migx_save_model_name";
constexpr const char* kLoadCompiledModel = "migx_load_compiled_model";
constexpr const char* kLoadModelPath = "migx_load_model_name";

} // namespace provider_option_names
} // namespace migraphx
Expand All @@ -39,6 +43,8 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
})
.AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
.Parse(options));

return info;
Expand All @@ -49,6 +55,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
};
return options;
}
Expand All @@ -58,6 +66,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
};
return options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ struct MIGraphXExecutionProviderInfo {
bool int8_enable{false};
std::string int8_calibration_table_name{""};
bool int8_use_native_calibration_table{false};
bool save_compiled_model{true};
std::string save_model_file{"./compiled_model.mxr"};
bool load_compiled_model{true};
std::string load_model_file{"./compiled_model.mxr"};

static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info);
Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ struct MIGraphX_Provider : Provider {
info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name;
}
info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0;
info.save_compiled_model = options.migraphx_save_compiled_model;
info.save_model_file = "";
if (options.migraphx_save_model_path != nullptr) {
info.save_model_file = options.migraphx_save_model_path;
}
info.load_compiled_model = options.migraphx_load_compiled_model;
info.load_model_file = "";
if (options.migraphx_load_model_path != nullptr) {
info.load_model_file = options.migraphx_load_model_path;
}
return std::make_shared<MIGraphXProviderFactory>(info);
}

Expand All @@ -92,6 +102,11 @@ struct MIGraphX_Provider : Provider {
}

migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table;

migx_options.migraphx_save_compiled_model = internal_options.save_compiled_model;
migx_options.migraphx_save_model_path = internal_options.save_model_file.c_str();
migx_options.migraphx_load_compiled_model = internal_options.load_compiled_model;
migx_options.migraphx_load_model_path = internal_options.load_model_file.c_str();
}

ProviderOptions GetProviderOptions(const void* provider_options) override {
Expand Down
46 changes: 45 additions & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -763,14 +763,20 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else if (type == kMIGraphXExecutionProvider) {
#ifdef USE_MIGRAPHX
std::string calibration_table;
std::string save_model_path;
std::string load_model_path;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtMIGraphXProviderOptions params{
0,
0,
0,
0,
nullptr};
nullptr,
1,
"./compiled_model.mxr",
1,
"./compiled_model.mxr"};
for (auto option : it->second) {
if (option.first == "device_id") {
if (!option.second.empty()) {
Expand Down Expand Up @@ -817,6 +823,44 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
"[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_save_compiled_model") {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp16_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_fp16_enable = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_save_compiled_model' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_save_model_path") {
if (!option.second.empty()) {
save_model_path = option.second;
params.migraphx_save_model_path = save_model_path.c_str();
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_save_model_name' should be a "
"file name i.e. 'compiled_model.mxr'.\n");
}
} else if (option.first == "migraphx_load_compiled_model") {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp16_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_fp16_enable = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_load_compiled_model' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_load_model_path") {
if (!option.second.empty()) {
load_model_path = option.second;
params.migraphx_load_model_path = load_model_path.c_str();
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_load_model_name' should be a "
"file name i.e. 'compiled_model.mxr'.\n");
}
} else {
ORT_THROW("Invalid MIGraphX EP option: ", option.first);
}
Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/test/util/default_providers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@ std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider() {
0,
0,
0,
nullptr};
nullptr,
1,
"./compiled_model.mxr",
1,
"./compiled_model.mxr"};
return MIGraphXProviderFactoryCreator::Create(&params)->CreateProvider();
#else
return nullptr;
Expand Down

0 comments on commit 2e20edc

Please sign in to comment.