-
Notifications
You must be signed in to change notification settings - Fork 2.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PT FE] Add aten::atan2 #23003
Closed
Closed
[PT FE] Add aten::atan2 #23003
Changes from all commits
Commits
Show all changes
51 commits
Select commit
Hold shift + click to select a range
c65927d
Create atan2.cpp
rghvsh 38a3891
Update op_table.cpp
rghvsh 9d4141f
Create test_atan2.py
rghvsh ff54b63
Update atan2.cpp
rghvsh a80c3d8
Update test_atan2.py
rghvsh 9f435bb
Update test_atan2.py
rghvsh 3b9fd29
Update atan2.cpp
rghvsh 36fcd1d
Update atan2.cpp
rghvsh 004439f
Update test_atan2.py
rghvsh c5941bb
Update atan2.cpp
rghvsh 3c110e3
Update test_atan2.py
rghvsh 05c11cb
Update test_atan2.py
rghvsh 7c79090
Update atan2.cpp
rghvsh 61097ea
Update atan2.cpp
rghvsh 38bd4b1
Merge branch 'master' into ptfe-atan2
rghvsh d15d1a0
Update src/frontends/pytorch/src/op_table.cpp
rghvsh d0da16f
Update tests/layer_tests/pytorch_tests/test_atan2.py
rghvsh c15ca80
Update op_table.cpp
rghvsh 14f2a91
Update atan2.cpp
rghvsh c48f56d
Update atan2.cpp
rghvsh c7d5c39
Update test_atan2.py
rghvsh 0c6932f
Update atan2.cpp
rghvsh 82d718f
Update atan2.cpp
rghvsh 0da1b30
Update atan2.cpp
rghvsh eb6bdbe
Update atan2.cpp
rghvsh b72910c
Update atan2.cpp
rghvsh c8aa285
Update atan2.cpp
rghvsh f7409c7
Update test_atan2.py
rghvsh e2c6c6f
Update atan2.cpp
rghvsh 2cb6863
Update src/frontends/pytorch/src/op/atan2.cpp
rghvsh 59dc55b
Update src/frontends/pytorch/src/op/atan2.cpp
rghvsh 5e69f05
Merge branch 'master' into ptfe-atan2
rghvsh 020b5ae
Update test_atan2.py
rghvsh 4d31d2f
Update atan2.cpp
rghvsh 5cc47bb
Update atan2.cpp
rghvsh ae81808
Update atan2.cpp
rghvsh 0aa7174
Update op_table.cpp
rghvsh 24f0198
Update atan2.cpp
rghvsh c6c066c
Update test_atan2.py
rghvsh 7d38db0
Update atan2.cpp
rghvsh 9bc5e27
Update atan2.cpp
rghvsh 30f642c
Update test_atan2.py
rghvsh d6478a7
Merge branch 'master' into ptfe-atan2
mvafin 7a7d79d
Update atan2.cpp
rghvsh d0d8d2e
Update tests/layer_tests/pytorch_tests/test_atan2.py
rghvsh 0343b28
Merge branch 'master' into ptfe-atan2
rghvsh b1d67bf
Update atan2.cpp
rghvsh 022569b
Update atan2.cpp
rghvsh a0d2b15
Update test_atan2.py
rghvsh fff1cd1
Update test_atan2.py
rghvsh ed9df76
Update test_atan2.py
rghvsh File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/frontend/pytorch/node_context.hpp" | ||
#include "openvino/op/add.hpp" | ||
#include "openvino/op/atan.hpp" | ||
#include "openvino/op/concat.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/convert.hpp" | ||
#include "openvino/op/convert_like.hpp" | ||
#include "openvino/op/divide.hpp" | ||
#include "openvino/op/equal.hpp" | ||
#include "openvino/op/greater.hpp" | ||
#include "openvino/op/greater_eq.hpp" | ||
#include "openvino/op/less.hpp" | ||
#include "openvino/op/logical_and.hpp" | ||
#include "openvino/op/multiply.hpp" | ||
#include "openvino/op/range.hpp" | ||
#include "openvino/op/select.hpp" | ||
#include "openvino/op/subtract.hpp" | ||
#include "pt_framework_node.hpp" | ||
#include "utils.hpp" | ||
|
||
using namespace ov::op; | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace pytorch { | ||
namespace op { | ||
|
||
OutputVector translate_atan2(const NodeContext& context) {
    // "aten::atan2.out(Tensor input, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
    // Quadrant-aware arctangent of y/x, matching torch.atan2 semantics:
    //   x > 0            -> atan(y/x)
    //   x < 0, y >= 0    -> atan(y/x) + pi
    //   x < 0, y < 0     -> atan(y/x) - pi
    //   x == 0, y > 0    -> +pi/2
    //   x == 0, y < 0    -> -pi/2
    // Built from elementwise ops since OpenVINO has no native atan2.
    num_inputs_check(context, 2, 3);

    // Get input tensors y (numerator) and x (denominator).
    Output<Node> x = context.get_input(0);
    Output<Node> y = context.get_input(1);

    // Align x against an f32 scalar so integer inputs are promoted to a
    // floating type before Divide/Atan are applied.
    auto dummy_const = context.mark_node(v0::Constant::create(element::f32, Shape({}), {0.5}))->output(0);
    align_eltwise_input_types(context, x, dummy_const, false, true);

    // Align y with the (possibly promoted) x.
    align_eltwise_input_types(context, x, y, is_python_scalar_input(context, 0), is_python_scalar_input(context, 1));

    // Base case (x > 0): atan(y / x).
    auto div_y_x = context.mark_node(std::make_shared<v1::Divide>(y, x));
    auto atan = context.mark_node(std::make_shared<v0::Atan>(div_y_x));
    auto const_zero = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
    const_zero = context.mark_node(std::make_shared<v1::ConvertLike>(const_zero, x));
    auto result = atan->output(0);

    // Second condition: x < 0 && y >= 0 -> atan(y/x) + pi.
    auto const_pi = context.mark_node(v0::Constant::create(element::f32, Shape{}, {std::atan(1.0) * 4}));
    // Convert pi to the input element type.
    const_pi = context.mark_node(std::make_shared<v1::ConvertLike>(const_pi, x));
    auto is_x_negative = context.mark_node(std::make_shared<v1::Less>(x, const_zero));
    auto y_non_negative = context.mark_node(std::make_shared<v1::GreaterEqual>(y, const_zero));
    auto cond1 = context.mark_node(std::make_shared<v1::LogicalAnd>(is_x_negative, y_non_negative));
    auto atan_y_x_plus_pi = context.mark_node(std::make_shared<v1::Add>(atan, const_pi));
    result = context.mark_node(std::make_shared<v1::Select>(cond1, atan_y_x_plus_pi, result));

    // Third condition: x < 0 && y < 0 -> atan(y/x) - pi.
    auto is_y_negative = context.mark_node(std::make_shared<v1::Less>(y, const_zero));
    auto cond2 = context.mark_node(std::make_shared<v1::LogicalAnd>(is_x_negative, is_y_negative));
    auto atan_y_x_minus_pi = context.mark_node(std::make_shared<v1::Subtract>(atan, const_pi));
    result = context.mark_node(std::make_shared<v1::Select>(cond2, atan_y_x_minus_pi, result));

    // Fourth condition: x == 0 && y > 0 -> +pi/2.
    auto is_x_zero = context.mark_node(std::make_shared<v1::Equal>(x, const_zero));
    auto is_y_positive = context.mark_node(std::make_shared<v1::Greater>(y, const_zero));
    auto cond3 = context.mark_node(std::make_shared<v1::LogicalAnd>(is_x_zero, is_y_positive));
    auto const_two = context.mark_node(v0::Constant::create(element::f32, Shape{}, {2}));
    const_two = context.mark_node(std::make_shared<v1::ConvertLike>(const_two, x));
    auto pi_div_two = context.mark_node(std::make_shared<v1::Divide>(const_pi, const_two));
    result = context.mark_node(std::make_shared<v1::Select>(cond3, pi_div_two, result));

    // Fifth condition: x == 0 && y < 0 -> -pi/2.
    auto cond4 = context.mark_node(std::make_shared<v1::LogicalAnd>(is_x_zero, is_y_negative));
    auto const_minus_two = context.mark_node(v0::Constant::create(element::f32, Shape{}, {-2}));
    const_minus_two = context.mark_node(std::make_shared<v1::ConvertLike>(const_minus_two, x));
    auto pi_div_minus_two = context.mark_node(std::make_shared<v1::Divide>(const_pi, const_minus_two));
    // BUGFIX: select -pi/2 (pi_div_minus_two) here; the original selected
    // pi_div_two, yielding +pi/2 for atan2(y<0, x==0) and leaving
    // pi_div_minus_two dead.
    result = context.mark_node(std::make_shared<v1::Select>(cond4, pi_div_minus_two, result));

    return {result};
}
|
||
} // namespace op | ||
} // namespace pytorch | ||
} // namespace frontend | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright (C) 2018-2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
import torch | ||
import math | ||
import numpy as np | ||
|
||
from pytorch_layer_test_class import PytorchLayerTest | ||
|
||
class TestAtan2(PytorchLayerTest):
    def _prepare_input(self, y, x, dtype1=None, dtype2=None):
        # Build the two model inputs, casting each scalar/list to its
        # requested numpy dtype (None keeps numpy's default).
        return [np.array(y).astype(dtype1), np.array(x).astype(dtype2)]

    def create_model(self, dtype1=None, dtype2=None, use_out=False):
        # Return (model, ref_net, op_name) for tracing torch.atan2.
        # When use_out is True, the aten::atan2.out overload is exercised
        # with a preallocated output tensor of dtype1.

        class aten_atan2_out(torch.nn.Module):
            def __init__(self, dtype) -> None:
                super().__init__()
                # Preallocated output buffer; torch resizes it as needed.
                self.out = torch.empty(25, dtype=dtype)

            def forward(self, y, x):
                return torch.atan2(input=y, other=x, out=self.out)

        class aten_atan2(torch.nn.Module):
            def forward(self, y, x):
                return torch.atan2(input=y, other=x)

        # Map string dtype names onto torch dtypes (None maps to None).
        to_torch_dtype = {
            "float32": torch.float32,
            "float64": torch.float64,
            "int64": torch.int64,
            "int32": torch.int32,
            "int16": torch.int16,
            "uint8": torch.uint8,
            "int8": torch.int8,
        }
        out_dtype = to_torch_dtype.get(dtype1)

        model = aten_atan2_out(out_dtype) if use_out else aten_atan2()
        return model, None, "aten::atan2"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("dtype1, dtype2", [(None, None), ("float32", "int32"), ("float64", "float64"), ("int32", "float64"), ("int64", "int16"), ("int8", "int8"), ("uint8", "uint8")])
    @pytest.mark.parametrize(
        "y, x", [(0, 1.5), (0, 0), (1.25, -5), (1, 10), (-1, -5.5), (-1, -5), (1.25, -5.5), (1.9, 2.9), [10, 9.9]]
    )
    @pytest.mark.parametrize("use_out", [False, True])
    def test_atan2_with_out(self, dtype1, dtype2, use_out, y, x, ie_device, precision, ir_version):
        self._test(
            *self.create_model(dtype2=dtype2, dtype1=dtype1, use_out=use_out),
            ie_device,
            precision,
            ir_version,
            kwargs_to_prepare_input={"y": y, "x": x}
        )
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think you need this include