diff --git a/src/bindings/js/node/include/tensor.hpp b/src/bindings/js/node/include/tensor.hpp index 3ee9ae043c4e2d..e04c5acc2b5e5c 100644 --- a/src/bindings/js/node/include/tensor.hpp +++ b/src/bindings/js/node/include/tensor.hpp @@ -51,6 +51,8 @@ class TensorWrap : public Napi::ObjectWrap { Napi::Value get_shape(const Napi::CallbackInfo& info); /** @return Napi::String containing ov::element type. */ Napi::Value get_element_type(const Napi::CallbackInfo& info); + /**@return Napi::Number containing tensor size as total number of elements.*/ + Napi::Value get_size(const Napi::CallbackInfo& info); private: ov::Tensor _tensor; diff --git a/src/bindings/js/node/lib/addon.ts b/src/bindings/js/node/lib/addon.ts index fcd1d0681b625b..281430f27d7a6b 100644 --- a/src/bindings/js/node/lib/addon.ts +++ b/src/bindings/js/node/lib/addon.ts @@ -75,6 +75,7 @@ interface Tensor { getElementType(): element; getShape(): number[]; getData(): number[]; + getSize(): number; } interface TensorConstructor { new(type: element | elementTypeString, diff --git a/src/bindings/js/node/src/tensor.cpp b/src/bindings/js/node/src/tensor.cpp index 8f654dc0d70e52..9324e6abe3b90a 100644 --- a/src/bindings/js/node/src/tensor.cpp +++ b/src/bindings/js/node/src/tensor.cpp @@ -48,7 +48,8 @@ Napi::Function TensorWrap::get_class(Napi::Env env) { {InstanceAccessor<&TensorWrap::get_data>("data"), InstanceMethod("getData", &TensorWrap::get_data), InstanceMethod("getShape", &TensorWrap::get_shape), - InstanceMethod("getElementType", &TensorWrap::get_element_type)}); + InstanceMethod("getElementType", &TensorWrap::get_element_type), + InstanceMethod("getSize", &TensorWrap::get_size)}); } ov::Tensor TensorWrap::get_tensor() const { @@ -138,3 +139,13 @@ Napi::Value TensorWrap::get_shape(const Napi::CallbackInfo& info) { Napi::Value TensorWrap::get_element_type(const Napi::CallbackInfo& info) { return cpp_to_js(info, _tensor.get_element_type()); } + +Napi::Value TensorWrap::get_size(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + if (info.Length() > 0) { + reportError(env, "getSize() does not accept any arguments."); + return env.Undefined(); + } + const auto size = static_cast(_tensor.get_size()); + return Napi::Number::New(env, size); +} diff --git a/src/bindings/js/node/tests/tensor.test.js b/src/bindings/js/node/tests/tensor.test.js index e4268374fbcc63..b54d164f8c3604 100644 --- a/src/bindings/js/node/tests/tensor.test.js +++ b/src/bindings/js/node/tests/tensor.test.js @@ -136,3 +136,49 @@ describe('Tensor element type', () => { }); }); }); + + +describe('Tensor getSize', () => { + + it('getSize returns the correct total number of elements', () => { + const tensor = new ov.Tensor(ov.element.f32, shape, data); + const expectedSize = shape.reduce((acc, dim) => acc * dim, 1); + assert.strictEqual(tensor.getSize(), expectedSize); + }); + + it('getSize should throw an error if arguments are provided', () => { + const tensor = new ov.Tensor(ov.element.f32, shape, data); + assert.throws( + () => tensor.getSize(1), + { message: 'getSize() does not accept any arguments.' } + ); + }); +}); + +describe('Tensor getSize for various shapes', () => { + + it('calculates size correctly for a common image data shape [3, 224, 224]', () => { + const shape = [3, 224, 224]; + const expectedSize = 3*224*224; + const tensorData = new Float32Array(expectedSize).fill(0); + const tensor = new ov.Tensor(ov.element.f32, shape, tensorData); + assert.strictEqual(tensor.getSize(), expectedSize); + }); + + it('calculates size correctly for a scalar wrapped in a tensor [1]', () => { + const shape = [1]; + const expectedSize = 1; + const tensorData = new Float32Array(expectedSize).fill(0); + const tensor = new ov.Tensor(ov.element.f32, shape, tensorData); + assert.strictEqual(tensor.getSize(), expectedSize); + }); + + it('calculates size correctly for a vector [10]', () => { + const shape = [10]; + const expectedSize = 10; + const tensorData = new Float32Array(expectedSize).fill(0); + const tensor = new ov.Tensor(ov.element.f32, shape, tensorData); + assert.strictEqual(tensor.getSize(), expectedSize); + }); +}); +