fix aten::empty_like dtype (openvinotoolkit#22584)
eaidova authored Feb 1, 2024
1 parent 2749c3a commit 6b010ad
Showing 2 changed files with 23 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/frontends/pytorch/src/op/full.cpp
@@ -235,7 +235,7 @@ OutputVector translate_empty_like(const NodeContext& context) {
if (!context.input_is_none(dtype_id)) {
empty = base_translate_full_with_convert(context, sizes, value, dtype_id);
} else {
- empty = base_translate_full(context, sizes, value);
+ empty = base_translate_full_with_convertlike(context, sizes, value, input);
}
} else if (context.get_input_size() == 4) {
auto out = context.input_is_none(3) ? input : context.get_input(3);
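The one-line change above swaps `base_translate_full` for `base_translate_full_with_convertlike` and passes the input tensor, so that when no `dtype` argument is given the translated op takes its element type from the input rather than the default fill type. A minimal PyTorch sketch of the behaviour being matched (illustrative only, not part of the commit):

```python
import torch

# Illustrative sketch: without an explicit dtype, empty_like inherits the
# element type of its input; an explicit dtype argument overrides it.
x = torch.randn(2, 3, dtype=torch.float64)

inherited = torch.empty_like(x)                       # dtype follows x: float64
assert inherited.dtype == torch.float64

overridden = torch.empty_like(x, dtype=torch.int32)   # explicit dtype wins
assert overridden.dtype == torch.int32
```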
26 changes: 22 additions & 4 deletions tests/layer_tests/pytorch_tests/test_empty.py
@@ -55,11 +55,11 @@ def _prepare_input(self, shape, dtype=np.float32, out=False):
return (np.random.randn(*shape).astype(dtype if dtype is not None else np.float32),)
return (np.random.randn(*shape), np.ones(shape, dtype=(dtype if dtype is not None else np.float32)))

- def create_model(self, dtype, out):
+ def create_model(self, dtype, out, no_expose_dtype=False):

class aten_empty_like(torch.nn.Module):

- def __init__(self, dtype=None, out=False):
+ def __init__(self, dtype=None, out=False, no_expose_dtype=False):
dtype_map = {
"float32": torch.float32,
"float64": torch.float64,
@@ -72,6 +72,8 @@ def __init__(self, dtype=None, out=False):
self.dtype = dtype_map.get(dtype, None)
if out:
self.forward = self.forward_out
+ if no_expose_dtype:
+ self.forward = self.forward_input_dtype

def forward(self, input_tensor):
empty = torch.empty_like(input_tensor, dtype=self.dtype)
@@ -80,6 +82,14 @@ def forward(self, input_tensor):
# produce sporadic errors if nan would be in empty.
return torch.zeros_like(empty)

+ def forward_input_dtype(self, input_tensor):
+ # We don't want to compare values, just shape and type,
+ # so we call zeros_like on data. Multiplying by zero would
+ # produce sporadic errors if nan would be in empty.
+ input_tensor = input_tensor.to(self.dtype)
+ empty = torch.empty_like(input_tensor)
+ return torch.zeros_like(empty)

def forward_out(self, input_tensor, out_tensor):
torch.empty_like(input_tensor, out=out_tensor)
# We don't want to compare values, just shape and type,
@@ -89,17 +99,25 @@ def forward_out(self, input_tensor, out_tensor):

ref_net = None

- return aten_empty_like(dtype, out), ref_net, "aten::empty_like"
+ return aten_empty_like(dtype, out, no_expose_dtype), ref_net, "aten::empty_like"

@pytest.mark.parametrize('dtype', (None, "float32", "float64", "int64", "int32", "uint8", "int8"))
@pytest.mark.parametrize("input_shape", [[2,], [1, 10], [10, 5, 2]])
@pytest.mark.parametrize("out", [True, False])
@pytest.mark.nightly
@pytest.mark.precommit
- def test_empty(self, ie_device, precision, ir_version, dtype, input_shape, out):
+ def test_empty_like(self, ie_device, precision, ir_version, dtype, input_shape, out):
self._test(*self.create_model(dtype, out), ie_device, precision, ir_version,
kwargs_to_prepare_input={"shape": input_shape, "out": out, "dtype": dtype})

+ @pytest.mark.parametrize('dtype', (None, "float32", "float64", "int64", "int32", "uint8", "int8"))
+ @pytest.mark.parametrize("input_shape", [[2,], [1, 10], [10, 5, 2]])
+ @pytest.mark.nightly
+ @pytest.mark.precommit
+ def test_empty_like_no_dtype(self, ie_device, precision, ir_version, dtype, input_shape):
+ self._test(*self.create_model(dtype, out=False, no_expose_dtype=True), ie_device, precision, ir_version,
+ kwargs_to_prepare_input={"shape": input_shape, "out": False, "dtype": dtype})


class TestEmptyBoolean(PytorchLayerTest):

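The added test_empty_like_no_dtype case drives the new forward_input_dtype path: the requested dtype is applied to the input tensor itself rather than passed to empty_like, so the converted model has to propagate the input's element type. A standalone sketch of that expectation (illustrative, not part of the test file):

```python
import numpy as np
import torch

# Illustrative sketch of what the no_expose_dtype path checks: empty_like is
# called without a dtype argument, so the result must still carry the dtype
# of the (already converted) input tensor.
x = torch.from_numpy(np.random.randn(1, 10)).to(torch.int32)

result = torch.zeros_like(torch.empty_like(x))  # mirrors forward_input_dtype
assert result.shape == x.shape
assert result.dtype == torch.int32
```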
