-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added Tensor.get_size() method to Node.js API #23498
Changes from 3 commits
22464be
497d0e7
18177fa
81a1212
c2bed82
78012bf
43cbfad
f9507d0
c4751cf
06d9dcf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -48,7 +48,8 @@ Napi::Function TensorWrap::get_class(Napi::Env env) { | |||||
{InstanceAccessor<&TensorWrap::get_data>("data"), | ||||||
InstanceMethod("getData", &TensorWrap::get_data), | ||||||
InstanceMethod("getShape", &TensorWrap::get_shape), | ||||||
InstanceMethod("getElementType", &TensorWrap::get_element_type)}); | ||||||
InstanceMethod("getElementType", &TensorWrap::get_element_type), | ||||||
InstanceMethod("getSize", &TensorWrap::get_size)}); | ||||||
} | ||||||
|
||||||
ov::Tensor TensorWrap::get_tensor() const { | ||||||
|
@@ -138,3 +139,17 @@ Napi::Value TensorWrap::get_shape(const Napi::CallbackInfo& info) { | |||||
Napi::Value TensorWrap::get_element_type(const Napi::CallbackInfo& info) { | ||||||
return cpp_to_js<ov::element::Type_t, Napi::String>(info, _tensor.get_element_type()); | ||||||
} | ||||||
|
||||||
|
||||||
|
||||||
Napi::Value TensorWrap::get_size(const Napi::CallbackInfo& info) { | ||||||
if (info.Length() > 0) { | ||||||
reportError(info.Env(), "getSize() does not accept any arguments."); | ||||||
return info.Env().Null(); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I propose return |
||||||
} | ||||||
size_t size = _tensor.get_size(); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
double jsSize = static_cast<double>(size); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use snake_case in c++ part There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @almilosz Thanks , for the feedback I have made the suggested changes. |
||||||
|
||||||
return Napi::Number::New(info.Env(), jsSize); | ||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -136,3 +136,73 @@ describe('Tensor element type', () => { | |
}); | ||
}); | ||
}); | ||
|
||
|
||
describe('Tensor getSize', () => { | ||
|
||
it('getSize returns the correct total number of elements', () => { | ||
const tensor = new ov.Tensor(ov.element.f32, shape, data); | ||
const expectedSize = shape.reduce((acc, dim) => acc * dim, 1); | ||
assert.strictEqual(tensor.getSize(), expectedSize); | ||
}); | ||
|
||
it('getSize should throw an error if arguments are provided', () => { | ||
const tensor = new ov.Tensor(ov.element.f32, shape, data); | ||
assert.throws( | ||
() => tensor.getSize(1), | ||
{ message: 'getSize() does not accept any arguments.' } | ||
); | ||
}); | ||
}); | ||
|
||
|
||
describe('Tensor getSize for various shapes', () => { | ||
it('calculates size correctly for a common image data shape [3, 224, 224]', () => { | ||
const shape = [3, 224, 224]; | ||
const expectedSize = 3*224*224; | ||
const tensorData = new Float32Array(expectedSize).fill(0); | ||
const tensor = new ov.Tensor(ov.element.f32, shape, tensorData); | ||
assert.strictEqual(tensor.getSize(), expectedSize); | ||
}); | ||
|
||
it('calculates size correctly for a scalar wrapped in a tensor [1]', () => { | ||
const shape = [1]; | ||
const expectedSize = 1; | ||
const tensorData = new Float32Array(expectedSize).fill(0); | ||
const tensor = new ov.Tensor(ov.element.f32, shape, tensorData); | ||
assert.strictEqual(tensor.getSize(), expectedSize); | ||
}); | ||
|
||
it('calculates size correctly for a vector [10]', () => { | ||
const shape = [10]; | ||
const expectedSize = 10; | ||
const tensorData = new Float32Array(expectedSize).fill(0); | ||
const tensor = new ov.Tensor(ov.element.f32, shape, tensorData); | ||
assert.strictEqual(tensor.getSize(), expectedSize); | ||
}); | ||
|
||
it('calculates size correctly for a small 3D tensor [2, 5, 5]', () => { | ||
const shape = [2, 5, 5]; | ||
const expectedSize = 2*5*5; | ||
const tensorData = new Float32Array(expectedSize).fill(0); | ||
const tensor = new ov.Tensor(ov.element.f32, shape, tensorData); | ||
assert.strictEqual(tensor.getSize(), expectedSize); | ||
}); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I really like that you've implemented great tests for implemented functionality. Thank you! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
it('calculates size correctly for a small square matrix [4, 4]', () => { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is also extra, I propose to remove it. |
||
const shape = [4, 4]; | ||
const expectedSize = 16; | ||
const tensorData = new Float32Array(expectedSize).fill(0); | ||
const tensor = new ov.Tensor(ov.element.f32, shape, tensorData); | ||
assert.strictEqual(tensor.getSize(), expectedSize); | ||
}); | ||
|
||
it('calculates size correctly for another small 3D tensor [2, 3, 3]', () => { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And this one, I propose to remove it. |
||
const shape = [2, 3, 3]; | ||
const expectedSize = 18; | ||
const tensorData = new Float32Array(expectedSize).fill(0); | ||
const tensor = new ov.Tensor(ov.element.f32, shape, tensorData); | ||
assert.strictEqual(tensor.getSize(), expectedSize); | ||
}); | ||
}); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can extract
info.Env()
into variable if you use it several times.