Skip to content

Commit

Permalink
[OV][JS] Expose the Tensor.isContinuous to Node.js API
Browse files Browse the repository at this point in the history
* Add a TensorWrap::is_continuous function:
  Calls the underlying Tensor.isContinous function
* Update the addon.ts file with the isContinuous method
* Add unit tests for the isContinuous Api

Closes: #27701

Signed-off-by: Nashez Zubair <nashezzubair@gmail.com>
  • Loading branch information
nashez committed Dec 25, 2024
1 parent 35f2a0c commit b11e0e9
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src/bindings/js/node/include/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ class TensorWrap : public Napi::ObjectWrap<TensorWrap> {
Napi::Value get_element_type(const Napi::CallbackInfo& info);
/** @return Napi::Number containing tensor size as total number of elements. */
Napi::Value get_size(const Napi::CallbackInfo& info);
/**
* @brief Getter to check if tensor is continuous
* @return Napi::Boolean
*/
Napi::Value is_continuous(const Napi::CallbackInfo& info);

private:
ov::Tensor _tensor;
Expand Down
4 changes: 4 additions & 0 deletions src/bindings/js/node/lib/addon.ts
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,10 @@ interface Tensor {
* It gets the tensor size as a total number of elements.
*/
getSize(): number;
/**
* Reports whether the tensor is continuous or not.
*/
isContinuous(): boolean;
}

/**
Expand Down
12 changes: 11 additions & 1 deletion src/bindings/js/node/src/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ Napi::Function TensorWrap::get_class(Napi::Env env) {
InstanceMethod("getData", &TensorWrap::get_data),
InstanceMethod("getShape", &TensorWrap::get_shape),
InstanceMethod("getElementType", &TensorWrap::get_element_type),
InstanceMethod("getSize", &TensorWrap::get_size)});
InstanceMethod("getSize", &TensorWrap::get_size),
InstanceMethod("isContinuous", &TensorWrap::is_continuous)});
}

ov::Tensor TensorWrap::get_tensor() const {
Expand Down Expand Up @@ -181,3 +182,12 @@ Napi::Value TensorWrap::get_size(const Napi::CallbackInfo& info) {
const auto size = static_cast<double>(_tensor.get_size());
return Napi::Number::New(env, size);
}

Napi::Value TensorWrap::is_continuous(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
if (info.Length() > 0) {
reportError(env, "isContinuous() does not accept any arguments.");
return env.Undefined();
}
return Napi::Boolean::New(env, _tensor.is_continuous());
}
14 changes: 14 additions & 0 deletions src/bindings/js/node/tests/unit/tensor.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -297,4 +297,18 @@ describe('ov.Tensor tests', () => {
assert.strictEqual(tensor.getSize(), expectedSize);
});
});

describe('Tensor isContinuous', () => {
it('isContinuous returns true if tensor is continuous', () => {
const tensor = new ov.Tensor(ov.element.f32, [3, 2, 2]);
assert.strictEqual(tensor.isContinuous(), true);
});

it('isContinuous should throw an error if arguments are provided', () => {
const tensor = new ov.Tensor(ov.element.f32, shape, data);
assert.throws(() => tensor.isContinuous(1), {
message: 'isContinuous() does not accept any arguments.',
});
});
});
});

0 comments on commit b11e0e9

Please sign in to comment.