Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT FE] Add aten::atan2 #23003

Closed
wants to merge 51 commits into from
Closed
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
c65927d
Create atan2.cpp
rghvsh Feb 22, 2024
38a3891
Update op_table.cpp
rghvsh Feb 22, 2024
9d4141f
Create test_atan2.py
rghvsh Feb 22, 2024
ff54b63
Update atan2.cpp
rghvsh Feb 22, 2024
a80c3d8
Update test_atan2.py
rghvsh Feb 22, 2024
9f435bb
Update test_atan2.py
rghvsh Feb 22, 2024
3b9fd29
Update atan2.cpp
rghvsh Mar 3, 2024
36fcd1d
Update atan2.cpp
rghvsh Mar 3, 2024
004439f
Update test_atan2.py
rghvsh Mar 3, 2024
c5941bb
Update atan2.cpp
rghvsh Mar 3, 2024
3c110e3
Update test_atan2.py
rghvsh Mar 3, 2024
05c11cb
Update test_atan2.py
rghvsh Mar 3, 2024
7c79090
Update atan2.cpp
rghvsh Mar 5, 2024
61097ea
Update atan2.cpp
rghvsh Mar 12, 2024
38bd4b1
Merge branch 'master' into ptfe-atan2
rghvsh Mar 12, 2024
d15d1a0
Update src/frontends/pytorch/src/op_table.cpp
rghvsh Mar 13, 2024
d0da16f
Update tests/layer_tests/pytorch_tests/test_atan2.py
rghvsh Mar 13, 2024
c15ca80
Update op_table.cpp
rghvsh Mar 13, 2024
14f2a91
Update atan2.cpp
rghvsh Mar 13, 2024
c48f56d
Update atan2.cpp
rghvsh Mar 13, 2024
c7d5c39
Update test_atan2.py
rghvsh Mar 13, 2024
0c6932f
Update atan2.cpp
rghvsh Mar 13, 2024
82d718f
Update atan2.cpp
rghvsh Mar 13, 2024
0da1b30
Update atan2.cpp
rghvsh Mar 13, 2024
eb6bdbe
Update atan2.cpp
rghvsh Mar 17, 2024
b72910c
Update atan2.cpp
rghvsh Mar 17, 2024
c8aa285
Update atan2.cpp
rghvsh Mar 17, 2024
f7409c7
Update test_atan2.py
rghvsh Mar 17, 2024
e2c6c6f
Update atan2.cpp
rghvsh Mar 17, 2024
2cb6863
Update src/frontends/pytorch/src/op/atan2.cpp
rghvsh Mar 19, 2024
59dc55b
Update src/frontends/pytorch/src/op/atan2.cpp
rghvsh Mar 19, 2024
5e69f05
Merge branch 'master' into ptfe-atan2
rghvsh Mar 23, 2024
020b5ae
Update test_atan2.py
rghvsh Mar 23, 2024
4d31d2f
Update atan2.cpp
rghvsh Mar 23, 2024
5cc47bb
Update atan2.cpp
rghvsh Mar 23, 2024
ae81808
Update atan2.cpp
rghvsh Mar 23, 2024
0aa7174
Update op_table.cpp
rghvsh Mar 23, 2024
24f0198
Update atan2.cpp
rghvsh Mar 23, 2024
c6c066c
Update test_atan2.py
rghvsh Mar 23, 2024
7d38db0
Update atan2.cpp
rghvsh Mar 23, 2024
9bc5e27
Update atan2.cpp
rghvsh Mar 25, 2024
30f642c
Update test_atan2.py
rghvsh Mar 26, 2024
d6478a7
Merge branch 'master' into ptfe-atan2
mvafin Mar 26, 2024
7a7d79d
Update atan2.cpp
rghvsh Mar 26, 2024
d0d8d2e
Update tests/layer_tests/pytorch_tests/test_atan2.py
rghvsh Mar 26, 2024
0343b28
Merge branch 'master' into ptfe-atan2
rghvsh Mar 26, 2024
b1d67bf
Update atan2.cpp
rghvsh Mar 26, 2024
022569b
Update atan2.cpp
rghvsh Mar 26, 2024
a0d2b15
Update test_atan2.py
rghvsh Mar 29, 2024
fff1cd1
Update test_atan2.py
rghvsh Apr 11, 2024
ed9df76
Update test_atan2.py
rghvsh Apr 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions src/frontends/pytorch/src/op/atan2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/atan.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/greater.hpp"
#include "openvino/op/greater_eq.hpp"
#include "openvino/op/less.hpp"
#include "openvino/op/logical_and.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/subtract.hpp"
#include "pt_framework_node.hpp"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need this include

#include "utils.hpp"

using namespace ov::op;

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_atan2(const NodeContext& context) {
// Check whether inputs present
num_inputs_check(context, 2, 3);
// "aten::atan2.out(Tensor input,Tensor other, *,Tensor(a!) out) → Tensor(a!)"
Output<Node> y;
Output<Node> x;
// tie inputs together
std::tie(y, x) = get_inputs_with_promoted_types(context, 0, 1);
auto dummy_const = context.mark_node(ov::op::v0::Constant::create(element::f32, Shape({}), {0.5}))->output(0);
// align input types
align_eltwise_input_types(context, x, dummy_const, false, true);
rghvsh marked this conversation as resolved.
Show resolved Hide resolved
rghvsh marked this conversation as resolved.
Show resolved Hide resolved

// handle the first condition : x>0
auto div_y_x = context.mark_node(std::make_shared<v1::Divide>(y, x));
auto atan = context.mark_node(std::make_shared<v0::Atan>(div_y_x));
auto const_zero = v0::Constant::create(element::f32, Shape{}, {0});
rghvsh marked this conversation as resolved.
Show resolved Hide resolved
auto result = atan->output(0);

// handle the second condition : x<0 && y>=0
auto const_pi = v0::Constant::create(element::f32, Shape{}, {std::atan(1.0) * 4});
// Same input type
x = context.mark_node(std::make_shared<v1::ConvertLike>(x, const_pi));
auto is_x_negative = context.mark_node(std::make_shared<v1::Less>(x, const_zero));
y = context.mark_node(std::make_shared<v1::ConvertLike>(y, const_zero));
auto y_non_negative = context.mark_node(std::make_shared<v1::GreaterEqual>(y, const_zero));
rghvsh marked this conversation as resolved.
Show resolved Hide resolved
auto cond1 = context.mark_node(std::make_shared<v1::LogicalAnd>(is_x_negative, y_non_negative));
atan = context.mark_node(std::make_shared<v1::ConvertLike>(atan, const_pi));
rghvsh marked this conversation as resolved.
Show resolved Hide resolved
auto atan_y_x_plus_pi = context.mark_node(std::make_shared<v1::Add>(atan, const_pi));
result = context.mark_node(std::make_shared<v1::Select>(cond1, atan_y_x_plus_pi, result));

// handle the third condition : x<0 && y<0
y = context.mark_node(std::make_shared<v1::ConvertLike>(y, const_zero));
auto is_y_negative = context.mark_node(std::make_shared<v1::Less>(y, const_zero));
auto cond2 = context.mark_node(std::make_shared<v1::LogicalAnd>(is_x_negative, is_y_negative));
y = context.mark_node(std::make_shared<v1::ConvertLike>(atan, const_pi));
rghvsh marked this conversation as resolved.
Show resolved Hide resolved
auto atan_y_x_minus_pi = context.mark_node(std::make_shared<v1::Subtract>(atan, const_pi));
result = context.mark_node(std::make_shared<v1::Select>(cond2, atan_y_x_minus_pi, result));

// handle the fourth condition : x=0 && y>0
x = context.mark_node(std::make_shared<v1::ConvertLike>(x, const_zero));
auto is_x_zero = context.mark_node(std::make_shared<v1::Equal>(x, const_zero));
y = context.mark_node(std::make_shared<v1::ConvertLike>(y, const_zero));
auto is_y_positive = context.mark_node(std::make_shared<v1::Greater>(y, const_zero));
auto cond3 = context.mark_node(std::make_shared<v1::LogicalAnd>(is_x_zero, is_y_positive));
auto const_two = v0::Constant::create(element::f32, Shape{}, {2});
const_pi = context.mark_node(std::make_shared<v1::ConvertLike>(const_pi, const_two));
rghvsh marked this conversation as resolved.
Show resolved Hide resolved
auto pi_div_two = context.mark_node(std::make_shared<v1::Divide>(const_pi, const_two));
result = context.mark_node(std::make_shared<v1::Select>(cond3, pi_div_two, result));

// handle the fifth condition : x=0 && y<0
auto cond4 = context.mark_node(std::make_shared<v1::LogicalAnd>(is_x_zero, is_y_negative));
auto const_minus_two = v0::Constant::create(element::f32, Shape{}, {-2});
rghvsh marked this conversation as resolved.
Show resolved Hide resolved
const_minus_two = context.mark_node(std::make_shared<v1::ConvertLike>(const_minus_two, const_pi));
auto pi_div_minus_two = context.mark_node(std::make_shared<v1::Divide>(const_pi, const_minus_two));
result = context.mark_node(std::make_shared<v1::Select>(cond4, pi_div_two, result));

// check whether out tensor is given
if (!context.input_is_none(2) && context.get_input_size() == 3) {
context.mutate_input(2, result);
}
rghvsh marked this conversation as resolved.
Show resolved Hide resolved
rghvsh marked this conversation as resolved.
Show resolved Hide resolved

// when out tensor is not in input
return {result};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ OP_CONVERTER(translate_argmax);
OP_CONVERTER(translate_argmin);
OP_CONVERTER(translate_as_strided);
OP_CONVERTER(translate_as_tensor);
OP_CONVERTER(translate_atan2);
OP_CONVERTER(translate_avg_poolnd);
OP_CONVERTER(translate_bool);
OP_CONVERTER(translate_batch_norm);
Expand Down Expand Up @@ -359,6 +360,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::asinh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Asinh>>},
{"aten::atan", op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atan>, 1>},
{"aten::atan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Atan>>},
{"aten::atan2", op::optional_out<op::translate_atan2, 2>},
{"aten::atanh",
op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atanh>, 1>},
{"aten::atanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Atanh>>},
Expand Down
69 changes: 69 additions & 0 deletions tests/layer_tests/pytorch_tests/test_atan2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0i

import pytest
import torch
import math
import numpy as np

from pytorch_layer_test_class import PytorchLayerTest

class TestAtan2(PytorchLayerTest):
def _prepare_input(self, y, x, dtype1=None, dtype2=None):
inputs = [np.array(y).astype(dtype1), np.array(x).astype(dtype2)]
return inputs

def create_model(self, dtype1=None, dtype2=None use_out=False):
rghvsh marked this conversation as resolved.
Show resolved Hide resolved
dtype_map = {
"float32": torch.float32,
"float64": torch.float64,
"int64": torch.int64,
"int32": torch.int32,
"int16": torch.int16,
"uint8": torch.uint8,
"int8": torch.int8,
}

class aten_atan2_out(torch.nn.Module):
def __init__(self, out) -> None:
super().__init__()
self.out = torch.empty(25, dtype=out)

def forward(self, y, x):
return torch.atan2(input = y, other = x, out=self.out)

class aten_atan2(torch.nn.Module):
def __init__(self) -> None:
super().__init__()

def forward(self, y, x):
return torch.atan2(input = y, other = x)

dtype1 = dtype_map.get(dtype1)
dtype2 = dtype_map.get(dtype2)

if use_out:
model_class = aten_atan2_out(dtype1)
else:
model_class = aten_atan2()


ref_net = None

return model_class, ref_net, "aten::atan2"

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("dtype1, dtype2", [(None, None), ("float32", "int32"), ("float64", "float64"), ("int32", "float64"), ("int64", "int16"), ("int8", "int8"), ("uint8", "uint8")])
@pytest.mark.parametrize(
"y, x", [(0, 1.5), (0, 0), (1.25, -5), (1, 10), (-1, -5.5), (-1, -5), (1.25, -5.5), (1.9, 2.9), [10, 9.9]]
)
@pytest.mark.parametrize("use_out", [False, True])
def test_atan2_with_out(self, dtype1, dtype2, use_out, y, x, ie_device, precision, ir_version):
self._test(
*self.create_model(dtype=dtype, use_out=use_out),
ie_device,
precision,
ir_version,
kwargs_to_prepare_input={"y": y, "x": x}
)
Loading