diff --git a/backends/qualcomm/runtime/QnnManager.cpp b/backends/qualcomm/runtime/QnnManager.cpp index 3027c184d9..38245ca7f9 100644 --- a/backends/qualcomm/runtime/QnnManager.cpp +++ b/backends/qualcomm/runtime/QnnManager.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -281,6 +282,8 @@ Error QnnManager::Init() { options_->backend_options()->backend_type()); backend_params_ptr_ = QnnBackendFactory().Create( qnn_loaded_backend_, logger_.get(), qnn_context_blob_, options_); + ET_CHECK_OR_RETURN_ERROR( + backend_params_ptr_ != nullptr, Internal, "Failed to load Qnn backend.") ET_CHECK_OR_RETURN_ERROR( backend_params_ptr_->qnn_backend_ptr_->Configure() == Error::Ok, Internal, diff --git a/backends/qualcomm/runtime/backends/QnnBackendCommon.cpp b/backends/qualcomm/runtime/backends/QnnBackendCommon.cpp index 3e286c07b0..c67f9b52f5 100644 --- a/backends/qualcomm/runtime/backends/QnnBackendCommon.cpp +++ b/backends/qualcomm/runtime/backends/QnnBackendCommon.cpp @@ -53,6 +53,85 @@ Error QnnBackend::Configure() { } return Error::Ok; } + +Error QnnBackend::VerifyQNNSDKVersion( + const QnnExecuTorchBackendType backend_id) { + const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + + Qnn_ApiVersion_t qnn_version = {QNN_VERSION_INIT}; + Qnn_ErrorHandle_t error = + qnn_interface.qnn_backend_get_api_version(&qnn_version); + if (error != QNN_SUCCESS) { + QNN_EXECUTORCH_LOG_ERROR("Failed to get Qnn API version."); + return Error::Internal; + } + + Qnn_ApiVersion_t expected_version = {QNN_VERSION_INIT}; + expected_version.coreApiVersion.major = QNN_API_VERSION_MAJOR; + expected_version.coreApiVersion.minor = QNN_API_VERSION_MINOR; + expected_version.coreApiVersion.patch = QNN_API_VERSION_PATCH; + expected_version.backendApiVersion = GetExpectedBackendVersion(); + const char* backend_type = EnumNameQnnExecuTorchBackendType(backend_id); + + Error status = VersionChecker( + qnn_version.coreApiVersion, expected_version.coreApiVersion, "Qnn API"); + if (status == Error::Ok) { + status = VersionChecker( + qnn_version.backendApiVersion, + expected_version.backendApiVersion, + backend_type); + } + + return status; +} + +Error QnnBackend::VersionChecker( + const Qnn_Version_t& qnn_version, + const Qnn_Version_t& expected, + const std::string& prefix) { + if (qnn_version.major != expected.major) { + QNN_EXECUTORCH_LOG_ERROR( + "%s version %u.%u.%u is not supported. " + "The minimum supported version is %u.%u.%u. Please make " + "sure you have the correct backend library version.", + prefix.c_str(), + qnn_version.major, + qnn_version.minor, + qnn_version.patch, + expected.major, + expected.minor, + expected.patch); + return Error::Internal; + } + if (qnn_version.major == QNN_API_VERSION_MAJOR && + qnn_version.minor < expected.minor) { + QNN_EXECUTORCH_LOG_WARN( + "%s version %u.%u.%u is mismatched. " + "The minimum supported version is %u.%u.%u. Please make " + "sure you have the correct backend library version.", + prefix.c_str(), + qnn_version.major, + qnn_version.minor, + qnn_version.patch, + expected.major, + expected.minor, + expected.patch); + } + if ((qnn_version.major == QNN_API_VERSION_MAJOR && + qnn_version.minor > expected.minor)) { + QNN_EXECUTORCH_LOG_WARN( + "%s version %u.%u.%u is used. " + "The version is tested against %u.%u.%u.", + prefix.c_str(), + qnn_version.major, + qnn_version.minor, + qnn_version.patch, + expected.major, + expected.minor, + expected.patch); + } + return Error::Ok; +} } // namespace qnn } // namespace executor } // namespace torch diff --git a/backends/qualcomm/runtime/backends/QnnBackendCommon.h b/backends/qualcomm/runtime/backends/QnnBackendCommon.h index e6ea0adff8..de007898e5 100644 --- a/backends/qualcomm/runtime/backends/QnnBackendCommon.h +++ b/backends/qualcomm/runtime/backends/QnnBackendCommon.h @@ -13,8 +13,10 @@ #include +#include "HTP/QnnHtpCommon.h" #include "QnnBackend.h" #include "QnnCommon.h" +#include "QnnTypes.h" namespace torch { namespace executor { namespace qnn { @@ -43,7 +45,10 @@ class QnnBackend { return handle_; } + Error VerifyQNNSDKVersion(const QnnExecuTorchBackendType backend_id); + protected: + virtual Qnn_Version_t GetExpectedBackendVersion() const = 0; virtual Error MakeConfig(std::vector& config) { return Error::Ok; }; @@ -52,6 +57,10 @@ class QnnBackend { Qnn_BackendHandle_t handle_; const QnnImplementation& implementation_; QnnLogger* logger_; + Error VersionChecker( + const Qnn_Version_t& qnn_version, + const Qnn_Version_t& expected, + const std::string& prefix); }; } // namespace qnn } // namespace executor diff --git a/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp b/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp index acb9552468..9fb292613a 100644 --- a/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp +++ b/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp @@ -16,6 +16,7 @@ std::unique_ptr QnnBackendFactory::Create( const QnnExecuTorchContextBinary& qnn_context_blob, const QnnExecuTorchOptions* options) { auto backend_params = std::make_unique(); + switch (options->backend_options()->backend_type()) { case QnnExecuTorchBackendType::kHtpBackend: { auto htp_options = options->backend_options()->htp_options(); @@ -51,6 +52,7 @@ std::unique_ptr QnnBackendFactory::Create( } backend_params->qnn_backend_ptr_ = std::make_unique(implementation, logger); + backend_params->qnn_device_ptr_ = std::make_unique( implementation, logger, options->soc_info(), htp_options); @@ -72,7 +74,6 @@ std::unique_ptr QnnBackendFactory::Create( backend_params->qnn_mem_manager_ptr_ = std::make_unique( implementation, backend_params->qnn_context_ptr_.get()); backend_params->backend_init_state_ = BackendInitializeState::INITIALIZED; - return backend_params; } break; case QnnExecuTorchBackendType::kGpuBackend: case QnnExecuTorchBackendType::kDspBackend: @@ -81,7 +82,11 @@ std::unique_ptr QnnBackendFactory::Create( return nullptr; } - // should not reach here + if (backend_params->qnn_backend_ptr_->VerifyQNNSDKVersion( + options->backend_options()->backend_type()) == Error::Ok) { + return backend_params; + } + return nullptr; } } // namespace qnn diff --git a/backends/qualcomm/runtime/backends/htpbackend/HtpBackend.h b/backends/qualcomm/runtime/backends/htpbackend/HtpBackend.h index d4b14178a4..d00bd50cdc 100644 --- a/backends/qualcomm/runtime/backends/htpbackend/HtpBackend.h +++ b/backends/qualcomm/runtime/backends/htpbackend/HtpBackend.h @@ -8,7 +8,9 @@ #pragma once #include +#include "HTP/QnnHtpCommon.h" #include "HTP/QnnHtpProfile.h" +#include "QnnTypes.h" namespace torch { namespace executor { namespace qnn { @@ -24,6 +26,14 @@ class HtpBackend : public QnnBackend { event_type == QNN_HTP_PROFILE_EVENTTYPE_GRAPH_EXECUTE_ACCEL_TIME_CYCLE); } + Qnn_Version_t GetExpectedBackendVersion() const override { + Qnn_Version_t backend_version; + backend_version.major = QNN_HTP_API_VERSION_MAJOR; + backend_version.minor = QNN_HTP_API_VERSION_MINOR; + backend_version.patch = QNN_HTP_API_VERSION_PATCH; + return backend_version; + } + protected: Error MakeConfig(std::vector& config) override { return Error::Ok;