-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'upstream/master' into azure-python
- Loading branch information
Showing
67 changed files
with
1,663 additions
and
1,696 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright (C) 2018-2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
import openvino.runtime.opset13 as ops | ||
from openvino.runtime import PartialShape, Dimension, Type | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("probs_shape", "num_samples_shape", "convert_type", "with_replacement", "log_probs", "global_seed", "op_seed", "expected_out_shape"), | ||
[ | ||
([4, 16], [], "i32", False, True, 7461, 1546, PartialShape([4, -1])), | ||
([8], [1], "i64", True, False, 0, 0, PartialShape([-1])), | ||
], | ||
) | ||
def test_multinomial_param_inputs(probs_shape, num_samples_shape, convert_type, with_replacement, log_probs, global_seed, op_seed, expected_out_shape): | ||
probs = ops.parameter(probs_shape, dtype=np.float32) | ||
num_samples = ops.parameter(num_samples_shape, dtype=np.int32) | ||
|
||
op = ops.multinomial(probs, num_samples, | ||
convert_type=convert_type, | ||
with_replacement=with_replacement, | ||
log_probs=log_probs, | ||
global_seed=global_seed, | ||
op_seed=op_seed) | ||
assert op.get_output_size() == 1 | ||
assert op.get_type_name() == "Multinomial" | ||
assert op.get_output_element_type(0) == Type.i32 if convert_type == "i32" else Type.i64 | ||
assert op.get_output_partial_shape(0) == expected_out_shape | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("probs_array", "num_samples_val", "convert_type", "with_replacement", "log_probs", "global_seed", "op_seed", "expected_out_shape"), | ||
[ | ||
(np.array([0.7, 0.3, 0.6, 0.5]), 3, "i32", False, True, 111, 222, PartialShape([3])), | ||
(np.array([[0.7, 0.3], [0.6, 0.5]]), 2, "i64", True, False, 111, 222, PartialShape([2, 2])), | ||
], | ||
) | ||
def test_multinomial_const_inputs(probs_array, num_samples_val, convert_type, with_replacement, log_probs, global_seed, op_seed, expected_out_shape): | ||
probs = ops.constant(probs_array, dtype=np.float32) | ||
num_samples = ops.constant(num_samples_val, dtype=np.int32) | ||
|
||
op = ops.multinomial(probs, num_samples, | ||
convert_type=convert_type, | ||
with_replacement=with_replacement, | ||
log_probs=log_probs, | ||
global_seed=global_seed, | ||
op_seed=op_seed) | ||
|
||
assert op.get_output_size() == 1 | ||
assert op.get_type_name() == "Multinomial" | ||
assert op.get_output_element_type(0) == Type.i32 if convert_type == "i32" else Type.i64 | ||
assert op.get_output_partial_shape(0) == expected_out_shape | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("probs_shape", "num_samples_shape", "convert_type", "with_replacement", "log_probs", "expected_out_shape"), | ||
[ | ||
([10], [1], "i32", True, True, PartialShape([-1])), | ||
([2, 16], [], "i64", False, False, PartialShape([2, -1])), | ||
], | ||
) | ||
def test_multinomial_default_attrs(probs_shape, num_samples_shape, convert_type, with_replacement, log_probs, expected_out_shape): | ||
probs = ops.parameter(probs_shape, dtype=np.float32) | ||
num_samples = ops.parameter(num_samples_shape, dtype=np.int32) | ||
|
||
op = ops.multinomial(probs, num_samples, | ||
convert_type=convert_type, | ||
with_replacement=with_replacement, | ||
log_probs=log_probs) | ||
|
||
assert op.get_output_size() == 1 | ||
assert op.get_type_name() == "Multinomial" | ||
assert op.get_output_element_type(0) == Type.i32 if convert_type == "i32" else Type.i64 | ||
assert op.get_output_partial_shape(0) == expected_out_shape |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.