Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Tensor.get_size() method to Node.js API #23498

Merged
merged 10 commits into from
Mar 28, 2024
3 changes: 3 additions & 0 deletions src/bindings/js/node/include/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class TensorWrap : public Napi::ObjectWrap<TensorWrap> {
Napi::Value get_shape(const Napi::CallbackInfo& info);
/** @return Napi::String containing ov::element type. */
Napi::Value get_element_type(const Napi::CallbackInfo& info);
/**@return Napi::Number containing tensor size as total number of elements.*/
Napi::Value get_size(const Napi::CallbackInfo& info);


private:
ov::Tensor _tensor;
Expand Down
1 change: 1 addition & 0 deletions src/bindings/js/node/lib/addon.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ interface Tensor {
getElementType(): element;
getShape(): number[];
getData(): number[];
getSize(): number;
}
interface TensorConstructor {
new(type: element | elementTypeString,
Expand Down
17 changes: 16 additions & 1 deletion src/bindings/js/node/src/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ Napi::Function TensorWrap::get_class(Napi::Env env) {
{InstanceAccessor<&TensorWrap::get_data>("data"),
InstanceMethod("getData", &TensorWrap::get_data),
InstanceMethod("getShape", &TensorWrap::get_shape),
InstanceMethod("getElementType", &TensorWrap::get_element_type)});
InstanceMethod("getElementType", &TensorWrap::get_element_type),
InstanceMethod("getSize", &TensorWrap::get_size)});
}

ov::Tensor TensorWrap::get_tensor() const {
Expand Down Expand Up @@ -138,3 +139,17 @@ Napi::Value TensorWrap::get_shape(const Napi::CallbackInfo& info) {
Napi::Value TensorWrap::get_element_type(const Napi::CallbackInfo& info) {
return cpp_to_js<ov::element::Type_t, Napi::String>(info, _tensor.get_element_type());
}



Napi::Value TensorWrap::get_size(const Napi::CallbackInfo& info) {
if (info.Length() > 0) {
reportError(info.Env(), "getSize() does not accept any arguments.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can extract info.Env() into variable if you use it several times.

return info.Env().Null();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I propose return info.Env().Undefined() in case of error. (I know that in the sources you can meet both variants, but it could be aligned later).

}
size_t size = _tensor.get_size();
Copy link
Contributor

@almilosz almilosz Mar 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
size_t size = _tensor.get_size();
const auto size = static_cast<double>(_tensor.get_size());


double jsSize = static_cast<double>(size);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use snake_case in c++ part

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@almilosz Thanks , for the feedback I have made the suggested changes.


return Napi::Number::New(info.Env(), jsSize);
}
70 changes: 70 additions & 0 deletions src/bindings/js/node/tests/tensor.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,73 @@ describe('Tensor element type', () => {
});
});
});


describe('Tensor getSize', () => {

it('getSize returns the correct total number of elements', () => {
const tensor = new ov.Tensor(ov.element.f32, shape, data);
const expectedSize = shape.reduce((acc, dim) => acc * dim, 1);
assert.strictEqual(tensor.getSize(), expectedSize);
});

it('getSize should throw an error if arguments are provided', () => {
const tensor = new ov.Tensor(ov.element.f32, shape, data);
assert.throws(
() => tensor.getSize(1),
{ message: 'getSize() does not accept any arguments.' }
);
});
});


describe('Tensor getSize for various shapes', () => {
it('calculates size correctly for a common image data shape [3, 224, 224]', () => {
const shape = [3, 224, 224];
const expectedSize = 3*224*224;
const tensorData = new Float32Array(expectedSize).fill(0);
const tensor = new ov.Tensor(ov.element.f32, shape, tensorData);
assert.strictEqual(tensor.getSize(), expectedSize);
});

it('calculates size correctly for a scalar wrapped in a tensor [1]', () => {
const shape = [1];
const expectedSize = 1;
const tensorData = new Float32Array(expectedSize).fill(0);
const tensor = new ov.Tensor(ov.element.f32, shape, tensorData);
assert.strictEqual(tensor.getSize(), expectedSize);
});

it('calculates size correctly for a vector [10]', () => {
const shape = [10];
const expectedSize = 10;
const tensorData = new Float32Array(expectedSize).fill(0);
const tensor = new ov.Tensor(ov.element.f32, shape, tensorData);
assert.strictEqual(tensor.getSize(), expectedSize);
});

it('calculates size correctly for a small 3D tensor [2, 5, 5]', () => {
const shape = [2, 5, 5];
const expectedSize = 2*5*5;
const tensorData = new Float32Array(expectedSize).fill(0);
const tensor = new ov.Tensor(ov.element.f32, shape, tensorData);
assert.strictEqual(tensor.getSize(), expectedSize);
});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like that you've implemented great tests for implemented functionality. Thank you!
But I don't see difference between this test and calculates size correctly for a common image data shape [3, 224, 224]. I propose to remove current test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


it('calculates size correctly for a small square matrix [4, 4]', () => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also extra, I propose to remove it.

const shape = [4, 4];
const expectedSize = 16;
const tensorData = new Float32Array(expectedSize).fill(0);
const tensor = new ov.Tensor(ov.element.f32, shape, tensorData);
assert.strictEqual(tensor.getSize(), expectedSize);
});

it('calculates size correctly for another small 3D tensor [2, 3, 3]', () => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And this one, I propose to remove it.

const shape = [2, 3, 3];
const expectedSize = 18;
const tensorData = new Float32Array(expectedSize).fill(0);
const tensor = new ov.Tensor(ov.element.f32, shape, tensorData);
assert.strictEqual(tensor.getSize(), expectedSize);
});
});

Loading