Skip to content

Commit

Permalink
[TF FE] Deduce Switch-Merge predicate shape (openvinotoolkit#27277)
Browse files Browse the repository at this point in the history
**Details:** It helps to convert some TF models out-of-the-box with
static rank tensors that are required by plugins for inference.

**Ticket:** 156204

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
  • Loading branch information
rkazants authored Oct 29, 2024
1 parent c158480 commit e07546d
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,15 @@ bool ov::pass::ReverseShapeAndTypeInfer::run_on_model(const std::shared_ptr<ov::
if_op->get_input_tensor(0).m_element_type = element::boolean;
is_changed = true;
}

// in case TensorFlow models, we can deduce predicate shape that must be a scalar
// If operations created by fusing Switch-Merge sub-graph contain tf_switch_merge_if rt-info
if (if_op->get_rt_info().count("tf_switch_merge_if") &&
if_op->get_rt_info()["tf_switch_merge_if"].as<bool>() &&
if_op->input_value(0).get_partial_shape().rank().is_dynamic()) {
if_op->get_input_tensor(0).m_partial_shape = ov::PartialShape({});
is_changed = true;
}
} else if (ov::as_type_ptr<ov::op::v1::ConvertLike>(op)) {
is_changed |= inherit_output_shape(op, {0});
is_changed |= inherit_output_type(op, {1});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ bool pass::SwitchMergeResolver::run_on_model(const shared_ptr<Model>& m) {
auto else_body = make_shared<Model>(else_results, else_params);

auto if_op = make_shared<v8::If>(cond);
// in case TensorFlow models, we can deduce predicate shape that must be a scalar
if_op->get_rt_info()["tf_switch_merge_if"] = true;

set_cf_marker(if_cf_marker, if_op);
if_op->set_then_body(then_body);
if_op->set_else_body(else_body);
Expand Down
45 changes: 45 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_SwitchMerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,48 @@ def test_merge_eliminating_several_cond_flows(self, params, cond_value, x_type,
self._test(*self.merge_eliminating_several_cond_flows_net(**params, cond_value=cond_value, x_type=x_type),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)


class TestSwitchMergeWithVariablePredicate(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
assert 'x:0' in inputs_info
x_shape = inputs_info['x:0']
inputs_data = {}
rng = np.random.default_rng()
inputs_data['x:0'] = rng.integers(-10, 10, x_shape).astype(np.float32)
inputs_data['cond:0'] = np.array(self.cond_value, dtype=bool)
return inputs_data

def switch_merge_with_variable_predicate_net(self, x_shape, cond_shape, cond_value):
self.cond_value = cond_value
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
x = tf.compat.v1.placeholder(tf.float32, x_shape, 'x')
cond = tf.compat.v1.placeholder(tf.bool, cond_shape, 'cond')
const_add = tf.constant(3, dtype=tf.float32)
const_sub = tf.constant(1, dtype=tf.float32)
switch_false, switch_true = tf.raw_ops.Switch(data=x, pred=cond)
add = tf.raw_ops.AddV2(x=switch_false, y=const_add)
sub = tf.raw_ops.Sub(x=switch_true, y=const_sub)
merge = tf.raw_ops.Merge(inputs=[add, sub])
const_main = tf.constant(1, dtype=tf.float32)
tf.raw_ops.AddV2(x=merge[0], y=const_main, name='add_res')
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def

return tf_net, None

@pytest.mark.parametrize('x_shape', [[], [2], [3, 2]])
@pytest.mark.parametrize('cond_shape', [None, []])
@pytest.mark.parametrize('cond_value', [True, False])
@pytest.mark.precommit
@pytest.mark.nightly
def test_switch_merge_with_variable_predicate(self, x_shape, cond_shape, cond_value,
ie_device, precision, ir_version, temp_dir,
use_legacy_frontend):
if ie_device == 'GPU':
pytest.skip("156244: accuracy error on GPU")
self._test(*self.switch_merge_with_variable_predicate_net(x_shape, cond_shape, cond_value),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)

0 comments on commit e07546d

Please sign in to comment.