Skip to content

Commit

Permalink
Productizing cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
gkrivor committed Oct 21, 2024
1 parent db5ec58 commit 361fd89
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 69 deletions.
43 changes: 20 additions & 23 deletions src/frontends/onnx/frontend/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ using namespace ov;
using namespace ov::frontend::onnx;
using namespace ov::frontend::onnx::common;
using ::ONNX_NAMESPACE::ModelProto;

typedef std::shared_ptr<ModelProto> ModelProtoPtr;
using ::ONNX_NAMESPACE::Version;

ONNX_FRONTEND_C_API ov::frontend::FrontEndVersion get_api_version() {
return OV_FRONTEND_API_VERSION;
Expand Down Expand Up @@ -86,20 +85,17 @@ InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& variants) const
#endif
return std::make_shared<InputModel>(*stream, enable_mmap, m_extensions);
}
if (variants[0].is<ModelProtoPtr>()) {
std::cerr << "shared_ptr<ModelProto> has been received\n";
return std::make_shared<InputModel>(variants[0].as<ModelProtoPtr>(), m_extensions);
}
if (variants[0].is<ModelProto*>()) {
std::cerr << "ModelProto* has been received\n";
return std::make_shared<InputModel>(std::make_shared<ModelProto>(*variants[0].as<ModelProto*>()), m_extensions);
}
// !!! Experimental feature, it may be changed or removed in the future !!!
if (variants[0].is<uint64_t>()) {
std::cerr << "uint64_t as a ModelProto* has been received\n";
void* model_proto_ptr = reinterpret_cast<void*>(variants[0].as<uint64_t>());
return std::make_shared<InputModel>(std::make_shared<ModelProto>(*static_cast<ModelProto*>(model_proto_ptr)),
m_extensions);
void* model_proto_addr = reinterpret_cast<void*>(variants[0].as<uint64_t>());
FRONT_END_GENERAL_CHECK(model_proto_addr != 0, "Wrong address of a ModelProto object is passed");
ModelProto* model_proto_ptr = static_cast<ModelProto*>(model_proto_addr);
FRONT_END_GENERAL_CHECK(
model_proto_ptr->has_ir_version() && model_proto_ptr->ir_version() < Version::IR_VERSION,
"A ModelProto object contains unsupported IR version");
return std::make_shared<InputModel>(std::make_shared<ModelProto>(*model_proto_ptr), m_extensions);
}
// !!! End of Experimental feature
return nullptr;
}

Expand Down Expand Up @@ -230,18 +226,19 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
StreamRewinder rwd{*stream};
return is_valid_model(*stream);
}
if (variants[0].is<ModelProtoPtr>()) {
std::cerr << "shared_ptr<ModelProto> is supported\n";
return true;
}
if (variants[0].is<ModelProto*>()) {
std::cerr << "ModelProto* is supported\n";
return true;
}
// !!! Experimental feature, it may be changed or removed in the future !!!
if (variants[0].is<uint64_t>()) {
std::cerr << "uint64_t as a ModelProto* is supported\n";
void* model_proto_addr = reinterpret_cast<void*>(variants[0].as<uint64_t>());
if (model_proto_addr == 0) {
return false;
}
ModelProto* model_proto_ptr = static_cast<ModelProto*>(model_proto_addr);
if (!model_proto_ptr->has_ir_version() || model_proto_ptr->ir_version() > Version::IR_VERSION) {
return false;
}
return true;
}
// !!! End of Experimental feature
return false;
}

Expand Down
75 changes: 29 additions & 46 deletions src/frontends/onnx/tests/load_from.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,11 @@ INSTANTIATE_TEST_SUITE_P(ONNXLoadTest,
::testing::Values(getTestData()),
FrontEndLoadFromTest::getTestCaseName);

// !!! Experimental feature, it may be changed or removed in the future !!!
using ::ONNX_NAMESPACE::ModelProto;
using ::ONNX_NAMESPACE::Version;

TEST_P(FrontEndLoadFromTest, testLoadFromModelProtoSharedPtr) {
TEST_P(FrontEndLoadFromTest, testLoadFromModelProtoUint64) {
const auto path =
ov::util::path_join({ov::test::utils::getExecutableDirectory(), TEST_ONNX_MODELS_DIRNAME, "abs.onnx"});
std::ifstream ifs(path, std::ios::in | std::ios::binary);
Expand All @@ -77,36 +79,12 @@ TEST_P(FrontEndLoadFromTest, testLoadFromModelProtoSharedPtr) {
auto model_proto = std::make_shared<ModelProto>();
ASSERT_TRUE(model_proto->ParseFromIstream(&ifs)) << "Could not parse ModelProto from file: " << path;

ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_model(model_proto))
<< "Could not create the ONNX FE using a shared_ptr on a ModelProto object";
ASSERT_NE(m_frontEnd, nullptr);
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load(model_proto)) << "Could not load the model";
ASSERT_NE(m_inputModel, nullptr);
}

std::shared_ptr<ov::Model> model;
ASSERT_NO_THROW(model = m_frontEnd->convert(m_inputModel)) << "Could not convert the model to OV representation";
ASSERT_NE(model, nullptr);

ASSERT_TRUE(model->get_ordered_ops().size() > 0);
}

TEST_P(FrontEndLoadFromTest, testLoadFromModelProtoPtr) {
const auto path =
ov::util::path_join({ov::test::utils::getExecutableDirectory(), TEST_ONNX_MODELS_DIRNAME, "abs.onnx"});
std::ifstream ifs(path, std::ios::in | std::ios::binary);
ASSERT_TRUE(ifs.is_open()) << "Could not open an ifstream for the model path: " << path;
std::vector<std::string> frontends;
FrontEnd::Ptr fe;

{
auto model_proto = std::make_shared<ModelProto>();
ASSERT_TRUE(model_proto->ParseFromIstream(&ifs)) << "Could not parse ModelProto from file: " << path;
uint64_t model_proto_ptr = reinterpret_cast<uint64_t>(model_proto.get());

ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_model(model_proto.get()))
<< "Could not create the ONNX FE using a pointer on ModelProto object";
ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_model(model_proto_ptr))
<< "Could not create the ONNX FE using a pointer on ModelProto object as uint64_t";
ASSERT_NE(m_frontEnd, nullptr);
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load(model_proto.get())) << "Could not load the model";
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load(model_proto_ptr)) << "Could not load the model";
ASSERT_NE(m_inputModel, nullptr);
}

Expand All @@ -117,30 +95,35 @@ TEST_P(FrontEndLoadFromTest, testLoadFromModelProtoPtr) {
ASSERT_TRUE(model->get_ordered_ops().size() > 0);
}

TEST_P(FrontEndLoadFromTest, testLoadFromModelProtoUint64) {
TEST_P(FrontEndLoadFromTest, testLoadFromModelProtoUint64_Negative) {
const auto path =
ov::util::path_join({ov::test::utils::getExecutableDirectory(), TEST_ONNX_MODELS_DIRNAME, "abs.onnx"});
std::ifstream ifs(path, std::ios::in | std::ios::binary);
ASSERT_TRUE(ifs.is_open()) << "Could not open an ifstream for the model path: " << path;
std::vector<std::string> frontends;
FrontEnd::Ptr fe;

{
auto model_proto = std::make_shared<ModelProto>();
ASSERT_TRUE(model_proto->ParseFromIstream(&ifs)) << "Could not parse ModelProto from file: " << path;

uint64_t model_proto_ptr = reinterpret_cast<uint64_t>(model_proto.get());

ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_model(model_proto_ptr))
<< "Could not create the ONNX FE using a pointer on ModelProto object as uint64_t";
ASSERT_NE(m_frontEnd, nullptr);
ASSERT_NO_THROW(m_inputModel = m_frontEnd->load(model_proto_ptr)) << "Could not load the model";
ASSERT_NE(m_inputModel, nullptr);
}
auto model_proto = std::make_shared<ModelProto>();
ASSERT_TRUE(model_proto->ParseFromIstream(&ifs)) << "Could not parse ModelProto from file: " << path;

std::shared_ptr<ov::Model> model;
ASSERT_NO_THROW(model = m_frontEnd->convert(m_inputModel)) << "Could not convert the model to OV representation";
ASSERT_NE(model, nullptr);
uint64_t model_proto_ptr = reinterpret_cast<uint64_t>(model_proto.get());

ASSERT_TRUE(model->get_ordered_ops().size() > 0);
ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_model(model_proto_ptr))
<< "Could not create the ONNX FE using a pointer on ModelProto object as uint64_t";
ASSERT_NE(m_frontEnd, nullptr);
// Should say unsupported if an address is 0
ASSERT_FALSE(m_frontEnd->supported(static_cast<uint64_t>(0)));
// Should throw an ov::Exception if address is 0
OV_EXPECT_THROW(m_inputModel = m_frontEnd->load(static_cast<uint64_t>(0)),
ov::Exception,
testing::HasSubstr("Wrong address"));

model_proto->set_ir_version(Version::IR_VERSION + 1);
// Should say unsupported if ModelProto has IR_VERSION higher than supported
ASSERT_FALSE(m_frontEnd->supported(model_proto_ptr));
// Should throw an ov::Exception if address is 0
OV_EXPECT_THROW(m_inputModel = m_frontEnd->load(model_proto_ptr),
ov::Exception,
testing::HasSubstr("unsupported IR version"));
}
// !!! End of Experimental feature !!!

0 comments on commit 361fd89

Please sign in to comment.