Skip to content

Commit

Permalink
Implemented getOutputElementType (openvinotoolkit#25760)
Browse files Browse the repository at this point in the history
Implemented Method on c++ side.
Updated typescript definitions.
Created unit tests.
For Issue
[https://github.com/openvinotoolkit/openvino/issues/25406](https://github.com/openvinotoolkit/openvino/issues/25406)

Resolved merge errors

---------

Co-authored-by: Alicja Miloszewska <alicja.miloszewska@intel.com>
  • Loading branch information
Pey-crypto and almilosz authored Aug 3, 2024
1 parent d29948c commit 59a0f01
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/bindings/js/node/include/model_wrap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ class ModelWrap : public Napi::ObjectWrap<ModelWrap> {
* @return Napi::Array containing a shape of requested output.
*/
Napi::Value get_output_shape(const Napi::CallbackInfo& info);

/**
* @brief Helper function to access model output elements types.
* @return Napi::String representing the element type of the requested output.
*
*/
Napi::Value get_output_element_type(const Napi::CallbackInfo& info);

private:
std::shared_ptr<ov::Model> _model;
Expand Down
5 changes: 5 additions & 0 deletions src/bindings/js/node/lib/addon.ts
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,11 @@ interface Model {
* It returns the number of the model outputs.
*/
getOutputSize(): number;
/**
* It gets the element type of a specific output of the model.
* @param index The index of the output.
*/
getOutputElementType(index: number): string;
/**
* It gets the input of the model.
* If a model has more than one input, this method throws an exception.
Expand Down
18 changes: 18 additions & 0 deletions src/bindings/js/node/src/model_wrap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "node/include/errors.hpp"
#include "node/include/helper.hpp"
#include "node/include/node_output.hpp"
#include "node/include/type_validation.hpp"

ModelWrap::ModelWrap(const Napi::CallbackInfo& info)
: Napi::ObjectWrap<ModelWrap>(info),
Expand All @@ -25,6 +26,7 @@ Napi::Function ModelWrap::get_class(Napi::Env env) {
InstanceMethod("setFriendlyName", &ModelWrap::set_friendly_name),
InstanceMethod("getFriendlyName", &ModelWrap::get_friendly_name),
InstanceMethod("getOutputShape", &ModelWrap::get_output_shape),
InstanceMethod("getOutputElementType", &ModelWrap::get_output_element_type),
InstanceAccessor<&ModelWrap::get_inputs>("inputs"),
InstanceAccessor<&ModelWrap::get_outputs>("outputs")});
}
Expand Down Expand Up @@ -171,3 +173,19 @@ Napi::Value ModelWrap::get_output_shape(const Napi::CallbackInfo& info) {
return info.Env().Undefined();
}
}

Napi::Value ModelWrap::get_output_element_type(const Napi::CallbackInfo& info) {
std::vector<std::string> allowed_signatures;
try {
if (ov::js::validate<int>(info, allowed_signatures)) {
auto idx = info[0].As<Napi::Number>().Int32Value();
const auto& output = _model->output(idx);
return cpp_to_js<ov::element::Type_t, Napi::String>(info, output.get_element_type());
} else {
OPENVINO_THROW("'getOutputElementType'", ov::js::get_parameters_error_msg(info, allowed_signatures));
}
} catch (const std::exception& e) {
reportError(info.Env(), e.what());
return info.Env().Undefined();
}
}
45 changes: 45 additions & 0 deletions src/bindings/js/node/tests/unit/model.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,48 @@ describe('Model.getOutputSize()', () => {
assert.strictEqual(model.getOutputSize(), 1, 'Expected getOutputSize to return 1 for the default model');
});
});

describe('Model.getOutputElementType()', () => {
it('should return a string for the element type ', () => {
const result = model.getOutputElementType(0);
assert.strictEqual(typeof result, 'string',
'getOutputElementType() should return a string');
});

it('should accept a single integer argument', () => {
assert.throws(() => {
model.getOutputElementType();
}, /'getOutputElementType' method called with incorrect parameters/,
'Should throw when called without arguments');

assert.throws(() => {
model.getOutputElementType('unexpected argument');
}, /'getOutputElementType' method called with incorrect parameters/,
'Should throw on non-number argument');

assert.throws(() => {
model.getOutputElementType(0, 1);
}, /'getOutputElementType' method called with incorrect parameters/,
'Should throw on multiple arguments');

assert.throws(() => {
model.getOutputElementType(3.14);
}, /'getOutputElementType' method called with incorrect parameters/,
'Should throw on non-integer number');
});

it('should return a valid element type for the default model', () => {
const elementType = model.getOutputElementType(0);
assert.ok(typeof elementType === 'string' && elementType.length > 0,
`Expected a non-empty string, got ${elementType}`);
});

it('should throw an error for out-of-range index', () => {
const outputSize = model.getOutputSize();
assert.throws(
() => { model.getOutputElementType(outputSize); },
/^Error: /,
'Should throw for out-of-range index'
);
});
});

0 comments on commit 59a0f01

Please sign in to comment.