forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TF FE][JAX FE] Support latest TF 2.18, JAX 0.4.35 and NumPy 2.x (ope…
…nvinotoolkit#27246) **Details:** Support TF 2.18 and JAX 0.4.35 **Ticket:** TBD --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
- Loading branch information
Showing
15 changed files
with
126 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/frontend/jax/node_context.hpp" | ||
#include "openvino/op/erf.hpp" | ||
#include "openvino/op/subtract.hpp" | ||
#include "utils.hpp" | ||
|
||
using namespace std; | ||
using namespace ov; | ||
using namespace ov::op; | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace jax { | ||
namespace op { | ||
|
||
OutputVector translate_erfc(const NodeContext& context) { | ||
num_inputs_check(context, 1, 1); | ||
auto x = context.get_input(0); | ||
|
||
// create const one of the same type as x | ||
auto const_one = create_same_type_const_scalar<int64_t>(x, 1); | ||
Output<Node> res = make_shared<v0::Erf>(x); | ||
res = make_shared<v1::Subtract>(const_one, res); | ||
return {res}; | ||
}; | ||
|
||
} // namespace op | ||
} // namespace jax | ||
} // namespace frontend | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Copyright (C) 2018-2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import jax | ||
import numpy as np | ||
import pytest | ||
from jax import numpy as jnp | ||
|
||
from jax_layer_test_class import JaxLayerTest | ||
|
||
rng = np.random.default_rng(109734) | ||
|
||
|
||
class TestErfc(JaxLayerTest): | ||
def _prepare_input(self): | ||
# erf are mostly changing in a range [-4, 4] | ||
x = rng.uniform(-4.0, 4.0, self.input_shape).astype(self.input_type) | ||
|
||
x = jnp.array(x) | ||
return [x] | ||
|
||
def create_model(self, input_shape, input_type): | ||
self.input_shape = input_shape | ||
self.input_type = input_type | ||
|
||
def jax_erfc(x): | ||
return jax.lax.erfc(x) | ||
|
||
return jax_erfc, None, 'erfc' | ||
|
||
@pytest.mark.parametrize("input_shape", [[2], [3, 4]]) | ||
@pytest.mark.parametrize("input_type", [np.float16, np.float32, np.float64]) | ||
@pytest.mark.nightly | ||
@pytest.mark.precommit_jax_fe | ||
def test_erfc(self, ie_device, precision, ir_version, input_shape, input_type): | ||
self._test(*self.create_model(input_shape, input_type), | ||
ie_device, precision, | ||
ir_version) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,21 @@ | ||
# tensorflow-intel inside tensorflow still requires numpy<2.0.0 | ||
numpy==1.26.4 | ||
# test ovc with NumPy 2.x on Ubuntu 24 with default Python 3.12 | ||
# test against NumPy 1.x with older Python versions | ||
# tensorflow-intel 2.18.0 depends on numpy<2.1.0 and >=1.26.0 | ||
numpy==1.26.4; python_version < "3.12" | ||
numpy==2.0.2; python_version >= "3.12" | ||
pytest==7.0.1 | ||
pytest-xdist[psutil]==3.6.1 | ||
pytest-html==4.1.1 | ||
transformers==4.45.1 | ||
# install exact keras version since tensorflow depends and has no upper bound for it | ||
keras==3.6.0 | ||
tensorflow==2.17.0; platform_system != "Darwin" or platform_machine != "x86_64" | ||
tensorflow==2.18.0; python_version >= "3.12" and (platform_system != "Darwin" or platform_machine != "x86_64") | ||
tensorflow==2.17.0; python_version < "3.12" and (platform_system != "Darwin" or platform_machine != "x86_64") | ||
tensorflow==2.16.2; platform_system == "Darwin" and platform_machine == "x86_64" | ||
# install explicit version of wrapt to avoid "this __dict__ descriptor does not support '_DictWrapper' objects" error from TensorFlow 2.18 | ||
wrapt==1.15.0; python_version >= "3.12" | ||
# tensorflow-text is not available for both Windows and ARM platforms | ||
tensorflow-text==2.17.0; platform_system == "Linux" and platform_machine == "x86_64" | ||
tensorflow-text==2.17.0; python_version < "3.12" and platform_system == "Linux" and platform_machine == "x86_64" | ||
tensorflow-hub==0.16.1 | ||
jax==0.4.33 | ||
jax==0.4.35 | ||
defusedxml==0.7.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters