diff --git a/3rdparty/tvm b/3rdparty/tvm
index 6ab4da678341..290226e1c9ad 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 6ab4da6783417d8afdeb6b0426b44959b2afc709
+Subproject commit 290226e1c9adbb3e598f9ed9184018df1c12be33
diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index b1f065e9f822..fcfafb3be2f8 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -191,3 +191,130 @@ def check_input(inputs, in_type, msg):
     if not_data_list and len(outputs) == 1:
         outputs = outputs[0]
     return (outputs, states)
+
+
+def while_loop(loop_vars, cond, func, max_iterations):
+    """Run a while loop with user-defined computation and loop condition.
+
+    This operator simulates a while loop which iteratively performs customized computation
+    as long as the condition is satisfied.
+
+    `loop_vars` is a list of NDArrays holding the initial values of the loop variables.
+
+    `cond` is a user-defined function serving as the loop condition.
+    It consumes `loop_vars` and produces a scalar MXNet NDArray
+    indicating whether the loop should continue.
+    The loop ends when `cond` returns false (zero).
+    `cond` is variadic, and its signature should be
+    `cond(*loop_vars) => NDArray`.
+
+    `func` is a user-defined function serving as the loop body.
+    It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step.
+    The number of elements, and the shape and dtype of each element in `step_output`,
+    should be consistent across steps.
+    `new_loop_vars` should be consistent with `loop_vars` on each step.
+    `func` is variadic, and its signature should be
+    `func(*loop_vars) => (List[NDArray] step_output, List[NDArray] new_loop_vars)`.
+
+    `max_iterations` is a scalar that defines the maximum number of iterations allowed.
+
+    This function returns a list of NDArrays of length `|step_output| + |loop_vars|`.
+    Each of the first `|step_output|` elements stacks the corresponding `step_output`
+    from all steps along axis 0.
+    Each of the last `|loop_vars|` elements is the final state of the corresponding
+    loop variable.
+
+    Warning: when `cond` is never satisfied, `step_output` is assumed to be empty.
+    TODO(Junru): the output shape along axis 0 is not consistent with the symbolic version.
+    Should we mention this in our doc?
+
+    Parameters
+    ----------
+    loop_vars: list of NDArrays.
+        The initial values of the loop variables.
+    cond: a Python function.
+        The loop condition.
+    func: a Python function.
+        The loop body.
+    max_iterations: a Python int.
+        Maximum number of iterations.
+
+    Returns
+    -------
+    outputs: a list of NDArrays of length `|step_output| + |loop_vars|`.
+        The first `|step_output|` NDArrays are outputs.
+        The last `|loop_vars|` NDArrays are the final state of loop variables.
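To make the documented return layout concrete, here is a minimal sketch of the imperative API, mirroring the unit tests added later in this patch. The names `i`, `s` and the lambdas are illustrative; the shapes and values in the comments follow from the semantics described above (in the imperative version, the stacked axis-0 length equals the number of steps actually executed).

```python
import mxnet as mx

# cond/func follow the documented signatures:
#   cond(*loop_vars) -> scalar NDArray
#   func(*loop_vars) -> (step_output, new_loop_vars)
cond = lambda i, s: i <= 5
func = lambda i, s: (i, (i + 1, s + i))   # step_output = i, new_loop_vars = (i + 1, s + i)
loop_vars = (mx.nd.array([1], dtype="int64"), mx.nd.array([0], dtype="int64"))

outputs, final_vars = mx.nd.contrib.while_loop(loop_vars, cond, func, max_iterations=10)
print(outputs[0].shape)    # (5, 1): the per-step values of i, stacked along axis 0
print(final_vars[0].asscalar(), final_vars[1].asscalar())   # 6 15
```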
+    TODO(Junru): change the output format
+
+    Examples
+    --------
+    TODO(Junru): run this
+    >>> cond = lambda i, s: i <= 5
+    >>> func = lambda i, s: (None, (i + 1, s + i))
+    >>> loop_vars = (mx.nd.array([1], dtype="int64"), mx.nd.array([0], dtype="int64"))
+    >>> outputs, states = mx.nd.contrib.while_loop(loop_vars, cond, func, max_iterations=10)
+    """
+    def _to_python_scalar(inputs, type, name):
+        """Converts "inputs", possibly an mxnet NDArray, a numpy ndarray, or another python type,
+        to the given type
+        """
+        if isinstance(inputs, ndarray.NDArray):
+            inputs = inputs.asscalar()
+        try:
+            inputs = type(inputs)
+        except (TypeError, ValueError):
+            raise ValueError("Cannot convert %s to python %s" % (name, type.__name__))
+        return inputs
+
+    def _to_ndarray_tuple(inputs, name):
+        """Converts "inputs", possibly a single mxnet NDArray, a list of mxnet NDArray,
+        or a tuple of mxnet NDArray, into a tuple of NDArray
+        """
+        if isinstance(inputs, list):
+            inputs = tuple(inputs)
+        if isinstance(inputs, ndarray.NDArray):
+            inputs = (inputs, )
+        if not isinstance(inputs, tuple):
+            raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, ))
+        for item in inputs:
+            if not isinstance(item, ndarray.NDArray):
+                raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, ))
+        return inputs
+
+    def _func_wrapper(loop_vars):
+        """This wrapper unifies
+             "func: loop_vars -> new_loop_vars"
+         and "func: loop_vars -> (step_output, new_loop_vars)"
+        into "func: loop_vars -> (None or tuple of step_outputs, tuple of new_loop_vars)"
+        """
+        step_output, new_loop_vars = func(*loop_vars)
+        if step_output is None:
+            step_output = []
+        if new_loop_vars is None:
+            new_loop_vars = []
+        step_output = _to_ndarray_tuple(step_output, "step_output")
+        new_loop_vars = _to_ndarray_tuple(new_loop_vars, "new_loop_vars")
+        if len(loop_vars) != len(new_loop_vars):
+            raise ValueError("The length of loop_vars should be consistent during the loop")
+        return step_output, new_loop_vars
+
+    max_iterations = _to_python_scalar(max_iterations, int, "max_iterations")
+    loop_vars = _to_ndarray_tuple(loop_vars, "loop_vars")
+    # This would also work when loop_vars is empty,
+    # but that case is semantically unnecessary to support.
+    if len(loop_vars) == 0:
+        raise ValueError("loop_vars should contain at least one element")
+
+    steps = 0
+    outputs = []
+    while steps < max_iterations and \
+            _to_python_scalar(cond(*loop_vars), bool, "Return value of cond"):  # loop condition
+        step_output, loop_vars = _func_wrapper(loop_vars)
+        outputs.append(step_output)
+        steps += 1
+        if len(outputs) != steps or len(step_output) != len(outputs[0]):
+            raise ValueError("step_outputs are inconsistent on each step")
+    try:
+        outputs = list(ndarray.op.stack(*item) for item in zip(*outputs))
+    except ValueError:
+        raise ValueError("step_outputs are inconsistent on each step")
+    return outputs, list(loop_vars)
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 28bb507dd13d..bf1ec52e3657 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -336,3 +336,204 @@ def check_data(inputs, in_type, msg):
         states = states[0]
     return (outs, states)
+
+def while_loop(cond, func, loop_vars, max_iterations, name="while_loop"):
+    """Run a while loop with user-defined computation and loop condition.
+
+    This operator simulates a while loop which iteratively performs customized computation
+    as long as the condition is satisfied.
+
+    `loop_vars` is a list of Symbols holding the initial values of the loop variables.
+
+    `cond` is a user-defined function serving as the loop condition.
+    It consumes `loop_vars` and produces a scalar MXNet symbol
+    indicating whether the loop should continue.
+    The loop ends when `cond` returns false (zero).
+    `cond` is variadic, and its signature should be
+    `cond(*loop_vars) => Symbol`.
+
+    `func` is a user-defined function serving as the loop body.
+    It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step.
+    The number of elements, and the shape and dtype of each element in `step_output`,
+    should be consistent across steps.
+    `new_loop_vars` should be consistent with `loop_vars` on each step.
+    `func` is variadic, and its signature should be
+    `func(*loop_vars) => (List[Symbol] step_output, List[Symbol] new_loop_vars)`.
+
+    `max_iterations` is a scalar that defines the maximum number of iterations allowed.
+
+    This function returns a list of Symbols of length `|step_output| + |loop_vars|`.
+    Each of the first `|step_output|` elements stacks the corresponding `step_output`
+    from all steps along axis 0.
+    Each of the last `|loop_vars|` elements is the final state of the corresponding
+    loop variable.
+
+    TODO(Junru): writing style: use Symbol or symbol?
+    Parameters
+    ----------
+    loop_vars: list of Symbol.
+        The initial values of the loop variables.
+    cond: a Python function.
+        The loop condition.
+    func: a Python function.
+        The loop body.
+    max_iterations: a Python int.
+        Maximum number of iterations.
+
+    Returns
+    -------
+    outputs: a list of Symbols of length `|step_output| + |loop_vars|`.
+        The first `|step_output|` Symbols are outputs.
+        The last `|loop_vars|` Symbols are the final state of loop variables.
+    TODO(Junru): change the output format
+
+    Examples
+    --------
+    TODO(Junru): run this
+    >>> cond = lambda i, s: i <= 5
+    >>> func = lambda i, s: (None, (i + 1, s + i))
+    >>> loop_vars = (mx.sym.var('i'), mx.sym.var('s'))
+    >>> outputs, states = mx.sym.contrib.while_loop(cond, func, loop_vars, max_iterations=10)
+    """
+    def _to_python_scalar(inputs, type, name):
+        """Converts "inputs", possibly an mxnet NDArray, a numpy ndarray, or another python type,
+        to the given type
+        """
+        if hasattr(inputs, "asscalar"):
+            inputs = inputs.asscalar()
+        try:
+            inputs = type(inputs)
+        except (TypeError, ValueError):
+            raise ValueError("Cannot convert %s to python %s" % (name, type.__name__))
+        return inputs
+
+    def _to_symbol_tuple(inputs, name):
+        """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet Symbol,
+        or a tuple of mxnet Symbol, into a tuple of Symbol
+        """
+        if isinstance(inputs, list):
+            inputs = tuple(inputs)
+        if isinstance(inputs, Symbol):
+            inputs = (inputs, )
+        if not isinstance(inputs, tuple):
+            raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, ))
+        for item in inputs:
+            if not isinstance(item, Symbol):
+                raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, ))
+        return inputs
+
+    def _cond_wrapper(loop_vars):
+        result = cond(*loop_vars)
+        if not isinstance(result, Symbol):
+            raise ValueError("Return of cond must be a Symbol")
+        return [], [result]
+
+    def _func_wrapper(loop_vars):
+        """This wrapper unifies
+             "func: loop_vars -> new_loop_vars"
+         and "func: loop_vars -> (step_output, new_loop_vars)"
+        into "func: loop_vars -> (list of step_outputs, tuple of new_loop_vars)"
+        """
+        step_output, new_loop_vars = func(*loop_vars)
+        if step_output is None:
+            step_output = []
+        if new_loop_vars is None:
+            new_loop_vars = []
+        step_output = _to_symbol_tuple(step_output, "step_output")
+        new_loop_vars =
_to_symbol_tuple(new_loop_vars, "new_loop_vars") + if len(loop_vars) != len(new_loop_vars): + raise ValueError("The number of loop_vars should be consistent during the loop") + return list(step_output), list(new_loop_vars) + + def _create_subgraph(graph_vars, graph_func, subgraph_name): + with AttrScope(__subgraph_name__=subgraph_name): + # create new variables with the same name, + # them feed them to the given func + new_graph_vars = [symbol.var(sym.name) for sym in graph_vars] + outputs, final_state = graph_func(new_graph_vars) + # first `num_out_data` elements belong to `outputs` + # other elements belong to `final_state` + num_out_data = len(outputs) + num_outputs = len(outputs) + len(final_state) + # nnvm graph does not allow inputs and outputs overlap + id_new_graph_vars = {id(x) for x in new_graph_vars} + make_identity = lambda x: symbol.op.identity(x) if id(x) in id_new_graph_vars else x + # group all outputs of graph_func + graph = symbol.Group(list(map(make_identity, outputs + final_state))) + return graph, num_out_data, num_outputs + + def _union_inputs(*graphs): + # Given a list of graphs, each whose inputs are either from loop_vars or other variables. + # 1) calculate a list `inputs`, the union of their inputs. + # 2) for each graph, determine in which indices their inputs reside in `inputs` + # 3) for each variable in the input of `graph`, find which index it is + inputs = [] # List[Symbol], result of 1) + locs = [] # List[Tuple(List[Int], List[Int])], a list of tuples, where tuples are results of 2) and 3) + input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it to a `loc`, where inputs[loc] = sym + for graph in graphs: + # input_syms: all inputs to the `graph` + name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)} + # some loop_vars are inputs to `graph`, some are not + name_to_loop_vars = {sym.name: sym for sym in loop_vars} + # other inputs to `graph` created by cut_graph + name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)} + # also we collect the mapping from var's name to var's loc in loop_vars + name_to_var_locs = {sym.name: i for i, sym in enumerate(loop_vars)} + # collect arguments for each subgraph + input_locs = [] # results from the second step + var_locs = [-1] * len(loop_vars) # results from the third step + for name in graph.list_inputs(): + assert name in name_to_input_syms # it should obviously hold + # name -> sym + if name in name_to_loop_vars: + sym = name_to_loop_vars[name] + elif name in name_to_cut_g_syms: + sym = name_to_cut_g_syms[name] + else: + sym = copy.deepcopy(name_to_input_syms[name]) + # do 2), and 1) is implicitly done + if id(sym) in input_id_to_loc: + loc = input_id_to_loc[id(sym)] + else: + loc = len(input_id_to_loc) + inputs.append(sym) + input_id_to_loc[id(sym)] = loc + input_locs.append(loc) + # do 3) + if name in name_to_var_locs: + var_locs[name_to_var_locs[name]] = len(input_locs) - 1 + locs.append((input_locs, var_locs)) + return inputs, locs + max_iterations = _to_python_scalar(max_iterations, int, "max_iteration") + loop_vars = _to_symbol_tuple(loop_vars, "loop_vars") + # It should be work as fine if loop_vars are empty I guess, + # but it is semantically unnecessary to include this case. 
+ if len(loop_vars) == 0: + raise ValueError("loop_vars should contain at least one element") + # create graph for `cond' + cond_g, num_out_data, num_outputs = \ + _create_subgraph(loop_vars, _cond_wrapper, name + "_cond") + assert num_out_data == 0 + assert num_outputs == 1 + # create graph for `func` + func_g, num_out_data, num_outputs = \ + _create_subgraph(loop_vars, _func_wrapper, name + "_func") + # find symbols used in either cond_g or func_g + input_syms, ((cond_input_locs, _), (func_input_locs, func_var_locs)) = _union_inputs(cond_g, func_g) + for loc in func_var_locs: + # TODO(Junru): re-examine this + assert loc != -1 + result = symbol._internal._while_loop( + # [cond, func_g, *input_syms] + cond_g, + func_g, + *input_syms, + max_iterations=max_iterations, + cond_input_locs=cond_input_locs, + func_input_locs=func_input_locs, + func_var_locs=func_var_locs, + num_out_data=num_out_data, + num_outputs=num_outputs + ) + outputs = [result[i] for i in range(num_out_data)] + final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)] + return outputs, final_loop_vars diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc index c091fdb67e0f..9e8045270dc7 100644 --- a/src/operator/control_flow.cc +++ b/src/operator/control_flow.cc @@ -480,6 +480,521 @@ ForeachGradient(const nnvm::NodePtr& n, const std::vector& ogra return entries; } +struct WhileLoopParam : public dmlc::Parameter { + int num_args; + int num_outputs; + int num_out_data; + int max_iterations; + // `cond' and `func' each takes a subset of while_loop's inputs as that to their subgraphs + // `cond_input_locs' contains indices of inputs fed to `cond', and + // `func_input_locs' contains indices of inputs fed to `func'. + // `func_var_locs' are indices in which input "variables" are stored in func's inputs. 
+ nnvm::Tuple cond_input_locs; + nnvm::Tuple func_input_locs; + nnvm::Tuple func_var_locs; + DMLC_DECLARE_PARAMETER(WhileLoopParam) { + DMLC_DECLARE_FIELD(num_args).set_lower_bound(2) + .describe("Number of input arguments, including cond and func as two symbol inputs."); + DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1) + .describe("The number of outputs of the subgraph, including outputs from the function body, and all loop variables."); + DMLC_DECLARE_FIELD(num_out_data).set_lower_bound(0) + .describe("The number of outputs from the function body."); + DMLC_DECLARE_FIELD(max_iterations).set_lower_bound(1) + .describe("Maximum number of iterations."); + DMLC_DECLARE_FIELD(cond_input_locs) + .describe("The locations of cond's inputs in the given inputs."); + DMLC_DECLARE_FIELD(func_input_locs) + .describe("The locations of func's inputs in the given inputs."); + DMLC_DECLARE_FIELD(func_var_locs) + .describe("The locations of loop_vars among func's inputs."); + } +}; // struct WhileLoopParam + +DMLC_REGISTER_PARAMETER(WhileLoopParam); + +class WhileLoopState: public LoopState { + public: + WhileLoopParam params; + Symbol cond; // symbol of the `cond' subgraph + size_t n_iterations; // the actual number of steps taken in this while loop, <= max_iterations + CachedOpPtr cond_op; + // abbrev for output_input_mapping + // indicates to which index the output of `func' will be copied to the input of `cond' + std::vector oi_map; + + WhileLoopState(const WhileLoopParam ¶ms, const Symbol &cond, const Symbol &func) : + LoopState(func), + params(params), + cond(cond), + n_iterations(0U), + cond_op(LoopState::MakeSharedOp(cond)), + oi_map(params.func_var_locs.ndim(), -1) { + const nnvm::Tuple &func_input_locs = params.func_input_locs; + const nnvm::Tuple &func_var_locs = params.func_var_locs; + const nnvm::Tuple &cond_input_locs = params.cond_input_locs; + for (size_t i = 0; i < func_var_locs.ndim(); ++i) { + dim_t pos_i = func_input_locs[func_var_locs[i]]; + for (size_t j = 0; j < cond_input_locs.ndim(); ++j) { + dim_t pos_j = cond_input_locs[j]; + if (pos_i == pos_j) { + this->oi_map[i] = j; + } + } + } + } + template + static void extract_by_loc(const std::vector &array, + const nnvm::Tuple input_locs, + std::vector *out) { + out->clear(); + out->reserve(input_locs.ndim()); + for (dim_t i : input_locs) { + out->push_back(array[i]); + } + } + static bool is_shape_udf(const TShape &x) { + return x.ndim() == 0 || x.Size() == 0; + } + static bool is_stype_udf(const int &x) { + return x == exec::kBadStorageID; + } + static bool is_type_udf(const int &x) { + return x == -1; + } + template + static bool fill_value(T &x, T &y, bool x_empty, bool y_empty) { + if (x == y || (x_empty && y_empty)) { + return true; + } + if (!x_empty && !y_empty) { + return false; + } + if (x_empty) { + x = y; + } + if (y_empty) { + y = x; + } + return true; + } + template + static bool sync_in_in(const nnvm::Tuple &input_locs, std::vector *in, std::vector *subg_in, std::function is_empty) { + for (size_t i = 0; i < input_locs.ndim(); ++i) { + T &x = in->at(input_locs[i]); + T &y = subg_in->at(i); + fill_value(x, y, is_empty(x), is_empty(y)); + } + return true; + } + template + static bool sync_in_out(const WhileLoopParam& params, std::vector *in, std::vector *out, std::function is_empty) { + for (int i = params.num_out_data; i < params.num_outputs; ++i) { + // each out->at(i) is a params, loop_var + T &x = in->at(params.func_input_locs[params.func_var_locs[i - params.num_out_data]]); + T &y = out->at(i); + fill_value(x, y, 
is_empty(x), is_empty(y)); + } + return true; + } +}; + +template +T _asscalar(const NDArray &a) { + CHECK_EQ(a.shape().Size(), 1U); + T data; + a.SyncCopyToCPU(&data, 1U); + return data; +} + +bool as_bool_scalar(const NDArray &a) { + MSHADOW_TYPE_SWITCH(a.dtype(), DType, { + return bool(_asscalar(a)); + }); + CHECK(false) << "Unknown dtype"; + return false; +} + +// TODO(Junru): delete it +void print_scalar(const NDArray &a) { + MSHADOW_TYPE_SWITCH(a.dtype(), DType, { + DType typed_result = _asscalar(a); + std::cout << a.dtype() << " " << typed_result << std::endl; + }); +} + +static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + // The argument `inputs' are loop_vars and other inputs + // loop_vars are stored in stored in `loop_vars_locs' + // The argument `outputs' are output and new_loop_vars + // [0: num_out_data) are outputs at each step. + // [num_out_data: ) are new_loop_vars + // TODO(Junru): avoid dynamic NDArray allocation + WhileLoopState &state = state_ptr.get_state(); + const WhileLoopParam& params = state.params; + // a helper function, converting std::vector to std::vector + const auto to_ptr_vec = [](std::vector &in, std::vector *out) { + out->clear(); + out->reserve(in.size()); + std::transform(std::begin(in), std::end(in), std::back_inserter(*out), [](NDArray &a) {return &a;}); + }; + // sanity checks + CHECK_EQ(inputs.size() + 2U, (size_t) params.num_args); + CHECK_EQ(outputs.size(), (size_t) params.num_outputs); + CHECK_EQ(outputs.size(), req.size()); + for (size_t i = 0; i < (size_t) params.num_out_data; i++) + CHECK_EQ(params.max_iterations, outputs[i].shape()[0]); + for (const auto &arr : outputs) + CHECK_EQ(arr.storage_type(), kDefaultStorage) << "The while_loop operator doesn't support the sparse format"; + // construct inputs and outputs for cond + std::vector cond_inputs, cond_outputs = {NDArray()}; + WhileLoopState::extract_by_loc(inputs, params.cond_input_locs, &cond_inputs); + std::vector cond_input_ptr, cond_output_ptr; + to_ptr_vec(cond_inputs, &cond_input_ptr); + to_ptr_vec(cond_outputs, &cond_output_ptr); + // construct inputs and outputs for func + std::vector func_inputs, func_outputs(outputs.size()); + WhileLoopState::extract_by_loc(inputs, params.func_input_locs, &func_inputs); + for (size_t &step = state.n_iterations = 0; step < (size_t) params.max_iterations; ++step) { + state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr); + if (!as_bool_scalar(*cond_output_ptr[0])) { + break; + } + // we create func_outputs for the current step: + // func_outputs[0: num_out_data] is a slice of outputs[][step] + for (size_t i = 0; i < (size_t) params.num_out_data; ++i) { + func_outputs[i] = outputs[i].At(step); + } + // func_outputs[num_out_data: ] are new_loop_vars, need to allocate new memory + for (size_t i = params.num_out_data; i < outputs.size(); ++i) { + func_outputs[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true, outputs[i].dtype()); + } + state.Forward(step, func_inputs, req, func_outputs, ctx.need_grad); + // func_inputs on the next step: + // the output (new_loop_vars) will become the new inputs (loop_vars) + for (size_t i = params.num_out_data; i < outputs.size(); ++i) { + size_t j = params.func_var_locs[i - params.num_out_data]; + CHECK_EQ(func_inputs[j].shape(), func_outputs[i].shape()); + func_inputs[j] = func_outputs[i]; + int k = state.oi_map[i - params.num_out_data]; + if (k != -1) { + // I actually don't 
need to update cond_inputs + cond_inputs[k] = func_outputs[i]; + cond_input_ptr[k] = &func_outputs[i]; + } + } + } + // copy output data to `outputs' + // case 1: at least one step is executed, + // the final_loop_vars must be stored in func_inputs + // case 2: no step is executed + // the final_loop_vars is the same as loop_vars, which are also stored in func_inputs + // therefore, we copy func_inputs[:] to outputs[num_out_data: ] + for (size_t i = params.num_out_data; i < outputs.size(); ++i) { + size_t j = params.func_var_locs[i - params.num_out_data]; + mxnet::CopyFromTo(func_inputs[j], &outputs[i]); + } +} + +// TODO(Junru): delete helper func +void _print_shape(const TShape &s) { + std::cout << "["; + for (auto i : s) { + std::cout << " " << i; + } + std::cout << " ]" << std::endl; +} + +void _ps(const std::vector &shapes) { + for (const TShape &s : shapes) { + _print_shape(s); + } +} + +static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& _req, + const std::vector& _outputs) { + // inputs are dl / df(x) + // outputs are dl / dx + // where f is the current function, + // x is the input to the current function, + // TODO(Junru): avoid dynamic NDArray allocation + WhileLoopState &state = state_ptr.get_state(); + const WhileLoopParam& params = state.params; + // sanity checks + CHECK_EQ(_outputs.size() + 2U, (size_t) params.num_args); + CHECK_EQ(_outputs.size(), _req.size()); + for (auto x : _req) { + CHECK_NE(x, kWriteInplace); + } + for (auto x: _outputs) { + CHECK_EQ(x.storage_type(), kDefaultStorage) << "The while_loop operator doesn't support the sparse format"; + } + std::vector outputs; + std::vector req; + WhileLoopState::extract_by_loc(_outputs, params.func_input_locs, &outputs); + WhileLoopState::extract_by_loc(_req, params.func_input_locs, &req); + if (state.n_iterations == 0) { + for (int i = params.num_out_data; i < params.num_outputs; ++i) { + int j = params.func_var_locs[i - params.num_out_data]; + mxnet::CopyFromTo(inputs[i], &outputs[j]); + } + state.Cleanup(); + return; + } + // collect var_locs and out_locs, positions other than var_locs are out_locs, i.e. + // [0, var_locs[0]) + // (var_locs[1], var_locs[2]), + // (var_locs[2], var_locs[3]), + // ... 
+ // (var_locs[-2], var_locs[-1] = params.num_args - 2) + std::vector var_locs(params.func_var_locs.begin(), params.func_var_locs.end()); + var_locs.push_back((dim_t) params.num_args - 2U); + sort(var_locs.begin(), var_locs.end()); + // vectors for the backward loop + std::vector ograds(params.num_outputs); + std::vector igrads(outputs.size()); + std::vector iter_req(req.size()); + for (int i = params.num_out_data; i < params.num_outputs; ++i) + ograds[i] = inputs[i]; + for (int step = (int) state.n_iterations - 1; step >= 0; --step) { + // ograds[ : num_out_data] = inputs[ : num_out_data][step] + // ograds[num_out_data: ] is maintained in the end of each loop + std::transform(std::begin(inputs), + std::begin(inputs) + params.num_out_data, + std::begin(ograds), + [step] (const NDArray &a) { return a.At(step); } ); + // igrads[i] = + // outputs[i] (step == 0) + // outputs[i] (step != 0 && i not in loop_var_locs) + // ArrayLike(outputs[i]) (step != 0 && i in loop_var_locs) + // iter_req = + // kWriteTo (step != 0 && i in loop_var_locs) + // req[i] (step == 0 && i in loop_var_locs) + // kAddTo (step != n_iters - 1 && i not in loop_var_locs) + // req[i] (step == n_iters - 1 && i not in loop_var_locs) + { + size_t i = 0; + for (size_t loc : var_locs) { + for ( ; i < loc; ++i) { + // locs other that var_locs + igrads[i] = outputs[i]; + iter_req[i] = (step + 1 == (int) state.n_iterations || req[i] == kNullOp) + ? req[i] + : kAddTo; + } + if (i < (size_t) params.num_args - 2U) { + // a var + igrads[i] = (step == 0) + ? outputs[i] + : NDArray(outputs[i].shape(), outputs[i].ctx(), true, outputs[i].dtype()); + iter_req[i] = (step == 0 || req[i] == kNullOp) + ? req[i] + : kWriteTo; + ++i; + } + else { + break; + } + } + } + state.Backward(step, ograds, iter_req, igrads); + for (int i = params.num_out_data; i < params.num_outputs; ++i) { + size_t j = params.func_var_locs[i - params.num_out_data]; + ograds[i] = igrads[j]; + } + } + state.Cleanup(); +} + +static bool WhileLoopShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shape, + std::vector *out_shape) { + using nnvm::ShapeVector; + const WhileLoopParam& params = nnvm::get(attrs.parsed); + static const std::function is_udf = WhileLoopState::is_shape_udf; + // sanity checks + CHECK_EQ(in_shape->size() + 2U, (size_t) params.num_args); + CHECK_EQ(out_shape->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 2U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + // infer shape for cond and func + auto infer_subg = [¶ms, in_shape, out_shape](std::shared_ptr subg, + ShapeVector *_subg_out, + const nnvm::Tuple &input_locs, + int num_out_data, + bool fill_out_shape) { + // create subg_in + ShapeVector subg_in; + ShapeVector &subg_out = *_subg_out; + WhileLoopState::extract_by_loc(*in_shape, input_locs, &subg_in); + // create an indexed graph + nnvm::Graph g; + g.outputs = subg->outputs; + const auto& idx = g.indexed_graph(); + // get input nodes + const auto &input_nids = idx.input_nodes(); + // sanity checks + CHECK_EQ(input_nids.size(), subg_in.size()); + CHECK_EQ(g.outputs.size(), subg_out.size()); + CHECK_EQ(idx.input_nodes().size(), subg_in.size()); + CHECK_EQ(idx.outputs().size(), subg_out.size()); + // create empty shapes for inference + ShapeVector shapes(idx.num_node_entries()); + // copy subg_in into shapes + for (size_t i = 0; i < subg_in.size(); ++i) { + auto eid = idx.entry_id(input_nids[i], 0); + shapes[eid] = subg_in[i]; + } + // copy subg_out into shapes + // note that ndim of out_data is not increased + // 
because subg is only one step + for (size_t i = 0; i < subg_out.size(); ++i) { + auto eid = idx.entry_id(g.outputs[i]); + shapes[eid] = subg_out[i]; + } + // copy done, call InferShape + g.attrs["shape"] = std::make_shared(std::move(shapes)); + g = exec::InferShape(std::move(g)); + // now `shapes' won't be used anymore, use new_shapes instead + const auto& new_shapes = g.GetAttr("shape"); + // copy subg_in back to in_shape + for (size_t i = 0; i < subg_in.size(); ++i) { + auto eid = idx.entry_id(input_nids[i], 0); + auto g_out_shape = new_shapes[eid]; + if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + // when the shape is not fully inferred + continue; + } + SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape); + } + if (!fill_out_shape) { + return true; + } + // copy subg_out back to out_shape + // for results in [0, num_out_data), ndim should increase by 1 + for (int i = 0; i < num_out_data; ++i) { + auto eid = idx.entry_id(g.outputs[i]); + auto g_out_shape = new_shapes[eid]; + if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + // when the shape is not fully inferred + continue; + } + auto out = TShape(g_out_shape.ndim() + 1); + out[0] = params.max_iterations; + for (size_t i = 1; i < out.ndim(); i++) + out[i] = g_out_shape[i - 1]; + SHAPE_ASSIGN_CHECK(*out_shape, i, out); + } + // for results in [num_out_data, ...), ndim does not change + for (size_t i = num_out_data; i < g.outputs.size(); ++i) { + auto eid = idx.entry_id(g.outputs[i]); + auto g_out_shape = new_shapes[eid]; + if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + // when the shape is not fully inferred + continue; + } + SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape); + } + return g.GetAttr("shape_num_unknown_nodes") == 0; + }; + ShapeVector cond_out_shape{TShape(1U)}; // this means: [(1, )] + ShapeVector func_out_shape(params.num_outputs); + CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, params.cond_input_locs, 0, false); + CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + bool succ_1 = infer_subg(attrs.subgraphs[1], &func_out_shape, params.func_input_locs, params.num_out_data, true); + CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + return succ_0 && succ_1; +} + +static bool WhileLoopType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, std::vector *out_type) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + static const std::function is_udf = WhileLoopState::is_type_udf; + CHECK_EQ(in_type->size() + 2U, (size_t) params.num_args); + CHECK_EQ(out_type->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 2U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + std::vector cond_in_type; + std::vector func_in_type; + WhileLoopState::extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type); + WhileLoopState::extract_by_loc(*in_type, params.func_input_locs, &func_in_type); + std::vector cond_out_type = {0}; + CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); + bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, &cond_out_type); + CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); + CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf)); + bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &func_in_type, out_type); + CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); + 
CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_type, &func_in_type, is_udf)); + return succ_0 && succ_1; +} + +static bool WhileLoopStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + static const std::function is_udf = WhileLoopState::is_stype_udf; + CHECK_EQ(in_attrs->size() + 2U, (size_t) params.num_args); + CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 2U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + std::vector cond_in_attrs; + std::vector func_in_attrs; + WhileLoopState::extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs); + WhileLoopState::extract_by_loc(*in_attrs, params.func_input_locs, &func_in_attrs); + std::vector cond_out_attrs = {kDefaultStorage}; + DispatchMode cond_mode = DispatchMode::kUndefined; + DispatchMode func_mode = DispatchMode::kUndefined; + *dispatch_mode = DispatchMode::kFComputeEx; + CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); + bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, &cond_mode, &cond_in_attrs, &cond_out_attrs); + CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); + CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf)); + bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, &func_mode, &func_in_attrs, out_attrs); + CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); + CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_attrs, &func_in_attrs, is_udf)); + return succ_0 && succ_1; +} + +static bool BackwardWhileLoopStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + // `cond' is not backwarded, don't check + const WhileLoopParam& params = nnvm::get(attrs.parsed); + CHECK_EQ(out_attrs->size() + 2U, (size_t) params.num_args); + CHECK_EQ(attrs.subgraphs.size(), 2U); + CachedOp op(*attrs.subgraphs[1], {}); + return op.BackwardStorageType(attrs, dev_mask, dispatch_mode, + in_attrs, out_attrs); +} + +static OpStatePtr CreateWhileLoopState(const NodeAttrs& attrs, + Context ctx, + const std::vector& ishape, + const std::vector& itype) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + return OpStatePtr::Create(params, *attrs.subgraphs[0], *attrs.subgraphs[1]); +} + +static std::vector +WhileLoopGradient(const nnvm::NodePtr& n, const std::vector& ograds) { + ElemwiseGradUseInOut fgrad{"_backward_while_loop"}; + std::vector entries = fgrad(n, ograds); + entries[0].node->attrs.subgraphs = n->attrs.subgraphs; + return entries; +} + NNVM_REGISTER_OP(_foreach) .MXNET_DESCRIBE("Run a for loop over an NDArray with user-defined computation") .set_attr_parser(ParamParser) @@ -526,11 +1041,11 @@ NNVM_REGISTER_OP(_backward_foreach) .set_num_inputs([](const NodeAttrs& attrs){ const ForeachParam& params = nnvm::get(attrs.parsed); return params.num_outputs * 2 + params.num_args - 1; - }) +}) .set_num_outputs([](const NodeAttrs& attrs){ const ForeachParam& params = nnvm::get(attrs.parsed); return params.num_args - 1; - }) +}) .set_attr("FExecType", [](const NodeAttrs& attrs) { return ExecType::kSubgraphExec; }) @@ -541,5 +1056,67 @@ NNVM_REGISTER_OP(_backward_foreach) .set_attr("FStatefulComputeEx", ForeachGradComputeExCPU) .set_attr("FStatefulComputeEx", ForeachGradComputeExCPU); 
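Before the operator registration below, it may help to see what the three `*_locs` tuples of `WhileLoopParam` encode. The sketch here is purely illustrative (the index values are hypothetical, not produced by running the patch); it mirrors how `_union_inputs` in python/mxnet/symbol/contrib.py lays out the fused inputs and how the `WhileLoopState` constructor derives `oi_map` from them.

```python
# Suppose the fused op inputs are [i, s, w], where i and s are loop variables
# and w is a free variable used only by func.
cond_input_locs = [0, 1]      # cond consumes fused inputs 0 and 1 (i, s)
func_input_locs = [0, 1, 2]   # func consumes fused inputs 0, 1 and 2 (i, s, w)
func_var_locs = [0, 1]        # positions of the loop variables among func's inputs

# oi_map: for each loop variable, find the cond input backed by the same fused
# input, so that new_loop_vars produced by func can be fed back into cond.
oi_map = [-1] * len(func_var_locs)
for i, v in enumerate(func_var_locs):
    pos_i = func_input_locs[v]
    for j, pos_j in enumerate(cond_input_locs):
        if pos_i == pos_j:
            oi_map[i] = j
print(oi_map)   # [0, 1]
```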
+NNVM_REGISTER_OP(_while_loop) +.MXNET_DESCRIBE("Run a while loop over with user-defined condition and computation") +.set_attr_parser(ParamParser) +.set_attr("FInferStorageType", WhileLoopStorageType) +.set_num_inputs([](const NodeAttrs& attrs) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + return params.num_args; +}) +.set_num_outputs([](const NodeAttrs& attrs) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + return params.num_outputs; +}) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + std::vector names; + names.reserve(params.num_args); + names.push_back("cond"); + names.push_back("func"); + for (int i = 2; i < params.num_args; i++) + names.push_back("data" + std::to_string(i - 2)); + return names; +}) +.set_attr("FInputGraph", + [](const NodeAttrs& attrs) { + return std::vector{0, 1}; +}) +.set_attr("FGradient", WhileLoopGradient) +.set_attr("FCreateOpState", CreateWhileLoopState) +.set_attr("FInferShape", WhileLoopShape) +.set_attr("FInferType", WhileLoopType) +.set_attr("FStatefulComputeEx", WhileLoopComputeExCPU) +.set_attr("FExecType", [](const NodeAttrs& attrs) { + return ExecType::kSubgraphExec; +}) +.set_attr("FStatefulComputeEx", WhileLoopComputeExCPU) +.set_attr("key_var_num_args", "num_args") +.add_argument("cond", "Symbol", "Input graph for the loop condition.") +.add_argument("func", "Symbol", "Input graph for the loop body.") +.add_argument("data", "NDArray-or-Symbol[]", + "The input arrays that include data arrays and states.") +.add_arguments(WhileLoopParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_while_loop) +.set_num_inputs([](const NodeAttrs& attrs){ + const WhileLoopParam& params = nnvm::get(attrs.parsed); + return params.num_outputs * 2 + params.num_args - 2; +}) +.set_num_outputs([](const NodeAttrs& attrs){ + const WhileLoopParam& params = nnvm::get(attrs.parsed); + return params.num_args - 2; +}) +.set_attr("FExecType", [](const NodeAttrs& attrs) { + return ExecType::kSubgraphExec; +}) +.set_attr("FInferStorageType", BackwardWhileLoopStorageType) +.set_attr_parser(ParamParser) +.set_attr("TIsLayerOpBackward", true) +.set_attr("TIsBackward", true) +.set_attr("FStatefulComputeEx", WhileLoopGradComputeExCPU) +.set_attr("FStatefulComputeEx", WhileLoopGradComputeExCPU); + } // namespace op } // namespace mxnet diff --git a/src/operator/subgraph_op_common.cc b/src/operator/subgraph_op_common.cc index 71a9a21c28c4..d845aa907d33 100644 --- a/src/operator/subgraph_op_common.cc +++ b/src/operator/subgraph_op_common.cc @@ -164,14 +164,7 @@ bool InferSubgraphShape(const nnvm::Symbol &subgraph, LoopState::LoopState(const Symbol &g) { this->subgraph_sym = g; this->subgraph.outputs = g.outputs; - - std::vector > kwargs; - kwargs.push_back(std::pair("inline_limit", "0")); - // We turn on static_alloc for two reasons. - // It avoids the overhead of unnecessary memory allocation. - // only static_alloc supports nested call of CachedOp. 
- kwargs.push_back(std::pair("static_alloc", "1")); - iter_op = std::make_shared(subgraph_sym, kwargs); + this->iter_op = LoopState::MakeSharedOp(g); } void LoopState::Forward(int iter_no, diff --git a/src/operator/subgraph_op_common.h b/src/operator/subgraph_op_common.h index 79078409e214..a5a54620b166 100644 --- a/src/operator/subgraph_op_common.h +++ b/src/operator/subgraph_op_common.h @@ -69,8 +69,8 @@ class LoopState { // For training, each iteration has a cached op because each iteration // needs to maintain a set of memory buffers for all computation states, // which will be used in the backward. - CachedOpPtr iter_op; std::vector all_states; + CachedOpPtr iter_op; Symbol subgraph_sym; nnvm::Graph subgraph; @@ -91,6 +91,16 @@ class LoopState { all_inputs.clear(); all_states.clear(); } + static CachedOpPtr MakeSharedOp(const Symbol &sym) { + // We turn on static_alloc for two reasons. + // It avoids the overhead of unnecessary memory allocation. + // only static_alloc supports nested call of CachedOp. + std::vector > kwargs = { + {"inline_limit", "0"}, + {"static_alloc", "1"} + }; + return std::make_shared(sym, kwargs); + } }; } // namespace op diff --git a/tests/python/unittest/test_while_loop.py b/tests/python/unittest/test_while_loop.py new file mode 100644 index 000000000000..5f4b04d92f02 --- /dev/null +++ b/tests/python/unittest/test_while_loop.py @@ -0,0 +1,493 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import mxnet as mx +from mxnet import gluon +import numpy as np +import copy +from numpy.testing import assert_allclose +import unittest +from mxnet.test_utils import almost_equal, default_context +from numpy.testing import assert_allclose as assert_almost_equal # This is more restrictive + + +def test_simple_add(): + + class _TestBlock(gluon.HybridBlock): + + def __init__(self, cond, func, max_iterations): + super(_TestBlock, self).__init__() + self.cond = cond + self.func = func + self.max_iterations = max_iterations + + def hybrid_forward(self, F, *loop_vars): + return F.contrib.while_loop( + cond=self.cond, + func=self.func, + loop_vars=loop_vars, + max_iterations=self.max_iterations + ) + + for hybridize in [False, True]: + # Case 1.1: result should be sum([1, 2, 3 ... 100]) + model = _TestBlock( + cond=lambda i, s: i <= 5, + func=lambda i, s: (None, (i + 1, s + i)), + max_iterations=10, + ) + if hybridize: + model.hybridize() + _, result = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + ) + assert result[0].asscalar() == 6 + assert result[1].asscalar() == 15 + # Case 1.2: result should be sum([1, 2, 3 ... 
1000]) + model = _TestBlock( + cond=lambda i, s, true: true, + func=lambda i, s, true: (None, (i + 1, s + i, true)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + _, result = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + mx.nd.array([1], dtype="int64"), # true + ) + assert result[0].asscalar() == 1001 + assert result[1].asscalar() == 500500 + assert result[2].asscalar() == 1 + # Case 1.3: result should be sum([]) + model = _TestBlock( + cond=lambda i, s, false: false, + func=lambda i, s, false: (None, (i + 1, s + i, false)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + _, result = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + mx.nd.array([0], dtype="int64"), # false + ) + assert result[0].asscalar() == 1 + assert result[1].asscalar() == 0 + assert result[2].asscalar() == 0 + # Case 2.1: result should be sum([1, 2, 3 ... 100]) + model = _TestBlock( + cond=lambda i, s: i <= 100, + func=lambda i, s: (i, (i + 1, s + i)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + (outputs, ), (result_i, result_s) = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + ) + assert all(outputs.asnumpy()[ : 100] == np.arange(1, 101).reshape(100, 1)) + assert result_i.asscalar() == 101 + assert result_s.asscalar() == 5050 + # Case 2.2: result should be sum([1, 2, 3 ... 1000]) + model = _TestBlock( + cond=lambda i, s, true: true, + func=lambda i, s, true: (i, (i + 1, s + i, true)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + (outputs, ), (result_i, result_s, _) = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + mx.nd.array([1], dtype="int64"), # true + ) + assert all(outputs.asnumpy() == np.arange(1, 1001).reshape(1000, 1)) + assert result_i.asscalar() == 1001 + assert result_s.asscalar() == 500500 + # Case 2.3: very corner case + model = _TestBlock( + cond=lambda i, s, false: false, + func=lambda i, s, false: (i, (i + 1, s + i, false)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + _, (result_i, result_s, _) = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + mx.nd.array([0], dtype="int64"), # false + ) + assert result_i.asscalar() == 1 + assert result_s.asscalar() == 0 + + +def _verify_while_loop(cond, func, loop_var_shapes, free_var_shapes, is_train, max_iterations, is_for): + + def _create_vars(num, prefix): + return [mx.sym.var(prefix + str(i)) for i in range(num)] + + def _create_arrays(shapes): + return [mx.nd.random.uniform(-1.0, 1.0, shape=x) for x in shapes] + + def _create_dict(prefix, shapes): + return {prefix + str(i): mx.nd.random.uniform(-1.0, 1.0, shape=x) for i, x in enumerate(shapes)} + + def _merge_dict(*dicts): + result = {} + for item in dicts: + result.update(item) + return result + + def _to_numpy_list(arrays): + return [x.asnumpy() if x is not None else x for x in arrays] + + def _get_imperative_result(): + free_vars = [args["FreeVar" + str(i)].copy() for i, _ in enumerate(free_var_shapes)] + loop_vars = [args["LoopVar" + str(i)].copy() for i, _ in enumerate(loop_var_shapes)] + loop_var_start = int(is_for) + if is_train: + for var in free_vars + loop_vars[loop_var_start: ]: + var.attach_grad() + with mx.autograd.record(train_mode=is_train): + outputs, final_loop_vars = mx.nd.contrib.while_loop( + cond=lambda *_loop_vars: cond(_loop_vars, free_vars), + func=lambda *_loop_vars: func(_loop_vars, free_vars), + 
loop_vars=loop_vars, + max_iterations=max_iterations, + ) + n_steps = outputs[0].shape[0] if outputs else 0 + out_grads = _create_arrays(x.shape for x in outputs) \ + + _create_arrays(x.shape for x in final_loop_vars) + loop_result_nd = [x * 2 for x in outputs] + [x * 3 for x in final_loop_vars] + grads = [] + if is_train: + cat_out = mx.nd.concat(*[x.reshape(-1) for x in loop_result_nd], dim=0) + cat_out.backward(out_grad=mx.nd.concat(*[x.reshape(-1) for x in out_grads], dim=0)) + grads = [free_vars[i].grad for i, _ in enumerate(free_var_shapes)] \ + + [loop_vars[i].grad for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start] + return _to_numpy_list(loop_result_nd), _to_numpy_list(grads), out_grads, n_steps + + def _get_symbolic_result(out_grads, n_steps): + + def _copy_args_dict(name_list): + return {name: args[name].copy() for name in name_list} + + def _zeros_like_dict(name_list): + return {name: mx.nd.zeros_like(args[name]) for name in name_list} + + free_syms = _create_vars(len(free_var_shapes), "FreeVar") + loop_syms = _create_vars(len(loop_var_shapes), "LoopVar") + outputs, final_loop_syms = mx.sym.contrib.while_loop( + cond=lambda *_loop_vars: cond(_loop_vars, free_syms), + func=lambda *_loop_vars: func(_loop_vars, free_syms), + loop_vars=loop_syms, + max_iterations=max_iterations, + ) + if n_steps == 0: + outputs = [] + else: + outputs = [x.slice_axis(axis=0, begin=0, end=n_steps) for x in outputs] + loop_result_sym = [x * 2 for x in outputs] + [x * 3 for x in final_loop_syms] + loop_result_sym = mx.sym.Group(loop_result_sym) + + loop_var_start = int(is_for) + args_names = ["FreeVar" + str(i) for i, _ in enumerate(free_var_shapes)] \ + + ["LoopVar" + str(i) for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start] + args_grad = None if not is_train else _zeros_like_dict(x for x in args_names) + executor = loop_result_sym.bind( + ctx=default_context(), + args=_copy_args_dict(loop_result_sym.list_inputs()), + args_grad=args_grad, + ) + loop_result_nd = executor.forward(is_train=is_train) + grads = [] + if is_train: + executor.backward(out_grads=out_grads) + grads = [executor.grad_dict.get("FreeVar" + str(i), None) for i, _ in enumerate(free_var_shapes)] \ + + [executor.grad_dict.get("LoopVar" + str(i), None) for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start] + return _to_numpy_list(loop_result_nd), _to_numpy_list(grads) + + args = _merge_dict( + _create_dict("FreeVar", free_var_shapes), + _create_dict("LoopVar", loop_var_shapes), + ) + if is_for: + assert loop_var_shapes[0] == (1, ) + args["LoopVar0"] = mx.nd.array([0]) + imp_outs, imp_grads, out_grads, n_steps = _get_imperative_result() + sym_outs, sym_grads = _get_symbolic_result(out_grads, n_steps) + for imp_out, sym_out in zip(imp_outs, sym_outs): + if imp_out is None or sym_out is None: + continue + assert_almost_equal(imp_out, sym_out) + for imp_grad, sym_grad in zip(imp_grads, sym_grads): + if imp_grad is None or sym_grad is None: + continue + assert_almost_equal(imp_grad, sym_grad, rtol=1e-5, atol=1e-5) + + +def test_while_loop_for_foreach(): + + def make_true_cond(): + return lambda loop_vars, _: (loop_vars[0] < 1e9).prod() + + def make_false_cond(): + return lambda loop_vars, _: (loop_vars[0] > 1e9).prod() + + def make_for_cond(length): + return lambda loop_vars, _: loop_vars[0] < length + + def case_0(): + def _simple_func(loop, free): + (i, ), (scanned, ) = loop, free + in_ = scanned.take(i).squeeze(axis=0) + return (in_, i + 1) + _verify_while_loop( + cond=make_true_cond(), + func=_simple_func, 
+ max_iterations=1, + is_train=True, + is_for=True, + loop_var_shapes=[ + (1, ), # i + ], + free_var_shapes=[ + (1, 3), # scanned + ], + ) + + def case_1(**params): + step_funcs = [ + lambda a, b, s: a * 1.5 + b * 2.5 - s * 3.5, + lambda a, b, s: a * 1.5 - s * 3.5 + b * 2.5, + lambda a, b, s: b * 2.5 + a * 1.5 - s * 3.5, + lambda a, b, s: b * 2.5 - s * 3.5 + a * 1.5, + lambda a, b, s: s * -3.5 + a * 1.5 + b * 2.5, + lambda a, b, s: s * -3.5 + b * 2.5 + a * 1.5, + lambda a, b, s: a * 2.5 * b + s * 0.3, + lambda a, b, s: b * 2.5 * a + s * 0.3, + lambda a, b, s: 2.5 * a * b + s * 0.3, + lambda a, b, s: b * a * 2.5 + s * 0.3, + lambda a, b, s: 2.5 * b * a + s * 0.3, + lambda a, b, s: b * a * 2.5 + s * 0.3, + lambda a, b, s: s * 0.3 + a * 2.5 * b, + lambda a, b, s: s * 0.3 + b * 2.5 * a, + lambda a, b, s: s * 0.3 + 2.5 * a * b, + lambda a, b, s: s * 0.3 + b * a * 2.5, + lambda a, b, s: s * 0.3 + 2.5 * b * a, + lambda a, b, s: s * 0.3 + b * a * 2.5, + ] + def make_func(step_func): + def step(loop, free): + (s, ), (a, b) = loop, free + out = step_func(a, b, s) + return (out, out) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + print "Case", case_id + _verify_while_loop( + func=make_func(step_func), + is_train=is_train, + is_for=False, + **params + ) + + def case_2(**params): + # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op + step_funcs = [ + lambda in_, s, f_1: (in_ * 2) * s * f_1, + lambda in_, s, f_1: (in_ * 2) * f_1 * s, + lambda in_, s, f_1: s * (in_ * 2) * f_1, + lambda in_, s, f_1: s * f_1 * (in_ * 2), + lambda in_, s, f_1: f_1 * (in_ * 2) * s, + lambda in_, s, f_1: f_1 * s * (in_ * 2), + lambda in_, s, f_1: (2 * in_) * s * f_1, + lambda in_, s, f_1: (2 * in_) * f_1 * s, + lambda in_, s, f_1: s * (2 * in_) * f_1, + lambda in_, s, f_1: s * f_1 * (2 * in_), + lambda in_, s, f_1: f_1 * (2 * in_) * s, + lambda in_, s, f_1: f_1 * s * (2 * in_), + ] + def make_func(step_func): + """This simulates: + def compute(s, inputs, f_1, length): + outputs = [] + for i in range(length): + s += inputs[i] * 2 + f_1 + outputs.append(s) + return outputs, s + """ + def step(loop, free): + (i, s), (scanned, f_1, _) = loop, free + in_ = scanned.take(i).squeeze(axis=0) + out = step_func(in_, s, f_1) + return (out, (i + 1, out)) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + print "Case", case_id + _verify_while_loop( + func=make_func(step_func), + max_iterations=1000, + is_train=is_train, + is_for=True, + **params + ) + + def case_3(length, **params): + # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op + step_funcs = [ + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, + ] + def make_func(step_func): + """This simulates: + def compute(s, inputs, f_1, length): + outputs = [] + for i in range(length): + s += inputs[i] * 2 + f_1 + outputs.append(s) + return outputs, s + """ + def step(loop, free): + (i, s_0, s_1), (sc_0, 
sc_1, f_0, _) = loop, free + i_0 = sc_0.take(i).squeeze(axis=0) + i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) + out = step_func(i_0, i_1, s_0, s_1, f_0) + return ([out, out * 1.5], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95]) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + print "Case", case_id + _verify_while_loop( + func=make_func(step_func), + max_iterations=1000, + is_train=is_train, + is_for=True, + **params + ) + + # Case 0: the simpest case + print("Testing Case 0") + case_0() + # Case 1.1.* + print("Testing Case 1.1") + case_1( + cond=make_true_cond(), + loop_var_shapes=[ + (1, ), # s + ], + free_var_shapes=[ + (1, ), # a + (1, ), # b + ], + max_iterations=23, + ) + # Case 1.2.* + print("Testing Case 1.2") + case_1( + cond=make_true_cond(), + loop_var_shapes=[ + (2, 3, 4), # s + ], + free_var_shapes=[ + (2, 3, 4), # a + (2, 3, 4), # b + ], + max_iterations=31, + ) + # Case 1.3.* + print("Testing Case 1.3") + case_1( + cond=make_false_cond(), + loop_var_shapes=[ + (2, 3, 4), # s + ], + free_var_shapes=[ + (2, 3, 4), # a + (2, 3, 4), # b + ], + max_iterations=20, + ) + # Case 2.1.* + print("Testing Case 2.1") + case_2( + cond=make_for_cond(length=31), + loop_var_shapes=[ + (1, ), # i + (2, ), # s + ], + free_var_shapes=[ + (100, 2), # scanned + (2, ), # f_1 + (3, 4, 5, 6), # f_2, unused + ], + ) + # Case 2.2.* + print("Testing Case 2.2") + case_2( + cond=make_for_cond(length=25), + loop_var_shapes=[ + (1, ), # i + (2, ), # s + ], + free_var_shapes=[ + (30, 2), # scanned + (2, ), # f_1 + (3, 4, 5, 6), # f_2, unused + ], + ) + # Case 3.* + print("Testing Case 3") + case_3( + length=11, + cond=make_for_cond(length=11), + loop_var_shapes=[ + (1, ), # i + (2, ), # s_0 + (2, ), # s_1 + ], + free_var_shapes=[ + (30, 2), # sc_0 + (30, 2), # sc_1 + (2, ), # f_1 + (3, 4, 5, 6), # f_2, unused + ], + ) + + +if __name__ == '__main__': + # import nose + # nose.runmodule() + test_simple_add() + test_while_loop_for_foreach()
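To complement the unit tests above, here is a condensed, self-contained sketch of the symbolic API end to end (build, bind, forward), assuming a CPU context. The variable names and the tiny counting loop are illustrative; the expected values in the comments follow from the semantics documented in this patch (the symbolic output is padded to `max_iterations` along axis 0, so only the first executed rows are meaningful).

```python
import mxnet as mx

# Symbolic signature: while_loop(cond, func, loop_vars, max_iterations).
i, s = mx.sym.var("i"), mx.sym.var("s")
outputs, final_vars = mx.sym.contrib.while_loop(
    cond=lambda i, s: i <= 5,
    func=lambda i, s: (i, (i + 1, s + i)),
    loop_vars=[i, s],
    max_iterations=10,
)
out = mx.sym.Group(list(outputs) + list(final_vars))

exe = out.bind(ctx=mx.cpu(), args={
    "i": mx.nd.array([1], dtype="int64"),
    "s": mx.nd.array([0], dtype="int64"),
})
stacked, final_i, final_s = exe.forward(is_train=False)
print(stacked.shape)                           # should be (10, 1); rows 0..4 hold i = 1..5
print(final_i.asscalar(), final_s.asscalar())  # 6 15
```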