Skip to content

Commit

Permalink
Fixed seed interpretation in random_uniform* operations, updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gkrivor committed Jun 21, 2024
1 parent 3c0480d commit 1d8980d
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 11 deletions.
6 changes: 2 additions & 4 deletions src/frontends/onnx/frontend/src/op/random_uniform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,18 @@ ov::OutputVector random_uniform(const ov::frontend::onnx::Node& node) {
static_cast<int64_t>(TensorProto_DataType::TensorProto_DataType_FLOAT));
const auto high_const = node.get_attribute_as_constant<float>("high", 1.0f);
const auto low_const = node.get_attribute_as_constant<float>("low", 0.0f);
const auto seed = node.get_attribute_value<float>("seed", 0.0f);
const auto seed = common::convert_float_seed(node.get_attribute_value<float>("seed", 0.0f));
const auto target_shape_const = node.get_attribute_as_constant<std::vector<int64_t>>("shape");

const auto target_type = common::get_ov_element_type(dtype);
const uint64_t global_seed = 0;
// TODO: This multiplication leads to a mismatch in accuracy. Issue: 123003
const auto seed_uint64 = static_cast<uint64_t>(seed * 1000);

return {std::make_shared<v8::RandomUniform>(target_shape_const,
low_const,
high_const,
target_type,
global_seed,
seed_uint64)};
seed)};
}

} // namespace set_1
Expand Down
5 changes: 2 additions & 3 deletions src/frontends/onnx/frontend/src/op/random_uniform_like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,16 @@ ov::OutputVector random_uniform_like(const ov::frontend::onnx::Node& node) {

const auto high_const = node.get_attribute_as_constant<float>("high", 1.0f);
const auto low_const = node.get_attribute_as_constant<float>("low", 0.0f);
const auto seed = node.get_attribute_value<float>("seed", 0.f);
const auto seed = common::convert_float_seed(node.get_attribute_value<float>("seed", 0.f));

const uint64_t global_seed = 0;
const auto seed_uint64 = static_cast<uint64_t>(seed * 1000);

return {std::make_shared<v8::RandomUniform>(target_shape,
low_const,
high_const,
target_type,
global_seed,
seed_uint64)};
seed)};
}

} // namespace set_1
Expand Down
8 changes: 4 additions & 4 deletions src/frontends/onnx/tests/onnx_import.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4949,7 +4949,7 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_random_uniform) {
const auto model = convert_model("random_uniform.onnx");

auto test_case = ov::test::TestCase(model, s_device);
test_case.add_expected_output<float>(Shape{2, 2}, {43.45518f, 48.67585f, 42.227386f, 40.86294f});
test_case.add_expected_output<float>(Shape{2, 2}, {43.70129f, 45.26042f, 43.48503f, 46.43743f});
test_case.run();
}

Expand All @@ -4958,15 +4958,15 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_random_uniform_like) {

auto test_case = ov::test::TestCase(model, s_device);
test_case.add_input<float>(Shape{2, 2}, {41, 42, 43, 44});
test_case.add_expected_output<float>(Shape{2, 2}, {43.45518f, 48.67585f, 42.227386f, 40.86294f});
test_case.add_expected_output<float>(Shape{2, 2}, {43.70129f, 45.26042f, 43.48503f, 46.43743f});
test_case.run();
}

OPENVINO_TEST(${BACKEND_NAME}, onnx_model_random_normal) {
const auto model = convert_model("random_normal.onnx");

auto test_case = ov::test::TestCase(model, s_device);
test_case.add_expected_output<float>(Shape{2, 2}, {83.052017f, 55.496368f, 119.31188f, -3.6946249f});
test_case.add_expected_output<float>(Shape{2, 2}, {30.357481f, 72.41268f, 12.999034f, 70.04985f});
test_case.run();
}

Expand All @@ -4975,7 +4975,7 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_random_normal_like) {

auto test_case = ov::test::TestCase(model, s_device);
test_case.add_input<float>(Shape{2, 2}, {0, 0, 0, 0});
test_case.add_expected_output<float>(Shape{2, 2}, {83.052017f, 55.496368f, 119.31188f, -3.6946249f});
test_case.add_expected_output<float>(Shape{2, 2}, {30.357481f, 72.41268f, 12.999034f, 70.04985f});
test_case.run();
}

Expand Down

0 comments on commit 1d8980d

Please sign in to comment.