diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md
index 80d8ef23b459..f575bc8e7ce2 100644
--- a/docs/api/python/ndarray/contrib.md
+++ b/docs/api/python/ndarray/contrib.md
@@ -54,7 +54,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
     quantize
     foreach
     while_loop
-    ifelse
+    condition
 ```
 
 ## API Reference
diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md
index 96ce7987d800..69d38beffd1c 100644
--- a/docs/api/python/symbol/contrib.md
+++ b/docs/api/python/symbol/contrib.md
@@ -54,7 +54,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib`
     quantize
     foreach
     while_loop
-    ifelse
+    condition
 ```
 
 ## API Reference
diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index 12407cf4fe74..4e30e5e16e53 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -28,7 +28,7 @@
 except ImportError:
     pass
 
-__all__ = ["rand_zipfian", "foreach", "while_loop", "ifelse"]
+__all__ = ["rand_zipfian", "foreach", "while_loop", "condition"]
 
 # pylint: disable=line-too-long
 def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
@@ -363,7 +363,7 @@ def _func_wrapper(loop_vars):
         ))
     return stacked_outputs, list(loop_vars)
 
-def ifelse(cond, then_func, else_func, inputs):
+def condition(cond_func, then_func, else_func, inputs):  # pylint: disable=redefined-outer-name
     """Run an if-then-else using user-defined condition and computation
 
     This operator simulates a if-like branch which chooses to do one of
@@ -371,11 +371,11 @@ def ifelse(cond, then_func, else_func, inputs):
     `inputs` is a list of NDArrays on which the condition and computations
     rely on.
 
-    `cond` is a user-defined function, used as the if condition.
+    `cond_func` is a user-defined function, used as the if condition.
     It consumes `inputs`, and produces a scalar MXNet NDArray,
     indicating which branch of computation should be used.
-    The `cond` is variadic, and its signature should be
-    `cond(*loop_vars) => NDArray`.
+    The `cond_func` is variadic, and its signature should be
+    `cond_func(*loop_vars) => NDArray`.
 
     `then_func` is a user-defined function, used as computation of the then branch.
     It consumes `inputs`, and produces `outputs`.
@@ -394,14 +394,14 @@
 
     Parameters
     ----------
-    cond: a Python function.
+    cond_func: a Python function.
         The branch condition.
     then_func: a Python function.
-        The computation to be executed if `cond` is true.
+        The computation to be executed if `cond_func` is true.
     else_func: a Python function.
-        The computation to be executed if `cond` is false.
+        The computation to be executed if `cond_func` is false.
     inputs: list of NDArrays.
-        The variables fed to `cond`, `then_func` and `else_func`.
+        The variables fed to `cond_func`, `then_func` and `else_func`.
 
     Returns
     -------
@@ -409,11 +409,11 @@
 
     Examples
     --------
-    >>> cond = lambda a, b: a * b < 5
+    >>> cond_func = lambda a, b: a * b < 5
     >>> then_func = lambda a, b: (a + 5) * (b + 5)
     >>> else_func = lambda a, b: (a - 5) * (b - 5)
     >>> inputs = (mx.nd.array([1]), mx.nd.array([2]))
-    >>> outputs = mx.nd.contrib.ifelse(cond, then_func, else_func, inputs)
+    >>> outputs = mx.nd.contrib.condition(cond_func, then_func, else_func, inputs)
     >>> outputs[0]
     [42.]
@@ -448,7 +448,7 @@ def _to_ndarray_tuple(inputs, name):
     inputs = _to_ndarray_tuple(inputs, "inputs")
     if len(inputs) == 0:
         raise ValueError("inputs should contain at least one element")
-    branch = _to_python_scalar(cond(*inputs), bool, "Return value of cond")
+    branch = _to_python_scalar(cond_func(*inputs), bool, "Return value of cond_func")
     if branch:
         outputs = then_func(*inputs)
         outputs = _to_ndarray_tuple(outputs, "outputs of then_func")
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 13bb89e8d9f1..3274b7833c47 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -34,7 +34,7 @@
 from ..base import SymbolHandle, _as_list
 from ..attribute import AttrScope
 
-__all__ = ["rand_zipfian", "foreach", "while_loop", "ifelse"]
+__all__ = ["rand_zipfian", "foreach", "while_loop", "condition"]
 
 def rand_zipfian(true_classes, num_sampled, range_max):
     """Draw random samples from an approximately log-uniform or Zipfian distribution.
@@ -557,7 +557,7 @@ def _union_inputs(*graphs):
         final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)]
     return outputs, final_loop_vars
 
-def ifelse(cond, then_func, else_func, inputs, name="ifelse"):
+def condition(cond_func, then_func, else_func, inputs, name="cond"):  # pylint: disable=redefined-outer-name
     """Run an if-then-else using user-defined condition and computation
 
     This operator simulates a if-like branch which chooses to do one of
@@ -565,11 +565,11 @@ def ifelse(cond, then_func, else_func, inputs, name="ifelse"):
     `inputs` is a list of Symbols on which the condition and computations
     rely on.
 
-    `cond` is a user-defined function, used as the if condition.
+    `cond_func` is a user-defined function, used as the if condition.
     It consumes `inputs`, and produces a scalar MXNet symbol,
     indicating which branch of computation should be used.
-    The `cond` is variadic, and its signature should be
-    `cond(*loop_vars) => Symbol`.
+    The `cond_func` is variadic, and its signature should be
+    `cond_func(*loop_vars) => Symbol`.
 
     `then_func` is a user-defined function, used as computation of the then branch.
     It consumes `inputs`, and produces `outputs`.
@@ -588,14 +588,14 @@
 
     Parameters
     ----------
-    cond: a Python function.
+    cond_func: a Python function.
         The branch condition.
     then_func: a Python function.
-        The computation to be executed if `cond` is true.
+        The computation to be executed if `cond_func` is true.
     else_func: a Python function.
-        The computation to be executed if `cond` is false.
+        The computation to be executed if `cond_func` is false.
     inputs: list of Symbols.
-        The variables fed to `cond`, `then_func` and `else_func`.
+        The variables fed to `cond_func`, `then_func` and `else_func`.
 
     Returns
     -------
@@ -603,11 +603,11 @@
 
     Examples
     --------
-    >>> cond = lambda a, b: a * b < 5
+    >>> cond_func = lambda a, b: a * b < 5
     >>> then_func = lambda a, b: (a + 5) * (b + 5)
     >>> else_func = lambda a, b: (a - 5) * (b - 5)
     >>> inputs = (mx.sym.var('a'), mx.sym.var('b'))
-    >>> outputs = mx.sym.contrib.ifelse(cond, then_func, else_func, inputs)
+    >>> outputs = mx.sym.contrib.condition(cond_func, then_func, else_func, inputs)
     """
     def _to_symbol_tuple(inputs, name):
         """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet Symbol,
@@ -681,10 +681,10 @@ def _union_inputs(*graphs):
     inputs = _to_symbol_tuple(inputs, "inputs")
     if len(inputs) == 0:
         raise ValueError("loop_vars should contain at least one element")
-    # create graph for `cond'
-    cond_g, num_outputs = _create_subgraph(inputs, cond, name + "_cond")
-    if num_outputs != 1:
-        raise ValueError("cond should always produce a single output")
+    # create graph for `cond_func'
+    cond_g, cond_num_outputs = _create_subgraph(inputs, cond_func, name + "_cond")
+    if cond_num_outputs != 1:
+        raise ValueError("cond_func should always produce a single output")
     # create graph for `then`
     then_g, then_num_outputs = _create_subgraph(inputs, then_func, name + "_then")
     # create graph for `else`
@@ -694,7 +694,7 @@
     # find symbols used in either cond_g or func_g
     input_syms, (cond_input_locs, then_input_locs, else_input_locs) = \
         _union_inputs(cond_g, then_g, else_g)
-    result = symbol._internal._ifelse(
+    result = symbol._internal._cond(
         # [cond, then_g, else_g, *input_syms]
         cond_g,
         then_g,
diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc
index 5159a27cb508..7c1beccb0288 100644
--- a/src/operator/control_flow.cc
+++ b/src/operator/control_flow.cc
@@ -913,13 +913,13 @@ WhileLoopGradient(const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& og
   return entries;
 }
 
-struct IfelseParam : public dmlc::Parameter<IfelseParam> {
+struct CondParam : public dmlc::Parameter<CondParam> {
   int num_args;
   int num_outputs;
   nnvm::Tuple<dim_t> cond_input_locs;
   nnvm::Tuple<dim_t> then_input_locs;
   nnvm::Tuple<dim_t> else_input_locs;
-  DMLC_DECLARE_PARAMETER(IfelseParam) {
+  DMLC_DECLARE_PARAMETER(CondParam) {
     DMLC_DECLARE_FIELD(num_args).set_lower_bound(3)
     .describe("Number of input arguments, including cond, then and else as three symbol inputs.");
     DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1)
@@ -931,42 +931,42 @@ struct IfelseParam : public dmlc::Parameter<IfelseParam> {
     DMLC_DECLARE_FIELD(else_input_locs)
     .describe("The locations of else's inputs in the given inputs.");
   }
-}; // struct IfelseParam
+}; // struct CondParam
 
-DMLC_REGISTER_PARAMETER(IfelseParam);
+DMLC_REGISTER_PARAMETER(CondParam);
 
-class IfelseState {
+class CondState {
  public:
-  IfelseParam params;
+  CondParam params;
   CachedOpPtr cond_op;
   LoopState then_branch;
   LoopState else_branch;
   int branch_selection; // 1 if then branch; 0 if else branch; -1 if undefined
 
-  IfelseState(const IfelseParam &params,
-              const Symbol &cond,
-              const Symbol &then_sym,
-              const Symbol &else_sym):
-              params(params),
-              cond_op(LoopState::MakeSharedOp(cond)),
-              then_branch(then_sym),
-              else_branch(else_sym),
-              branch_selection(-1) {
+  CondState(const CondParam &params,
+            const Symbol &cond,
+            const Symbol &then_sym,
+            const Symbol &else_sym):
+            params(params),
+            cond_op(LoopState::MakeSharedOp(cond)),
+            then_branch(then_sym),
+            else_branch(else_sym),
+            branch_selection(-1) {
   }
 };
 
-static void IfelseComputeExCPU(const OpStatePtr& state_ptr,
-                               const OpContext& ctx,
-                               const std::vector<NDArray>& inputs,
-                               const std::vector<OpReqType>& req,
-                               const std::vector<NDArray>& outputs) {
+static void CondComputeExCPU(const OpStatePtr& state_ptr,
+                             const OpContext& ctx,
+                             const std::vector<NDArray>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<NDArray>& outputs) {
   // The argument `inputs' are loop_vars and other inputs
   // loop_vars are stored in stored in `loop_vars_locs'
   // The argument `outputs' are output and new_loop_vars
   // [0: num_out_data) are outputs at each step.
   // [num_out_data: ) are new_loop_vars
-  IfelseState &state = state_ptr.get_state<IfelseState>();
-  const IfelseParam& params = state.params;
+  CondState &state = state_ptr.get_state<CondState>();
+  const CondParam& params = state.params;
   // a helper function, converting std::vector<NDArray> to std::vector<NDArray*>
   const auto to_ptr_vec = [](std::vector<NDArray> &in, std::vector<NDArray*> *out) {
     out->clear();
@@ -1005,13 +1005,13 @@ static void IfelseComputeExCPU(const OpStatePtr& state_ptr,
   loop_state.Forward(0, func_inputs, req, outputs, ctx.need_grad);
 }
 
-static void IfelseGradComputeExCPU(const OpStatePtr& state_ptr,
-                                   const OpContext& ctx,
-                                   const std::vector<NDArray>& inputs,
-                                   const std::vector<OpReqType>& _req,
-                                   const std::vector<NDArray>& outputs) {
-  IfelseState &state = state_ptr.get_state<IfelseState>();
-  const IfelseParam& params = state.params;
+static void CondGradComputeExCPU(const OpStatePtr& state_ptr,
+                                 const OpContext& ctx,
+                                 const std::vector<NDArray>& inputs,
+                                 const std::vector<OpReqType>& _req,
+                                 const std::vector<NDArray>& outputs) {
+  CondState &state = state_ptr.get_state<CondState>();
+  const CondParam& params = state.params;
   // sanity checks
   CHECK_EQ(outputs.size() + 3U, (size_t) params.num_args);
   CHECK_EQ(outputs.size(), _req.size());
@@ -1034,11 +1034,11 @@ static void IfelseGradComputeExCPU(const OpStatePtr& state_ptr,
   loop_state.Cleanup();
 }
 
-static bool IfelseShape(const nnvm::NodeAttrs& attrs,
-                        std::vector<TShape> *in_shape,
-                        std::vector<TShape> *out_shape) {
+static bool CondShape(const nnvm::NodeAttrs& attrs,
+                      std::vector<TShape> *in_shape,
+                      std::vector<TShape> *out_shape) {
   using nnvm::ShapeVector;
-  const IfelseParam& params = nnvm::get<IfelseParam>(attrs.parsed);
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
   static const std::function<bool(const TShape &)> is_udf = is_shape_udf;
   // sanity checks
   CHECK_EQ(in_shape->size() + 3U, (size_t) params.num_args);
@@ -1121,10 +1121,10 @@ static bool IfelseShape(const nnvm::NodeAttrs& attrs,
   return succ_0 && succ_1 && succ_2;
 }
 
-static bool IfelseType(const nnvm::NodeAttrs& attrs,
-                       std::vector<int> *in_type,
-                       std::vector<int> *out_type) {
-  const IfelseParam& params = nnvm::get<IfelseParam>(attrs.parsed);
+static bool CondType(const nnvm::NodeAttrs& attrs,
+                     std::vector<int> *in_type,
+                     std::vector<int> *out_type) {
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
   static const std::function<bool(const int &)> is_udf = is_type_udf;
   CHECK_EQ(in_type->size() + 3U, (size_t) params.num_args);
   CHECK_EQ(out_type->size(), (size_t) params.num_outputs);
@@ -1147,12 +1147,12 @@ static bool IfelseType(const nnvm::NodeAttrs& attrs,
   return succ_0 && succ_1 && succ_2;
 }
 
-static bool IfelseStorageType(const nnvm::NodeAttrs& attrs,
-                              const int dev_mask,
-                              DispatchMode* dispatch_mode,
-                              std::vector<int> *in_attrs,
-                              std::vector<int> *out_attrs) {
-  const IfelseParam& params = nnvm::get<IfelseParam>(attrs.parsed);
+static bool CondStorageType(const nnvm::NodeAttrs& attrs,
+                            const int dev_mask,
+                            DispatchMode* dispatch_mode,
+                            std::vector<int> *in_attrs,
+                            std::vector<int> *out_attrs) {
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
   static const std::function<bool(const int &)> is_udf = is_stype_udf;
   CHECK_EQ(in_attrs->size() + 3U, (size_t) params.num_args);
   CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs);
@@ -1182,12 +1182,12 @@ static bool IfelseStorageType(const nnvm::NodeAttrs& attrs,
   return succ_0 && succ_1 && succ_2;
 }
 
-static bool BackwardIfelseStorageType(const nnvm::NodeAttrs& attrs,
-                                      const int dev_mask,
-                                      DispatchMode* dispatch_mode,
-                                      std::vector<int> *in_attrs,
-                                      std::vector<int> *out_attrs) {
-  const IfelseParam& params = nnvm::get<IfelseParam>(attrs.parsed);
+static bool BackwardCondStorageType(const nnvm::NodeAttrs& attrs,
+                                    const int dev_mask,
+                                    DispatchMode* dispatch_mode,
+                                    std::vector<int> *in_attrs,
+                                    std::vector<int> *out_attrs) {
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
   CHECK_EQ(out_attrs->size() + 3U, (size_t) params.num_args);
   CHECK_EQ(attrs.subgraphs.size(), 3U);
   static const std::function<bool(const int &)> is_udf = is_stype_udf;
@@ -1230,12 +1230,12 @@ static bool BackwardIfelseStorageType(const nnvm::NodeAttrs& attrs,
   return succ_0 && succ_1;
 }
 
-static OpStatePtr CreateIfelseState(const NodeAttrs& attrs,
-                                    Context ctx,
-                                    const std::vector<TShape>& ishape,
-                                    const std::vector<int>& itype) {
-  const IfelseParam& params = nnvm::get<IfelseParam>(attrs.parsed);
-  return OpStatePtr::Create<IfelseState>(
+static OpStatePtr CreateCondState(const NodeAttrs& attrs,
+                                  Context ctx,
+                                  const std::vector<TShape>& ishape,
+                                  const std::vector<int>& itype) {
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
+  return OpStatePtr::Create<CondState>(
       params,
       *attrs.subgraphs[0],
       *attrs.subgraphs[1],
@@ -1243,8 +1243,8 @@ static OpStatePtr CreateIfelseState(const NodeAttrs& attrs,
 }
 
 static std::vector<nnvm::NodeEntry>
-IfelseGradient(const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
-  ElemwiseGradUseInOut fgrad{"_backward_ifelse"};
+CondGradient(const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+  ElemwiseGradUseInOut fgrad{"_backward_cond"};
   std::vector<nnvm::NodeEntry> entries = fgrad(n, ograds);
   entries[0].node->attrs.subgraphs = n->attrs.subgraphs;
   return entries;
@@ -1373,21 +1373,21 @@ NNVM_REGISTER_OP(_backward_while_loop)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", WhileLoopGradComputeExCPU)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", WhileLoopGradComputeExCPU);
 
-NNVM_REGISTER_OP(_ifelse)
+NNVM_REGISTER_OP(_cond)
 .MXNET_DESCRIBE("Run a if-then-else using user-defined condition and computation")
-.set_attr_parser(ParamParser<IfelseParam>)
-.set_attr<FInferStorageType>("FInferStorageType", IfelseStorageType)
+.set_attr_parser(ParamParser<CondParam>)
+.set_attr<FInferStorageType>("FInferStorageType", CondStorageType)
 .set_num_inputs([](const NodeAttrs& attrs) {
-  const IfelseParam& params = nnvm::get<IfelseParam>(attrs.parsed);
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
   return params.num_args;
 })
 .set_num_outputs([](const NodeAttrs& attrs) {
-  const IfelseParam& params = nnvm::get<IfelseParam>(attrs.parsed);
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
   return params.num_outputs;
 })
 .set_attr<nnvm::FListInputNames>("FListInputNames",
     [](const NodeAttrs& attrs) {
-  const IfelseParam& params = nnvm::get<IfelseParam>(attrs.parsed);
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
   std::vector<std::string> names;
   names.reserve(params.num_args);
   names.push_back("cond");
@@ -1401,40 +1401,40 @@ NNVM_REGISTER_OP(_ifelse)
     [](const NodeAttrs& attrs) {
   return std::vector<uint32_t>{0, 1, 2};
 })
-.set_attr<nnvm::FGradient>("FGradient", IfelseGradient)
-.set_attr<FCreateOpState>("FCreateOpState", CreateIfelseState)
-.set_attr<nnvm::FInferShape>("FInferShape", IfelseShape)
-.set_attr<nnvm::FInferType>("FInferType", IfelseType)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", IfelseComputeExCPU)
+.set_attr<nnvm::FGradient>("FGradient", CondGradient)
+.set_attr<FCreateOpState>("FCreateOpState", CreateCondState)
+.set_attr<nnvm::FInferShape>("FInferShape", CondShape)
+.set_attr<nnvm::FInferType>("FInferType", CondType)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CondComputeExCPU)
 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
   return ExecType::kSubgraphExec;
 })
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", IfelseComputeExCPU)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CondComputeExCPU)
 .set_attr<std::string>("key_var_num_args", "num_args")
 .add_argument("cond", "Symbol", "Input graph for the condition.")
 .add_argument("then_branch", "Symbol", "Input graph for the then branch.")
 .add_argument("else_branch", "Symbol", "Input graph for the else branch.")
 .add_argument("data", "NDArray-or-Symbol[]",
               "The input arrays that include data arrays and states.")
-.add_arguments(IfelseParam::__FIELDS__());
+.add_arguments(CondParam::__FIELDS__());
 
-NNVM_REGISTER_OP(_backward_ifelse)
+NNVM_REGISTER_OP(_backward_cond)
 .set_num_inputs([](const NodeAttrs& attrs){
-  const IfelseParam& params = nnvm::get<IfelseParam>(attrs.parsed);
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
   return params.num_outputs * 2 + params.num_args - 3;
 })
 .set_num_outputs([](const NodeAttrs& attrs){
-  const IfelseParam& params = nnvm::get<IfelseParam>(attrs.parsed);
+  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
   return params.num_args - 3;
 })
 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
   return ExecType::kSubgraphExec;
 })
-.set_attr<FInferStorageType>("FInferStorageType", BackwardIfelseStorageType)
-.set_attr_parser(ParamParser<IfelseParam>)
+.set_attr<FInferStorageType>("FInferStorageType", BackwardCondStorageType)
+.set_attr_parser(ParamParser<CondParam>)
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", IfelseGradComputeExCPU)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", IfelseGradComputeExCPU);
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CondGradComputeExCPU)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CondGradComputeExCPU);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py
index 12694572bb7c..87eac8960339 100644
--- a/tests/python/unittest/test_contrib_control_flow.py
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -974,7 +974,7 @@ def _func(*states):
     y = y.asnumpy()
     assert_almost_equal(x, y, rtol=1e-4, atol=1e-4)
 
-def _verify_ifelse(cond, then_func, else_func, input_var_shapes, free_var_shapes, is_train):
+def _verify_cond(cond_func, then_func, else_func, input_var_shapes, free_var_shapes, is_train):
 
     def _create_symbol(prefix, i):
         return mx.sym.var(prefix + str(i))
@@ -1008,8 +1008,8 @@ def _get_imperative_result():
         for var in free_vars + input_vars:
             var.attach_grad()
         with mx.autograd.record(train_mode=is_train):
-            outputs = mx.nd.contrib.ifelse(
-                cond=lambda *__input_vars: cond(__input_vars, free_vars),
+            outputs = mx.nd.contrib.condition(
+                cond_func=lambda *__input_vars: cond_func(__input_vars, free_vars),
                 then_func=lambda *__input_vars: then_func(__input_vars, free_vars),
                 else_func=lambda *__input_vars: else_func(__input_vars, free_vars),
                 inputs=input_vars,
@@ -1025,8 +1025,8 @@ def _get_imperative_result():
         return _to_numpy_list(outputs), _to_numpy_list(grads), out_grads
 
     def _get_symbolic_result(out_grads):
-        outputs_sym = mx.sym.contrib.ifelse(
-            cond=lambda *__loop_vars: cond(__loop_vars, _free_syms),
+        outputs_sym = mx.sym.contrib.condition(
+            cond_func=lambda *__loop_vars: cond_func(__loop_vars, _free_syms),
             then_func=lambda *__loop_vars: then_func(__loop_vars, _free_syms),
             else_func=lambda *__loop_vars: else_func(__loop_vars, _free_syms),
             inputs=_input_syms,
@@ -1062,7 +1062,7 @@ def _get_symbolic_result(out_grads):
 
 
 @with_seed()
-def test_ifelse():
+def test_cond():
     # whether there are free variables in three graphs
     # whether these three graphs contain input_vars
     # whether to use all input_vars
@@ -1080,8 +1080,8 @@ def cond(inputs, free):
         return cond
     for is_train in [True, False]:
         for is_inverse in [False, True]:
-            _verify_ifelse(
-                cond=make_cond(is_inverse),
+            _verify_cond(
+                cond_func=make_cond(is_inverse),
                 then_func=then_func,
                 else_func=else_func,
                 is_train=is_train,
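
---

A quick smoke test of the renamed imperative front end (a sketch, not part of the patch: it assumes an MXNet build with this diff applied, and mirrors the doctest in `python/mxnet/ndarray/contrib.py` above):

```python
import mxnet as mx

# The three user-defined functions, exactly as in the doctest above.
cond_func = lambda a, b: a * b < 5          # scalar NDArray used as the branch predicate
then_func = lambda a, b: (a + 5) * (b + 5)  # runs when cond_func(*inputs) is true
else_func = lambda a, b: (a - 5) * (b - 5)  # runs when cond_func(*inputs) is false

# Then branch: 1 * 2 < 5 holds, so (1 + 5) * (2 + 5) = 42.
inputs = (mx.nd.array([1]), mx.nd.array([2]))
outputs = mx.nd.contrib.condition(cond_func, then_func, else_func, inputs)
print(outputs[0].asnumpy())  # [42.]

# Else branch: 3 * 4 < 5 fails, so (3 - 5) * (4 - 5) = 2.
inputs = (mx.nd.array([3]), mx.nd.array([4]))
outputs = mx.nd.contrib.condition(cond_func, then_func, else_func, inputs)
print(outputs[0].asnumpy())  # [2.]
```

The symbolic front end, `mx.sym.contrib.condition`, takes the same arguments plus a `name` (default `"cond"`); instead of eagerly evaluating one branch, it builds the `_cond` subgraph operator registered in `control_flow.cc`, which selects a branch at execution time.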