From a0822d3353cb19eb18629280adca7bacf50eb4a7 Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Tue, 14 Aug 2018 19:34:22 +0300 Subject: [PATCH 01/11] [TVM] Automatic differentiation for tensor expressions --- include/tvm/ir_operator.h | 30 + python/tvm/__init__.py | 2 + python/tvm/autodiff.py | 130 ++ python/tvm/testing.py | 172 ++ src/op/op_util.cc | 43 + src/op/op_util.h | 40 + src/pass/autodiff.cc | 498 +++++ src/pass/autodiff.h | 146 ++ src/pass/dump_tensor.cc | 100 + src/pass/zero_elimination.cc | 1666 +++++++++++++++++ src/pass/zero_elimination.h | 243 +++ tests/python/unittest/test_pass_autodiff.py | 414 ++++ .../unittest/test_pass_zero_elimination.py | 461 +++++ 13 files changed, 3945 insertions(+) create mode 100644 python/tvm/autodiff.py create mode 100644 src/pass/autodiff.cc create mode 100644 src/pass/autodiff.h create mode 100644 src/pass/dump_tensor.cc create mode 100644 src/pass/zero_elimination.cc create mode 100644 src/pass/zero_elimination.h create mode 100644 tests/python/unittest/test_pass_autodiff.py create mode 100644 tests/python/unittest/test_pass_zero_elimination.py diff --git a/include/tvm/ir_operator.h b/include/tvm/ir_operator.h index f2c9c3d517a5..ded99fb52206 100644 --- a/include/tvm/ir_operator.h +++ b/include/tvm/ir_operator.h @@ -85,6 +85,16 @@ inline const uint64_t* as_const_uint(const Expr& x) { */ inline bool is_const_int(const Expr& x, int64_t value); +/*! + * \brief Check if the given expr is a const of any type equal to the given integer value. + * \param e The expression. + * \param value The value to compare to. + * \return Whether the expression is a const equal to the value. + * \tparam ValueType The value type + */ +template +inline bool is_const_value(const Expr& e, ValueType value); + /*! * \brief Check whether stmt is nop. * \param stmt The input statement @@ -515,6 +525,26 @@ inline bool is_const_int(const Expr& x, int64_t value) { return false; } +template +inline bool is_const_value(const Expr& e, ValueType value) { + static_assert(std::is_integral::value, + "Comparison to non-integer values is forbidden."); + // This implementation was copy-pasted from HalideIR + if (const ir::IntImm* i = e.as()) { + return i->value == value; + } else if (const ir::UIntImm* i = e.as()) { + return (value >= 0) && (i->value == (uint64_t)value); + } else if (const ir::FloatImm* i = e.as()) { + return i->value == value; + } else if (const ir::Cast* c = e.as()) { + return is_const_value(c->value, value); + } else if (const ir::Broadcast* b = e.as()) { + return is_const_value(b->value, value); + } else { + return false; + } +} + inline bool is_no_op(const Stmt& stmt) { if (!stmt.defined()) return true; if (const auto* op = stmt.as()) { diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 9c09dc5a4ac3..749b6b3dbb07 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -19,6 +19,7 @@ from . import generic from . import hybrid from . import testing +from . import autodiff from . import ndarray as nd from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl @@ -36,6 +37,7 @@ from .schedule import create_schedule from .build_module import build, lower, build_config from .tag import tag_scope +from .autodiff import differentiate # Contrib initializers from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel diff --git a/python/tvm/autodiff.py b/python/tvm/autodiff.py new file mode 100644 index 000000000000..63364c42cad0 --- /dev/null +++ b/python/tvm/autodiff.py @@ -0,0 +1,130 @@ +"""Namespace of autodiff-related functions. + +The functions are automatically exported from C++ side via PackedFunc. +You can read "include/tvm/autodiff.h" for the function signature of these functions. +""" +import logging + +from ._ffi.function import _init_api +from ._ffi.node import NodeBase, register_node + +_init_api("tvm.autodiff") + +@register_node +class DifferentiationResult(NodeBase): + """Result of differentiation. + + Parameters + ---------- + result : list of Tensor + The requested adjoints, i.e. the jacobians or gradients of the given output + wrt to the given inputs. + + adjoints : dict from Tensor to Tensor + A map from tensors to the corresponding adjoints (including internal nodes). + + adjoint_summands : dict from Tensor to dict from Tensor to Tensor + Single summands of the adjoints. + """ + def __getattr__(self, name): + # Here we convert tvm Maps to dicts because Map compares keys by reference which is + # wrong for Tensors. Hopefully, in the future Map gets fixed somehow, and this function + # may be removed then. + res = NodeBase.__getattr__(self, name) + if name == 'adjoints': + return dict(res.items()) + if name == 'adjoint_summands': + return {k: dict(v.items()) for k, v in res.items()} + return res + + def __getitem__(self, i): + return self.result[i] + + def __len__(self): + return len(self.result) + + +def differentiate(output, inputs=None, head=None, manual=None, fdiff=None): + """Perform reverse-mode automatic differentiation. + + Example:: + + x = tvm.placeholder((32, 3, 28, 28), name='x') + w1 = tvm.placeholder((10, 3, 3, 3), name='w1') + w2 = tvm.placeholder((10, 10, 3, 3), name='w2') + y = topi.sum(topi.nn.conv2d(topi.nn.conv2d(x, w1, 1, 0), w2, 1, 0)) + + [dw1, dw2] = tvm.differentiate(y, [w1, w2]) + + Parameters + ---------- + output : Tensor + The tensor to differentiate. + + inputs : list of Tensor + The list of input tensors. When the list is empty or None, will perform + differentiation wrt all tensors the output depends on (i.e. will compute all + adjoints and populate the corresponding dict, but the list of results + will be empty). + + head : Tensor + The adjoint of the output, in other words, some tensor, by which the Jacobians + will be multiplied. Its shape must be of the form `prefix + output.shape`. + If `None` is passed, the identity tensor of shape `output.shape + output.shape` + will be used. + + manual : dict (Tensor, Tensor) -> function + A dict providing custom multiplication-differentiation functions (see `fdiff`) + for certain pairs of tensors. Each pair consists of an output and an input tensor, + the input one being an immediate dependency of the output one. Pairs of the form + `(None, tensor)` and `(tensor, None)` are allowed, `None` working as a wildcard. + + fdiff : function (Tensor, Tensor, Tensor) -> Tensor + The default function performing differentiation and multiplication, by default + `tvm.autodiff.FDiffBuildingBlock` is used. The function must accept three + parameters: + - `output` - an output tensor + - `input` - an input tensor + - `head` - the adjoint of the output tensor + The result should be `head` multiplied by the jacobian of `output` wrt `input` + + Returns + ------- + differentiation_result : DifferentiationResult + """ + if inputs is None: + inputs = [] + + if fdiff is None: + fdiff = DiffBuildingBlock + + if manual is not None: + if not isinstance(manual, dict): + manual = dict(manual) + + # pylint: disable=dangerous-default-value + used_items = set() + + def _modified_fdiff(out, inp, head, manual=manual, old_fdiff=fdiff, used_items=used_items): + if (out, inp) in manual: + used_items.add((out, inp)) + return manual[(out, inp)](out, inp, head) + if (out, None) in manual: + used_items.add((out, None)) + return manual[(out, None)](out, inp, head) + if (None, inp) in manual: + used_items.add((None, inp)) + return manual[(None, inp)](out, inp, head) + return old_fdiff(out, inp, head) + + fdiff = _modified_fdiff + + res = Differentiate(output, inputs, head, fdiff) + + if manual is not None: + for k in manual: + if k not in used_items: + logging.warning("The manually specified differentiation function " + "for %s hasn't been used", k) + + return res diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 1a6666bdee2a..e21dc5be6863 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -1,6 +1,7 @@ """ TVM testing utilities """ import logging import numpy as np +import tvm def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): """ Version of np.testing.assert_allclose with `atol` and `rtol` fields set @@ -145,3 +146,174 @@ def compare_derivative(j, n_der, grad): logging.info("Numerical grad test wrt '%s' of shape %s passes, " "dist = %f, max_diff = %f, avg_diff = %f", x_name, grad.shape, dist, max_diff, avg_diff) + + +class PerformanceEstimate: + """A result of static performance estimation. + + Parameters + ---------- + iterations : int + The total number of iterations of all the loops. + + multiplications : int + The total number of expensive operations like multiplications. + + memory : int + The amount of memory to allocate. + """ + def __init__(self, iterations=0, multiplications=0, memory=0): + self.iterations = iterations + self.multiplications = multiplications + self.memory = memory + + def as_tuple(self): + return (self.iterations, self.multiplications, self.memory) + + def __add__(self, other): + return PerformanceEstimate(iterations=self.iterations + other.iterations, + multiplications=self.multiplications + other.multiplications, + memory=self.memory + other.memory) + + def max(self, other): + return PerformanceEstimate( + iterations=max(self.iterations, other.iterations), + multiplications=max(self.multiplications, other.multiplications), + memory=max(self.memory, other.memory)) + + def times(self, iters): + return PerformanceEstimate(iterations=self.iterations*iters, + multiplications=self.multiplications*iters, + memory=self.memory) + + def __repr__(self): + return "PerformanceEstimate(iterations={}, multiplications={}, memory={})".format( + self.iterations, self.multiplications, self.memory) + + def __le__(self, other): + return \ + self.iterations <= other.iterations and \ + self.multiplications <= other.multiplications and \ + self.memory <= other.memory + + +def estimate_performance(s, processed_ops=None): + """Statically estimate performance of statements, expressions and tensors. Note that the + estimate is very rough, it mustn't be used to predict future performance, its only purpose is + to detect possible performance regressions. + + Parameters: + ----------- + s + A statement, an expression, a tensor, an operation, or a list + of any of the above. + + Returns + ------- + estimate : PerformanceEstimate + """ + from tvm import stmt + from tvm import expr + + if processed_ops is None: + processed_ops = {} + res = estimate_performance(s, processed_ops) + for op_est in processed_ops.values(): + res += op_est + return res + + est = lambda e, processed_ops=processed_ops: estimate_performance(e, processed_ops) + + def _prod(elems): + res = 1 + for x in elems: + res *= x + return res + + if s is None or isinstance(s, (stmt.AssertStmt, stmt.Free, stmt.Prefetch, + expr.ConstExpr, expr.Var, tvm.tensor.PlaceholderOp)): + return PerformanceEstimate() + elif isinstance(s, list): + res = PerformanceEstimate() + for item in s: + res += est(item) + return res + elif s in processed_ops: + return PerformanceEstimate() + elif isinstance(s, stmt.Allocate): + mem = _prod([e.value for e in s.extents]) + return est(s.condition) + est(s.body) + PerformanceEstimate(memory=mem) + elif isinstance(s, stmt.Block): + return est(s.first) + est(s.rest) + elif isinstance(s, stmt.Evaluate): + return est(s.value) + elif isinstance(s, stmt.For): + body_est = est(s.body) + body_est.iterations = max(1, body_est.iterations) + return body_est.times(s.extent.value) + elif isinstance(s, stmt.IfThenElse): + return est(s.condition) + est(s.then_case) + est(s.else_case) + elif isinstance(s, stmt.LetStmt): + return est(s.value) + est(s.body) + elif isinstance(s, (stmt.ProducerConsumer, stmt.AttrStmt)): + return est(s.body) + elif isinstance(s, stmt.Provide): + return est(s.value) + elif isinstance(s, stmt.Realize): + return est(s.condition) + est(s.body) + elif isinstance(s, stmt.Store): + return est(s.value) + est(s.index) + est(s.predicate) + elif isinstance(s, (expr.Mul, expr.Div, expr.Mod)): + return est(s.a) + est(s.b) + PerformanceEstimate(multiplications=1) + elif isinstance(s, (expr.BinaryOpExpr, expr.CmpExpr, expr.LogicalExpr)): + if not hasattr(s, 'b'): + return est(s.a) + return est(s.a) + est(s.b) + elif isinstance(s, expr.Call): + res = PerformanceEstimate() + for a in s.args: + res += est(a) + if s.call_type == expr.Call.Halide: + # The estimate is added to processed_ops, we don't need the result here + est(s.func) + elif s.name == "tvm_if_then_else": + pass + else: + # expr.If it is a non-halide call (e.g. exp or log), consider it a mul + res += PerformanceEstimate(multiplications=1) + return res + elif isinstance(s, expr.Cast): + return est(s.value) + elif isinstance(s, expr.Load): + return est(s.index) + est(s.predicate) + elif isinstance(s, expr.Select): + return est(s.condition) + est(s.true_value) + est(s.false_value) + elif isinstance(s, expr.Reduce): + iterations = _prod([iv.dom.extent.value for iv in s.axis]) + res = PerformanceEstimate() + for id_elem in s.combiner.identity_element: + res += est(id_elem) + on_each_iter = est(s.condition) + for src in s.source: + on_each_iter += est(src) + for comb_res in s.combiner.result: + on_each_iter += est(comb_res) + on_each_iter.iterations = max(1, on_each_iter.iterations) + return res + on_each_iter.times(iterations) + elif isinstance(s, tvm.tensor.Tensor): + return est(s.op) + elif isinstance(s, tvm.tensor.ComputeOp): + iterations = _prod([iv.dom.extent.value for iv in s.axis]) + if s.reduce_axis: + res = est(s.body[0]) + else: + res = PerformanceEstimate() + for b in s.body: + res += est(b) + res.iterations = max(1, res.iterations) + res = res.times(iterations) + PerformanceEstimate(memory=iterations*len(s.body)) + processed_ops[s] = res + return PerformanceEstimate() + + raise ValueError("Don't know how to estimate performance of {} of type {}" + .format(s, type(s))) diff --git a/src/op/op_util.cc b/src/op/op_util.cc index b18552d5c562..4231f336a01b 100644 --- a/src/op/op_util.cc +++ b/src/op/op_util.cc @@ -245,5 +245,48 @@ ir::ForType IterVarTypeToForType(IterVarType iter_type) { } } +Tensor TensorFromExpr(const Expr& expr, const Array& axis, + const std::string& name, const std::string& tag, + const Map& attrs) { + Array new_bodies; + int new_value_index = 0; + + // If this is a reduction then we have to clone its body + if (const Reduce* red = expr.as()) { + new_value_index = red->value_index; + + for (size_t i = 0; i < red->source.size(); ++i) { + Expr ith_red = Reduce::make(red->combiner, red->source, red->axis, red->condition, i); + new_bodies.push_back(ith_red); + } + } else { + new_value_index = 0; + new_bodies.push_back(expr); + } + + return ComputeOpNode::make(name, tag, attrs, axis, new_bodies).output(new_value_index); +} + +Tensor TransformBody(const Tensor& tensor, + std::function&)> func) { + if (const ComputeOpNode* op = tensor->op.as()) { + // Transform only one body + Expr new_body = func(op->body[tensor->value_index], op->axis); + + // If the body didn't change then we can return the same tensor + if (new_body.same_as(op->body[tensor->value_index])) { + return tensor; + } + + return TensorFromExpr(new_body, op->axis, op->name, op->tag, op->attrs); + } else { + return tensor; + } +} + +Tensor TransformBody(const Tensor& tensor, std::function func) { + return TransformBody(tensor, [func](const Expr& e, const Array&) { return func(e); }); +} + } // namespace op } // namespace tvm diff --git a/src/op/op_util.h b/src/op/op_util.h index de2e44c2ed59..da7987f7162f 100644 --- a/src/op/op_util.h +++ b/src/op/op_util.h @@ -11,6 +11,7 @@ #include #include #include +#include #include "../pass/ir_util.h" #include "../pass/arg_binder.h" #include "../schedule/message_passing.h" @@ -84,6 +85,45 @@ IterVarType ForTypeToIterVarType(ir::ForType for_type); */ ir::ForType IterVarTypeToForType(IterVarType iter_type); +/*! + * \brief Create a tensor from an expression. The expression may be a reduction, in which + * case its body will be correctly duplicated if it is a multi-valued reduction. + * + * \param expr The expr which will be the tensor's body. + * \param axis The input variables with ranges. + * \param name The tensor's name. + * \param tag The tensor's tag. + * \param attrs The tensor's attrs. + * \return A tensor. + */ +Tensor TensorFromExpr(const Expr& expr, const Array& axis, + const std::string& name = "tensor", const std::string& tag = "", + const Map& attrs = {}); + +/*! + * \brief Transform the body of a tensor if it is a compute tensor, otherwise return it + * unchanged. Note that if the compute returns a tuple, it transforms only one element, + * other elements are discarded. + * + * \param tensor The tensor to transform. + * \param func The transformation function working on expressions and additionally taking + * the array of the tensor's itervars. + * \return The transformed tensor. + */ +Tensor TransformBody(const Tensor& tensor, + std::function&)> func); + +/*! + * \brief Transform the body of a tensor if it is a compute tensor, otherwise return it + * unchanged. Note that if the compute returns a tuple, it transforms only one element, + * other elements are discarded. + * + * \param tensor The tensor to transform. + * \param func The transformation function (working on expressions). + * \return The transformed tensor. + */ +Tensor TransformBody(const Tensor& tensor, std::function func); + } // namespace op } // namespace tvm #endif // TVM_OP_OP_UTIL_H_ diff --git a/src/pass/autodiff.cc b/src/pass/autodiff.cc new file mode 100644 index 000000000000..5afc5d5476f5 --- /dev/null +++ b/src/pass/autodiff.cc @@ -0,0 +1,498 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file autodiff.cc + * \brief Automatic differentiation of IR Expr + */ +#include "autodiff.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/op_util.h" +#include "zero_elimination.h" + +namespace tvm { +namespace ir { + + +DifferentiationResult DifferentiationResultNode::make(Array result, + Map adjoints, + Map> summands) { + auto n = make_node(); + n->result = std::move(result); + n->adjoints = adjoints; + n->adjoint_summands = summands; + return DifferentiationResult(n); +} + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const DifferentiationResultNode* r, IRPrinter* p) { + p->stream << "DifferentiationResult(result=" << r->result + << ", adjoints=" << r->adjoints + << ", adjoint_summands=" << r->adjoint_summands << ')'; + }); + +TVM_REGISTER_NODE_TYPE(DifferentiationResultNode); + + +#define NOT_IMPLEMENTED \ + { CHECK(false) << "Derivative of this op is not implemented"; \ + throw dmlc::Error("Derivative of this op is not implemented"); } + +/*! \brief Differentiate an expression wrt a variable or a tensor element */ +class JacobianMutator : public IRMutator { + public: + /*! + * \brief Differentiate wrt `input(indices)`. + * \param input The input tensor. + * \param indices The indices of the element with respect to which to differentiate. + */ + explicit JacobianMutator(Tensor input, Array indices) + : input_(input), indices_(indices) {} + /*! + * \brief Differentiate wrt the input variable. + * \param input The input variable. + */ + explicit JacobianMutator(VarExpr input) + : input_var_(input) {} + + Expr Mutate_(const Variable* op, const Expr& e) { + if (input_var_.operator->() && input_var_.get() == op) { + return FloatImm::make(op->type, 1.0); + } else { + return make_zero(op->type); + } + } + + Expr Mutate_(const Load* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const Let* op, const Expr& e) NOT_IMPLEMENTED + + Expr Mutate_(const Call* op, const Expr& e) { + if (op->call_type == Call::CallType::Halide) { + if (input_.operator->() && op->func.same_as(input_->op) && + op->value_index == input_->value_index) { + Expr condition = const_true(); + for (size_t i = 0; i < input_.ndim(); ++i) { + condition = And::make(condition, EQ::make(indices_[i], op->args[i])); + } + return Cast::make(op->type, condition); + } else { + return make_zero(op->type); + } + } else if (op->call_type == Call::CallType::PureIntrinsic) { + if (op->name == "exp") { + return Mul::make(Mutate(op->args[0]), e); + } else if (op->name == "log") { + return Div::make(Mutate(op->args[0]), op->args[0]); + } else if (op->name == "sigmoid") { + return Mul::make(Mutate(op->args[0]), + Mul::make(e, Sub::make(FloatImm::make(e.type(), 1.0), e))); + } else if (op->name == "tanh") { + return Mul::make(Mutate(op->args[0]), + Sub::make(FloatImm::make(e.type(), 1.0), Mul::make(e, e))); + } else if (op->name == "fabs") { + auto type = op->args[0].type(); + return Mul::make(Mutate(op->args[0]), + Select::make(GE::make(op->args[0], make_zero(type)), + FloatImm::make(type, 1.0), FloatImm::make(type, -1.0))); + } else if (op->name == intrinsic::tvm_if_then_else) { + Array new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])}; + return Call::make(op->type, op->name, new_args, op->call_type, op->func, op->value_index); + } else { + throw dmlc::Error("Derivative of this intrinsic is not implemented: " + op->name); + } + } + NOT_IMPLEMENTED + } + + Expr Mutate_(const Add* op, const Expr& e) { + return op->make(Mutate(op->a), Mutate(op->b)); + } + + Expr Mutate_(const Sub* op, const Expr& e) { + return op->make(Mutate(op->a), Mutate(op->b)); + } + + Expr Mutate_(const Mul* op, const Expr& e) { + return Add::make(Mul::make(Mutate(op->a), op->b), Mul::make(op->a, Mutate(op->b))); + } + + Expr Mutate_(const Div* op, const Expr& e) { + return Div::make( + Sub::make(Mul::make(Mutate(op->a), op->b), Mul::make(op->a, Mutate(op->b))), + Mul::make(op->b, op->b)); + } + + Expr Mutate_(const Mod* op, const Expr& e) NOT_IMPLEMENTED + + Expr Mutate_(const Min* op, const Expr& e) { + return Select::make(LE::make(op->a, op->b), Mutate(op->a), Mutate(op->b)); + } + + Expr Mutate_(const Max* op, const Expr& e) { + return Select::make(GE::make(op->a, op->b), Mutate(op->a), Mutate(op->b)); + } + + Expr Mutate_(const EQ* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const NE* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const LT* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const LE* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const GT* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const GE* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const And* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const Or* op, const Expr& e) NOT_IMPLEMENTED + + Expr Mutate_(const Reduce*, const Expr& e) { + // This case is relatively difficult because a reduction expression + // may use an arbitrary combiner. + // The resulting reduction expression will return a tuple containing + // both derivatives and the original results (in exactly this order). + + // We have to clone the reduction axes because otherwise the original expression + // cannot be used together with the derivative (it will lead to errors during lowering) + Expr expr_with_new_axes = CloneReduction(e); + const Reduce* op = expr_with_new_axes.as(); + + // New lhs and rhs variables of the new combiner consist of variables + // representing derivatives followed by the original variables. + Array new_lhs; + for (const auto& var : op->combiner->lhs) { + new_lhs.push_back(var.copy_with_suffix(".der")); + } + for (const auto& var : op->combiner->lhs) { + new_lhs.push_back(var); + } + + Array new_rhs; + for (const auto& var : op->combiner->rhs) { + new_rhs.push_back(var.copy_with_suffix(".der")); + } + for (const auto& var : op->combiner->rhs) { + new_rhs.push_back(var); + } + + // The new combiner result also consists of the resulting derivatives + // followed by the original results. + Array new_result; + for (const auto& res : op->combiner->result) { + // Each resulting derivative is computed as a sum of derivatives + // wrt lhs and rhs multiplied by the derivatives of lhs and rhs + Expr new_res = make_zero(res.type()); + for (size_t i = 0; i < op->combiner->lhs.size(); ++i) { + Expr res_di = Derivative(res, op->combiner->lhs[i]); + // new_lhs[i] is the derivative of lhs[i] (wrt our input tensor) + new_res = Add::make(new_res, Mul::make(new_lhs[i], res_di)); + } + for (size_t i = 0; i < op->combiner->rhs.size(); ++i) { + Expr res_di = Derivative(res, op->combiner->rhs[i]); + new_res = Add::make(new_res, Mul::make(new_rhs[i], res_di)); + } + new_result.push_back(new_res); + } + for (const auto& res : op->combiner->result) { + new_result.push_back(res); + } + + // The identity is transformed in a similar way + Array new_identity; + for (const auto& id : op->combiner->identity_element) { + new_identity.push_back(Mutate(id)); + } + for (const auto& id : op->combiner->identity_element) { + new_identity.push_back(id); + } + + Array new_source; + for (const auto& src : op->source) { + new_source.push_back(Mutate(src)); + } + for (const auto& src : op->source) { + new_source.push_back(src); + } + + CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); + // Also simplify the resulting combiner (mostly to get rid of unused components) + return Simplify( + Reduce::make(new_combiner, new_source, op->axis, op->condition, op->value_index)); + } + + Expr Mutate_(const Cast* op, const Expr& e) { + if (op->type.is_float()) { + return Cast::make(op->type, Mutate(op->value)); + } else { + return make_zero(op->type); + } + } + + Expr Mutate_(const Not* op, const Expr& e) NOT_IMPLEMENTED + + Expr Mutate_(const Select* op, const Expr& e) { + return Select::make(op->condition, Mutate(op->true_value), Mutate(op->false_value)); + } + + Expr Mutate_(const Ramp* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const Broadcast* op, const Expr& e) NOT_IMPLEMENTED + + Expr Mutate_(const IntImm* op, const Expr& e) { return op->make(op->type, 0); } + Expr Mutate_(const UIntImm* op, const Expr& e) { return op->make(op->type, 0); } + Expr Mutate_(const FloatImm* op, const Expr& e) { return op->make(op->type, 0); } + + Expr Mutate_(const StringImm* op, const Expr& e) NOT_IMPLEMENTED + Expr Mutate_(const Shuffle* op, const Expr& e) NOT_IMPLEMENTED + + private: + Tensor input_; + Array indices_; + VarExpr input_var_; +}; + +Expr Jacobian(const Expr& expr, const Tensor& input, const Array& indices) { + return JacobianMutator(input, indices).Mutate(expr); +} + +Expr Derivative(const Expr& expr, const VarExpr& var) { + return JacobianMutator(var).Mutate(expr); +} + +Tensor Jacobian(const Tensor& output, const Tensor& input, bool optimize) { + if (const ComputeOpNode* op = output->op.as()) { + // We have to clone the iteration axes because otherwise the original expression + // cannot be used together with the derivative (it will lead to errors during lowering) + Array new_axis; + std::unordered_map vmap; + for (IterVar iv : op->axis) { + IterVar new_v = + IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""), + iv->iter_type, iv->thread_tag); + new_axis.push_back(new_v); + vmap[iv->var.operator->()] = new_v; + } + + // Generate new itervars for the input + Array input_itervars; + size_t i = 0; + for (Expr ext : input->shape) { + IterVar new_v = + IterVarNode::make(Range(0, ext), Var("jac_i" + std::to_string(i)), + IterVarType::kDataPar); + // Append them to new_axis + new_axis.push_back(new_v); + // We also need a separate array of these itervars + input_itervars.push_back(new_v); + ++i; + } + + // The differentiation itself happens here + Expr new_body = + Jacobian(Substitute(op->body[output->value_index], vmap), input, input_itervars); + new_body = Simplify(new_body); + + int value_index = 0; + Array new_bodies; + + // If this is a reduction then it may return a tuple and we have + // to repeat the body several times + if (const Reduce* red = new_body.as()) { + value_index = red->value_index; + for (size_t i = 0; i < red->source.size(); ++i) { + new_bodies.push_back( + Reduce::make(red->combiner, red->source, red->axis, red->condition, i)); + } + } else { + new_bodies.push_back(new_body); + } + + auto new_op = + ComputeOpNode::make(op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies); + + // new_shape = output.shape + input.shape + Array new_shape = output->shape; + for (const auto& e : input->shape) { + new_shape.push_back(e); + } + + Tensor tensor = TensorNode::make(new_shape, output->dtype, new_op, value_index); + + if (optimize) { + tensor = OptimizeAndLiftNonzeronessConditions(tensor); + } + + return tensor; + } else { + NOT_IMPLEMENTED; + } +} + +Tensor DiffBuildingBlock(const Tensor& output, const Tensor& input, const Tensor& head) { + Tensor jac_output_input = Jacobian(output, input); + Tensor result = topi::tensordot(head, jac_output_input, output->shape.size(), + output->op->name + "." + input->op->name + ".grad"); + // TODO(sgrechanik-h): Here we inline only jac_output_input because otherwise there will be + // performance problems. A better solution would be to inline smartly. + result = InlineTensors(result, {jac_output_input}); + result = OptimizeAndLiftNonzeronessConditions(result); + result = InlineTailCall(result); + return result; +} + +DifferentiationResult Differentiate(const Tensor& output, + const Array& inputs, + const Tensor& head_or_null, + const FDiffBuildingBlock& fdiff) { + Tensor head = head_or_null; + + // If the head is a null pointer, create an identity tensor + if (!head.get()) { + Array shape = output->shape; + for (auto e : output->shape) { + shape.push_back(e); + } + auto func = + [&output](const Array& input_indices) { + Expr res = const_true(); + for (size_t i = 0; i < output->shape.size(); ++i) { + res = res && Expr(input_indices[i]) == Expr(input_indices[output->shape.size() + i]); + } + return Cast::make(output->dtype, res); + }; + head = tvm::compute(shape, func, "identity"); + } + + // This map maps a tensor to the list of tensors immediately depending on it (using it in their + // bodies) + std::unordered_map> reverse_dependencies; + + // Collect reverse dependencies + std::vector stack({output}); + while (!stack.empty()) { + Tensor tensor = stack.back(); + stack.pop_back(); + + for (const Tensor& child : tensor->op->InputTensors()) { + if (!reverse_dependencies.count(child)) { + stack.push_back(child); + } + reverse_dependencies[child].push_back(tensor); + } + } + + // Individual summands of the adjoints + std::unordered_map> summands; + + // This map maps tensors to the corresponding adjoints (dLoss/dTensor) + std::unordered_map adjoints; + // head is the adjoint of output by definition + adjoints[output] = head; + + // This is a recursive function that does all the work. It computes the adjoint for a given + // tensor, adds it to the map, and returns it + std::function compute_adjoint; + compute_adjoint = + [&compute_adjoint, &adjoints, &summands, &reverse_dependencies, &fdiff, &head, &output] + (const Tensor& tensor) { + if (!adjoints.count(tensor)) { + // Here the adjoint hasn't been computed yet + Tensor res_adjoint; + std::vector deps = reverse_dependencies[tensor]; + if (deps.empty()) { + // No reverse dependencies means that the output does not depend on this tensor, + // return a zero tensor of the appropriate shape + Array result_shape(head->shape.begin(), + head->shape.end() + (-output->shape.size())); + for (auto e : tensor->shape) { + result_shape.push_back(e); + } + res_adjoint = topi::full(result_shape, output->dtype, make_zero(output->dtype)); + } else { + // The new adjoint is computed as a sum of the reverse dependencies' adjoints multiplied + // by the corresponding "local" jacobians (dDep/dTensor). The computation of the jacobian + // and the multiplication is done in the function fdiff (DiffBuildingBlock by default). + for (const Tensor& dep : deps) { + Tensor part = fdiff(dep, tensor, compute_adjoint(dep)); + res_adjoint = res_adjoint.get() ? topi::add(res_adjoint, part) : part; + + // Add this part to summands + auto& summands_of_adjoint = summands[tensor]; + if (summands_of_adjoint.get()) { + summands_of_adjoint.Set(dep, part); + } else { + summands_of_adjoint = Map({{dep, part}}); + } + } + } + + adjoints[tensor] = res_adjoint; + return res_adjoint; + } else { + return adjoints[tensor]; + } + }; + + // Adjoints corresponding to inputs + Array result; + + // If inputs is empty, compute adjoints for all tensors, on which output depends + if (inputs.empty()) { + for (const auto& dep : reverse_dependencies) { + compute_adjoint(dep.first); + } + } + + // Compute an adjoint for each input + for (const Tensor& input : inputs) { + result.push_back(compute_adjoint(input)); + } + + return DifferentiationResultNode::make(result, adjoints, summands); +} + + +TVM_REGISTER_API("autodiff.Jacobian") +.set_body([](TVMArgs args, TVMRetValue *ret) { + if (args.size() > 2) { + *ret = Jacobian(args[0], args[1], args[2].operator bool()); + } else { + *ret = Jacobian(args[0], args[1]); + } + }); + +TVM_REGISTER_API("autodiff.Derivative") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = Derivative(args[0], args[1]); + }); + +TVM_REGISTER_API("autodiff.DiffBuildingBlock") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = DiffBuildingBlock(args[0], args[1], args[2]); + }); + +TVM_REGISTER_API("autodiff.Differentiate") +.set_body([](TVMArgs args, TVMRetValue *ret) { + if (args.size() <= 1) { + *ret = Differentiate(args[0]); + } else if (args.size() == 2) { + *ret = Differentiate(args[0], args[1]); + } else if (args.size() == 3) { + *ret = Differentiate(args[0], args[1], args[2]); + } else if (args.size() >= 4) { + auto pfunc = args[3].operator PackedFunc(); + auto fdiff = + [pfunc](const Tensor& o, const Tensor& i, const Tensor& h) { + return pfunc(o, i, h); + }; + *ret = Differentiate(args[0], args[1], args[2], fdiff); + } + }); + +} // namespace ir +} // namespace tvm diff --git a/src/pass/autodiff.h b/src/pass/autodiff.h new file mode 100644 index 000000000000..43b0b8ebf23b --- /dev/null +++ b/src/pass/autodiff.h @@ -0,0 +1,146 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file autodiff.h + * \brief Automatic differentiation of IR Expr. + */ +#ifndef TVM_PASS_AUTODIFF_H_ +#define TVM_PASS_AUTODIFF_H_ + +#include +#include + +namespace tvm { +namespace ir { + +class DifferentiationResultNode; + +/*! + * \brief A result of differentiation. + */ +class DifferentiationResult : public NodeRef { + public: + /*! \brief default constructor, used internally */ + DifferentiationResult() {} + explicit DifferentiationResult(NodePtr n) : NodeRef(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const DifferentiationResultNode* operator->() const; + /*! \brief specify container node */ + using ContainerType = DifferentiationResultNode; +}; + +/*! \brief Node to represent a differentiation result */ +class DifferentiationResultNode : public Node { + public: + /*! \brief The requested adjoints, i.e. Jacobians or gradients wrt to the given inputs */ + Array result; + /*! \brief A map from tensors to the corresponding adjoints (including internal nodes) */ + Map adjoints; + /*! \brief Single summands of the adjoints*/ + Map> adjoint_summands; + /*! \brief constructor */ + DifferentiationResultNode() {} + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("result", &result); + v->Visit("adjoints", &adjoints); + v->Visit("adjoint_summands", &adjoint_summands); + } + TVM_DLL static DifferentiationResult make(Array result, + Map adjoints, + Map> adjoint_summands); + + static constexpr const char* _type_key = "DifferentiationResult"; + TVM_DECLARE_NODE_TYPE_INFO(DifferentiationResultNode, Node); +}; + +inline const DifferentiationResultNode* DifferentiationResult::operator->() const { + return static_cast(node_.get()); +} + + +/*! \brief A type of a "local" differentiation function for reverse mode AD + * + * A function of this type is a building block for reverse-mode automatic differentiation. It + * should take three tensors: `output`, `input` and `head`, `head` being the adjoint corresponding + * to the `output`, and return (a summand of) the adjoint corresponding to the input. In other + * words, it should differentiate `output` wrt `input` and multiply the result by `head` with + * tensor dot product (`head` should be on the left of the multiplication). `input` should be an + * immediate dependency of `output` (should be called from within the body of `output`). + * + * See also ::DiffBuildingBlock, which might be considered the reference implementation. + */ +using FDiffBuildingBlock = std::function; + +/*! + * \brief Take the derivative of the expression with respect to the given variable. + * \param expr The expression to differentiate. + * \param var The variable to differentiate with respect to. + * \return The expression for the derivative. + */ +EXPORT Expr Derivative(const Expr& expr, const VarExpr& var); + +/*! + * \brief Get the tensor representing the Jacobian of the output with respect to the input. + * + * Note that if \p output depends on \p input indirectly (by using some other tensor + * depending on \p input), this dependency won't contribute to the resulting Jacobian. + * For such cases use the function ::Differentiate. + * + * \param output The tensor to differentiate. + * \param input The input tensor, which \p output should directly use. + * \param optimize Whether to perform optimizations like lifting of nonzeroness conditions. + * \return The tensor representing the Jacobian of shape `output.shape + input.shape`. + */ +EXPORT Tensor Jacobian(const Tensor& output, const Tensor& input, bool optimize = true); + +/*! + * \brief The building block for reverse-mode AD. + * + * Differentiate \p output wrt \p input and multiply the result by \p head on the left using tensor + * dot product. \p input must be an immediate dependency of \p output (must be called from within + * the body of \p output). That is, the function will compute a summand of the adjoint for \p input + * given the adjoint for \p output (which is called \p head here). + * + * \param output The tensor to differentiate. + * \param input The input tensor, which \p output should directly use. + * \param head The adjoint of \p output. Must be of shape `prefix + output.shape` + * \return The tensor representing the adjoint of \p input of shape `prefix + input.shape`. + */ +EXPORT Tensor DiffBuildingBlock(const Tensor& output, const Tensor& input, const Tensor& head); + +/*! + * \brief Perform reverse mode automatic differentiation. + * + * Each item of the `result` field of the result is an adjoint for the corresponding item of + * \p inputs, i.e. \p head multiplied by the Jacobian of \p output with respect to the + * corresponding item of \p inputs. + * + * \param output The tensor to differentiate. + * \param inputs The array of input tensors. When the array is empty, will perform differentiation + * wrt all tensors the output depends on. + * \param head The adjoint of the output, in other words, some tensor, by which the Jacobians + * will be multiplied. Its shape must be of the form `prefix + output.shape`. If the + * null pointer is provided, the identity tensor of shape + * `output.shape + output.shape` will be used. + * \param fdiff The function performing differentiation and multiplication, see + * ::FDiffBuildingBlock. + * \return An object of type DifferentiationResult which contains three fields: + * - `result` An array of adjoints corresponding to \p inputs. + * - `adjoints` A map from tensors to the corresponding adjoints (includes intermediate + * tensors). + * - `adjoint_summands` A map from tensors to maps from parent tensors to individual + * summands of the adjoint. + */ +EXPORT DifferentiationResult Differentiate(const Tensor& output, + const Array& inputs = Array(), + const Tensor& head = Tensor(), + const FDiffBuildingBlock& fdiff = DiffBuildingBlock); + +} // namespace ir +} // namespace tvm +#endif // TVM_PASS_AUTODIFF_H_ diff --git a/src/pass/dump_tensor.cc b/src/pass/dump_tensor.cc new file mode 100644 index 000000000000..18927e07d1de --- /dev/null +++ b/src/pass/dump_tensor.cc @@ -0,0 +1,100 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file dump_tensor.cc + * \brief Print out tensors recursively. + */ +#include +#include +#include +#include + +namespace tvm { +namespace ir { + +std::string PrintTensorName(const Tensor& tensor) { + if (!tensor.get()) { + return "NULL_TENSOR"; + } + + std::ostringstream oss; + oss << tensor->op->name << "[" << tensor->value_index << "]"; + return oss.str(); +} + +std::string PrintIterVars(const Array& itervars) { + std::ostringstream oss; + oss << "("; + bool first = true; + for (const IterVar& iv : itervars) { + if (!first) oss << ", "; + first = false; + oss << iv->var << " : " << "[" << iv->dom->min + << ", " << (iv->dom->min + iv->dom->extent - 1) << "]"; + } + oss << ")"; + return oss.str(); +} + +std::string PrintTensorsRecursively(const Array& tensors) { + std::vector unprocessed; + std::unordered_set processed; + std::ostringstream oss; + + for (const Tensor& t : tensors) { + unprocessed.push_back(t); + } + + while (!unprocessed.empty()) { + Tensor cur = unprocessed.back(); + unprocessed.pop_back(); + processed.insert(cur); + + oss << "tensor " << PrintTensorName(cur) << " : " << cur->dtype << " " << cur->shape << "\n"; + if (const ComputeOpNode* comp = cur->op.as()) { + oss << "axes " << PrintIterVars(comp->axis) << "\n"; + Expr body = comp->body[cur->value_index]; + + for (const Tensor& t : comp->InputTensors()) { + if (processed.count(t) == 0) { + unprocessed.push_back(t); + } + } + + if (const Reduce* red = body.as()) { + oss << "Reduction\n"; + oss << " identity " << red->combiner->identity_element << "\n"; + oss << " lhs " << red->combiner->lhs << " rhs " << red->combiner->rhs << "\n"; + oss << " combiner " << red->combiner->result << "\n"; + oss << " axis " << PrintIterVars(red->axis) << "\n"; + oss << " condition " << red->condition << "\n"; + for (size_t i = 0; i < red->source.size(); ++i) { + oss << " source[" << i << "] = " << red->source[i] << "\n"; + } + } else { + oss << " " << body << "\n"; + } + } else { + oss << " " << cur->op << "\n"; + } + oss << "\n"; + } + + return oss.str(); +} + +std::string PrintTensorRecursively(const Tensor& tensor) { + return PrintTensorsRecursively({tensor}); +} + +TVM_REGISTER_API("PrintTensorRecursively") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = PrintTensorRecursively(args[0]); + }); + +TVM_REGISTER_API("PrintTensorsRecursively") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = PrintTensorsRecursively(args[0]); + }); + +} // namespace ir +} // namespace tvm diff --git a/src/pass/zero_elimination.cc b/src/pass/zero_elimination.cc new file mode 100644 index 000000000000..327d8a47032e --- /dev/null +++ b/src/pass/zero_elimination.cc @@ -0,0 +1,1666 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file zero_elimination.cc + * \brief Transform tensors in such a way as to eliminate summation over zeros. + */ +#include "zero_elimination.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "arithmetic/ModulusRemainder.h" +#include "../op/op_util.h" + +namespace tvm { +namespace ir { + +using HalideIR::Internal::gcd; +using HalideIR::Internal::lcm; + +struct ExprLess { + bool operator()(const Expr& l, const Expr& r) const { + return Compare(l, r) < 0; + } +}; + +struct ExprEq { + bool operator()(const Expr& l, const Expr& r) const { + return Compare(l, r) == 0; + } +}; + +// Merge two maps, prefer the right one on conflict +template +Map Merge(Map original, const Map& update) { + for (const auto& p : update) { + original.Set(p.first, p.second); + } + return std::move(original); +} + +// Concatenate two arrays +template +Array Concat(Array a, const Array& b) { + for (const auto& x : b) { + a.push_back(x); + } + return std::move(a); +} + +// Combine all expressions from the container using &&. +template +Expr All(const container& c) { + Expr res; + for (const auto& e : c) { + if (res.get()) { + res = res && e; + } else { + res = e; + } + } + if (res.get()) { + return res; + } else { + return const_true(); + } +} + +// Create a select statement of the form cond ? on_true : 0 +Expr SelectElseZero(const Expr& cond, const Expr& on_true) { + return Select::make(cond, on_true, make_zero(on_true.type())); +} + +// Simplify the expression as thoroughly as possible by using all available simplifiers. +Expr SuperSimplify(Expr e, const Map& vranges = Map()) { + // For some reason no simplifier can detect that there is only one value of the variable + std::unordered_map vmap; + for (const auto& var_range : vranges) { + if (is_const_int(var_range.second->extent, 1)) { + vmap[var_range.first.get()] = var_range.second->min; + } + } + if (!vmap.empty()) { + e = Substitute(e, vmap); + } + + return CanonicalSimplify(Simplify(CanonicalSimplify(e, vranges), vranges), vranges); +} + +// Provability check that uses SuperSimplify +bool CanProve(Expr e, const Map& vranges = Map()) { + return is_one(SuperSimplify(e, vranges)); +} + +class ExprFreeVarsVisitor : public IRVisitor { + public: + std::vector free_array; + std::unordered_set bound; + std::unordered_set free; + + virtual void Visit(const NodeRef& node) { + if (const Variable* v = node.as()) { + if (!bound.count(v) && !free.count(v)) { + free.insert(v); + free_array.push_back(Var(node.node_)); + } + } else { + IRVisitor::Visit(node); + } + } + + void Visit_(const Variable* op) { + CHECK(false) << "This case shouldn't happen"; + } + + void Visit_(const LetStmt* op) { + bound.insert(op->var.get()); + IRVisitor::Visit_(op); + } + + void Visit_(const For* op) { + bound.insert(op->loop_var.get()); + IRVisitor::Visit_(op); + } + + void Visit_(const Let* op) { + bound.insert(op->var.get()); + IRVisitor::Visit_(op); + } + + void Visit_(const Reduce* op) { + for (const auto& iv : op->axis) { + bound.insert(iv->var.get()); + } + IRVisitor::Visit_(op); + } + + void Visit_(const Store* op) { + Visit(op->buffer_var); + IRVisitor::Visit_(op); + } + + void Visit_(const Allocate* op) { + Visit(op->buffer_var); + IRVisitor::Visit_(op); + } + + void Visit_(const Free* op) { + Visit(op->buffer_var); + IRVisitor::Visit_(op); + } + + void Visit_(const Load* op) { + Visit(op->buffer_var); + IRVisitor::Visit_(op); + } +}; + +// Get free variables of an expression +Array ExprFreeVars(const Expr& expr) { + ExprFreeVarsVisitor visitor; + visitor.Visit(expr); + return visitor.free_array; +} + +// Clone iter vars and return both the new vars and the substitution from old to new. +std::pair, std::unordered_map> CloneIterVars( + const Array& vars) { + Array new_vars; + std::unordered_map vmap; + for (const IterVar& iv : vars) { + IterVar new_v = + IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""), + iv->iter_type, iv->thread_tag); + new_vars.push_back(new_v); + vmap[iv->var.get()] = new_v; + } + return std::make_pair(std::move(new_vars), std::move(vmap)); +} + +// Clone reduction by cloning the axis variables. +Expr CloneReduction(const Expr& expr) { + if (const Reduce* red = expr.as()) { + Array new_axis; + std::unordered_map vmap; + std::tie(new_axis, vmap) = CloneIterVars(red->axis); + + Array src_with_newaxis; + for (const auto& src : red->source) { + src_with_newaxis.push_back(Substitute(src, vmap)); + } + + return Reduce::make(red->combiner, src_with_newaxis, + new_axis, Substitute(red->condition, vmap), red->value_index); + } else { + return expr; + } +} + +// Convert an array of itervars to an array of inequalities +Array IterVarsToInequalities(const Array& itervars) { + Array res; + for (const IterVar& v : itervars) { + res.push_back(GE::make(v->var, v->dom->min)); + res.push_back(LT::make(v->var, v->dom->min + v->dom->extent)); + } + return res; +} + +// Convert an array of itervars to a map from vars to ranges +Map IterVarsToMap(const Array& itervars) { + Map res; + for (const IterVar& v : itervars) { + res.Set(v->var, v->dom); + } + return res; +} + +// Convert an array of itervars to an array of vars +Array IterVarsToVars(const Array& itervars) { + Array res; + for (const IterVar& v : itervars) { + res.push_back(v->var); + } + return res; +} + +// Given a map from vars to ranges create an array of itervars +Array IterVarsFromMap(const Array& vars, const Map& vranges, + IterVarType iter_type = kDataPar, std::string thread_tag = "") { + Array res; + for (const Var& v : vars) { + res.push_back(IterVarNode::make(vranges[v], v, iter_type, thread_tag)); + } + return res; +} + +// Return true if this combiner is just a sum. +bool IsSumCombiner(const CommReducer& combiner) { + if (combiner->result.size() != 1) { + return false; + } + + if (!is_const_value(SuperSimplify(combiner->identity_element[0]), 0)) { + return false; + } + + return is_const_value(SuperSimplify(combiner->result[0] - + (combiner->lhs[0] + combiner->rhs[0])), + 0); +} + +// Return true if zero may be factored out of a reduction with this combiner. +bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index) { + if (!is_const_value(combiner->identity_element[value_index], 0)) { + return false; + } + + Expr zero = make_zero(combiner->result[value_index].type()); + Expr in = Substitute(combiner->result[value_index], + {{combiner->lhs[value_index], zero}, + {combiner->rhs[value_index], zero}}); + in = SuperSimplify(in); + + return is_const_value(in, 0); +} + +Expr InlineThisCall(const Expr& expr) { + if (const Call* op = expr.as()) { + if (op->call_type == Call::CallType::Halide) { + if (const ComputeOpNode* op_comp = op->func.as()) { + Array tensor_axes; + for (const auto& var : op_comp->axis) { + tensor_axes.push_back(var->var); + } + + Stmt inlined = Inline(Evaluate::make(expr), op->func, tensor_axes, + op_comp->body[op->value_index]); + if (const ir::Evaluate* ev = inlined.as()) { + // If it is a reduction, clone it + return CloneReduction(ev->value); + } + } + } + } + + return expr; +} + +Tensor InlineTailCall(const Tensor& tensor) { + return op::TransformBody(tensor, InlineThisCall); +} + +class InlineTensorsMutator : public IRMutator { + public: + explicit InlineTensorsMutator(const Array& inlineable, bool inline_reductions = false) + : inline_reductions_(inline_reductions) { + for (const Tensor& tensor : inlineable) { + inlineable_.emplace(tensor->op.operator->(), tensor->value_index); + } + } + + Expr Mutate_(const Call* op, const Expr& e) { + if (op->call_type == Call::CallType::Halide) { + const ComputeOpNode* op_comp = op->func.as(); + if (inlineable_.empty() || inlineable_.count({op_comp, op->value_index})) { + if (op_comp && (inline_reductions_ || !op_comp->body[0].as())) { + Array tensor_axes; + for (const auto& var : op_comp->axis) { + tensor_axes.push_back(var->var); + } + + Stmt inlined = Inline(Evaluate::make(e), op->func, tensor_axes, + op_comp->body[op->value_index]); + if (const ir::Evaluate* ev = inlined.as()) { + // If it is a reduction, clone it + return Mutate(ev->value); + } + } + } + } + + return e; + } + + private: + std::set> inlineable_; + bool inline_reductions_; +}; + +Expr InlineTensors(const Expr& expr, const Array& inlineable, + bool inline_reductions) { + return InlineTensorsMutator(inlineable, inline_reductions).Mutate(expr); +} + +Tensor InlineTensors(const Tensor& tensor, const Array& inlineable, + bool inline_reductions) { + auto transformation = + [inlineable, inline_reductions](const Expr& e) { + return InlineTensorsMutator(inlineable, inline_reductions).Mutate(e); }; + return op::TransformBody(tensor, transformation); +} + + +struct NonzeronessConditionResult { + Expr cond; + Expr value; + + Expr to_expr() const { + return SelectElseZero(cond, value); + } +}; + +class NonzeronessConditionFunctor + : public ExprFunctor { + public: + NonzeronessConditionResult NonzeronessCondition(const Expr& e) { + return VisitExpr(e, e); + } + + result_type VisitExpr_(const Variable*, const Expr& e) final { return Default_(e); } + result_type VisitExpr_(const IntImm* op, const Expr& e) final { return Const_(op, e); } + result_type VisitExpr_(const UIntImm* op, const Expr& e) final { return Const_(op, e); } + result_type VisitExpr_(const FloatImm* op, const Expr& e) final { return Const_(op, e); } + result_type VisitExpr_(const StringImm*, const Expr& e) final { return Default_(e); } + result_type VisitExpr_(const Add* op, const Expr& e) final { return BinOpAddLike_(op, e); } + result_type VisitExpr_(const Sub* op, const Expr& e) final { return BinOpAddLike_(op, e); } + result_type VisitExpr_(const Mul* op, const Expr& e) final { return BinOpMulLike_(op, e); } + result_type VisitExpr_(const Div* op, const Expr& e) final { return BinOpDivLike_(op, e); } + result_type VisitExpr_(const Mod* op, const Expr& e) final { return BinOpDivLike_(op, e); } + result_type VisitExpr_(const Min* op, const Expr& e) final { return BinOpAddLike_(op, e); } + result_type VisitExpr_(const Max* op, const Expr& e) final { return BinOpAddLike_(op, e); } + + result_type VisitExpr_(const Cast* op, const Expr& e) final { + if (op->value.type().is_bool()) { + return {op->value, make_const(e.type(), 1)}; + } else { + auto nz_a = NonzeronessCondition(op->value); + + if (nz_a.value.same_as(op->value)) { + return {nz_a.cond, e}; + } else { + return {nz_a.cond, Cast::make(op->type, nz_a.value)}; + } + } + } + + result_type VisitExpr_(const Select* op, const Expr& e) final { + return SelectLike_(e, op->condition, op->true_value, op->false_value, Select::make); + } + + result_type VisitExpr_(const Call* op, const Expr& e) final { + if (op->name == intrinsic::tvm_if_then_else) { + return SelectLike_(e, op->args[0], op->args[1], op->args[2], if_then_else); + } else { + return Default_(e); + } + } + + NonzeronessConditionResult Default_(const Expr& e) { + return {const_true(), e}; + } + + template + NonzeronessConditionResult Const_(const TNode* op, const Expr& e) { + if (op->value == 0) { + return {const_false(), e}; + } else { + return {const_true(), e}; + } + } + + template + NonzeronessConditionResult SelectLike_(const Expr& e, const Expr& cond, const Expr& true_val, + const Expr& false_val, make_select_type make_select) { + auto nz_a = NonzeronessCondition(true_val); + auto nz_b = NonzeronessCondition(false_val); + + if (is_const_value(nz_b.value, 0)) { + Expr new_cond = SuperSimplify(nz_a.cond && cond); + return {new_cond, nz_a.value}; + } + + if (is_const_value(nz_a.value, 0)) { + Expr new_cond = SuperSimplify(nz_b.cond && !cond); + return {new_cond, nz_b.value}; + } + + Expr new_cond = + SuperSimplify(Or::make(cond && nz_a.cond, + !cond && nz_b.cond)); + if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { + return {new_cond, e}; + } else { + return {new_cond, make_select(cond, nz_a.value, nz_b.value)}; + } + } + + template + NonzeronessConditionResult BinOpAddLike_(const TNode* op, const Expr& e) { + auto nz_a = NonzeronessCondition(op->a); + auto nz_b = NonzeronessCondition(op->b); + + if (Equal(nz_a.cond, nz_b.cond)) { + if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { + return {nz_a.cond, e}; + } else { + return {nz_a.cond, TNode::make(nz_a.value, nz_b.value)}; + } + } else { + Expr new_cond = SuperSimplify(Or::make(nz_a.cond, nz_b.cond)); + Expr new_a = Equal(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr(); + Expr new_b = Equal(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr(); + Expr new_expr = TNode::make(new_a, new_b); + return {new_cond, new_expr}; + } + } + + template + NonzeronessConditionResult BinOpMulLike_(const TNode* op, const Expr& e) { + auto nz_a = NonzeronessCondition(op->a); + auto nz_b = NonzeronessCondition(op->b); + + Expr new_cond = SuperSimplify(nz_a.cond && nz_b.cond); + + if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { + return {new_cond, e}; + } else { + return {new_cond, TNode::make(nz_a.value, nz_b.value)}; + } + } + + template + NonzeronessConditionResult BinOpDivLike_(const TNode* op, const Expr& e) { + auto nz_a = NonzeronessCondition(op->a); + + if (nz_a.value.same_as(op->a)) { + return {nz_a.cond, e}; + } else { + return {nz_a.cond, TNode::make(nz_a.value, op->b)}; + } + } +}; + +NonzeronessConditionResult NonzeronessCondition(const Expr& expr) { + return NonzeronessConditionFunctor().NonzeronessCondition(expr); +} + +Expr LiftNonzeronessCondition(const Expr& expr) { + return NonzeronessCondition(expr).to_expr(); +} + + +class NormalizeComparisonsMutator : public IRMutator { + public: + virtual Expr Mutate_(const EQ* op, const Expr& e) { return Make(op->a, op->b); } + virtual Expr Mutate_(const NE* op, const Expr& e) { return Make(op->a, op->b); } + virtual Expr Mutate_(const LT* op, const Expr& e) { return Make(op->a, op->b); } + virtual Expr Mutate_(const LE* op, const Expr& e) { return Make(op->a, op->b); } + virtual Expr Mutate_(const GT* op, const Expr& e) { return Make(op->b, op->a); } + virtual Expr Mutate_(const GE* op, const Expr& e) { return Make(op->b, op->a); } + + private: + template + Expr Make(const Expr& a, const Expr& b) { + // rewrite LT to LE for ints + if (std::is_same::value && (a.type().is_int() || a.type().is_uint())) { + return LE::make(SuperSimplify(a - b + 1), make_zero(a.type())); + } + return TNode::make(SuperSimplify(a - b), make_zero(a.type())); + } +}; + +// Rewrite every comparison into the form a == 0, a != 0, a <= 0, and sometimes for floats a < 0 +Expr NormalizeComparisons(const Expr& expr) { + return NormalizeComparisonsMutator().Mutate(expr); +} + + +struct FactorOutAtomicFormulasResult { + std::vector atomic_formulas; + Expr rest; + + Expr to_expr() const { + Expr res = rest; + for (const Expr& e : atomic_formulas) { + res = And::make(e, res); + } + return res; + } +}; + +class FactorOutAtomicFormulasFunctor + : public ExprFunctor { + public: + result_type Atomic_(const Expr& e) { + return {{e}, make_const(e.type(), 1)}; + } + + result_type VisitExpr_(const Variable*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const Call*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const IntImm*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const UIntImm*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const EQ*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const NE*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const LE*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const LT*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const GE*, const Expr& e) final { return Atomic_(e); } + result_type VisitExpr_(const GT*, const Expr& e) final { return Atomic_(e); } + + result_type VisitExpr_(const And* op, const Expr& e) final { + auto res_a = VisitExpr(op->a, op->a); + auto res_b = VisitExpr(op->b, op->b); + + std::vector res; + res.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size()); + std::set_union(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + std::back_inserter(res), + ExprLess()); + + return {res, res_a.rest && res_b.rest}; + } + + result_type VisitExpr_(const Mul* op, const Expr& e) final { + auto res_a = VisitExpr(op->a, op->a); + auto res_b = VisitExpr(op->b, op->b); + + std::vector res; + res.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size()); + std::set_union(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + std::back_inserter(res), + ExprLess()); + + return {res, res_a.rest * res_b.rest}; + } + + result_type VisitExpr_(const Or* op, const Expr& e) final { + auto res_a = VisitExpr(op->a, op->a); + auto res_b = VisitExpr(op->b, op->b); + + std::vector res; + res.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size())); + std::set_intersection(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + std::back_inserter(res), + ExprLess()); + + std::vector new_cond_a; + new_cond_a.reserve(res_a.atomic_formulas.size() - res.size()); + std::set_difference(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + res.begin(), res.end(), + std::back_inserter(new_cond_a), + ExprLess()); + + std::vector new_cond_b; + new_cond_b.reserve(res_b.atomic_formulas.size() - res.size()); + std::set_difference(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + res.begin(), res.end(), + std::back_inserter(new_cond_b), + ExprLess()); + + res_a.atomic_formulas = std::move(new_cond_a); + res_b.atomic_formulas = std::move(new_cond_b); + + Expr new_rest = Or::make(res_a.to_expr(), res_b.to_expr()); + + return {res, new_rest}; + } +}; + +// Transform the given formula into an array of atomic formulas and a non-atomic residual. +FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const Expr& e) { + return FactorOutAtomicFormulasFunctor().VisitExpr(e, e); +} + + +struct EliminateDivModResult { + Expr expr; + Map substitution; + Array new_variables; + Array conditions; + Map ranges; +}; + +class EliminateDivModMutator : public IRMutator { + public: + Map substitution; + Array new_variables; + Array conditions; + Map ranges; + + explicit EliminateDivModMutator(Map ranges) + : ranges(ranges) {} + + virtual Expr Mutate_(const Div* op, const Expr& e) { + const IntImm* imm = op->b.as(); + if (imm && imm->value > 0) { + auto it = expr_to_vars_.find({op->a.get(), imm->value}); + if (it != expr_to_vars_.end()) { + return it->second.first; + } + + Expr mutated_a = Mutate(op->a); + return AddNewVarPair(op->a, mutated_a, imm->value).first; + } + + return Div::make(Mutate(op->a), Mutate(op->b)); + } + + virtual Expr Mutate_(const Mod* op, const Expr& e) { + const IntImm* imm = op->b.as(); + if (imm && imm->value > 0) { + auto it = expr_to_vars_.find({op->a.get(), imm->value}); + if (it != expr_to_vars_.end()) { + return it->second.second; + } + + Expr mutated_a = Mutate(op->a); + return AddNewVarPair(op->a, mutated_a, imm->value).second; + } + + return Mod::make(Mutate(op->a), Mutate(op->b)); + } + + private: + std::pair AddNewVarPair(const Expr& e, const Expr& mut, int64_t val) { + Expr val_e = make_const(e.type(), val); + idx_ += 1; + + auto div = Var("div" + std::to_string(idx_), e.type()); + auto mod = Var("mod" + std::to_string(idx_), e.type()); + + new_variables.push_back(div); + new_variables.push_back(mod); + + substitution.Set(div, mut / val_e); + substitution.Set(mod, mut % val_e); + + std::unordered_map var_intsets; + for (const auto& p : ranges) { + var_intsets[p.first.get()] = IntSet::range(p.second); + } + + Range div_range = EvalSet(mut / val_e, var_intsets).cover_range(Range()); + Range mod_range = EvalSet(mut % val_e, var_intsets).cover_range(Range()); + + ranges.Set(div, div_range); + ranges.Set(mod, mod_range); + + conditions.push_back(mut == div*val_e + mod); + + if (!CanProve(mod_range->extent <= val_e)) { + LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod of expr " << e + << " (probably it may change its sign)"; + // We cannot prove that mod is unique, so add additional condition + conditions.push_back(Select::make(e >= 0, mod >= 0, mod <= 0)); + } + + auto p = std::make_pair(div, mod); + expr_to_vars_[{e.get(), val}] = p; + return p; + } + + int idx_{0}; + std::map, std::pair> + expr_to_vars_; +}; + +// replace every subexpr of the form e/const and e % const with a new variable +EliminateDivModResult EliminateDivMod(const Expr& expr, Map ranges) { + EliminateDivModResult res; + EliminateDivModMutator mutator(ranges); + res.expr = mutator.Mutate(expr); + res.conditions = std::move(mutator.conditions); + res.new_variables = std::move(mutator.new_variables); + res.substitution = std::move(mutator.substitution); + res.ranges = std::move(mutator.ranges); + return res; +} + +// run EliminateDivMod from the condition of a reduction +Expr EliminateDivModFromReductionCondition(const Expr& expr, + Map vranges = Map()) { + if (const Reduce* red = expr.as()) { + for (const IterVar& iv : red->axis) { + vranges.Set(iv->var, iv->dom); + } + + auto elim_res = EliminateDivMod(red->condition, vranges); + + vranges = elim_res.ranges; + + Array new_axis = + Concat(red->axis, IterVarsFromMap(elim_res.new_variables, vranges, kCommReduce)); + + Expr new_cond = elim_res.expr && All(elim_res.conditions); + + return Reduce::make(red->combiner, red->source, new_axis, new_cond, red->value_index); + } else { + return expr; + } +} + + +VarBounds VarBounds::substitute(const Map& subst) const { + auto apply_fun = [&subst](const Expr& e) { return Substitute(e, subst); }; + return {Substitute(coef, subst), + UpdateArray(lower, apply_fun), + UpdateArray(equal, apply_fun), + UpdateArray(upper, apply_fun)}; +} + +Array SolveSystemOfInequalitiesResult::as_conditions() const { + Array res; + for (const Var& v : variables) { + auto it = bounds.find(v.get()); + CHECK(it != bounds.end()); + const VarBounds& bnds = it->second; + Expr lhs = bnds.coef * v; + for (const Expr& rhs : bnds.equal) { + res.push_back(EQ::make(lhs, rhs)); + } + for (const Expr& rhs : bnds.lower) { + res.push_back(GE::make(lhs, rhs)); + } + for (const Expr& rhs : bnds.upper) { + res.push_back(LE::make(lhs, rhs)); + } + } + for (const Expr& e : other_conditions) { + res.push_back(e); + } + return res; +} + +// Rewrite the system of inequalities using Fourier-Motzkin elimination +// Note that variable ranges help a lot, so this parameter is even non-optional +SolveSystemOfInequalitiesResult SolveSystemOfInequalities(const Array& inequalities, + const Array& variables, + const Map& vranges) { + SolveSystemOfInequalitiesResult res; + res.variables = variables; + + // The algorithm consists in doing the following things for each variable v + // - Take formulas from `current` and classify them according to polarity wrt v + // - Combine each formula of positive polarity (wrt v) with each formula of negative polarity + // - Put the resulting combinations into `new_current` along with unclassifiable formulas + // - Replace `current` with `new_current` and move to the next variable + + // current and new_current are sorted to enable some heuristics + std::set current; + std::set new_current; + // A vector of pairs (c, e), c > 0, representing formulas of the form c*v + e <= 0 + std::vector> coef_pos; + // A vector of pairs (c, e), c < 0, representing formulas of the form c*v + e <= 0 + std::vector> coef_neg; + + // formulas we don't know what to do with + std::vector rest; + + // A helper that adds an inequality to new_current if it's not obviously redundant + auto add_to_new_current = [&new_current, &vranges] (const Expr& new_ineq) { + if (CanProve(new_ineq, vranges)) { + // redundant: follows from the vranges + return; + } + if (const LE* new_le = new_ineq.as()) { + // A heuristic: check if the new inequality is a consequence of one + // of its future neighbors (in this case don't add it) or if a future neighbor is + // a consequence of the new ineq (in which case remove the neighbor) + auto it_neighbor = new_current.lower_bound(new_ineq); + if (it_neighbor != new_current.begin()) { + const LE* le = std::prev(it_neighbor)->as(); + if (le && CanProve(new_le->a - le->a <= 0, vranges)) { + return; + } else if (le && CanProve(le->a - new_le->a <= 0, vranges)) { + new_current.erase(std::prev(it_neighbor)); + } + } + // Check the other neighbor + if (it_neighbor != new_current.end()) { + const LE* le = it_neighbor->as(); + if (le && CanProve(new_le->a - le->a <= 0, vranges)) { + return; + } else if (le && CanProve(le->a - new_le->a <= 0, vranges)) { + it_neighbor = new_current.erase(it_neighbor); + } + } + + new_current.insert(it_neighbor, new_ineq); + } else { + new_current.insert(new_ineq); + } + }; + + // Simplify each inequality into the form `expr <= 0` and add to new_current formulas + for (const Expr& ineq : inequalities) { + add_to_new_current(NormalizeComparisons(SuperSimplify(ineq, vranges))); + } + + std::swap(current, new_current); + + for (const Var& v : variables) { + CHECK(!res.bounds.count(v.get())) << + "Variable " << v << " appears several times in the `variables` which might be a bug"; + + new_current.clear(); + coef_pos.clear(); + coef_neg.clear(); + + // Add bounds from vranges + if (vranges.count(v)) { + const Range& range = vranges[v]; + Expr range_lbound = SuperSimplify(range->min, vranges); + Expr range_ubound = SuperSimplify(range->min + range->extent - 1, vranges); + coef_neg.push_back({-1, range_lbound}); + coef_pos.push_back({1, -range_ubound}); + } + + // Take formulas from `current` and classify them according to polarity wrt v + for (const Expr& ineq : current) { + if (const LE* le = ineq.as()) { + Array coef = arith::DetectLinearEquation(le->a, {v}); + if (!coef.empty() && is_const(coef[0])) { + int64_t coef0 = *as_const_int(coef[0]); + if (coef0 == 0) { + // zero polarity, straight to new_current + add_to_new_current(ineq); + } else if (coef0 > 0) { + coef_pos.push_back({coef0, coef[1]}); + } else if (coef0 < 0) { + coef_neg.push_back({coef0, coef[1]}); + } + continue; + } + } else if (const EQ* eq = ineq.as()) { + Array coef = arith::DetectLinearEquation(eq->a, {v}); + if (!coef.empty() && is_const(coef[0])) { + int64_t coef0 = *as_const_int(coef[0]); + if (coef0 == 0) { + // zero polarity, straight to new_current + add_to_new_current(ineq); + } else if (coef0 > 0) { + // Equalities may be considered as pairs of two inequalities + coef_pos.push_back({coef0, coef[1]}); + coef_neg.push_back({-coef0, -coef[1]}); + } else if (coef0 < 0) { + coef_pos.push_back({-coef0, -coef[1]}); + coef_neg.push_back({coef0, coef[1]}); + } + continue; + } + } + + // if nothing worked, put it in rest + rest.push_back(ineq); + } + + // Combine each positive inequality with each negative one (by adding them together) + for (const auto& pos : coef_pos) { + for (const auto& neg : coef_neg) { + auto first_gcd = gcd(pos.first, -neg.first); + Expr c_pos = make_const(v.type(), neg.first/first_gcd); + Expr c_neg = make_const(v.type(), pos.first/first_gcd); + Expr new_lhs = c_neg*neg.second - c_pos*pos.second; + Expr new_ineq = LE::make(new_lhs, make_zero(pos.second.type())); + new_ineq = NormalizeComparisons(SuperSimplify(new_ineq, vranges)); + add_to_new_current(new_ineq); + } + } + + // Now we have to generate resulting (in)equalities for the variable v + + // Find the common denominator in a sense + // We will generate formulas of the form coef_lcm*v <= bound + int64_t coef_lcm = 1; + for (const auto& pos : coef_pos) { + coef_lcm = lcm(coef_lcm, pos.first); + } + for (const auto& neg : coef_neg) { + coef_lcm = lcm(coef_lcm, -neg.first); + } + + // The resulting lower and upper bounds stored in sorted vectors + std::vector upper_bounds; + std::vector lower_bounds; + upper_bounds.reserve(coef_pos.size()); + lower_bounds.reserve(coef_neg.size()); + + for (const auto& pos : coef_pos) { + Expr bound = make_const(v.type(), -coef_lcm/pos.first)*pos.second; + bound = SuperSimplify(bound, vranges); + // Don't add if any of the existing bounds is better + if (std::any_of(upper_bounds.begin(), upper_bounds.end(), + [&bound, &vranges](const Expr& o) { return CanProve(o - bound <= 0, + vranges); })) { + continue; + } + // Erase all worse bounds + upper_bounds.erase( + std::remove_if(upper_bounds.begin(), upper_bounds.end(), + [&bound, &vranges](const Expr& o) { return CanProve(o - bound >= 0, + vranges); }), + upper_bounds.end()); + // Add + upper_bounds.push_back(bound); + } + for (const auto& neg : coef_neg) { + Expr bound = make_const(v.type(), -coef_lcm/neg.first)*neg.second; + bound = SuperSimplify(bound, vranges); + // Don't add if any of the existing bounds is better + if (std::any_of(lower_bounds.begin(), lower_bounds.end(), + [&bound, &vranges](const Expr& o) { return CanProve(o - bound >= 0, + vranges); })) { + continue; + } + // Erase all worse bounds + lower_bounds.erase( + std::remove_if(lower_bounds.begin(), lower_bounds.end(), + [&bound, &vranges](const Expr& o) { return CanProve(o - bound <= 0, + vranges); }), + lower_bounds.end()); + // Add + lower_bounds.push_back(bound); + } + + // Sort the vectors and remove duplicates + for (std::vector* bounds : {&upper_bounds, &lower_bounds}) { + std::sort(bounds->begin(), bounds->end(), ExprLess()); + bounds->erase(std::unique(bounds->begin(), bounds->end(), ExprEq()), bounds->end()); + } + + // Bounds which are both lower and upper should go to equal... + std::vector equal; + equal.reserve(std::min(upper_bounds.size(), lower_bounds.size())); + std::set_intersection(upper_bounds.begin(), upper_bounds.end(), + lower_bounds.begin(), lower_bounds.end(), + std::back_inserter(equal), ExprLess()); + + // ...and be removed from upper bounds... + std::vector new_upper; + new_upper.reserve(upper_bounds.size() - equal.size()); + std::set_difference(upper_bounds.begin(), upper_bounds.end(), + equal.begin(), equal.end(), + std::back_inserter(new_upper), ExprLess()); + + // ...and from lower bounds. + std::vector new_lower; + new_lower.reserve(lower_bounds.size() - equal.size()); + std::set_difference(lower_bounds.begin(), lower_bounds.end(), + equal.begin(), equal.end(), + std::back_inserter(new_lower), ExprLess()); + + // Write it to the result. + auto& bnds = res.bounds[v.get()]; + bnds.coef = make_const(v.type(), coef_lcm); + bnds.equal = equal; + bnds.lower = new_lower; + bnds.upper = new_upper; + + std::swap(current, new_current); + } + + // Everything that is left goes to res.other_conditions + for (const Expr& e : current) { + Expr e_simp = SuperSimplify(e, vranges); + if (is_const_int(e_simp, 0)) { + // contradiction detected + res.other_conditions = {const_false()}; + return res; + } else if (is_const_int(e_simp, 1)) { + continue; + } else { + res.other_conditions.push_back(e_simp); + } + } + + for (const Expr& e : rest) + res.other_conditions.push_back(e); + + return res; +} + + +// Simplify an iteration domain. +DomainSimplificationResult SimplifyDomain(const Expr& cond, + const Array& axis, + Map vranges, + bool eliminate_div_mod) { + if (eliminate_div_mod) { + auto elim_res = EliminateDivMod(cond, vranges); + + Map new_vranges = elim_res.ranges; + Array new_axis = Concat(axis, elim_res.new_variables); + Expr new_cond = elim_res.expr && All(elim_res.conditions); + + auto res = SimplifyDomain(new_cond, new_axis, new_vranges, false); + + Map new_old_to_new; + for (const Var& v : axis) { + new_old_to_new.Set(v, res.old_to_new[v]); + } + + Map new_new_to_old; + for (const auto& pair : res.new_to_old) { + new_new_to_old.Set(pair.first, Substitute(pair.second, elim_res.substitution)); + } + + res.old_to_new = std::move(new_old_to_new); + res.new_to_old = std::move(new_new_to_old); + + return res; + } + + auto factoratomic_res = FactorOutAtomicFormulas(cond); + std::vector& atomic_formulas = factoratomic_res.atomic_formulas; + Expr rest_of_cond = factoratomic_res.rest; + + // Put rest_of_cond into the vector of atomic formulas so that we don't forget about it. + // Although rest_of_cond is not atomic, the subsequent functions won't complain about it. + atomic_formulas.push_back(rest_of_cond); + + // vars are variables from axis followed by all the other variables from vranges + Array vars = axis; + for (const auto& pair : vranges) { + bool already = false; + for (const Var& v : vars) { + already = already || v.same_as(pair.first); + } + if (!already) { + vars.push_back(pair.first); + } + } + + auto solved_system = SolveSystemOfInequalities(atomic_formulas, vars, vranges); + + DomainSimplificationResult res; + std::unordered_map new_var_intsets; + + // Initialize new_var_intsets with the old var intsets + for (const auto& pair : vranges) { + new_var_intsets[pair.first.get()] = IntSet::range(pair.second); + } + + // We process variables in the reverse direction to start with the most independent one. + // This order is needed to compute new ranges. + for (auto it = axis.rbegin(); it != axis.rend(); ++it) { + const Var& var = *it; + auto& bnd = solved_system.bounds[var.get()]; + // Note that we replace old vars with new ones + bnd = bnd.substitute(res.old_to_new); + if (is_one(bnd.coef) && !bnd.equal.empty()) { + // There is an equation of the form `v == expr`, so this variable can be completely removed. + // Note that we use the 0-th expression because they are ordered by complexity, so it must be + // the simplest one. + res.old_to_new.Set(var, bnd.equal[0]); + } else { + Array lowers = Concat(bnd.equal, bnd.lower); + Array uppers = Concat(bnd.equal, bnd.upper); + + // Here we will try all pairs of lower and upper bounds and find the best pair, that is, the + // pair with the minimal difference between the upper and the lower. + // Note that the bounds are for v*coef, not for v (because we don't want complex expressions + // involving division). + + // The lower bound of the best pair so far + Expr best_lower = vranges[var]->min * bnd.coef; + // The difference between the upper and the lower of the best pair so far + Expr best_diff = (vranges[var]->extent - 1) * bnd.coef; + // The overapproximation of the best difference + Expr best_diff_over = best_diff; + + for (const Expr& low : lowers) { + for (const Expr& upp : uppers) { + Expr diff = SuperSimplify(upp - low, vranges); + // Since diff may depend on some other variables, we compute its overapproximation + Expr diff_over = EvalSet(diff, new_var_intsets).max(); + + if (diff_over.same_as(HalideIR::Internal::Interval::pos_inf)) { + continue; + } + + // If it is provable that the new one is strictly better than the current best one, + // then replace it. Note that we are biased towards earlier pairs which should be simpler. + if (CanProve(diff_over - best_diff_over < 0, vranges)) { + best_lower = low; + best_diff = diff; + best_diff_over = diff_over; + } + } + } + + if (is_const_int(best_diff, 0)) { + // In this case coef*iv = best_lower + // Don't create an itervar, just replace it everywhere with its min + res.old_to_new.Set(var, SuperSimplify(best_lower / bnd.coef, vranges)); + // To assure correctness, we have to add a condition that best_lower can be divided by coef + res.conditions.push_back(SuperSimplify(best_lower % bnd.coef == 0, vranges)); + } else { + std::string suffix = Equal(best_lower, vranges[var]->min * bnd.coef) ? "" : ".shifted"; + Var new_var = var.copy_with_suffix(suffix); + + // We will replace our iv with new_var + shift. + // We use rounding-up division to compute shift. Since we want to use a single formula + // without selects in as many cases as possible, we try to prove conditions manually. + Expr shift; + if (CanProve(best_lower <= 0, vranges)) { + shift = best_lower / bnd.coef; + } else if (CanProve(best_lower > -bnd.coef, vranges)) { + shift = (best_lower + bnd.coef - 1)/bnd.coef; + } else { + shift = Select::make(best_lower <= -bnd.coef, + best_lower / bnd.coef, + (best_lower + bnd.coef - 1)/bnd.coef); + } + shift = SuperSimplify(shift, vranges); + + Expr diff = SuperSimplify(best_diff_over / bnd.coef, vranges); + + if (is_const_int(diff, 0)) { + // Don't create an itervar, just replace it everywhere with its min + res.old_to_new.Set(var, shift); + } else { + res.old_to_new.Set(var, new_var + shift); + // Note that we are substituting old with new, so best_lower contains new var, + // that is we have to substitute new with old in best_lower here + res.new_to_old.Set(new_var, + SuperSimplify(var - Substitute(shift, res.new_to_old), vranges)); + + new_var_intsets[new_var.get()] = IntSet::interval(make_zero(new_var.type()), diff); + + // Add the new var to the resulting axis + auto range = Range(make_zero(new_var.type()), diff + 1); + res.axis.push_back(new_var); + res.ranges.Set(new_var, range); + vranges.Set(new_var, range); + } + } + } + } + + // Add the original conditions (with variables substituted) to the resulting conditions + for (const Expr& old_cond : solved_system.as_conditions()) { + res.conditions.push_back(SuperSimplify(Substitute(old_cond, res.old_to_new), vranges)); + } + + return res; +} + +// Use the condition of a reduction op to simplify its domain (axis) +Expr SimplifyReductionDomain(const Expr& expr, const Map& outer_vranges) { + if (const Reduce* red = expr.as()) { + Map vranges = Merge(outer_vranges, IterVarsToMap(red->axis)); + auto res = SimplifyDomain(red->condition, IterVarsToVars(red->axis), + Merge(outer_vranges, IterVarsToMap(red->axis))); + + Array new_source; + for (const Expr& src : red->source) { + new_source.push_back(Substitute(src, res.old_to_new)); + } + + Array new_axis = IterVarsFromMap(res.axis, res.ranges, kCommReduce); + + // Perform simplification mainly to remove a possibly empty reduction. + return Simplify(Reduce::make(red->combiner, new_source, new_axis, + All(res.conditions), red->value_index)); + } else { + return expr; + } +} + +// Extract the given expr under the given condition as a separate tensor if the volume of the +// extracted tensor will be less than the volume of the outer_axis +Expr ExtractAsTensorMaybe(const Expr& e, const Expr& cond, + const Array& outer_axis, + const Map& vranges) { + // TODO(sgrechanik-h): We don't use divmod elimination here because of some performance problems + auto res = SimplifyDomain(cond, outer_axis, vranges, false); + + Expr new_expr = SuperSimplify(Substitute(e, res.old_to_new), vranges); + + // Keep only those variables of the new axis which are used in the new_expr + { + Array used_res_axis; + for (const Var& var : res.axis) { + if (ExprUseVar(new_expr, var)) { + used_res_axis.push_back(var); + } + } + + res.axis = std::move(used_res_axis); + } + + // Use the new axis to simplify the new expr, removing redundant inequalities + new_expr = SuperSimplify(new_expr, res.ranges); + + // If the expression does not use vars then it is probably better to keep it inlined + if (res.axis.empty()) { + return new_expr; + } + + // Compute volumes before and after + Expr old_volume = make_const(Int(64), 1); + for (const Var& var : outer_axis) { + old_volume = old_volume * vranges[var]->extent; + } + + Expr new_volume = make_const(Int(64), 1); + for (const Var& var : res.axis) { + new_volume = new_volume * res.ranges[var]->extent; + } + + // if we can prove that the old volume is not greater than the new volume then + // prefer the old expression. + if (CanProve(old_volume <= new_volume, vranges)) { + return e; + } + + Tensor tensor = op::TensorFromExpr(new_expr, IterVarsFromMap(res.axis, res.ranges)); + + Array args; + for (const Var& var : res.axis) { + args.push_back(res.new_to_old[var]); + } + + return Call::make(e.type(), tensor->op->name, args, + Call::CallType::Halide, tensor->op, tensor->value_index); +} + + +class RemoveRedundantInequalitiesMutator : public IRMutator { + public: + explicit RemoveRedundantInequalitiesMutator(Array known) { + for (const Expr& cond : known) { + known_.push_back(SuperSimplify(cond)); + } + } + + virtual Expr Mutate_(const Select* op, const Expr& e) { + bool has_side_effect = HasSideEffect(e); + Expr new_cond = SuperSimplify(Mutate(op->condition)); + if (is_one(new_cond) && !has_side_effect) { + return Mutate(op->true_value); + } else if (is_zero(new_cond) && !has_side_effect) { + return Mutate(op->false_value); + } else { + Array new_known = known_; + for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + // Note that we mutate only the true value with the new mutator + // TODO(sgrechanik-h): Update known conditions for the false value as well + return Select::make(new_cond, new_mutator.Mutate(op->true_value), Mutate(op->false_value)); + } + } + + virtual Expr Mutate_(const Call* op, const Expr& e) { + if (op->name == intrinsic::tvm_if_then_else) { + Expr new_cond = SuperSimplify(Mutate(op->args[0])); + if (is_one(new_cond)) { + return Mutate(op->args[1]); + } else if (is_zero(new_cond)) { + return Mutate(op->args[2]); + } else { + Array new_known = known_; + for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + // Note that we mutate only the true value with the new mutator + // TODO(sgrechanik-h): Update known conditions for the false value as well + return if_then_else(new_cond, new_mutator.Mutate(op->args[1]), Mutate(op->args[2])); + } + } else { + return IRMutator::Mutate_(op, e); + } + } + + virtual Expr Mutate_(const Reduce* op, const Expr& e) { + Array known_with_axes = known_; + for (const Expr& axis_cond : IterVarsToInequalities(op->axis)) { + known_with_axes.push_back(axis_cond); + } + RemoveRedundantInequalitiesMutator mutator_with_axes(known_with_axes); + + Expr new_cond = mutator_with_axes.Mutate(op->condition); + + Array new_known = known_with_axes; + for (const Expr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { + new_known.push_back(atomic); + } + RemoveRedundantInequalitiesMutator new_mutator(new_known); + + Array new_source; + for (const Expr& src : op->source) { + new_source.push_back(new_mutator.Mutate(src)); + } + + return Reduce::make(op->combiner, new_source, op->axis, new_cond, op->value_index); + } + + virtual Expr Mutate_(const EQ* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const NE* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const LT* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const LE* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const GT* op, const Expr& e) { return MutateAtomic_(e); } + virtual Expr Mutate_(const GE* op, const Expr& e) { return MutateAtomic_(e); } + + virtual Expr Mutate_(const And* op, const Expr& e) { + return Mutate(op->a) && Mutate(op->b); + } + + private: + Expr MutateAtomic_(const Expr& e) { + Expr simplified = SuperSimplify(e); + for (const Expr& other : known_) { + if (Equal(simplified, other)) { + return const_true(); + } + } + return simplified; + } + + Array known_; +}; + +// Propagate information from conditions and remove redundant inequalities +// TODO(sgrechanik-h): This should be merged into standard simplifiers +Expr RemoveRedundantInequalities(const Expr& expr, const Array& known) { + return RemoveRedundantInequalitiesMutator(known).Mutate(expr); +} + +// Extract from cond an implication of cond not containing vars +std::pair ImplicationNotContainingVars( + const Expr& cond, const std::unordered_set& vars) { + CHECK(cond.type().is_bool()) << "The type of cond must be bool"; + // TODO(sgrechanik-h): not + if (const And* op = cond.as()) { + auto pair_a = ImplicationNotContainingVars(op->a, vars); + auto pair_b = ImplicationNotContainingVars(op->b, vars); + return {pair_a.first && pair_b.first, + pair_a.second && pair_b.second}; + } else if (const Or* op = cond.as()) { + auto pair_a = ImplicationNotContainingVars(op->a, vars); + auto pair_b = ImplicationNotContainingVars(op->b, vars); + return {Or::make(pair_a.first, pair_b.first), cond}; + } else if (!ExprUseVar(cond, vars)) { + return {cond, const_true()}; + } else { + return {const_true(), cond}; + } +} + +// Factor conditions out of a reduction by applying Fourier-Motzkin elimination and moving out +// (in)equalities which do not depend on the reduction variables. +std::pair LiftConditionsThroughReduction(const Expr& cond, + const Array& red_axis, + const Array& outer_axis) { + // Factor out atomics so that we can consider this as a system of inequalities + auto factoratomic_res = FactorOutAtomicFormulas(cond); + Array atomics = factoratomic_res.atomic_formulas; + const Expr& rest = factoratomic_res.rest; + + Array allvars; + for (const IterVar& v : red_axis) { + allvars.push_back(v->var); + } + for (const IterVar& v : outer_axis) { + allvars.push_back(v->var); + } + + auto vranges = Merge(IterVarsToMap(red_axis), IterVarsToMap(outer_axis)); + // start from reduction vars, so that input vars don't depend on them + atomics = SolveSystemOfInequalities(atomics, allvars, vranges).as_conditions(); + + // Append the rest part + Expr rewritten_cond = All(atomics) && rest; + + std::unordered_set vset; + for (const IterVar& v : red_axis) { + vset.insert(v->var.get()); + } + + // The outer (first) condition does not contain reduction vars, + // the inner (second) condition is everything else + return ImplicationNotContainingVars(rewritten_cond, vset); +} + +class ExtractReductionsMutator : public IRMutator { + public: + explicit ExtractReductionsMutator(Map vranges, std::string name = "extracted") + : vranges_(std::move(vranges)), name_(std::move(name)) {} + + Expr Mutate_(const Reduce* op, const Expr& e) { + ExtractReductionsMutator new_mutator(Merge(vranges_, IterVarsToMap(op->axis)), name_); + + Array new_source; + for (const Expr& src : op->source) { + new_source.push_back(new_mutator.Mutate(src)); + } + + Expr new_reduce = + Reduce::make(op->combiner, new_source, op->axis, op->condition, op->value_index); + Array vars = ExprFreeVars(new_reduce); + + auto newaxis_vmap_pair = CloneIterVars(IterVarsFromMap(vars, vranges_)); + Array new_axis = newaxis_vmap_pair.first; + new_reduce = SuperSimplify(Substitute(new_reduce, newaxis_vmap_pair.second), + IterVarsToMap(new_axis)); + + Tensor tensor = op::TensorFromExpr(new_reduce, new_axis, name_, tag_, attrs_); + + Array args; + for (const Var& v : vars) { + args.push_back(v); + } + + return Call::make(e.type(), tensor->op->name, args, + Call::CallType::Halide, tensor->op, tensor->value_index); + } + + private: + Map vranges_; + std::string name_; + std::string tag_; + Map attrs_; +}; + +// Extract reductions as separate tensors. +Expr ExtractReductions(const Expr& expr, const Map& vranges) { + return ExtractReductionsMutator(vranges).Mutate(expr); +} + +Expr ExtractNonTopReductions(const Expr& expr, const Map& vranges) { + if (const Reduce* red = expr.as()) { + Map new_vranges = Merge(vranges, IterVarsToMap(red->axis)); + Array new_source; + for (const Expr& src : red->source) { + new_source.push_back(ExtractReductions(src, new_vranges)); + } + Expr new_condition = ExtractReductions(red->condition, new_vranges); + + return Reduce::make(red->combiner, new_source, red->axis, + new_condition, red->value_index); + } else { + return ExtractReductions(expr, vranges); + } +} + +Expr OptimizeAndLiftNonzeronessConditionsImpl(const Expr& expr, const Array& axis) { + Expr result; + + if (const Reduce* red = expr.as()) { + // TODO(sgrechanik-h): There are some other operations which behave like sum + bool is_sum = IsSumCombiner(red->combiner); + if (is_sum || CanFactorZeroFromCombiner(red->combiner, red->value_index)) { + Expr new_red = expr; + + // Here we simplify the reduction + { + Expr cond = red->condition; + Array source = red->source; + + // If it is a summation then we can lift nonzeroness conditions from the source + // and add them to the reduction conditions + if (is_sum) { + auto nz = NonzeronessCondition(red->source[red->value_index]); + cond = nz.cond && cond; + source.Set(0, nz.value); + } + + new_red = Reduce::make(red->combiner, source, red->axis, cond, red->value_index); + new_red = SimplifyReductionDomain(new_red, IterVarsToMap(axis)); + red = new_red.as(); + + // If the reduction disappears completely then transform the result as a non-reduction + if (!red) { + return OptimizeAndLiftNonzeronessConditionsImpl(new_red, axis); + } + } + + Expr new_outer_cond, new_reduce_cond; + Array new_source = red->source; + + // Partially lift conditions from the reduce condition + std::tie(new_outer_cond, new_reduce_cond) = + LiftConditionsThroughReduction(red->condition, red->axis, axis); + + // If it's not sum then we haven't yet lifted nonzeroness cond from the source + if (!is_sum) { + Expr outer_nz_cond, nz_cond, nz_source; + auto nz = NonzeronessCondition(red->source[red->value_index]); + // Append conditions from the reduction + nz_cond = new_reduce_cond && nz.cond; + nz_source = nz.value; + std::tie(outer_nz_cond, nz_cond) = + LiftConditionsThroughReduction(nz_cond, red->axis, axis); + new_outer_cond = new_outer_cond && outer_nz_cond; + new_source.Set(red->value_index, SelectElseZero(nz_cond, nz_source)); + } + + Expr new_reduce = Reduce::make(red->combiner, new_source, red->axis, + new_reduce_cond, red->value_index); + new_reduce = ExtractAsTensorMaybe(new_reduce, new_outer_cond, + IterVarsToVars(axis), IterVarsToMap(axis)); + result = SelectElseZero(new_outer_cond, new_reduce); + } else { + return SimplifyReductionDomain(expr, IterVarsToMap(axis)); + } + } else { + auto nz = NonzeronessCondition(expr); + Expr new_expr = ExtractAsTensorMaybe(nz.value, nz.cond, + IterVarsToVars(axis), IterVarsToMap(axis)); + result = SelectElseZero(nz.cond, new_expr); + } + + // Note that RemoveRedundantInequalities can sometimes propagate equalities which + // other simplifiers cannot, like (i % 3) == 0. + Array axis_conds = IterVarsToInequalities(axis); + result = RemoveRedundantInequalities(result, axis_conds); + + // Sometimes ExtractAsTensorMaybe doesn't perform extraction, so there may be some non-top + // reductions left, take care of them + Map vrange = IterVarsToMap(axis); + return SuperSimplify(ExtractReductions(result, vrange), vrange); +} + +Tensor OptimizeAndLiftNonzeronessConditions(const Tensor& tensor) { + return op::TransformBody(tensor, OptimizeAndLiftNonzeronessConditionsImpl); +} + +TVM_REGISTER_API("ir_pass.IsSumCombiner") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = IsSumCombiner(args[0]); + }); + +TVM_REGISTER_API("ir_pass.CanFactorZeroFromCombiner") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = CanFactorZeroFromCombiner(args[0], args[1]); + }); + +TVM_REGISTER_API("ir_pass.LiftNonzeronessCondition") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = LiftNonzeronessCondition(args[0]); + }); + +TVM_REGISTER_API("ir_pass.InlineTailCall") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = InlineTailCall(args[0]); + }); + +TVM_REGISTER_API("ir_pass.InlineTensors") +.set_body([](TVMArgs args, TVMRetValue *ret) { + if (args[0].IsNodeType()) { + Expr e = args[0]; + if (args.size() == 1) { + *ret = InlineTensors(e); + } else if (args.size() == 2) { + *ret = InlineTensors(e, args[1]); + } else if (args.size() >= 3) { + *ret = InlineTensors(e, args[1], args[2]); + } + } else if (args[0].IsNodeType()) { + Tensor t = args[0]; + if (args.size() == 1) { + *ret = InlineTensors(t); + } else if (args.size() == 2) { + *ret = InlineTensors(t, args[1]); + } else if (args.size() >= 3) { + *ret = InlineTensors(t, args[1], args[2]); + } + } + }); + +TVM_REGISTER_API("ir_pass.SolveSystemOfInequalities") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = SolveSystemOfInequalities(args[0], args[1], args[2]).as_conditions(); + }); + +TVM_REGISTER_API("ir_pass.SimplifyDomain") +.set_body([](TVMArgs args, TVMRetValue *ret) { + auto res = SimplifyDomain(args[0], args[1], args[2]); + Array axis = IterVarsFromMap(res.axis, res.ranges); + *ret = Array({All(res.conditions), axis, res.old_to_new, res.new_to_old}); + }); + +TVM_REGISTER_API("ir_pass.SimplifyReductionDomain") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = SimplifyReductionDomain(args[0], args[1]); + }); + +TVM_REGISTER_API("ir_pass.ExtractAsTensorMaybe") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ExtractAsTensorMaybe(args[0], args[1], args[2], args[3]); + }); + +TVM_REGISTER_API("ir_pass.ExtractReductions") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ExtractReductions(args[0], args[1]); + }); + +TVM_REGISTER_API("ir_pass.ExtractNonTopReductions") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ExtractNonTopReductions(args[0], args[1]); + }); + +TVM_REGISTER_API("ir_pass.OptimizeAndLiftNonzeronessConditions") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = OptimizeAndLiftNonzeronessConditions(args[0]); + }); + +} // namespace ir +} // namespace tvm diff --git a/src/pass/zero_elimination.h b/src/pass/zero_elimination.h new file mode 100644 index 000000000000..26b2acc15b7a --- /dev/null +++ b/src/pass/zero_elimination.h @@ -0,0 +1,243 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file zero_elimination.h + * \brief Transform tensors in such a way as to eliminate summation over zeros. + */ +#ifndef TVM_PASS_ZERO_ELIMINATION_H_ +#define TVM_PASS_ZERO_ELIMINATION_H_ + +#include +#include + +#include + +namespace tvm { +namespace ir { + +/*! + * \brief Clone the reduction by cloning its iteration variables. + */ +Expr CloneReduction(const Expr& expr); + +/*! + * \brief Check if the given combiner represents summation. + */ +EXPORT bool IsSumCombiner(const CommReducer& combiner); + +/*! + * \brief Check if zero may be factored out of a reduction with this combiner when it is in + * the \p value_index position. + * + * For example, if the combiner works on tuples of two elements and `value_index = 1`, + * check that `(a, 0) combine (b, 0) = (c, 0)` for any a, b and some c. + * Note that all combiners generated by autodiff have this property. + */ +EXPORT bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index); + +/*! + * \brief Transform the expression into `c ? e : 0`, that is lift the condition of being + * possible to be non-zero to the top level. + */ +EXPORT Expr LiftNonzeronessCondition(const Expr& expr); + +/*! + * \brief If the body of the tensor consists of a single tensor call (indexing) expression, + * inline it. + */ +EXPORT Tensor InlineTailCall(const Tensor& tensor); + +/*! + * \brief Inline tensors recursively. + * + * This function will inline tensors recursively until it reaches a tensor which is impossible to + * inline (a reduction if \p inline_reductions is false, a non-compute tensor, a tensor which is + * not from \p inlineable). It won't descend into non-inlinable tensors' bodies. + * + * \param expr The expression to transform. + * \param inlineable A list of tensors which are allowed to be inlined. If empty, try + * to inline all tensors. + * \param inline_reductions Whether to inline reductions (this may result in top-level reduction + * nodes). + */ +EXPORT Expr InlineTensors(const Expr& expr, + const Array& inlineable = Array(), + bool inline_reductions = false); + +/*! + * \brief Inline tensors recursively. + * + * This function will inline tensors recursively until it reaches a tensor which is impossible to + * inline (a reduction if \p inline_reductions is false, a non-compute tensor, a tensor which is + * not from \p inlineable). It won't descend into non-inlinable tensors' bodies. + * + * \param tensor The tensor whose body to transform. + * \param inlineable A list of tensors which are allowed to be inlined. If empty, try + * to inline all tensors. + * \param inline_reductions Whether to inline reductions (this may result in top-level reduction + * nodes). + */ +EXPORT Tensor InlineTensors(const Tensor& tensor, + const Array& inlineable = Array(), + bool inline_reductions = false); + + +/*! + * \brief A struct representing a set of inequalities describing bounds of a variable. + * + * Given a variable x, this struct represents the following (in)equalities: + * - `coef*x >= low` for each `low` in `lower` + * - `coef*x == eq` for each `eq` in `equal` + * - `coef*x <= upp` for each `upp` in `upper` + * + * Note that every array is supposed to be sorted in the order of increasing expression + * complexity. + */ +struct VarBounds { + Expr coef; + Array lower; + Array equal; + Array upper; + + /*! + * \brief Perform substitution on all components of the struct. + */ + VarBounds substitute(const Map& subst) const; +}; + +/*! + * \brief A struct representing a system of inequalities resulted from Fourier-Motzkin elimination. + */ +struct SolveSystemOfInequalitiesResult { + Array variables; + std::unordered_map bounds; + Array other_conditions; + + /*! + * \brief Combine the information into an array of (in)equalities. + */ + Array as_conditions() const; +}; + +/*! + * \brief Rewrite the system of inequalities using Fourier-Motzkin elimination. + * + * This function takes an array of (in)equalities and an array of variables, and essentially + * rewrites the (in)equalities into an array of (in)equalities of the following form: + * + * x0 >= f0(x1, x2, ..., xn) + * x0 <= g0(x1, x2, ..., xn) + * x1 >= f1(x2, ..., xn) + * x1 <= g1(x2, ..., xn) + * ... + * xn >= fn() // just a constant + * xn <= gn() // just a constant + * + * This array is represented in a more structural way using SolveSystemOfInequalitiesResult. + * + * Note that the algorithm is extremely slow, it is super-exponential, so please provide variable + * ranges to aid the removal of redundant inequalities. + * + * \param inequalities The original (in)equalities. + * \param variables The variables x0, ..., xn + * \param vranges A map from variables to the corresponding value ranges. Extremely important for + * efficiency. + */ +EXPORT SolveSystemOfInequalitiesResult SolveSystemOfInequalities( + const Array& inequalities, const Array& variables, const Map& vranges); + +/*! + * \brief A struct representing a result of domain simplification. It is basically + * a new array of variables, the information about their ranges, and a new condition together with + * substitutions from the old variables to the new ones and from the new ones to the old ones. + */ +struct DomainSimplificationResult { + Array conditions; + Array axis; + Map ranges; + Map old_to_new; + Map new_to_old; +}; + +/*! + * \brief Simplify an iteration domain. + * + * An iteration domain is basically an array of variables and a condition. The function will do the + * following: + * - Replace div and mod operations with new variables (optional). + * - Extract (in)equalities from the condition. + * - Perform Fourier-Motzkin elimination. + * - Shear the domain of iteration (e.g. if `y <= x <= y + 2` then x will be replaced with `y + d` + * where `d` is a new variable such that `0 <= d <= 2`). + * - Remove redundant variables. + * - Infer new variable ranges (hopefully more precise). + * + * \param cond The condition of the original domain. + * \param axis The variables of the original domain. + * \param vranges A map from variables (both domain and outer) to their value ranges. + * \param eliminate_div_mod Whether to eliminate div and mod by introducing new variables. + */ +EXPORT DomainSimplificationResult SimplifyDomain(const Expr& cond, + const Array& axis, + Map vranges, + bool eliminate_div_mod = true); + + +/*! + * \brief Simplify the iteration domain of a reduction expression using SimplifyDomain. + */ +EXPORT Expr SimplifyReductionDomain(const Expr& expr, const Map& outer_vranges); + +/*! + * \brief Extract the given expression under the given condition as a separate tensor if the volume + * of the extracted tensor will be less than the volume of the \p outer_axis. + * + * \param expr The expression to extract. + * \param cond A condition which is assumed to be true. + * \param outer_axis Some variables, usually input variables of the enclosing tensor. + * \param vranges Information about ranges of variables. + * \return Either a call to an extracted tensor or the original expression. + */ +EXPORT Expr ExtractAsTensorMaybe(const Expr& expr, const Expr& cond, + const Array& outer_axis, + const Map& vranges); + +/*! + * \brief Extract reductions as separate tensors. This may be needed when non-top-level reductions + * are created. + * + * \param expr The expression from which to extract reductions. + * \param vranges Information about ranges of variables. + * \return An expression without non-top-level reductions. + */ +EXPORT Expr ExtractReductions(const Expr& expr, const Map& vranges); + +/*! + * \brief Extract reductions as separate tensors, but if the expr itself is a reduction, leave it + * intact. + * + * \param expr The expression from which to extract reductions. + * \param vranges Information about ranges of variables. + * \return An expression without non-top-level reductions. + */ +EXPORT Expr ExtractNonTopReductions(const Expr& expr, const Map& vranges); + +/*! + * \brief Perform lifting of conditions of being possible to be non-zero together with + * applying some transformations like simplifying the reduction domain. Works only with + * this particular tensor's body, i.e. doesn't perform inlining. + */ +EXPORT Tensor OptimizeAndLiftNonzeronessConditions(const Tensor& tensor); + +/*! + * \brief Pretty print the tensor with all its dependencies. + */ +EXPORT std::string PrintTensorRecursively(const Tensor& tensor); + +/*! + * \brief Pretty print the tensors with all their dependencies. + */ +EXPORT std::string PrintTensorsRecursively(const Array& tensor); + +} // namespace ir +} // namespace tvm +#endif // TVM_PASS_ZERO_ELIMINATION_H_ diff --git a/tests/python/unittest/test_pass_autodiff.py b/tests/python/unittest/test_pass_autodiff.py new file mode 100644 index 000000000000..6c971215e796 --- /dev/null +++ b/tests/python/unittest/test_pass_autodiff.py @@ -0,0 +1,414 @@ +# This example demonstrates Automatic Differentiation for TVM basic operations and TOPI primitives. +# See `test_autodiff()` and `test_nn_autodiff()` for details. + +import tvm +import topi +import numpy as np +from tvm.testing import check_numerical_grads, estimate_performance, PerformanceEstimate +import time +import inspect +import sys + +# Whether to dump the generated code +verbose = False + +def get_shape(tensor): + return [tvm.ir_pass.Simplify(s).value for s in tensor.shape] + +def check_equivalence(outputs1, outputs2, inputs, in_range=(-10, 10), iters=10): + outputs1 = list(outputs1) + outputs2 = list(outputs2) + sched1 = tvm.create_schedule([o.op for o in outputs1]) + mout1 = tvm.build(sched1, outputs1 + inputs) + + sched2 = tvm.create_schedule([o.op for o in outputs2]) + mout2 = tvm.build(sched2, outputs2 + inputs) + + arguments1 = [tvm.nd.empty(get_shape(t), t.dtype) for t in outputs1 + inputs] + arguments2 = [tvm.nd.empty(get_shape(t), t.dtype) for t in outputs1 + inputs] + + for i in range(iters): + arguments1 = [] + arguments2 = [] + for a in outputs1 + inputs: + val = np.random.uniform(in_range[0], in_range[1], size=get_shape(a)).astype(a.dtype) + arguments1.append(tvm.nd.array(val)) + arguments2.append(tvm.nd.array(val)) + mout1(*arguments1) + mout2(*arguments2) + + for j, _ in enumerate(outputs1): + tvm.testing.assert_allclose(arguments1[j].asnumpy(), arguments2[j].asnumpy()) + +def test_grad(out, inputs, args=[], in_range=(-10,10), perf=None): + line = inspect.getframeinfo(inspect.stack()[1][0]).lineno + + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + + if verbose: + print("\n" + 80*"=" + "\n") + print("Testing gradients, line {}\n".format(line)) + print("Original tensors:\n") + print(tvm.PrintTensorRecursively(out)) + print() + + sout = tvm.create_schedule(out.op) + mout = tvm.build(sout, [out] + inputs + args) + + ones = topi.full_like(out, 1.0) + + grads = list(tvm.differentiate(out, inputs, ones)) + + if verbose: + print("Gradients:\n") + print(tvm.PrintTensorsRecursively(grads)) + print() + + grads_sched = tvm.create_schedule([g.op for g in grads]) + mgrad = tvm.build(grads_sched, grads + inputs + args) + + lowered = tvm.lower(grads_sched, grads + inputs + args, simple_mode=True) + + if verbose: + print("Lowered gradients:\n") + print(lowered) + print() + + est = estimate_performance(grads) + est_lowered = estimate_performance(lowered) + + if verbose: + print("Note: performance tuples are (iterations, multiplications, memory)") + print("Expected performance of grads: {}".format(perf)) + print("Estimated performance of grads: {}".format(est.as_tuple())) + print("Estimated performance of lowered grads: {}".format(est_lowered.as_tuple())) + print() + + if est_lowered.memory > est.memory: + print("WARNING: Line {}: The estimated memory consumption increased after lowering, " + "this may indicate that tensor bounds have been expanded too much".format(line)) + print("before: {} after: {}".format(est, est_lowered)) + + (iters, mults, mem) = est.as_tuple() + if perf is None or isinstance(perf, str): + print("WARNING: Line {}: No performance information, you may set it to {}" + .format(line, est.as_tuple())) + if isinstance(perf, str): + print("0,/{!r}/{{s/{!r}/{}/}}".format(perf, perf, (iters, mults, mem))) + elif perf != (iters, mults, mem): + (ref_iters, ref_mults, ref_mem) = perf + ref_est = PerformanceEstimate(*perf) + + if est <= ref_est: + print("WARNING: Line {}: Estimated performance {} is better than {}. " + "Use this with sed:" + .format(line, est.as_tuple(), ref_est.as_tuple())) + print("0,/{}/{{s/{}/{}/}}".format(perf, perf, (iters, mults, mem))) + elif est >= ref_est: + print("WARNING: Line {}: Estimated performance {} IS WORSE THAN {}" + .format(line, est.as_tuple(), ref_est.as_tuple())) + else: + print("WARNING: Line {}: Estimated performance {} does not match {}" + .format(line, est.as_tuple(), ref_est.as_tuple())) + + EST_RTOL = 1.5 + if iters > ref_iters*EST_RTOL or mults > ref_mults*EST_RTOL or mem > ref_mem*EST_RTOL: + raise AssertionError("Line {}: Some of the estimated performance metrics are much " + "worse than the reference ones (by {}): estimated {}, expected {}" + .format(line, EST_RTOL, est.as_tuple(), ref_est.as_tuple())) + + def fun(*arguments): + arrays = [tvm.nd.empty(get_shape(out), out.dtype)] + [tvm.nd.array(a) for a in arguments] + mout(*arrays) + return arrays[0].asnumpy().sum() + + arg_vals = [tvm.nd.array(np.random.uniform(in_range[0], in_range[1], + size=get_shape(a)).astype(a.dtype)) + for a in inputs + args] + + g_arg_vals = [tvm.nd.empty(get_shape(i), g.dtype) for i, g in zip(inputs, grads)] + arg_vals + mgrad(*g_arg_vals) + g_res = [g_arg_vals[g].asnumpy() for g, _ in enumerate(grads)] + + check_numerical_grads(fun, [a.asnumpy() for a in arg_vals], g_res) + +def test_differentiate_function(): + x = tvm.placeholder((32, 3, 28, 28), name='x') + + w = tvm.placeholder((10, 3, 3, 3), name='w') + t1 = topi.nn.conv2d(x, w, 1, 0, 1) + + t2 = topi.nn.flatten(t1) + t3 = topi.sum(t2) + + [dx1, dw1] = tvm.differentiate(t3, [x, w]) + [dx2, dw2] = tvm.differentiate(t2, [x, w], topi.full_like(t2, 1.0)) + + check_equivalence([dx1, dw1], [dx2, dw2], [x, w]) + + def mydiff(out, inp, head): + return tvm.compute(inp.shape, + lambda ax0, ax1, ax2, ax3: head[ax0, ax3 + ax2*26 + ax1*676]) + + res = tvm.differentiate(t3, [x, w], manual={(t2, t1): mydiff}) + check_equivalence(res.result, [dx1, dw1], [x, w]) + + res = tvm.differentiate(t3, [x, w], manual={(t2, None): mydiff}) + check_equivalence(res.result, [dx1, dw1], [x, w]) + + res = tvm.differentiate(t3, [x, w], manual={(None, t1): mydiff}) + check_equivalence(res.result, [dx1, dw1], [x, w]) + +# Test some simple expressions +def test_autodiff(): + x = tvm.var("x", dtype='float32') + k = tvm.reduce_axis((0, 10), name="k") + l = tvm.reduce_axis((0, 10), name="l") + A0 = tvm.placeholder((10, 10), name='A0') + A1 = tvm.placeholder((10, 10), name='A1') + + B = tvm.compute((10, 10), lambda i, j: A0[i, j] + A0[j, i], name='B') + test_grad(B, A0, perf=(10100, 10000, 200)) + + B = tvm.compute((10, 10), lambda i, j: A0[i, j] + tvm.exp(A0[j, i]), name='B') + test_grad(B, A0, perf=(10100, 20000, 200)) + + B = tvm.compute((10, 10), lambda i, j: tvm.log(tvm.abs(A0[i, j] + tvm.exp(A0[j, i]))), name='B') + test_grad(B, A0, perf=(10100, 70000, 200)) + + B = tvm.compute((10, 10), lambda i, j: tvm.sigmoid(A0[i, j]*A0[i, j]*A0[j, i]), name='B') + test_grad(B, A0, perf=(10100, 120000, 200)) + + B = tvm.compute((10, 10), lambda i, j: tvm.tanh(A0[i, j]*A0[i, j]*A0[j, i]), name='B') + test_grad(B, A0, perf=(10100, 120000, 200)) + + B = tvm.compute((10, 10), lambda i, j: A0[i, j] * A0[j, i], name='B') + test_grad(B, A0, perf=(10100, 10000, 200)) + + # TODO: This one needs transforming Sum(a + b) -> Sum(a) + Sum(b) + B = tvm.compute((10,), lambda i: tvm.sum(A0[i, k]*A0[k, i], axis=k), name='B') + test_grad(B, A0, perf=(11010, 1000, 1110)) + + B = tvm.compute((10, 10), lambda i, j: tvm.sum(A0[i, k]*A0[k, i] + 5, axis=k), name='B') + test_grad(B, A0, perf=(20100, 10000, 1200)) + + B = tvm.compute((10, 10), lambda i, j: tvm.max(A0[i, k]*A0[k, j] + 5, axis=k), name='B') + test_grad(B, A0, perf=(110100, 310000, 20200)) + + B = tvm.compute((10, 10), lambda i, j: A0[i, j] * (A1[j, i] + A0[j, i]), name='B') + test_grad(B, A0, [A1], perf=(10100, 10000, 200)) + + B = tvm.compute((10, 10), lambda i, j: tvm.sum(A0[k, k] - A0[tvm.min(j + k, 9), j]*A0[i, k], + axis=k), + name='B') + test_grad(B, A0, perf=(110100, 10000, 10200)) + + def fcombine(x, y): + return x*y + + def fidentity(t0): + return tvm.const(1, t0) + + prod = tvm.comm_reducer(fcombine, fidentity, name='prod') + B = tvm.compute((10, 10), lambda i, j: prod(A0[i, k] + A0[k, i], axis=k), name='B') + test_grad(B, A0, perf=(20100, 40000, 2200)) + + X = tvm.placeholder((10,), name='X') + A = tvm.compute((10,), lambda i: X[i] + X[9 - i]) + B = tvm.compute((10,), lambda i: X[i] * X[9 - i]) + Y = topi.tensordot(A, B, 1) + test_grad(Y, X, perf=(251, 230, 71)) + +def test_topi_autodiff(): + X = tvm.placeholder((1, 2, 4, 4), name='X') + W = tvm.placeholder((5, 2, 3, 3), name='W') + W1 = tvm.placeholder((2, 5, 3, 3), name='W1') + W2 = tvm.placeholder((1,), name='W1') + + R = topi.nn.conv2d(X, W, 1, 1, 1) + test_grad(R, [X, W], perf=(3410, 2880, 652)) + + R1 = topi.nn.conv2d(topi.nn.relu(R), W1, 1, 0, 1) + test_grad(R1, [X, W, W1], perf=(6198, 5320, 1250)) + + R = topi.broadcast_to(W2, (5, 2, 3, 3)) + test_grad(R, [W2], perf=(180, 0, 91)) + + R = topi.nn.conv2d(X, topi.broadcast_to(W2, (5, 2, 3, 3)), 1, 1, 1) + test_grad(R, [X, W2], perf=(3590, 2880, 743)) + + R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'avg') + test_grad(R, X, perf=(40, 224, 40)) + + R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'max') + test_grad(R, X, perf=(168, 1248, 104)) + + X = tvm.placeholder((1, 2, 5, 5), name='X') + R = topi.reshape(X, (1, 32)) + test_grad(R, [X], perf=(82, 1200, 82)) + + X = tvm.placeholder((1, 2, 5, 5), name='X') + W = tvm.placeholder((2, 2, 3, 3), name='W') + + S = topi.reshape(X, (1, 50)) + test_grad(S, [X], perf=(100, 700, 100)) + + R = X + topi.nn.conv2d(X + topi.nn.conv2d(X, W, 1, 1, 1), W, 1, 1, 1) + test_grad(R, [X, W], perf=(6854, 5400, 1726)) + + S = topi.nn.softmax(topi.reshape(R, (1, 50))) + test_grad(S, [X, W], perf=(10956, 14201, 2333)) + + S = topi.sigmoid(topi.reshape(R, (1, 50))) + test_grad(S, [X, W], perf=(8004, 8350, 2026)) + + S = topi.tanh(topi.reshape(R, (1, 50))) + test_grad(S, [X, W], perf=(8004, 8350, 2026)) + + S = topi.nn.log_softmax(topi.reshape(R, (1, 50))) + test_grad(S, [X, W], perf=(10906, 13601, 2283)) + test_grad(S, [W], [X], perf=(8920, 11101, 1997)) + + X = tvm.placeholder((1, 2, 3, 5), name='X') + Y = tvm.placeholder((1, 2, 7, 5), name='Y') + S = topi.concatenate((X, Y), 2) + test_grad(S, [X, Y], perf=(100, 0, 100)) + + X = tvm.placeholder((1, 2, 6, 5), name='X') + (S, R) = topi.split(X, 2, 2) + test_grad(S, [X], perf=(120, 0, 120)) + test_grad(R, [X], perf=(120, 0, 120)) + R1 = topi.concatenate((S, R), 2) + test_grad(R1, [X], perf=(300, 0, 300)) + R2 = topi.concatenate((R, S), 2) + test_grad(R2, [X], perf=(300, 0, 300)) + +def test_stride_dilation(): + X = tvm.placeholder((1, 2, 10, 10), name='X') + + W = tvm.placeholder((2, 2, 1, 1), name='W') + + Y = topi.nn.conv2d(X, W, 1, 0, 1) + test_grad(Y, [X, W], perf=(1404, 800, 808)) + Y = topi.nn.conv2d(X, W, 2, 0, 1) + test_grad(Y, [X, W], perf=(928, 1572, 670)) + Y = topi.nn.conv2d(X, W, 3, 0, 1) + test_grad(Y, [X, W], perf=(932, 1728, 672)) + Y = topi.nn.conv2d(X, W, 1, 0, 2) + test_grad(Y, [X, W], perf=(1404, 800, 808)) + Y = topi.nn.conv2d(X, W, 2, 0, 2) + test_grad(Y, [X, W], perf=(928, 1572, 670)) + Y = topi.nn.conv2d(X, W, 3, 0, 2) + test_grad(Y, [X, W], perf=(932, 1728, 672)) + Y = topi.nn.conv2d(X, W, 1, 0, 3) + test_grad(Y, [X, W], perf=(1404, 800, 808)) + Y = topi.nn.conv2d(X, W, 2, 0, 3) + test_grad(Y, [X, W], perf=(928, 1572, 670)) + Y = topi.nn.conv2d(X, W, 3, 0, 3) + test_grad(Y, [X, W], perf=(932, 1728, 672)) + + W = tvm.placeholder((2, 2, 2, 2), name='W') + + Y = topi.nn.conv2d(X, W, 1, 0, 1) + test_grad(Y, [X, W], perf=(3922, 2896, 1242)) + Y = topi.nn.conv2d(X, W, 2, 0, 1) + test_grad(Y, [X, W], perf=(1650, 2800, 1066)) + Y = topi.nn.conv2d(X, W, 3, 0, 1) + test_grad(Y, [X, W], perf=(1146, 2880, 890)) + Y = topi.nn.conv2d(X, W, 1, 0, 2) + test_grad(Y, [X, W], perf=(3500, 19720, 1092)) + Y = topi.nn.conv2d(X, W, 2, 0, 2) + test_grad(Y, [X, W], perf=(3408, 88848, 2034)) + Y = topi.nn.conv2d(X, W, 3, 0, 2) + test_grad(Y, [X, W], perf=(2254, 65232, 992)) + Y = topi.nn.conv2d(X, W, 1, 0, 3) + test_grad(Y, [X, W], perf=(3138, 17696, 970)) + Y = topi.nn.conv2d(X, W, 2, 0, 3) + test_grad(Y, [X, W], perf=(3816, 82368, 2176)) + Y = topi.nn.conv2d(X, W, 3, 0, 3) + test_grad(Y, [X, W], perf=(3834, 104432, 2306)) + + W = tvm.placeholder((2, 2, 3, 3), name='W') + + Y = topi.nn.conv2d(X, W, 1, 0, 1) + test_grad(Y, [X, W], perf=(7420, 5904, 1752)) + Y = topi.nn.conv2d(X, W, 2, 0, 1) + test_grad(Y, [X, W], perf=(3888, 58592, 2214)) + Y = topi.nn.conv2d(X, W, 3, 0, 1) + test_grad(Y, [X, W], perf=(1552, 2268, 1102)) + Y = topi.nn.conv2d(X, W, 1, 0, 2) + test_grad(Y, [X, W], perf=(5916, 42392, 1256)) + Y = topi.nn.conv2d(X, W, 2, 0, 2) + test_grad(Y, [X, W], perf=(6736, 21784, 3694)) + Y = topi.nn.conv2d(X, W, 3, 0, 2) + test_grad(Y, [X, W], perf=(2672, 146096, 1668)) + Y = topi.nn.conv2d(X, W, 1, 0, 3) + test_grad(Y, [X, W], perf=(2896, 89152, 956)) + Y = topi.nn.conv2d(X, W, 2, 0, 3) + test_grad(Y, [X, W], perf=(2280, 12856, 1992)) + Y = topi.nn.conv2d(X, W, 3, 0, 3) + test_grad(Y, [X, W], perf=(2224, 12032, 716)) + + Y = topi.nn.pool(X, [1, 1], [1, 1], [0, 0, 0, 0], 'max') + test_grad(Y, [X], perf=(200, 0, 200)) + Y = topi.nn.pool(X, [1, 1], [2, 2], [0, 0, 0, 0], 'max') + test_grad(Y, [X], perf=(412, 1124, 412)) + Y = topi.nn.pool(X, [1, 1], [3, 3], [0, 0, 0, 0], 'max') + test_grad(Y, [X], perf=(232, 1200, 232)) + Y = topi.nn.pool(X, [2, 2], [1, 1], [0, 0, 0, 0], 'max') + test_grad(Y, [X], perf=(4162, 7200, 1962)) + Y = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'max') + test_grad(Y, [X], perf=(1050, 7800, 650)) + Y = topi.nn.pool(X, [2, 2], [3, 3], [0, 0, 0, 0], 'max') + test_grad(Y, [X], perf=(858, 6304, 602)) + Y = topi.nn.pool(X, [3, 3], [1, 1], [0, 0, 0, 0], 'max') + test_grad(Y, [X], perf=(18128, 34200, 3928)) + Y = topi.nn.pool(X, [3, 3], [2, 2], [0, 0, 0, 0], 'max') + test_grad(Y, [X], perf=(6712, 131312, 1690)) + Y = topi.nn.pool(X, [3, 3], [3, 3], [0, 0, 0, 0], 'max') + test_grad(Y, [X], perf=(1838, 12950, 704)) + +def test_some_conv2d_net(): + batch_size = 1 + num_classes = 10 + + features = 4 + dense_units = 16 + + x = tvm.placeholder((batch_size, 28, 14, 1)) + y = tvm.placeholder((batch_size, num_classes)) + + w1 = tvm.placeholder((features, 1, 3, 5)) + b1 = tvm.placeholder((features,)) + w2 = tvm.placeholder((features, features, 3, 5)) + b2 = tvm.placeholder((features,)) + b3 = tvm.placeholder((dense_units,)) + w4 = tvm.placeholder((num_classes, dense_units)) + b4 = tvm.placeholder((num_classes,)) + + t = topi.transpose(x, [0, 3, 1, 2]) + t = topi.nn.relu(topi.nn.conv2d(t, w1, 1, 0, 1) + topi.reshape(b1, (1, features, 1, 1))) + t = topi.nn.relu(topi.nn.conv2d(t, w2, 1, 0, 1) + topi.reshape(b2, (1, features, 1, 1))) + t = topi.nn.pool(t, [2, 2], [2, 2], [0, 0, 0, 0], 'avg') + t = topi.transpose(t, [0, 2, 3, 1]) + t = topi.nn.flatten(t) + w3 = tvm.placeholder((dense_units, get_shape(t)[1])) + t = topi.nn.relu(topi.nn.dense(t, w3, b3)) + t = topi.nn.dense(t, w4, b4) + + t = - topi.sum(y * topi.nn.log_softmax(t)) / batch_size + + weights = [w1, b1, w2, b2, w3, b3, w4, b4] + + test_grad(t, weights, [x, y], in_range=(-1.0, 1.0), perf=(194865, 179089, 28194)) + +if __name__ == "__main__": + if "-v" in sys.argv: + verbose = True + + # test_differentiate_function() + test_autodiff() + test_topi_autodiff() + test_stride_dilation() + test_some_conv2d_net() diff --git a/tests/python/unittest/test_pass_zero_elimination.py b/tests/python/unittest/test_pass_zero_elimination.py new file mode 100644 index 000000000000..8e4bd59394b5 --- /dev/null +++ b/tests/python/unittest/test_pass_zero_elimination.py @@ -0,0 +1,461 @@ +import random +import sys +import numpy as np +import tvm +from tvm import comm_reducer +from tvm.testing import estimate_performance +from tvm.ir_pass import Simplify, Equal, LiftNonzeronessCondition, IsSumCombiner, \ + CanFactorZeroFromCombiner, InlineTailCall, InlineTensors, SolveSystemOfInequalities, \ + SimplifyDomain, SimplifyReductionDomain, ExtractAsTensorMaybe, ExtractReductions, \ + ExtractNonTopReductions, OptimizeAndLiftNonzeronessConditions + +def get_shape(tensor): + return [s.value for s in tensor.shape] + +def check_eq(t1, t2, args): + s1 = tvm.create_schedule(t1.op) + m1 = tvm.build(s1, [t1] + args) + + s2 = tvm.create_schedule(t2.op) + m2 = tvm.build(s2, [t2] + args) + + for _ in range(5): + arg_vals = [tvm.ndarray.array(np.random.uniform(-10, 10, size=get_shape(a)) + .astype(a.dtype)) + for a in [t1] + args] + m1(*arg_vals) + res1 = arg_vals[0].asnumpy() + m2(*arg_vals) + res2 = arg_vals[0].asnumpy() + + np.testing.assert_allclose(res1, res2, atol=1e-3, rtol=1e-2) + +def check_symeq(expr1, expr2): + expr1 = tvm.ir_pass.Simplify(tvm.ir_pass.CanonicalSimplify(expr1)) + expr2 = tvm.ir_pass.Simplify(tvm.ir_pass.CanonicalSimplify(expr2)) + + if tvm.ir_pass.Equal(expr1, expr2): + return + + diff = tvm.ir_pass.Simplify(tvm.ir_pass.CanonicalSimplify(expr1 - expr2)) + if not Equal(diff, tvm.const(0, expr1.dtype)): + raise AssertionError("Expressions {} and {} are not equal, their diff is {}" + .format(expr1, expr2, diff)) + +def compute(shape, fcompute): + """Like tvm.compute but automatically extracts reductions.""" + return tvm.compute(shape, + lambda *vs: ExtractNonTopReductions( + fcompute(*vs), {v: tvm.Range(0, s) for v, s in zip(vs, shape)})) + +def check_tensor_symeq(A, B): + if not isinstance(B, tvm.tensor.Tensor): + B = compute(A.shape, B) + vmap = {a.var: b.var for a, b in zip(A.op.axis, B.op.axis)} + expr_a = tvm.ir_pass.Substitute(A.op.body[A.value_index], vmap) + expr_b = B.op.body[B.value_index] + expr_a = tvm.ir_pass.CanonicalSimplify(InlineTensors(expr_a, [], True)) + expr_b = tvm.ir_pass.CanonicalSimplify(InlineTensors(expr_b, [], True)) + if not Equal(expr_a, expr_b): + print(expr_a) + print(expr_b) + raise AssertionError("The expressions are not equal") + +def check_eq_bruteforce(expr1, expr2, vranges): + def _compute_body(*us, expr1=expr1, expr2=expr2, vranges=vranges): + vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)} + return tvm.ir_pass.Substitute(expr1 == expr2, vmap) + + A = compute([r.extent.value for v, r in vranges.items()], _compute_body) + args = [tvm.ndarray.empty(A.shape, A.dtype)] + sch = tvm.create_schedule(A.op) + mod = tvm.build(sch, [A]) + mod(*args) + res = args[0].asnumpy() + if not np.all(res): + indices = list(np.argwhere(res == 0)[0]) + counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)] + counterex = ", ".join([v + " = " + str(i) for v, i in sorted(counterex)]) + raise AssertionError("Expressions {}\nand {}\nare not equal on {}\n" + "Counterexample: {}" + .format(expr1, expr2, vranges, counterex)) + +prod_combiner = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0)) +sum_combiner = comm_reducer(lambda x, y: x + y, lambda t0: tvm.const(0, t0)) +sum2_combiner = comm_reducer(lambda x, y: y + x, lambda t0: tvm.const(0, t0)) +sum_derivative_combiner = comm_reducer(lambda x, y: (x[0] + y[0], y[1] + x[1]), + lambda t0, t1: (tvm.const(0, t0), tvm.const(0, t1))) +prod_derivative_combiner = comm_reducer(lambda x, y: (x[0]*y[0], x[0]*y[1] + x[1]*y[0]), + lambda t0, t1: (tvm.const(1, t0), tvm.const(0, t1))) +sum_both_combiner = comm_reducer(lambda x, y: (x[0] + y[0], x[0] + y[0] + x[1] + y[1]), + lambda t0, t1: (tvm.const(0, t0), tvm.const(0, t1))) +xor_combiner = comm_reducer(lambda x, y: x ^ y, lambda t0: tvm.const(0, t0)) + +def test_is_sum_combiner(): + k = tvm.reduce_axis((0, 10), name="k") + i = tvm.const(0, "int32") + f = tvm.const(0.0, "float32") + assert IsSumCombiner(sum_combiner(i, k).combiner) + assert IsSumCombiner(sum_combiner(f, k).combiner) + assert IsSumCombiner(sum2_combiner(i, k).combiner) + assert IsSumCombiner(sum2_combiner(f, k).combiner) + assert not IsSumCombiner(sum_derivative_combiner((f, f), k)[0].combiner) + assert not IsSumCombiner(prod_combiner(f, k).combiner) + assert not IsSumCombiner(prod_derivative_combiner((f, f), k)[1].combiner) + +def test_can_factor_zero_from_combiner(): + k = tvm.reduce_axis((0, 10), name="k") + i = tvm.const(0, "int32") + f = tvm.const(0.0, "float32") + assert CanFactorZeroFromCombiner(sum_combiner(i, k).combiner, 0) + assert CanFactorZeroFromCombiner(sum2_combiner(f, k).combiner, 0) + assert CanFactorZeroFromCombiner(sum_derivative_combiner((f, f), k)[0].combiner, 0) + assert CanFactorZeroFromCombiner(sum_derivative_combiner((f, f), k)[0].combiner, 1) + assert not CanFactorZeroFromCombiner(prod_derivative_combiner((f, f), k)[0].combiner, 0) + assert CanFactorZeroFromCombiner(prod_derivative_combiner((f, f), k)[0].combiner, 1) + assert CanFactorZeroFromCombiner(sum_both_combiner((f, f), k)[0].combiner, 0) + assert not CanFactorZeroFromCombiner(sum_both_combiner((f, f), k)[0].combiner, 1) + +def test_lift_nonzeroness_condition(): + k = tvm.reduce_axis((0, 5), name="k") + l = tvm.reduce_axis((0, 5), name="l") + n = tvm.reduce_axis((0, 5), name="n") + A = tvm.placeholder((10,), name='A') + + def _check(shape, fun, A=A): + T1 = tvm.compute(shape, fun) + T2 = tvm.compute(shape, lambda *args: LiftNonzeronessCondition(fun(*args))) + check_eq(T1, T2, [A]) + assert isinstance(T2.op.body[0], tvm.expr.Select) + + _check((10,), lambda i: A[i]) + _check((10,), lambda i: A[i] + (i % 2 == 0)) + _check((10,), lambda i: A[i]*(i % 2 == 0) + (i % 2 == 0)) + _check((10,), lambda i: tvm.expr.Select((i % 2 == 0), A[i], 0.0)) + _check((10,), lambda i: tvm.expr.Select((i % 2 == 0), A[i], 0.0) + (i % 2 == 0)) + _check((10,), lambda i: tvm.expr.Select((i % 2 == 0), 0.0, A[i]) + (i % 2 == 0)) + def e1(i): return tvm.expr.Select((i % 2 == 1), 0.0, A[i]) + def e2(i): return tvm.expr.Select((i % 2 == 0), A[(i + 1) % 10], 0.0) + def e3(i): return tvm.expr.Select((i % 2 == 1), A[i], 0.0) + _check((10,), lambda i: e1(i) + e2(i) + e3(i) + e1(i)*e2(i)) + _check((10,), lambda i: e1(i)*e3(i)) + _check((10,), lambda i: e1(i)*e2(i)) + _check((10,10), lambda i, j: A[i]*(i == j) + A[j]*(i == 2*j) + A[j]*(j == i)) + _check((10,10), lambda i, j: tvm.min(A[i]*(i == j), A[j]*(i == 2*j))) + _check((10,10), lambda i, j: tvm.max(A[i]*(i == j), A[j]*(i == 2*j))) + _check((10,10), lambda i, j: A[i]*(i == j) - A[j]*(i == 2*j)) + _check((10,10), lambda i, j: A[i]*(i == j) / (1 + tvm.abs(A[j]*(i == 2*j)))) + _check((10,10), lambda i, j: i*(i < j) + j*(i > j)) + _check((10,10), lambda i, j: i*(i < j) % (1 + j*(i > j))) + + def _check_symeq(expr1, expr2): + expr1 = LiftNonzeronessCondition(expr1) + expr2 = LiftNonzeronessCondition(expr2) + print(expr1) + print(expr2) + print() + check_symeq(expr1, expr2) + + _check_symeq(tvm.expr.Select(tvm.expr.EQ(k, l), 0.0, tvm.expr.Cast('float32', (k < n))), + tvm.expr.Select(tvm.expr.And((k < n), tvm.expr.NE(k, l)), 1.0, 0.0)) + _check_symeq(tvm.min(tvm.expr.Cast('int32', k < n)*l, tvm.expr.Select(k >= n, 0, 1)), + tvm.expr.Select(k < n, tvm.min(l, 1), 0)) + +def test_inline_tail_call(): + A = tvm.compute((10, 10), lambda i, j: i + j*j) + B = tvm.compute((5, 6), lambda k, l: A[k + l, k + 1]) + C = InlineTailCall(B) + resbody = lambda k, l: k + l + (k + 1)*(k + 1) + check_symeq(C.op.body[0], resbody(*[iv.var for iv in C.op.axis])) + +def test_inline_tensors(): + A = tvm.compute((10, 10), lambda i, j: i + j) + B = tvm.compute((10, 10), lambda i, j: i * j) + C = tvm.compute((10, 10), lambda i, j: A[i, j] + B[i, j]) + k = tvm.reduce_axis((0, 5), name="k") + D = tvm.compute((10, 10), lambda i, j: tvm.sum(A[i, k], k)) + E = tvm.compute((10, 10), lambda i, j: A[2, j] + C[i, 2] + D[i, j]) + + R = InlineTensors(E) + resbody = lambda i, j: 2 + j + i + 2 + i*2 + D[i, j] + check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis])) + + R = InlineTensors(E, [A]) + resbody = lambda i, j: 2 + j + C[i, 2] + D[i, j] + check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis])) + + R = InlineTensors(E, [A, C]) + resbody = lambda i, j: 2 + j + ((i + 2) + B[i, 2]) + D[i, j] + check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis])) + + R = InlineTensors(E, [B, C]) + resbody = lambda i, j: A[2, j] + (A[i, 2] + i*2) + D[i, j] + check_symeq(R.op.body[0], resbody(*[iv.var for iv in R.op.axis])) + +def test_solve_system_of_inequalities(): + seed = random.randrange(sys.maxsize) + print("\nseed: {}\n".format(seed)) + random.seed(seed) + + def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): + vs = [tvm.var("x" + str(i)) for i in range(variables)] + + fs = [] + for i in range(formulas): + s1 = sum([v*random.randint(coef[0], coef[1]) for v in vs]) + s1 += random.randint(coef[0], coef[1]) + s2 = sum([v*random.randint(coef[0], coef[1]) for v in vs]) + s2 += random.randint(coef[0], coef[1]) + op = random.choice([tvm.expr.EQ, tvm.expr.LE, tvm.expr.LT, tvm.expr.GE, tvm.expr.GT]) + fs.append(op(s1, s2)) + + vranges = {v: tvm.Range(bounds[0], bounds[1] + 1) for v in vs} + + before = tvm.all(*fs) + print(before) + after = tvm.all(*SolveSystemOfInequalities(fs, vs, vranges)) + print(after) + print() + + check_eq_bruteforce(before, after, vranges) + + for i in range(3): + _check(1, 1) + for i in range(3): + _check(1, 2) + + for i in range(3): + _check(2, 1) + for i in range(3): + _check(2, 2) + for i in range(3): + _check(2, 3) + + for i in range(5): + _check(3, 3) + for i in range(5): + _check(3, 4) + + # Somewhere here coefficients in the results become too large, leading to overflow, + # so we use smaller initial coefficients + + for i in range(5): + _check(4, 3, coef=(-1,1)) + + for i in range(5): + _check(10, 2, coef=(-1,1), bounds=(0, 4)) + for i in range(5): + _check(10, 3, coef=(0,1), bounds=(0, 4)) + +def test_simplify_domain(): + """Note that here we test both SimplifyDomain and SimplifyReductionDomain.""" + def _check(cond, axis, volume, vranges={}): + vranges_with_axis = dict(vranges) + vranges_with_axis.update({iv.var: iv.dom for iv in axis}) + variables = [iv.var for iv in axis] + new_cond, new_axis, old_to_new, new_to_old = SimplifyDomain(cond, variables, + vranges_with_axis) + + print("old", axis, cond) + print("new", new_axis, new_cond) + print("old_to_new", old_to_new) + print("new_to_old", new_to_old) + print() + + cond_subst = tvm.ir_pass.Substitute(cond, old_to_new) + new_vranges = vranges.copy() + new_vranges.update({v.var: v.dom for v in new_axis}) + # If new_cond is true in the new domain, then cond_subst must also be true in the new + # domain, but the reverse is not necessarily true + check_eq_bruteforce(tvm.all(new_cond, cond_subst), new_cond, new_vranges) + + new_cond_subst = tvm.ir_pass.Substitute(new_cond, new_to_old) + old_vranges = vranges.copy() + old_vranges.update({v.var: v.dom for v in axis}) + check_eq_bruteforce(cond, tvm.all(cond, new_cond_subst), old_vranges) + + # Also check SimplifyReductionDomain + reduction = xor_combiner(sum([v*(i + 1) for i, v in enumerate(axis)]), axis) + new_reduction = SimplifyReductionDomain(reduction, vranges) + check_eq_bruteforce(reduction, new_reduction, vranges) + + vol = np.prod([iv.dom.extent.value for iv in new_axis]) + if vol != volume: + raise AssertionError("New volume is {} != {}\n" + "Old domain {} where {}\nNew domain {} where {}" + .format(vol, volume, axis, cond, new_axis, new_cond)) + + k = tvm.reduce_axis((0, 5), name="k") + l = tvm.reduce_axis((0, 5), name="l") + n = tvm.reduce_axis((0, 5), name="n") + + _check((k <= l), [k, l, n], 125) + _check((k < l), [k, l, n], 80) + _check(tvm.expr.EQ(k, l), [k, l, n], 25) + _check(tvm.all(tvm.expr.EQ(k, l), (l < n)), [k, l, n], 16) + _check(tvm.expr.EQ(2*l, k), [k, l, n], 15) + # TODO: the result depends on the order of variables because we don't have a proper solver for + # systems of linear equations yet + _check(tvm.expr.EQ(2*l, k), [n, l, k], 25) + _check(tvm.all(l - k < 2, 2*n == k), [k, l, n], 15) + _check(tvm.all(l - k < 2, l >= k), [k, l, n], 50) + + some_var = tvm.var('some_var') + _check(tvm.all(l - k < some_var, l >= k), [k, l, n], 50, {some_var: tvm.Range(0, 3)}) + _check(tvm.all(l - k < some_var, l >= k), [k, l, n], 25, {some_var: tvm.Range(0, 2)}) + + + k = tvm.reduce_axis((-3, 2), name="k") + l = tvm.reduce_axis((-3, 2), name="l") + n = tvm.reduce_axis((-3, 2), name="n") + + _check((k < l), [k, l, n], 80) + _check(tvm.expr.EQ(k, l), [k, l, n], 25) + _check(tvm.all(tvm.expr.EQ(k, l), (l < n)), [k, l, n], 16) + # Now there are only two possible values for l: {l = -1, k = -2} and {l = 0, k = 0} + _check(tvm.expr.EQ(2*l, k), [k, l, n], 10) + # TODO: the result depends on the order of variables because we don't have a proper solver for + # systems of linear equations + _check(tvm.expr.EQ(2*l, k), [n, l, k], 25) + _check(tvm.all(l - k < 2, 2*n == k), [k, l, n], 10) + _check(tvm.all(l - k < 2, l >= k), [k, l, n], 50) + + some_var = tvm.var('some_var') + _check(tvm.all(l - k < some_var, l >= k), [k, l, n], 50, {some_var: tvm.Range(0, 3)}) + _check(tvm.all(l - k < some_var, l >= k), [k, l, n], 25, {some_var: tvm.Range(0, 2)}) + + + k = tvm.reduce_axis((0, 6), name="k") + l = tvm.reduce_axis((0, 5), name="l") + n = tvm.reduce_axis((0, 30), name="n") + + _check(tvm.all(k + l*6 == n), [k, l, n], 30) + _check(tvm.all(k + l*6 == n), [n, k, l], 30) + _check(tvm.all(k + l*6 == n), [n, l, k], 30) + + _check(tvm.all(n / 5 == k, n % 5 == l), [l, k, n], 30) + # TODO: Same thing with the order + _check(tvm.all(n / 5 == k, n % 5 == l), [n, l, k], 30) + + k = tvm.reduce_axis((0, 10), name="k") + l = tvm.reduce_axis((0, 10), name="l") + # TODO: This is not fully optimized because we don't have a solver + _check(tvm.all((l + k)%3 <= 1, (l + k)/3 <= 2), [l, k], 144) + +def test_extract_as_tensor_maybe(): + def _check(shape, fcompute, volume=None, vranges={}): + def fcompute_extracted(*variables, fcompute=fcompute, volume=volume, + vranges=vranges, shape=shape): + vranges = dict(vranges) + vranges.update({v: tvm.Range(0, s) for v, s in zip(variables, shape)}) + expr = fcompute(*variables) + if isinstance(expr, tvm.expr.Select): + new_true_value = ExtractAsTensorMaybe(expr.true_value, + expr.condition, + variables, + vranges) + expr = tvm.expr.Select(expr.condition, + new_true_value, + expr.false_value) + if volume is not None: + assert isinstance(new_true_value, tvm.expr.Call) + vol = np.prod([iv.dom.extent.value for iv in new_true_value.func.axis]) + if vol != volume: + raise AssertionError("New volume is {} != {}" + .format(vol, volume)) + return expr + + A = tvm.compute(shape, fcompute) + B = tvm.compute(shape, fcompute_extracted) + check_eq(A, B, []) + + _check((10, 10), lambda i, j: tvm.expr.Select(i < 3, i + j, 0), volume=30) + _check((10, 10), lambda i, j: tvm.expr.Select(i < 3, j, 0), volume=10) + _check((10, 10), lambda i, j: tvm.expr.Select(i < 3, i, 0), volume=3) + _check((10, 10), lambda i, j: tvm.expr.Select(tvm.all(i < j, j < 5), i + j, 0), volume=16) + # This one doesn't get extracted + _check((10, 10), lambda i, j: tvm.expr.Select(i <= j, i + j, 0)) + +def test_extract_reductions(): + k = tvm.reduce_axis((0, 10), name="k") + l = tvm.reduce_axis((0, 10), name="l") + n = tvm.reduce_axis((0, 10), name="n") + + A = tvm.compute((10, 10), + lambda i, j: + ExtractReductions(sum_combiner(i + k + xor_combiner(j*k + l, l), k), + {i: tvm.Range(0, 10), j: tvm.Range(0, 10)})) + B = tvm.compute((10, 10), lambda j, k: xor_combiner(j*k + l, l)) + C = tvm.compute((10, 10), lambda i, j: sum_combiner(i + k + B[j, k], k)) + check_eq(C, A, []) + + fcompute = lambda i, j: \ + ExtractReductions(sum_both_combiner((prod_derivative_combiner((i*n + 2*k, j + k), k)[1], + xor_combiner(j*n + l, l)), n)[1], + {i: tvm.Range(0, 10), j: tvm.Range(0, 10)}) + A = tvm.compute((10, 10), fcompute) + _, B = tvm.compute((10, 10, 10), + lambda i, j, n: prod_derivative_combiner((i*n + 2*k, j + k), k)) + C = tvm.compute((10, 10), lambda j, n: xor_combiner(j*n + l, l)) + _, D = tvm.compute((10, 10), lambda i, j: sum_both_combiner((B[i, j, n], C[j, n]), n)) + check_eq(A, D, []) + +def test_optimize_and_lift_nonzeroness(): + k = tvm.reduce_axis((0, 10), name="k") + l = tvm.reduce_axis((0, 10), name="l") + n = tvm.reduce_axis((0, 10), name="n") + A = tvm.placeholder((10, 10), name="A") + + zero = tvm.const(0, 'float32') + + B = compute((10, 10), lambda i, j: tvm.sum((i == j)*A[i, k] + A[k, j]*(i == j), k)) + B = OptimizeAndLiftNonzeronessConditions(B) + R = lambda i, j: tvm.expr.Select(i == j, + tvm.sum(A[j, k] + A[k, j], k), + zero) + check_tensor_symeq(B, R) + + B = compute((10, 10), lambda i, j: tvm.sum((i == j)*(i == k)*A[i, k] + + (i == j)*A[k, j]*(i == k), k)) + B = OptimizeAndLiftNonzeronessConditions(B) + R = lambda i, j: tvm.expr.Select(i == j, A[j, j]*2.0, zero) + check_tensor_symeq(B, R) + + B = compute((10, 10), lambda i, j: tvm.sum((i < j)*(j < k)*A[j, k], k)) + B = OptimizeAndLiftNonzeronessConditions(B) + k1 = tvm.reduce_axis((2, 10), name="k1") + R = compute((10, 10), lambda i, j: + tvm.expr.Select(tvm.all(i < j, j < 10), + tvm.sum(tvm.expr.Select(j < k1, A[j, k1], zero), k1), + zero)) + check_eq(B, R, [A]) + assert estimate_performance(B) <= estimate_performance(R) + + # TODO: This one needs the equation solver + # B = compute((10, 10), lambda i, j: tvm.sum((i <= j)*(j <= k)*A[j, k], k, where=(i >= k))) + # B = OptimizeAndLiftNonzeronessConditions(B) + # R = compute((10, 10), lambda i, j: tvm.expr.Select((i == j), A[i, i], zero)) + # check_eq(B, R, [A]) + # assert estimate_performance(B) <= estimate_performance(R) + + B = compute((10, 10), + lambda i, j: prod_derivative_combiner((A[j, k], (i <= j)*(j < k)*A[i, k]), k)[1]) + B = OptimizeAndLiftNonzeronessConditions(B) + R = compute((10, 10), lambda i, j: + tvm.expr.Select(tvm.all(i <= j, j < 10), + prod_derivative_combiner((A[j, k], (j < k)*A[i, k]), k)[1], + zero)) + check_eq(B, R, [A]) + assert estimate_performance(B) <= estimate_performance(R) + +if __name__ == "__main__": + test_is_sum_combiner() + test_can_factor_zero_from_combiner() + test_lift_nonzeroness_condition() + test_inline_tail_call() + test_inline_tensors() + test_solve_system_of_inequalities() + test_simplify_domain() + test_extract_as_tensor_maybe() + test_extract_reductions() + test_optimize_and_lift_nonzeroness() From 0de13b9adae10586cc79c1b37a64b5f1c4c11aa1 Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Wed, 23 Jan 2019 19:10:08 +0300 Subject: [PATCH 02/11] Fix the failing tests --- tests/python/unittest/test_pass_autodiff.py | 145 +++++++++--------- .../unittest/test_pass_zero_elimination.py | 25 +-- 2 files changed, 84 insertions(+), 86 deletions(-) diff --git a/tests/python/unittest/test_pass_autodiff.py b/tests/python/unittest/test_pass_autodiff.py index 6c971215e796..e27fb517b2d3 100644 --- a/tests/python/unittest/test_pass_autodiff.py +++ b/tests/python/unittest/test_pass_autodiff.py @@ -1,6 +1,3 @@ -# This example demonstrates Automatic Differentiation for TVM basic operations and TOPI primitives. -# See `test_autodiff()` and `test_nn_autodiff()` for details. - import tvm import topi import numpy as np @@ -40,7 +37,7 @@ def check_equivalence(outputs1, outputs2, inputs, in_range=(-10, 10), iters=10): for j, _ in enumerate(outputs1): tvm.testing.assert_allclose(arguments1[j].asnumpy(), arguments2[j].asnumpy()) -def test_grad(out, inputs, args=[], in_range=(-10,10), perf=None): +def check_grad(out, inputs, args=[], in_range=(-10,10), perf=None): line = inspect.getframeinfo(inspect.stack()[1][0]).lineno if not isinstance(inputs, (list, tuple)): @@ -169,40 +166,40 @@ def test_autodiff(): A1 = tvm.placeholder((10, 10), name='A1') B = tvm.compute((10, 10), lambda i, j: A0[i, j] + A0[j, i], name='B') - test_grad(B, A0, perf=(10100, 10000, 200)) + check_grad(B, A0, perf=(10100, 10000, 200)) B = tvm.compute((10, 10), lambda i, j: A0[i, j] + tvm.exp(A0[j, i]), name='B') - test_grad(B, A0, perf=(10100, 20000, 200)) + check_grad(B, A0, perf=(10100, 20000, 200)) B = tvm.compute((10, 10), lambda i, j: tvm.log(tvm.abs(A0[i, j] + tvm.exp(A0[j, i]))), name='B') - test_grad(B, A0, perf=(10100, 70000, 200)) + check_grad(B, A0, perf=(10100, 70000, 200)) B = tvm.compute((10, 10), lambda i, j: tvm.sigmoid(A0[i, j]*A0[i, j]*A0[j, i]), name='B') - test_grad(B, A0, perf=(10100, 120000, 200)) + check_grad(B, A0, perf=(10100, 120000, 200)) B = tvm.compute((10, 10), lambda i, j: tvm.tanh(A0[i, j]*A0[i, j]*A0[j, i]), name='B') - test_grad(B, A0, perf=(10100, 120000, 200)) + check_grad(B, A0, perf=(10100, 120000, 200)) B = tvm.compute((10, 10), lambda i, j: A0[i, j] * A0[j, i], name='B') - test_grad(B, A0, perf=(10100, 10000, 200)) + check_grad(B, A0, perf=(10100, 10000, 200)) # TODO: This one needs transforming Sum(a + b) -> Sum(a) + Sum(b) B = tvm.compute((10,), lambda i: tvm.sum(A0[i, k]*A0[k, i], axis=k), name='B') - test_grad(B, A0, perf=(11010, 1000, 1110)) + check_grad(B, A0, perf=(11010, 1000, 1110)) B = tvm.compute((10, 10), lambda i, j: tvm.sum(A0[i, k]*A0[k, i] + 5, axis=k), name='B') - test_grad(B, A0, perf=(20100, 10000, 1200)) + check_grad(B, A0, perf=(20100, 10000, 1200)) B = tvm.compute((10, 10), lambda i, j: tvm.max(A0[i, k]*A0[k, j] + 5, axis=k), name='B') - test_grad(B, A0, perf=(110100, 310000, 20200)) + check_grad(B, A0, perf=(110100, 310000, 20200)) B = tvm.compute((10, 10), lambda i, j: A0[i, j] * (A1[j, i] + A0[j, i]), name='B') - test_grad(B, A0, [A1], perf=(10100, 10000, 200)) + check_grad(B, A0, [A1], perf=(10100, 10000, 200)) B = tvm.compute((10, 10), lambda i, j: tvm.sum(A0[k, k] - A0[tvm.min(j + k, 9), j]*A0[i, k], axis=k), name='B') - test_grad(B, A0, perf=(110100, 10000, 10200)) + check_grad(B, A0, perf=(110100, 10000, 10200)) def fcombine(x, y): return x*y @@ -212,13 +209,13 @@ def fidentity(t0): prod = tvm.comm_reducer(fcombine, fidentity, name='prod') B = tvm.compute((10, 10), lambda i, j: prod(A0[i, k] + A0[k, i], axis=k), name='B') - test_grad(B, A0, perf=(20100, 40000, 2200)) + check_grad(B, A0, perf=(20100, 40000, 2200)) X = tvm.placeholder((10,), name='X') A = tvm.compute((10,), lambda i: X[i] + X[9 - i]) B = tvm.compute((10,), lambda i: X[i] * X[9 - i]) Y = topi.tensordot(A, B, 1) - test_grad(Y, X, perf=(251, 230, 71)) + check_grad(Y, X, perf=(251, 230, 71)) def test_topi_autodiff(): X = tvm.placeholder((1, 2, 4, 4), name='X') @@ -227,62 +224,62 @@ def test_topi_autodiff(): W2 = tvm.placeholder((1,), name='W1') R = topi.nn.conv2d(X, W, 1, 1, 1) - test_grad(R, [X, W], perf=(3410, 2880, 652)) + check_grad(R, [X, W], perf=(3410, 2880, 652)) R1 = topi.nn.conv2d(topi.nn.relu(R), W1, 1, 0, 1) - test_grad(R1, [X, W, W1], perf=(6198, 5320, 1250)) + check_grad(R1, [X, W, W1], perf=(6198, 5320, 1250)) R = topi.broadcast_to(W2, (5, 2, 3, 3)) - test_grad(R, [W2], perf=(180, 0, 91)) + check_grad(R, [W2], perf=(180, 0, 91)) R = topi.nn.conv2d(X, topi.broadcast_to(W2, (5, 2, 3, 3)), 1, 1, 1) - test_grad(R, [X, W2], perf=(3590, 2880, 743)) + check_grad(R, [X, W2], perf=(3590, 2880, 743)) R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'avg') - test_grad(R, X, perf=(40, 224, 40)) + check_grad(R, X, perf=(40, 224, 40)) R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'max') - test_grad(R, X, perf=(168, 1248, 104)) + check_grad(R, X, perf=(168, 1248, 104)) X = tvm.placeholder((1, 2, 5, 5), name='X') R = topi.reshape(X, (1, 32)) - test_grad(R, [X], perf=(82, 1200, 82)) + check_grad(R, [X], perf=(82, 1200, 82)) X = tvm.placeholder((1, 2, 5, 5), name='X') W = tvm.placeholder((2, 2, 3, 3), name='W') S = topi.reshape(X, (1, 50)) - test_grad(S, [X], perf=(100, 700, 100)) + check_grad(S, [X], perf=(100, 700, 100)) R = X + topi.nn.conv2d(X + topi.nn.conv2d(X, W, 1, 1, 1), W, 1, 1, 1) - test_grad(R, [X, W], perf=(6854, 5400, 1726)) + check_grad(R, [X, W], perf=(6854, 5400, 1726)) S = topi.nn.softmax(topi.reshape(R, (1, 50))) - test_grad(S, [X, W], perf=(10956, 14201, 2333)) + check_grad(S, [X, W], perf=(10956, 14201, 2333)) S = topi.sigmoid(topi.reshape(R, (1, 50))) - test_grad(S, [X, W], perf=(8004, 8350, 2026)) + check_grad(S, [X, W], perf=(8004, 8350, 2026)) S = topi.tanh(topi.reshape(R, (1, 50))) - test_grad(S, [X, W], perf=(8004, 8350, 2026)) + check_grad(S, [X, W], perf=(8004, 8350, 2026)) S = topi.nn.log_softmax(topi.reshape(R, (1, 50))) - test_grad(S, [X, W], perf=(10906, 13601, 2283)) - test_grad(S, [W], [X], perf=(8920, 11101, 1997)) + check_grad(S, [X, W], perf=(10906, 13601, 2283)) + check_grad(S, [W], [X], perf=(8920, 11101, 1997)) X = tvm.placeholder((1, 2, 3, 5), name='X') Y = tvm.placeholder((1, 2, 7, 5), name='Y') S = topi.concatenate((X, Y), 2) - test_grad(S, [X, Y], perf=(100, 0, 100)) + check_grad(S, [X, Y], perf=(100, 0, 100)) X = tvm.placeholder((1, 2, 6, 5), name='X') (S, R) = topi.split(X, 2, 2) - test_grad(S, [X], perf=(120, 0, 120)) - test_grad(R, [X], perf=(120, 0, 120)) + check_grad(S, [X], perf=(120, 0, 120)) + check_grad(R, [X], perf=(120, 0, 120)) R1 = topi.concatenate((S, R), 2) - test_grad(R1, [X], perf=(300, 0, 300)) + check_grad(R1, [X], perf=(300, 0, 300)) R2 = topi.concatenate((R, S), 2) - test_grad(R2, [X], perf=(300, 0, 300)) + check_grad(R2, [X], perf=(300, 0, 300)) def test_stride_dilation(): X = tvm.placeholder((1, 2, 10, 10), name='X') @@ -290,84 +287,84 @@ def test_stride_dilation(): W = tvm.placeholder((2, 2, 1, 1), name='W') Y = topi.nn.conv2d(X, W, 1, 0, 1) - test_grad(Y, [X, W], perf=(1404, 800, 808)) + check_grad(Y, [X, W], perf=(1404, 800, 808)) Y = topi.nn.conv2d(X, W, 2, 0, 1) - test_grad(Y, [X, W], perf=(928, 1572, 670)) + check_grad(Y, [X, W], perf=(928, 1572, 670)) Y = topi.nn.conv2d(X, W, 3, 0, 1) - test_grad(Y, [X, W], perf=(932, 1728, 672)) + check_grad(Y, [X, W], perf=(932, 1728, 672)) Y = topi.nn.conv2d(X, W, 1, 0, 2) - test_grad(Y, [X, W], perf=(1404, 800, 808)) + check_grad(Y, [X, W], perf=(1404, 800, 808)) Y = topi.nn.conv2d(X, W, 2, 0, 2) - test_grad(Y, [X, W], perf=(928, 1572, 670)) + check_grad(Y, [X, W], perf=(928, 1572, 670)) Y = topi.nn.conv2d(X, W, 3, 0, 2) - test_grad(Y, [X, W], perf=(932, 1728, 672)) + check_grad(Y, [X, W], perf=(932, 1728, 672)) Y = topi.nn.conv2d(X, W, 1, 0, 3) - test_grad(Y, [X, W], perf=(1404, 800, 808)) + check_grad(Y, [X, W], perf=(1404, 800, 808)) Y = topi.nn.conv2d(X, W, 2, 0, 3) - test_grad(Y, [X, W], perf=(928, 1572, 670)) + check_grad(Y, [X, W], perf=(928, 1572, 670)) Y = topi.nn.conv2d(X, W, 3, 0, 3) - test_grad(Y, [X, W], perf=(932, 1728, 672)) + check_grad(Y, [X, W], perf=(932, 1728, 672)) W = tvm.placeholder((2, 2, 2, 2), name='W') Y = topi.nn.conv2d(X, W, 1, 0, 1) - test_grad(Y, [X, W], perf=(3922, 2896, 1242)) + check_grad(Y, [X, W], perf=(3922, 2896, 1242)) Y = topi.nn.conv2d(X, W, 2, 0, 1) - test_grad(Y, [X, W], perf=(1650, 2800, 1066)) + check_grad(Y, [X, W], perf=(1650, 2800, 1066)) Y = topi.nn.conv2d(X, W, 3, 0, 1) - test_grad(Y, [X, W], perf=(1146, 2880, 890)) + check_grad(Y, [X, W], perf=(1146, 2880, 890)) Y = topi.nn.conv2d(X, W, 1, 0, 2) - test_grad(Y, [X, W], perf=(3500, 19720, 1092)) + check_grad(Y, [X, W], perf=(3500, 19720, 1092)) Y = topi.nn.conv2d(X, W, 2, 0, 2) - test_grad(Y, [X, W], perf=(3408, 88848, 2034)) + check_grad(Y, [X, W], perf=(3408, 88848, 2034)) Y = topi.nn.conv2d(X, W, 3, 0, 2) - test_grad(Y, [X, W], perf=(2254, 65232, 992)) + check_grad(Y, [X, W], perf=(2254, 65232, 992)) Y = topi.nn.conv2d(X, W, 1, 0, 3) - test_grad(Y, [X, W], perf=(3138, 17696, 970)) + check_grad(Y, [X, W], perf=(3138, 17696, 970)) Y = topi.nn.conv2d(X, W, 2, 0, 3) - test_grad(Y, [X, W], perf=(3816, 82368, 2176)) + check_grad(Y, [X, W], perf=(3816, 82368, 2176)) Y = topi.nn.conv2d(X, W, 3, 0, 3) - test_grad(Y, [X, W], perf=(3834, 104432, 2306)) + check_grad(Y, [X, W], perf=(3834, 104432, 2306)) W = tvm.placeholder((2, 2, 3, 3), name='W') Y = topi.nn.conv2d(X, W, 1, 0, 1) - test_grad(Y, [X, W], perf=(7420, 5904, 1752)) + check_grad(Y, [X, W], perf=(7420, 5904, 1752)) Y = topi.nn.conv2d(X, W, 2, 0, 1) - test_grad(Y, [X, W], perf=(3888, 58592, 2214)) + check_grad(Y, [X, W], perf=(3888, 58592, 2214)) Y = topi.nn.conv2d(X, W, 3, 0, 1) - test_grad(Y, [X, W], perf=(1552, 2268, 1102)) + check_grad(Y, [X, W], perf=(1552, 2268, 1102)) Y = topi.nn.conv2d(X, W, 1, 0, 2) - test_grad(Y, [X, W], perf=(5916, 42392, 1256)) + check_grad(Y, [X, W], perf=(5916, 42392, 1256)) Y = topi.nn.conv2d(X, W, 2, 0, 2) - test_grad(Y, [X, W], perf=(6736, 21784, 3694)) + check_grad(Y, [X, W], perf=(6736, 21784, 3694)) Y = topi.nn.conv2d(X, W, 3, 0, 2) - test_grad(Y, [X, W], perf=(2672, 146096, 1668)) + check_grad(Y, [X, W], perf=(2672, 146096, 1668)) Y = topi.nn.conv2d(X, W, 1, 0, 3) - test_grad(Y, [X, W], perf=(2896, 89152, 956)) + check_grad(Y, [X, W], perf=(2896, 89152, 956)) Y = topi.nn.conv2d(X, W, 2, 0, 3) - test_grad(Y, [X, W], perf=(2280, 12856, 1992)) + check_grad(Y, [X, W], perf=(2280, 12856, 1992)) Y = topi.nn.conv2d(X, W, 3, 0, 3) - test_grad(Y, [X, W], perf=(2224, 12032, 716)) + check_grad(Y, [X, W], perf=(2224, 12032, 716)) Y = topi.nn.pool(X, [1, 1], [1, 1], [0, 0, 0, 0], 'max') - test_grad(Y, [X], perf=(200, 0, 200)) + check_grad(Y, [X], perf=(200, 0, 200)) Y = topi.nn.pool(X, [1, 1], [2, 2], [0, 0, 0, 0], 'max') - test_grad(Y, [X], perf=(412, 1124, 412)) + check_grad(Y, [X], perf=(412, 1124, 412)) Y = topi.nn.pool(X, [1, 1], [3, 3], [0, 0, 0, 0], 'max') - test_grad(Y, [X], perf=(232, 1200, 232)) + check_grad(Y, [X], perf=(232, 1200, 232)) Y = topi.nn.pool(X, [2, 2], [1, 1], [0, 0, 0, 0], 'max') - test_grad(Y, [X], perf=(4162, 7200, 1962)) + check_grad(Y, [X], perf=(4162, 7200, 1962)) Y = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'max') - test_grad(Y, [X], perf=(1050, 7800, 650)) + check_grad(Y, [X], perf=(1050, 7800, 650)) Y = topi.nn.pool(X, [2, 2], [3, 3], [0, 0, 0, 0], 'max') - test_grad(Y, [X], perf=(858, 6304, 602)) + check_grad(Y, [X], perf=(858, 6304, 602)) Y = topi.nn.pool(X, [3, 3], [1, 1], [0, 0, 0, 0], 'max') - test_grad(Y, [X], perf=(18128, 34200, 3928)) + check_grad(Y, [X], perf=(18128, 34200, 3928)) Y = topi.nn.pool(X, [3, 3], [2, 2], [0, 0, 0, 0], 'max') - test_grad(Y, [X], perf=(6712, 131312, 1690)) + check_grad(Y, [X], perf=(6712, 131312, 1690)) Y = topi.nn.pool(X, [3, 3], [3, 3], [0, 0, 0, 0], 'max') - test_grad(Y, [X], perf=(1838, 12950, 704)) + check_grad(Y, [X], perf=(1838, 12950, 704)) def test_some_conv2d_net(): batch_size = 1 @@ -401,13 +398,13 @@ def test_some_conv2d_net(): weights = [w1, b1, w2, b2, w3, b3, w4, b4] - test_grad(t, weights, [x, y], in_range=(-1.0, 1.0), perf=(194865, 179089, 28194)) + check_grad(t, weights, [x, y], in_range=(-1.0, 1.0), perf=(194865, 179089, 28194)) if __name__ == "__main__": if "-v" in sys.argv: verbose = True - # test_differentiate_function() + test_differentiate_function() test_autodiff() test_topi_autodiff() test_stride_dilation() diff --git a/tests/python/unittest/test_pass_zero_elimination.py b/tests/python/unittest/test_pass_zero_elimination.py index 8e4bd59394b5..9eab3a9c6a87 100644 --- a/tests/python/unittest/test_pass_zero_elimination.py +++ b/tests/python/unittest/test_pass_zero_elimination.py @@ -62,7 +62,7 @@ def check_tensor_symeq(A, B): raise AssertionError("The expressions are not equal") def check_eq_bruteforce(expr1, expr2, vranges): - def _compute_body(*us, expr1=expr1, expr2=expr2, vranges=vranges): + def _compute_body(*us): vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)} return tvm.ir_pass.Substitute(expr1 == expr2, vmap) @@ -248,7 +248,7 @@ def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): _check(10, 3, coef=(0,1), bounds=(0, 4)) def test_simplify_domain(): - """Note that here we test both SimplifyDomain and SimplifyReductionDomain.""" + # Note that here we test both SimplifyDomain and SimplifyReductionDomain. def _check(cond, axis, volume, vranges={}): vranges_with_axis = dict(vranges) vranges_with_axis.update({iv.var: iv.dom for iv in axis}) @@ -344,16 +344,15 @@ def _check(cond, axis, volume, vranges={}): def test_extract_as_tensor_maybe(): def _check(shape, fcompute, volume=None, vranges={}): - def fcompute_extracted(*variables, fcompute=fcompute, volume=volume, - vranges=vranges, shape=shape): - vranges = dict(vranges) - vranges.update({v: tvm.Range(0, s) for v, s in zip(variables, shape)}) + def fcompute_extracted(*variables): + vranges_updated = dict(vranges) + vranges_updated.update({v: tvm.Range(0, s) for v, s in zip(variables, shape)}) expr = fcompute(*variables) if isinstance(expr, tvm.expr.Select): new_true_value = ExtractAsTensorMaybe(expr.true_value, expr.condition, variables, - vranges) + vranges_updated) expr = tvm.expr.Select(expr.condition, new_true_value, expr.false_value) @@ -415,11 +414,13 @@ def test_optimize_and_lift_nonzeroness(): zero) check_tensor_symeq(B, R) - B = compute((10, 10), lambda i, j: tvm.sum((i == j)*(i == k)*A[i, k] + - (i == j)*A[k, j]*(i == k), k)) - B = OptimizeAndLiftNonzeronessConditions(B) - R = lambda i, j: tvm.expr.Select(i == j, A[j, j]*2.0, zero) - check_tensor_symeq(B, R) + # TODO: This test is unstable: sometimes the resulting condition looks like + # (i == j)*(j == i) instead of (i == j) + # B = compute((10, 10), lambda i, j: tvm.sum((i == j)*(i == k)*A[i, k] + + # (i == j)*A[k, j]*(i == k), k)) + # B = OptimizeAndLiftNonzeronessConditions(B) + # R = lambda i, j: tvm.expr.Select(i == j, A[j, j]*2.0, zero) + # check_tensor_symeq(B, R) B = compute((10, 10), lambda i, j: tvm.sum((i < j)*(j < k)*A[j, k], k)) B = OptimizeAndLiftNonzeronessConditions(B) From 9f99cd742f1d39038773bb12b51180f84498930b Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Tue, 29 Jan 2019 13:40:04 +0300 Subject: [PATCH 03/11] Fix topi.take --- src/pass/zero_elimination.cc | 40 ++++++++++++++------- tests/python/unittest/test_pass_autodiff.py | 26 +++++++++----- 2 files changed, 46 insertions(+), 20 deletions(-) diff --git a/src/pass/zero_elimination.cc b/src/pass/zero_elimination.cc index 327d8a47032e..3a280b63d5af 100644 --- a/src/pass/zero_elimination.cc +++ b/src/pass/zero_elimination.cc @@ -653,7 +653,11 @@ class EliminateDivModMutator : public IRMutator { } Expr mutated_a = Mutate(op->a); - return AddNewVarPair(op->a, mutated_a, imm->value).first; + if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value)) { + return var_pair_opt.value().first; + } else { + return Div::make(mutated_a, Mutate(op->b)); + } } return Div::make(Mutate(op->a), Mutate(op->b)); @@ -668,17 +672,37 @@ class EliminateDivModMutator : public IRMutator { } Expr mutated_a = Mutate(op->a); - return AddNewVarPair(op->a, mutated_a, imm->value).second; + if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value)) { + return var_pair_opt.value().second; + } else { + return Mod::make(mutated_a, Mutate(op->b)); + } } return Mod::make(Mutate(op->a), Mutate(op->b)); } private: - std::pair AddNewVarPair(const Expr& e, const Expr& mut, int64_t val) { + dmlc::optional> AddNewVarPair(const Expr& e, const Expr& mut, int64_t val) { + using tresult = dmlc::optional>; + Expr val_e = make_const(e.type(), val); idx_ += 1; + std::unordered_map var_intsets; + for (const auto& p : ranges) { + var_intsets[p.first.get()] = IntSet::range(p.second); + } + + Range div_range = EvalSet(mut / val_e, var_intsets).cover_range(Range()); + Range mod_range = EvalSet(mut % val_e, var_intsets).cover_range(Range()); + + if (!div_range.get() || !mod_range.get()) { + LOG(WARNING) << "EliminateDivMod: won't eliminate div or mod of expr " << e + << " because its bounds cannot be inferred"; + return tresult(); + } + auto div = Var("div" + std::to_string(idx_), e.type()); auto mod = Var("mod" + std::to_string(idx_), e.type()); @@ -688,14 +712,6 @@ class EliminateDivModMutator : public IRMutator { substitution.Set(div, mut / val_e); substitution.Set(mod, mut % val_e); - std::unordered_map var_intsets; - for (const auto& p : ranges) { - var_intsets[p.first.get()] = IntSet::range(p.second); - } - - Range div_range = EvalSet(mut / val_e, var_intsets).cover_range(Range()); - Range mod_range = EvalSet(mut % val_e, var_intsets).cover_range(Range()); - ranges.Set(div, div_range); ranges.Set(mod, mod_range); @@ -710,7 +726,7 @@ class EliminateDivModMutator : public IRMutator { auto p = std::make_pair(div, mod); expr_to_vars_[{e.get(), val}] = p; - return p; + return tresult(p); } int idx_{0}; diff --git a/tests/python/unittest/test_pass_autodiff.py b/tests/python/unittest/test_pass_autodiff.py index e27fb517b2d3..ea9c703f32f2 100644 --- a/tests/python/unittest/test_pass_autodiff.py +++ b/tests/python/unittest/test_pass_autodiff.py @@ -115,20 +115,25 @@ def check_grad(out, inputs, args=[], in_range=(-10,10), perf=None): "worse than the reference ones (by {}): estimated {}, expected {}" .format(line, EST_RTOL, est.as_tuple(), ref_est.as_tuple())) - def fun(*arguments): - arrays = [tvm.nd.empty(get_shape(out), out.dtype)] + [tvm.nd.array(a) for a in arguments] - mout(*arrays) - return arrays[0].asnumpy().sum() - + input_vals = [tvm.nd.array(np.random.uniform(in_range[0], in_range[1], + size=get_shape(a)).astype(a.dtype)) + for a in inputs] arg_vals = [tvm.nd.array(np.random.uniform(in_range[0], in_range[1], size=get_shape(a)).astype(a.dtype)) - for a in inputs + args] + for a in args] - g_arg_vals = [tvm.nd.empty(get_shape(i), g.dtype) for i, g in zip(inputs, grads)] + arg_vals + def fun(*arguments, arg_vals=arg_vals): + arrays = [tvm.nd.empty(get_shape(out), out.dtype)] + \ + [tvm.nd.array(a) for a in list(arguments) + arg_vals] + mout(*arrays) + return arrays[0].asnumpy().sum() + + g_arg_vals = [tvm.nd.empty(get_shape(i), g.dtype) for i, g in zip(inputs, grads)] + \ + input_vals + arg_vals mgrad(*g_arg_vals) g_res = [g_arg_vals[g].asnumpy() for g, _ in enumerate(grads)] - check_numerical_grads(fun, [a.asnumpy() for a in arg_vals], g_res) + check_numerical_grads(fun, [a.asnumpy() for a in input_vals], g_res) def test_differentiate_function(): x = tvm.placeholder((32, 3, 28, 28), name='x') @@ -281,6 +286,11 @@ def test_topi_autodiff(): R2 = topi.concatenate((R, S), 2) check_grad(R2, [X], perf=(300, 0, 300)) + X = tvm.placeholder((4, 5), name='X') + I = tvm.placeholder((100,), name='I', dtype='int32') + R = topi.take(X, topi.abs(I)) + check_grad(R, [X], [I], perf=(2200, 6000, 220)) + def test_stride_dilation(): X = tvm.placeholder((1, 2, 10, 10), name='X') From 0775a4e95bf6033e002ca86164529954b66c92d9 Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Tue, 29 Jan 2019 14:54:55 +0300 Subject: [PATCH 04/11] Fix a python2 syntax error --- tests/python/unittest/test_pass_autodiff.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_pass_autodiff.py b/tests/python/unittest/test_pass_autodiff.py index ea9c703f32f2..42e83c58281c 100644 --- a/tests/python/unittest/test_pass_autodiff.py +++ b/tests/python/unittest/test_pass_autodiff.py @@ -122,7 +122,7 @@ def check_grad(out, inputs, args=[], in_range=(-10,10), perf=None): size=get_shape(a)).astype(a.dtype)) for a in args] - def fun(*arguments, arg_vals=arg_vals): + def fun(*arguments): arrays = [tvm.nd.empty(get_shape(out), out.dtype)] + \ [tvm.nd.array(a) for a in list(arguments) + arg_vals] mout(*arrays) From d90323eefc12745c04823c810736428c8d06486b Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Tue, 29 Jan 2019 16:23:42 +0300 Subject: [PATCH 05/11] Move autodiff.h to include/tvm --- {src/pass => include/tvm}/autodiff.h | 6 +++--- src/pass/autodiff.cc | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) rename {src/pass => include/tvm}/autodiff.h (98%) diff --git a/src/pass/autodiff.h b/include/tvm/autodiff.h similarity index 98% rename from src/pass/autodiff.h rename to include/tvm/autodiff.h index 43b0b8ebf23b..ade082d0b291 100644 --- a/src/pass/autodiff.h +++ b/include/tvm/autodiff.h @@ -3,8 +3,8 @@ * \file autodiff.h * \brief Automatic differentiation of IR Expr. */ -#ifndef TVM_PASS_AUTODIFF_H_ -#define TVM_PASS_AUTODIFF_H_ +#ifndef TVM_AUTODIFF_H_ +#define TVM_AUTODIFF_H_ #include #include @@ -143,4 +143,4 @@ EXPORT DifferentiationResult Differentiate(const Tensor& output, } // namespace ir } // namespace tvm -#endif // TVM_PASS_AUTODIFF_H_ +#endif // TVM_AUTODIFF_H_ diff --git a/src/pass/autodiff.cc b/src/pass/autodiff.cc index 5afc5d5476f5..5bbdd9691b1a 100644 --- a/src/pass/autodiff.cc +++ b/src/pass/autodiff.cc @@ -3,7 +3,7 @@ * \file autodiff.cc * \brief Automatic differentiation of IR Expr */ -#include "autodiff.h" +#include #include #include From 204f2f7dba5c5fbf344738fcf5d8c2e3f6466c3b Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Tue, 29 Jan 2019 16:28:23 +0300 Subject: [PATCH 06/11] Reduce the probability of test failure because of integer overflow --- tests/python/unittest/test_pass_zero_elimination.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/python/unittest/test_pass_zero_elimination.py b/tests/python/unittest/test_pass_zero_elimination.py index 9eab3a9c6a87..e04abbddab5c 100644 --- a/tests/python/unittest/test_pass_zero_elimination.py +++ b/tests/python/unittest/test_pass_zero_elimination.py @@ -231,14 +231,14 @@ def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): for i in range(3): _check(2, 3) - for i in range(5): - _check(3, 3) - for i in range(5): - _check(3, 4) - # Somewhere here coefficients in the results become too large, leading to overflow, # so we use smaller initial coefficients + for i in range(5): + _check(3, 3, coef=(-2,2)) + for i in range(5): + _check(3, 4, coef=(-2,2)) + for i in range(5): _check(4, 3, coef=(-1,1)) From 952629b42917be09e7cef9064bdec4d68279530a Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Wed, 30 Jan 2019 18:52:54 +0300 Subject: [PATCH 07/11] Fix a problem with free vars --- python/tvm/testing.py | 30 ++-- src/pass/zero_elimination.cc | 52 +++++-- src/pass/zero_elimination.h | 10 +- tests/python/unittest/test_pass_autodiff.py | 140 +++++++++++------- .../unittest/test_pass_zero_elimination.py | 4 +- 5 files changed, 160 insertions(+), 76 deletions(-) diff --git a/python/tvm/testing.py b/python/tvm/testing.py index e21dc5be6863..afdca6a19720 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -197,17 +197,20 @@ def __le__(self, other): self.memory <= other.memory -def estimate_performance(s, processed_ops=None): +def estimate_performance(s, param_values=None, processed_ops=None): """Statically estimate performance of statements, expressions and tensors. Note that the estimate is very rough, it mustn't be used to predict future performance, its only purpose is to detect possible performance regressions. - Parameters: - ----------- + Parameters + ---------- s A statement, an expression, a tensor, an operation, or a list of any of the above. + param_values : Dict[tvm.expr.Var, int], optional + Values for parameters (free variables). + Returns ------- estimate : PerformanceEstimate @@ -215,14 +218,23 @@ def estimate_performance(s, processed_ops=None): from tvm import stmt from tvm import expr + if param_values is None: + param_values = {} + if processed_ops is None: processed_ops = {} - res = estimate_performance(s, processed_ops) + res = estimate_performance(s, param_values=param_values, processed_ops=processed_ops) for op_est in processed_ops.values(): res += op_est return res - est = lambda e, processed_ops=processed_ops: estimate_performance(e, processed_ops) + def est(expression, param_values=param_values, processed_ops=processed_ops): + return estimate_performance(expression, + param_values=param_values, + processed_ops=processed_ops) + + def _eval(expression, param_values=param_values): + return tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(expression, param_values)).value def _prod(elems): res = 1 @@ -241,7 +253,7 @@ def _prod(elems): elif s in processed_ops: return PerformanceEstimate() elif isinstance(s, stmt.Allocate): - mem = _prod([e.value for e in s.extents]) + mem = _prod([_eval(e) for e in s.extents]) return est(s.condition) + est(s.body) + PerformanceEstimate(memory=mem) elif isinstance(s, stmt.Block): return est(s.first) + est(s.rest) @@ -250,7 +262,7 @@ def _prod(elems): elif isinstance(s, stmt.For): body_est = est(s.body) body_est.iterations = max(1, body_est.iterations) - return body_est.times(s.extent.value) + return body_est.times(_eval(s.extent)) elif isinstance(s, stmt.IfThenElse): return est(s.condition) + est(s.then_case) + est(s.else_case) elif isinstance(s, stmt.LetStmt): @@ -289,7 +301,7 @@ def _prod(elems): elif isinstance(s, expr.Select): return est(s.condition) + est(s.true_value) + est(s.false_value) elif isinstance(s, expr.Reduce): - iterations = _prod([iv.dom.extent.value for iv in s.axis]) + iterations = _prod([_eval(iv.dom.extent) for iv in s.axis]) res = PerformanceEstimate() for id_elem in s.combiner.identity_element: res += est(id_elem) @@ -303,7 +315,7 @@ def _prod(elems): elif isinstance(s, tvm.tensor.Tensor): return est(s.op) elif isinstance(s, tvm.tensor.ComputeOp): - iterations = _prod([iv.dom.extent.value for iv in s.axis]) + iterations = _prod([_eval(iv.dom.extent) for iv in s.axis]) if s.reduce_axis: res = est(s.body[0]) else: diff --git a/src/pass/zero_elimination.cc b/src/pass/zero_elimination.cc index 3a280b63d5af..2a053be81b46 100644 --- a/src/pass/zero_elimination.cc +++ b/src/pass/zero_elimination.cc @@ -240,6 +240,8 @@ Array IterVarsFromMap(const Array& vars, const Map& vr IterVarType iter_type = kDataPar, std::string thread_tag = "") { Array res; for (const Var& v : vars) { + CHECK(vranges.count(v)) << "A range for the variable " << v + << " was not provided in map " << vranges; res.push_back(IterVarNode::make(vranges[v], v, iter_type, thread_tag)); } return res; @@ -1198,7 +1200,7 @@ DomainSimplificationResult SimplifyDomain(const Expr& cond, new_var_intsets[new_var.get()] = IntSet::interval(make_zero(new_var.type()), diff); // Add the new var to the resulting axis - auto range = Range(make_zero(new_var.type()), diff + 1); + auto range = Range(make_zero(new_var.type()), SuperSimplify(diff + 1, vranges)); res.axis.push_back(new_var); res.ranges.Set(new_var, range); vranges.Set(new_var, range); @@ -1457,11 +1459,15 @@ std::pair LiftConditionsThroughReduction(const Expr& cond, class ExtractReductionsMutator : public IRMutator { public: - explicit ExtractReductionsMutator(Map vranges, std::string name = "extracted") - : vranges_(std::move(vranges)), name_(std::move(name)) {} + explicit ExtractReductionsMutator(const Array& outer_axis, + Map vranges, + std::string name = "extracted") + : outer_axis_(outer_axis), vranges_(std::move(vranges)), name_(std::move(name)) {} Expr Mutate_(const Reduce* op, const Expr& e) { - ExtractReductionsMutator new_mutator(Merge(vranges_, IterVarsToMap(op->axis)), name_); + ExtractReductionsMutator new_mutator(Concat(IterVarsToVars(op->axis), outer_axis_), + Merge(vranges_, IterVarsToMap(op->axis)), + name_); Array new_source; for (const Expr& src : op->source) { @@ -1470,7 +1476,18 @@ class ExtractReductionsMutator : public IRMutator { Expr new_reduce = Reduce::make(op->combiner, new_source, op->axis, op->condition, op->value_index); - Array vars = ExprFreeVars(new_reduce); + + ExprFreeVarsVisitor fv_visitor; + fv_visitor.Visit(new_reduce); + + // Vars of the tensor we are going to create for this reduction + Array vars; + for (const Var& v : outer_axis_) { + // We take variables from the outer_axis_ which are also present in the new reduction + if (fv_visitor.free.count(v.get())) { + vars.push_back(v); + } + } auto newaxis_vmap_pair = CloneIterVars(IterVarsFromMap(vars, vranges_)); Array new_axis = newaxis_vmap_pair.first; @@ -1489,6 +1506,7 @@ class ExtractReductionsMutator : public IRMutator { } private: + Array outer_axis_; Map vranges_; std::string name_; std::string tag_; @@ -1496,23 +1514,28 @@ class ExtractReductionsMutator : public IRMutator { }; // Extract reductions as separate tensors. -Expr ExtractReductions(const Expr& expr, const Map& vranges) { - return ExtractReductionsMutator(vranges).Mutate(expr); +Expr ExtractReductions(const Expr& expr, + const Array& outer_axis, + const Map& vranges) { + return ExtractReductionsMutator(outer_axis, vranges).Mutate(expr); } -Expr ExtractNonTopReductions(const Expr& expr, const Map& vranges) { +Expr ExtractNonTopReductions(const Expr& expr, + const Array& outer_axis, + const Map& vranges) { if (const Reduce* red = expr.as()) { + Array new_outer_axis = Concat(IterVarsToVars(red->axis), outer_axis); Map new_vranges = Merge(vranges, IterVarsToMap(red->axis)); Array new_source; for (const Expr& src : red->source) { - new_source.push_back(ExtractReductions(src, new_vranges)); + new_source.push_back(ExtractReductions(src, new_outer_axis, new_vranges)); } - Expr new_condition = ExtractReductions(red->condition, new_vranges); + Expr new_condition = ExtractReductions(red->condition, new_outer_axis, new_vranges); return Reduce::make(red->combiner, new_source, red->axis, new_condition, red->value_index); } else { - return ExtractReductions(expr, vranges); + return ExtractReductions(expr, outer_axis, vranges); } } @@ -1591,7 +1614,8 @@ Expr OptimizeAndLiftNonzeronessConditionsImpl(const Expr& expr, const Array vrange = IterVarsToMap(axis); - return SuperSimplify(ExtractReductions(result, vrange), vrange); + return SuperSimplify(ExtractReductions(result, IterVarsToVars(axis), vrange), + vrange); } Tensor OptimizeAndLiftNonzeronessConditions(const Tensor& tensor) { @@ -1665,12 +1689,12 @@ TVM_REGISTER_API("ir_pass.ExtractAsTensorMaybe") TVM_REGISTER_API("ir_pass.ExtractReductions") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = ExtractReductions(args[0], args[1]); + *ret = ExtractReductions(args[0], args[1], args[2]); }); TVM_REGISTER_API("ir_pass.ExtractNonTopReductions") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = ExtractNonTopReductions(args[0], args[1]); + *ret = ExtractNonTopReductions(args[0], args[1], args[2]); }); TVM_REGISTER_API("ir_pass.OptimizeAndLiftNonzeronessConditions") diff --git a/src/pass/zero_elimination.h b/src/pass/zero_elimination.h index 26b2acc15b7a..1ac887bcb049 100644 --- a/src/pass/zero_elimination.h +++ b/src/pass/zero_elimination.h @@ -206,20 +206,26 @@ EXPORT Expr ExtractAsTensorMaybe(const Expr& expr, const Expr& cond, * are created. * * \param expr The expression from which to extract reductions. + * \param outer_axis Input variables of the enclosing tensor. * \param vranges Information about ranges of variables. * \return An expression without non-top-level reductions. */ -EXPORT Expr ExtractReductions(const Expr& expr, const Map& vranges); +EXPORT Expr ExtractReductions(const Expr& expr, + const Array& outer_axis, + const Map& vranges); /*! * \brief Extract reductions as separate tensors, but if the expr itself is a reduction, leave it * intact. * * \param expr The expression from which to extract reductions. + * \param outer_axis Input variables of the enclosing tensor. * \param vranges Information about ranges of variables. * \return An expression without non-top-level reductions. */ -EXPORT Expr ExtractNonTopReductions(const Expr& expr, const Map& vranges); +EXPORT Expr ExtractNonTopReductions(const Expr& expr, + const Array& outer_axis, + const Map& vranges); /*! * \brief Perform lifting of conditions of being possible to be non-zero together with diff --git a/tests/python/unittest/test_pass_autodiff.py b/tests/python/unittest/test_pass_autodiff.py index 42e83c58281c..06884f7f055d 100644 --- a/tests/python/unittest/test_pass_autodiff.py +++ b/tests/python/unittest/test_pass_autodiff.py @@ -9,8 +9,11 @@ # Whether to dump the generated code verbose = False -def get_shape(tensor): - return [tvm.ir_pass.Simplify(s).value for s in tensor.shape] +def get_shape(tensor, param_values=None): + if param_values is None: + param_values = {} + return [tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(s, param_values)).value + for s in tensor.shape] def check_equivalence(outputs1, outputs2, inputs, in_range=(-10, 10), iters=10): outputs1 = list(outputs1) @@ -37,12 +40,15 @@ def check_equivalence(outputs1, outputs2, inputs, in_range=(-10, 10), iters=10): for j, _ in enumerate(outputs1): tvm.testing.assert_allclose(arguments1[j].asnumpy(), arguments2[j].asnumpy()) -def check_grad(out, inputs, args=[], in_range=(-10,10), perf=None): +def check_grad(out, inputs, args=[], in_range=(-10,10), perf=None, param_values=None): line = inspect.getframeinfo(inspect.stack()[1][0]).lineno if not isinstance(inputs, (list, tuple)): inputs = [inputs] + if param_values is None: + param_values = {} + if verbose: print("\n" + 80*"=" + "\n") print("Testing gradients, line {}\n".format(line)) @@ -72,63 +78,66 @@ def check_grad(out, inputs, args=[], in_range=(-10,10), perf=None): print(lowered) print() - est = estimate_performance(grads) - est_lowered = estimate_performance(lowered) - - if verbose: - print("Note: performance tuples are (iterations, multiplications, memory)") - print("Expected performance of grads: {}".format(perf)) - print("Estimated performance of grads: {}".format(est.as_tuple())) - print("Estimated performance of lowered grads: {}".format(est_lowered.as_tuple())) - print() - - if est_lowered.memory > est.memory: - print("WARNING: Line {}: The estimated memory consumption increased after lowering, " - "this may indicate that tensor bounds have been expanded too much".format(line)) - print("before: {} after: {}".format(est, est_lowered)) - - (iters, mults, mem) = est.as_tuple() - if perf is None or isinstance(perf, str): - print("WARNING: Line {}: No performance information, you may set it to {}" - .format(line, est.as_tuple())) - if isinstance(perf, str): - print("0,/{!r}/{{s/{!r}/{}/}}".format(perf, perf, (iters, mults, mem))) - elif perf != (iters, mults, mem): - (ref_iters, ref_mults, ref_mem) = perf - ref_est = PerformanceEstimate(*perf) - - if est <= ref_est: - print("WARNING: Line {}: Estimated performance {} is better than {}. " - "Use this with sed:" - .format(line, est.as_tuple(), ref_est.as_tuple())) - print("0,/{}/{{s/{}/{}/}}".format(perf, perf, (iters, mults, mem))) - elif est >= ref_est: - print("WARNING: Line {}: Estimated performance {} IS WORSE THAN {}" - .format(line, est.as_tuple(), ref_est.as_tuple())) - else: - print("WARNING: Line {}: Estimated performance {} does not match {}" - .format(line, est.as_tuple(), ref_est.as_tuple())) - - EST_RTOL = 1.5 - if iters > ref_iters*EST_RTOL or mults > ref_mults*EST_RTOL or mem > ref_mem*EST_RTOL: - raise AssertionError("Line {}: Some of the estimated performance metrics are much " - "worse than the reference ones (by {}): estimated {}, expected {}" - .format(line, EST_RTOL, est.as_tuple(), ref_est.as_tuple())) + if perf != False: + est = estimate_performance(grads, param_values=param_values) + est_lowered = estimate_performance(lowered, param_values=param_values) + + if verbose: + print("Note: performance tuples are (iterations, multiplications, memory)") + print("Expected performance of grads: {}".format(perf)) + print("Estimated performance of grads: {}".format(est.as_tuple())) + print("Estimated performance of lowered grads: {}".format(est_lowered.as_tuple())) + print() + + if est_lowered.memory > est.memory: + print("WARNING: Line {}: The estimated memory consumption increased after lowering, " + "this may indicate that tensor bounds have been expanded too much".format(line)) + print("before: {} after: {}".format(est, est_lowered)) + + (iters, mults, mem) = est.as_tuple() + if perf is None or isinstance(perf, str): + print("WARNING: Line {}: No performance information, you may set it to {}" + .format(line, est.as_tuple())) + if isinstance(perf, str): + print("0,/{!r}/{{s/{!r}/{}/}}".format(perf, perf, (iters, mults, mem))) + elif perf != (iters, mults, mem): + (ref_iters, ref_mults, ref_mem) = perf + ref_est = PerformanceEstimate(*perf) + + if est <= ref_est: + print("WARNING: Line {}: Estimated performance {} is better than {}. " + "Use this with sed:" + .format(line, est.as_tuple(), ref_est.as_tuple())) + print("0,/{}/{{s/{}/{}/}}".format(perf, perf, (iters, mults, mem))) + elif est >= ref_est: + print("WARNING: Line {}: Estimated performance {} IS WORSE THAN {}" + .format(line, est.as_tuple(), ref_est.as_tuple())) + else: + print("WARNING: Line {}: Estimated performance {} does not match {}" + .format(line, est.as_tuple(), ref_est.as_tuple())) + + EST_RTOL = 1.5 + if iters > ref_iters*EST_RTOL or mults > ref_mults*EST_RTOL or mem > ref_mem*EST_RTOL: + raise AssertionError("Line {}: Some of the estimated performance metrics are much " + "worse than the reference ones (by {}): " + "estimated {}, expected {}" + .format(line, EST_RTOL, est.as_tuple(), ref_est.as_tuple())) input_vals = [tvm.nd.array(np.random.uniform(in_range[0], in_range[1], - size=get_shape(a)).astype(a.dtype)) + size=get_shape(a, param_values)).astype(a.dtype)) for a in inputs] arg_vals = [tvm.nd.array(np.random.uniform(in_range[0], in_range[1], - size=get_shape(a)).astype(a.dtype)) + size=get_shape(a, param_values)).astype(a.dtype)) for a in args] def fun(*arguments): - arrays = [tvm.nd.empty(get_shape(out), out.dtype)] + \ + arrays = [tvm.nd.empty(get_shape(out, param_values), out.dtype)] + \ [tvm.nd.array(a) for a in list(arguments) + arg_vals] mout(*arrays) return arrays[0].asnumpy().sum() - g_arg_vals = [tvm.nd.empty(get_shape(i), g.dtype) for i, g in zip(inputs, grads)] + \ + g_arg_vals = \ + [tvm.nd.empty(get_shape(i, param_values), g.dtype) for i, g in zip(inputs, grads)] + \ input_vals + arg_vals mgrad(*g_arg_vals) g_res = [g_arg_vals[g].asnumpy() for g, _ in enumerate(grads)] @@ -226,7 +235,7 @@ def test_topi_autodiff(): X = tvm.placeholder((1, 2, 4, 4), name='X') W = tvm.placeholder((5, 2, 3, 3), name='W') W1 = tvm.placeholder((2, 5, 3, 3), name='W1') - W2 = tvm.placeholder((1,), name='W1') + W2 = tvm.placeholder((1,), name='W2') R = topi.nn.conv2d(X, W, 1, 1, 1) check_grad(R, [X, W], perf=(3410, 2880, 652)) @@ -410,6 +419,36 @@ def test_some_conv2d_net(): check_grad(t, weights, [x, y], in_range=(-1.0, 1.0), perf=(194865, 179089, 28194)) +def test_free_vars(): + m = tvm.var('m') + n = tvm.var('n') + A = tvm.placeholder((m, n), name='A') + B = tvm.placeholder((n,), name='B') + + Y = topi.add(A, B) + check_grad(Y, [A, B], perf=(160, 0, 120), param_values={m: 5, n: 10}) + + param_values = {m: 10} + x = tvm.var("x", dtype='float32') + k = tvm.reduce_axis((0, m), name="k") + A0 = tvm.placeholder((m, m), name='A0') + A1 = tvm.placeholder((m, m), name='A1') + + B = tvm.compute((m, m), lambda i, j: A0[i, j] + A0[j, i], name='B') + check_grad(B, A0, perf=(10200, 10000, 300), param_values=param_values) + + B = tvm.compute((m,), lambda i: tvm.sum(A0[i, k]*A0[k, i], axis=k), name='B') + check_grad(B, A0, perf=(11110, 1000, 1210), param_values=param_values) + + X = tvm.placeholder((m, n, 4, 4), name='X') + param_values = {m: 1, n: 2} + + R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'avg') + check_grad(R, X, perf=(72, 224, 72), param_values=param_values) + + R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], 'max') + check_grad(R, X, perf=(200, 1248, 136), param_values=param_values) + if __name__ == "__main__": if "-v" in sys.argv: verbose = True @@ -419,3 +458,4 @@ def test_some_conv2d_net(): test_topi_autodiff() test_stride_dilation() test_some_conv2d_net() + test_free_vars() diff --git a/tests/python/unittest/test_pass_zero_elimination.py b/tests/python/unittest/test_pass_zero_elimination.py index e04abbddab5c..a1d4070a72f7 100644 --- a/tests/python/unittest/test_pass_zero_elimination.py +++ b/tests/python/unittest/test_pass_zero_elimination.py @@ -46,7 +46,7 @@ def compute(shape, fcompute): """Like tvm.compute but automatically extracts reductions.""" return tvm.compute(shape, lambda *vs: ExtractNonTopReductions( - fcompute(*vs), {v: tvm.Range(0, s) for v, s in zip(vs, shape)})) + fcompute(*vs), vs, {v: tvm.Range(0, s) for v, s in zip(vs, shape)})) def check_tensor_symeq(A, B): if not isinstance(B, tvm.tensor.Tensor): @@ -383,6 +383,7 @@ def test_extract_reductions(): A = tvm.compute((10, 10), lambda i, j: ExtractReductions(sum_combiner(i + k + xor_combiner(j*k + l, l), k), + [i, j], {i: tvm.Range(0, 10), j: tvm.Range(0, 10)})) B = tvm.compute((10, 10), lambda j, k: xor_combiner(j*k + l, l)) C = tvm.compute((10, 10), lambda i, j: sum_combiner(i + k + B[j, k], k)) @@ -391,6 +392,7 @@ def test_extract_reductions(): fcompute = lambda i, j: \ ExtractReductions(sum_both_combiner((prod_derivative_combiner((i*n + 2*k, j + k), k)[1], xor_combiner(j*n + l, l)), n)[1], + [i, j], {i: tvm.Range(0, 10), j: tvm.Range(0, 10)}) A = tvm.compute((10, 10), fcompute) _, B = tvm.compute((10, 10, 10), From 75418c7440de7de81ffd99db8273b1b511b4ef2d Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Thu, 14 Feb 2019 19:27:38 +0300 Subject: [PATCH 08/11] [AD] More intrinsics; Fixed treatment of ints --- python/tvm/testing.py | 34 ++++++++++++++------- src/pass/autodiff.cc | 25 ++++++++++++--- src/pass/zero_elimination.cc | 12 ++++++++ tests/python/unittest/test_pass_autodiff.py | 24 +++++++++++++-- 4 files changed, 78 insertions(+), 17 deletions(-) diff --git a/python/tvm/testing.py b/python/tvm/testing.py index afdca6a19720..bca939881ea9 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -15,7 +15,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): def check_numerical_grads(function, input_values, grad_values, function_value=None, - delta=1e-3, atol=1e-2, rtol=0.1): + delta=1e-3, atol=1e-2, rtol=0.1, acceptable_fail_fraction=None): """A helper function that checks that numerical gradients of a function are equal to gradients computed in some different way (analytical gradients). @@ -51,6 +51,10 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No rtol : float, optional Relative tolerance. + + acceptable_fail_fraction : float, optional + If not None, raise an error only when the fraction of wrong elements for a gradient is + higher than this value. """ # If input_values is a list then function accepts positional arguments # In this case transform it to a function taking kwargs of the form {"0": ..., "1": ...} @@ -94,7 +98,7 @@ def compare_derivative(j, n_der, grad): wrong_positions = [] # compute partial derivatives for each position in this variable - for j in range(np.prod(grad.shape)): + for j in range(int(np.prod(grad.shape))): # forward difference approximation nder = derivative(x_name, j, delta) @@ -117,7 +121,7 @@ def compare_derivative(j, n_der, grad): ngrad.reshape(-1)[j] = nder - wrong_percentage = int(100*len(wrong_positions)/np.prod(grad.shape)) + wrong_fraction = len(wrong_positions)/np.prod(grad.shape) dist = np.sqrt(np.sum((ngrad - grad)**2)) grad_norm = np.sqrt(np.sum(ngrad**2)) @@ -132,14 +136,22 @@ def compare_derivative(j, n_der, grad): sqrt_n = np.sqrt(float(np.prod(grad.shape))) if dist > atol*sqrt_n + rtol*grad_norm: - raise AssertionError( - "Analytical and numerical grads wrt '{}' differ too much\n" - "analytical grad = {}\n numerical grad = {}\n" - "{}% of elements differ, first 10 of wrong positions: {}\n" - "distance > atol*sqrt(n) + rtol*grad_norm\n" - "distance {} > {}*{} + {}*{}" - .format(x_name, grad, ngrad, wrong_percentage, wrong_positions[:10], - dist, atol, sqrt_n, rtol, grad_norm)) + enough_failures = (acceptable_fail_fraction is None or + wrong_fraction > acceptable_fail_fraction) + if enough_failures: + raise AssertionError( + "Analytical and numerical grads wrt '{}' differ too much\n" + "analytical grad = {}\n numerical grad = {}\n" + "{}% of elements differ, first 10 of wrong positions: {}\n" + "distance > atol*sqrt(n) + rtol*grad_norm\n" + "distance {} > {}*{} + {}*{}" + .format(x_name, grad, ngrad, int(100*wrong_fraction), wrong_positions[:10], + dist, atol, sqrt_n, rtol, grad_norm)) + else: + logging.warning("Analytical and numerical grads wrt '%s' differ, however " + "there were not enough wrong elements to raise an error " + "(only %d%%)", + x_name, int(100*wrong_fraction)) max_diff = np.max(np.abs(ngrad - grad)) avg_diff = np.mean(np.abs(ngrad - grad)) diff --git a/src/pass/autodiff.cc b/src/pass/autodiff.cc index 5bbdd9691b1a..05f843a164c4 100644 --- a/src/pass/autodiff.cc +++ b/src/pass/autodiff.cc @@ -46,8 +46,7 @@ TVM_REGISTER_NODE_TYPE(DifferentiationResultNode); #define NOT_IMPLEMENTED \ - { CHECK(false) << "Derivative of this op is not implemented"; \ - throw dmlc::Error("Derivative of this op is not implemented"); } + { LOG(FATAL) << "Derivative of this expr is not implemented: " << e; throw; } /*! \brief Differentiate an expression wrt a variable or a tensor element */ class JacobianMutator : public IRMutator { @@ -66,8 +65,17 @@ class JacobianMutator : public IRMutator { explicit JacobianMutator(VarExpr input) : input_var_(input) {} + virtual Expr Mutate(Expr e) { + if (e.type().is_int() || e.type().is_uint()) { + // Assume that the derivative of any integer expression is always 0 + return make_zero(e.type()); + } else { + return IRMutator::Mutate(e); + } + } + Expr Mutate_(const Variable* op, const Expr& e) { - if (input_var_.operator->() && input_var_.get() == op) { + if (input_var_.operator->() && input_var_.get() == op && op->type.is_float()) { return FloatImm::make(op->type, 1.0); } else { return make_zero(op->type); @@ -90,6 +98,7 @@ class JacobianMutator : public IRMutator { return make_zero(op->type); } } else if (op->call_type == Call::CallType::PureIntrinsic) { + static std::unordered_set piecewise_const = {"floor", "ceil", "trunc", "round"}; if (op->name == "exp") { return Mul::make(Mutate(op->args[0]), e); } else if (op->name == "log") { @@ -97,9 +106,14 @@ class JacobianMutator : public IRMutator { } else if (op->name == "sigmoid") { return Mul::make(Mutate(op->args[0]), Mul::make(e, Sub::make(FloatImm::make(e.type(), 1.0), e))); + } else if (op->name == "sqrt") { + return Div::make(Mutate(op->args[0]), Mul::make(e, FloatImm::make(e.type(), 2.0))); } else if (op->name == "tanh") { return Mul::make(Mutate(op->args[0]), Sub::make(FloatImm::make(e.type(), 1.0), Mul::make(e, e))); + } else if (op->name == "pow") { + auto x = op->args[0], y = op->args[1]; + return e * (Mutate(y)*log(x) + Mutate(x)*y/x); } else if (op->name == "fabs") { auto type = op->args[0].type(); return Mul::make(Mutate(op->args[0]), @@ -108,6 +122,8 @@ class JacobianMutator : public IRMutator { } else if (op->name == intrinsic::tvm_if_then_else) { Array new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])}; return Call::make(op->type, op->name, new_args, op->call_type, op->func, op->value_index); + } else if (piecewise_const.count(op->name)) { + return FloatImm::make(e.type(), 0.0); } else { throw dmlc::Error("Derivative of this intrinsic is not implemented: " + op->name); } @@ -329,7 +345,8 @@ Tensor Jacobian(const Tensor& output, const Tensor& input, bool optimize) { return tensor; } else { - NOT_IMPLEMENTED; + LOG(FATAL) << "Derivative of this op is not implemented: " << output->op; + throw; } } diff --git a/src/pass/zero_elimination.cc b/src/pass/zero_elimination.cc index 2a053be81b46..247ea17b2379 100644 --- a/src/pass/zero_elimination.cc +++ b/src/pass/zero_elimination.cc @@ -382,6 +382,13 @@ class NonzeronessConditionFunctor result_type VisitExpr_(const Mod* op, const Expr& e) final { return BinOpDivLike_(op, e); } result_type VisitExpr_(const Min* op, const Expr& e) final { return BinOpAddLike_(op, e); } result_type VisitExpr_(const Max* op, const Expr& e) final { return BinOpAddLike_(op, e); } + result_type VisitExpr_(const EQ* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const NE* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const LE* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const LT* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const GE* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const GT* op, const Expr& e) final { return Bool_(op, e); } + result_type VisitExpr_(const Not* op, const Expr& e) final { return Bool_(op, e); } result_type VisitExpr_(const Cast* op, const Expr& e) final { if (op->value.type().is_bool()) { @@ -492,6 +499,11 @@ class NonzeronessConditionFunctor return {nz_a.cond, TNode::make(nz_a.value, op->b)}; } } + + template + NonzeronessConditionResult Bool_(const TNode* op, const Expr& e) { + return {e, make_const(e.type(), 1)}; + } }; NonzeronessConditionResult NonzeronessCondition(const Expr& expr) { diff --git a/tests/python/unittest/test_pass_autodiff.py b/tests/python/unittest/test_pass_autodiff.py index 06884f7f055d..645f54374de6 100644 --- a/tests/python/unittest/test_pass_autodiff.py +++ b/tests/python/unittest/test_pass_autodiff.py @@ -40,7 +40,8 @@ def check_equivalence(outputs1, outputs2, inputs, in_range=(-10, 10), iters=10): for j, _ in enumerate(outputs1): tvm.testing.assert_allclose(arguments1[j].asnumpy(), arguments2[j].asnumpy()) -def check_grad(out, inputs, args=[], in_range=(-10,10), perf=None, param_values=None): +def check_grad(out, inputs, args=[], in_range=(-10,10), perf=None, param_values=None, + acceptable_fail_fraction=None): line = inspect.getframeinfo(inspect.stack()[1][0]).lineno if not isinstance(inputs, (list, tuple)): @@ -142,7 +143,8 @@ def fun(*arguments): mgrad(*g_arg_vals) g_res = [g_arg_vals[g].asnumpy() for g, _ in enumerate(grads)] - check_numerical_grads(fun, [a.asnumpy() for a in input_vals], g_res) + check_numerical_grads(fun, [a.asnumpy() for a in input_vals], g_res, + acceptable_fail_fraction=acceptable_fail_fraction) def test_differentiate_function(): x = tvm.placeholder((32, 3, 28, 28), name='x') @@ -182,6 +184,18 @@ def test_autodiff(): B = tvm.compute((10, 10), lambda i, j: A0[i, j] + A0[j, i], name='B') check_grad(B, A0, perf=(10100, 10000, 200)) + B = tvm.compute((10, 10), lambda i, j: tvm.floor(A0[i, j]), name='B') + check_grad(B, A0, perf=(100, 0, 100), acceptable_fail_fraction=0.05) + + B = tvm.compute((10, 10), lambda i, j: tvm.ceil(A0[i, j]), name='B') + check_grad(B, A0, perf=(100, 0, 100), acceptable_fail_fraction=0.05) + + B = tvm.compute((10, 10), lambda i, j: tvm.trunc(A0[i, j]), name='B') + check_grad(B, A0, perf=(100, 0, 100), acceptable_fail_fraction=0.05) + + B = tvm.compute((10, 10), lambda i, j: tvm.round(A0[i, j]), name='B') + check_grad(B, A0, perf=(100, 0, 100), acceptable_fail_fraction=0.05) + B = tvm.compute((10, 10), lambda i, j: A0[i, j] + tvm.exp(A0[j, i]), name='B') check_grad(B, A0, perf=(10100, 20000, 200)) @@ -194,6 +208,12 @@ def test_autodiff(): B = tvm.compute((10, 10), lambda i, j: tvm.tanh(A0[i, j]*A0[i, j]*A0[j, i]), name='B') check_grad(B, A0, perf=(10100, 120000, 200)) + B = tvm.compute((10, 10), lambda i, j: tvm.sqrt(A0[i, j]*A0[i, j]*A0[j, i]), name='B') + check_grad(B, A0, perf=(10100, 90000, 200), in_range=(0.1, 10)) + + B = tvm.compute((10, 10), lambda i, j: tvm.power(tvm.abs(A0[i, j]), A0[j, i]), name='B') + check_grad(B, A0, perf=(10100, 90000, 200)) + B = tvm.compute((10, 10), lambda i, j: A0[i, j] * A0[j, i], name='B') check_grad(B, A0, perf=(10100, 10000, 200)) From cf5083a81072c72cedd14433d6ce90e744b216f7 Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Thu, 7 Feb 2019 14:45:46 +0300 Subject: [PATCH 09/11] [AD] Autodiff/relay integration --- src/op/op_util.cc | 28 +++ src/op/op_util.h | 9 + src/relay/op/autodiff_integration.cc | 212 ++++++++++++++++++++ tests/python/relay/test_primal_gradients.py | 191 ++++++++++++++++++ 4 files changed, 440 insertions(+) create mode 100644 src/relay/op/autodiff_integration.cc create mode 100644 tests/python/relay/test_primal_gradients.py diff --git a/src/op/op_util.cc b/src/op/op_util.cc index 4231f336a01b..f80f5f1eaabb 100644 --- a/src/op/op_util.cc +++ b/src/op/op_util.cc @@ -206,6 +206,34 @@ Expr ReplaceTensor(Expr expr, } +void ReplaceTensorRecursivelyImpl(Tensor tensor, + std::unordered_map* replace) { + if (!replace->count(tensor)) { + for (const Tensor& subtensor : tensor->op->InputTensors()) { + ReplaceTensorRecursivelyImpl(subtensor, replace); + } + Operation new_op = tensor->op->ReplaceInputs(tensor->op, *replace); + if (new_op.same_as(tensor->op)) { + (*replace)[tensor] = tensor; + } else { + (*replace)[tensor] = + TensorNode::make(tensor->shape, tensor->dtype, new_op, tensor->value_index); + } + } +} + +Array ReplaceTensorRecursively(Array tensors, + const std::unordered_map& replace) { + auto new_replace = replace; + Array res; + for (const Tensor& t : tensors) { + ReplaceTensorRecursivelyImpl(t, &new_replace); + res.push_back(new_replace[t]); + } + return res; +} + + Stmt Substitute(Stmt s, const std::unordered_map& value_map) { std::unordered_map init; diff --git a/src/op/op_util.h b/src/op/op_util.h index da7987f7162f..f8cebe229112 100644 --- a/src/op/op_util.h +++ b/src/op/op_util.h @@ -64,6 +64,15 @@ Stmt ReplaceTensor(Stmt stmt, Expr ReplaceTensor(Expr expr, const std::unordered_map& replace); +/*! + * \brief Replace tensor references in the given tensors recursively (not only in their bodies + * but also in the bodies of its dependencies). + * \param tensors The tensors to be processed. + * \param replace The replacement rule. + */ +Array ReplaceTensorRecursively(Array tensors, + const std::unordered_map& replace); + /*! * \brief Substitute the variables of stmt by value map. * \param stmt the statment diff --git a/src/relay/op/autodiff_integration.cc b/src/relay/op/autodiff_integration.cc new file mode 100644 index 000000000000..bec601169ceb --- /dev/null +++ b/src/relay/op/autodiff_integration.cc @@ -0,0 +1,212 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file autodiff_integration.cc + * \brief Integration with autodiff for TVM tensor expressions. + */ + +#include +#include +#include +#include +#include "./type_relations.h" +#include "./op_common.h" +#include "../../op/op_util.h" + +namespace tvm { +namespace relay { + +/*! \brief Attributes for the automatically generated gradient operation. */ +struct AutogeneratedGradientAttrs : public tvm::AttrsNode { + Op original_op; + Attrs original_attrs; + Type original_out_type; + + TVM_DECLARE_ATTRS(AutogeneratedGradientAttrs, "relay.attrs.AutogeneratedGradientAttrs") { + TVM_ATTR_FIELD(original_op) + .describe("The original operation."); + TVM_ATTR_FIELD(original_attrs) + .describe("The attributes of the original operation."); + TVM_ATTR_FIELD(original_out_type).set_default(Type(nullptr)) + .describe("The type of the original expression."); + } +}; + +bool AutogeneratedGradientRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + const AutogeneratedGradientAttrs* real_attrs = attrs.as(); + CHECK(real_attrs != nullptr) << "Attrs are null or have an invalid type."; + + // There are just two types: the type of the input tuple and the type of the output tuple. + CHECK(types.size() == 2) << "The size of the types array must be 2, not " << types.size(); + const auto* tuple_type = types[0].as(); + CHECK(tuple_type != nullptr) << "The input must be a tuple, not " << types[0]; + // The input tuple contains the original inputs and the last item is the adjoint + // for the output of the original operation. + Array input_types(tuple_type->fields.begin(), tuple_type->fields.end() + (-1)); + // The output of the gradient operation is a containing values of the same types as the + // original inputs. + reporter->Assign(types[1], TupleTypeNode::make(input_types)); + return true; +} + +Array AutogeneratedGradientCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + static auto fcompute = Op::GetAttr("FTVMCompute"); + + const AutogeneratedGradientAttrs* real_attrs = attrs.as(); + CHECK(real_attrs != nullptr); + + // We need the type of the original output to pass it to the + // FTVMCompute of the original operation. + Type original_out_type = real_attrs->original_out_type; + + // The `inputs` array contains both the original inputs and the adjoint, both in the + // flattened form. In general, the adjoint may consist of several tensors, so we need to know + // the number of the output tensors of the original operation. + size_t num_orig_outputs = 1; + // NOTE: Here we assume that there are no nested tuples + if (const auto* tuple_type = original_out_type.as()) { + num_orig_outputs = tuple_type->fields.size(); + } else if (const auto* tuple_type = out_type.as()) { + // Guess the number of outputs of the original op from the number of inputs of the original + // op (which is the same as the number of outputs of this gradient node). + num_orig_outputs = inputs.size() - tuple_type->fields.size(); + } + + CHECK(inputs.size() >= num_orig_outputs); + + // If the original output type hasn't been preserved, try to reconstruct it using the + // number of original outputs. + if (!original_out_type.defined()) { + Array fields; + for (auto it = inputs.end() + (-num_orig_outputs); it != inputs.end(); ++it) { + fields.push_back(TensorTypeNode::make((*it)->shape, (*it)->dtype)); + } + if (num_orig_outputs == 1) { + // If the number of the outputs is 1 then the output type is probably just a tensor, not + // a tuple of a single element. + original_out_type = fields[0]; + } else { + original_out_type = TupleTypeNode::make(fields); + } + } + + Array original_inputs(inputs.begin(), inputs.end() + (-num_orig_outputs)); + Array adjoints(inputs.end() + (-num_orig_outputs), inputs.end()); + + // In theory the inputs might contain duplicate entries which won't agree with the automatic + // differentiation, so we create new placeholders which we will replace with the inputs later. + Array input_placeholders; + std::unordered_map placeholders_to_inputs; + for (const Tensor& input : original_inputs) { + Tensor place = + tvm::PlaceholderOpNode::make(input->op->name, input->shape, input->dtype).output(0); + input_placeholders.push_back(place); + placeholders_to_inputs[place] = input; + } + + Array forward = + fcompute[real_attrs->original_op](real_attrs->original_attrs, input_placeholders, + original_out_type, target); + + CHECK(forward.size() == adjoints.size()); + + // If there are multiple outputs, we have to propagate gradients from all of them and + // add up the results. Note that there may be suboptimality, in the future we might want + // to make the Differentiate function accept arrays of outputs. + Array res; + for (size_t i = 0; i < forward.size(); ++i) { + Array part = + tvm::ir::Differentiate(forward[i], input_placeholders, adjoints[i])->result; + part = tvm::op::ReplaceTensorRecursively(part, placeholders_to_inputs); + + if (i == 0) { + res = part; + } else { + for (size_t j = 0; j < res.size(); ++j) { + res.Set(j, topi::add(res[j], part[j])); + } + } + } + + return res; +} + +RELAY_REGISTER_OP("autogenerated_gradient") +.describe(R"doc(Gradients for any specified operation generated using the automatic differentiation +for tensor expressions. + +- **input**: A tuple of the form `(x1, ..., xn, g)` where `x1, ..., xn` are the inputs of the + original operation, and g is the gradient of the loss with respect to the output + of the original operation. +- **out**: A tuple of the form `(g1, ..., gn)` containing the gradients of the loss with respect to + the inputs of the original operation. +)doc") +.set_num_inputs(1) +.add_argument("input", "Tuple", "A tuple containing the original inputs and the adjoint.") +.set_attrs_type_key("relay.attrs.AutogeneratedGradientAttrs") +.add_type_rel("AutogeneratedGradient", AutogeneratedGradientRel) +.set_attr("FTVMCompute", AutogeneratedGradientCompute) +.set_attr("TOpPattern", kOpaque) +.set_attr("FTVMSchedule", + [](const Attrs& attrs, const Array& outs, const Target& target) { + Array out_ops; + for (auto t : outs) + out_ops.push_back(t->op); + return create_schedule(out_ops); + }); + +FPrimalGradient AutogeneratedFPrimalGradient(const Op& op) { + return [op](const Expr& orig, const Expr& adjoint) -> Array { + const CallNode* call = orig.as(); + CHECK(call != nullptr); + + auto attrs = make_node(); + attrs->original_op = op; + attrs->original_attrs = call->attrs; + if (call->checked_type_.defined()) { + attrs->original_out_type = call->checked_type(); + } + + Array args_in_tuple = call->args; + args_in_tuple.push_back(adjoint); + Array args = {TupleNode::make(args_in_tuple)}; + auto grad_call = CallNode::make(Op::Get("autogenerated_gradient"), args, Attrs(attrs)); + + Array res; + for (size_t i = 0; i < call->args.size(); ++i) { + res.push_back(TupleGetItemNode::make(grad_call, i)); + } + return res; + }; +} + +/*! \brief Automatically generate primal gradient for the given operation. */ +void AutogeneratePrimalGradient(const std::string& op_name, int plevel = 100) { + OpRegistry& opreg = relay::OpRegistry::Registry()->__REGISTER_OR_GET__(op_name); + Op op = opreg.op(); + opreg.set_attr("FPrimalGradient", AutogeneratedFPrimalGradient(op), plevel); +} + +/*! \brief Automatically generate primal gradients for all operations in the registry. */ +void AutogeneratePrimalGradientForAll(int plevel = 5) { + for (const OpRegistry* opreg : relay::OpRegistry::Registry()->List()) { + AutogeneratePrimalGradient(opreg->op()->name, plevel); + } +} + +TVM_REGISTER_API("relay._ir_pass.AutogeneratePrimalGradient") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) { + AutogeneratePrimalGradient(args[0]); + }); +TVM_REGISTER_API("relay._ir_pass.AutogeneratePrimalGradientForAll") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) { + AutogeneratePrimalGradientForAll(); + }); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_primal_gradients.py b/tests/python/relay/test_primal_gradients.py new file mode 100644 index 000000000000..d23a920d578b --- /dev/null +++ b/tests/python/relay/test_primal_gradients.py @@ -0,0 +1,191 @@ +import tvm +import numpy as np + +from tvm import relay + +def to_int_array(arr, param_values=None): + if param_values is None: + param_values = {} + return [tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(s, param_values)).value + for s in arr] + +def tvm2numpy(something): + if isinstance(something, (tvm.ndarray.NDArray, relay.backend.interpreter.TensorValue)): + return something.asnumpy() + elif isinstance(something, list): + return [tvm2numpy(s) for s in something] + elif isinstance(something, (tuple, relay.backend.interpreter.TupleValue)): + return tuple(tvm2numpy(s) for s in something) + return something + +def check_relay_grad(expr, in_range=(-10,10), acceptable_fail_fraction=None): + expr = relay.ir_pass.infer_type(expr) + if len(expr.checked_type.shape) != 0: + expr = relay.op.sum(expr) + + input_vars = relay.ir_pass.free_vars(expr) + func = relay.Function(input_vars, expr) + func = relay.ir_pass.infer_type(func) + + gfunc = relay.ir_pass.gradient(func) + gfunc = relay.ir_pass.infer_type(gfunc) + + executor = relay.create_executor() + tvm_func = executor.evaluate(func) + tvm_gfunc = executor.evaluate(gfunc) + + np_func = lambda *a: tvm2numpy(tvm_func(*a)) + np_gfunc = lambda *a: tvm2numpy(tvm_gfunc(*a)) + + input_vals = [np.random.uniform(in_range[0], in_range[1], + size=to_int_array(a.type_annotation.shape)) + .astype(a.type_annotation.dtype) + for a in input_vars] + + tvm.testing.check_numerical_grads(np_func, input_vals, np_gfunc(*input_vals)[1], + acceptable_fail_fraction=acceptable_fail_fraction) + +def test_autogenerated_primal_gradients(): + relay._ir_pass.AutogeneratePrimalGradientForAll(100) + + x = relay.var("x", shape=(5,), dtype='float64') + y = relay.var("y", shape=(5,), dtype='float64') + + check_relay_grad(x*x) + check_relay_grad(x*y) + check_relay_grad(x*y + x + y) + + k = relay.var("k", shape=()) + x = relay.var("x", shape=(5,)) + y = relay.var("y", shape=(5,)) + ix = relay.var("ix", shape=(5,), dtype='int32') + iy = relay.var("iy", shape=(5,), dtype='int32') + ind = relay.var("ind", shape=(10,), dtype='int32') + x1 = relay.var("x1", shape=(5,)) + y1 = relay.var("y1", shape=(5,)) + X = relay.var("X", shape=(5, 7)) + Y = relay.var("Y", shape=(5, 10)) + Y1 = relay.var("Y1", shape=(1, 2, 5, 5)) + Y2 = relay.var("Y2", shape=(3, 2, 5, 5)) + W = relay.var("W", shape=(7, 10)) + A = relay.var("A", shape=(2, 5, 7, 7)) + w = relay.var("w", shape=(4, 5, 3, 3)) + w1 = relay.var("w1", shape=(5, 4, 3, 3)) + + check_relay_grad(x*x) + check_relay_grad(x*y) + check_relay_grad(x*y + x + y) + + check_relay_grad(relay.op.log(x), in_range=(0.1, 10)) + check_relay_grad(relay.op.sqrt(x), in_range=(0.1, 10)) + check_relay_grad(relay.op.exp(x)) + check_relay_grad(relay.op.sigmoid(x)) + check_relay_grad(relay.op.add(x, y)) + check_relay_grad(relay.op.subtract(x, y)) + check_relay_grad(relay.op.multiply(x, y)) + check_relay_grad(relay.op.divide(x, y)) + #check_relay_grad(relay.op.mod(x, y)) + check_relay_grad(relay.op.tanh(x)) + #check_relay_grad(relay.op.concatenate([x, y], 0)) + #check_relay_grad(relay.op.concatenate([X, Y], 1)) + check_relay_grad(relay.op.expand_dims(X, 1, 1)) + check_relay_grad(relay.op.expand_dims(X, 2, 3)) + check_relay_grad(relay.nn.softmax(x)) + check_relay_grad(relay.nn.log_softmax(X)) + check_relay_grad(relay.nn.relu(x)) + #check_relay_grad(relay.nn.dropout(x)) + #check_relay_grad(relay.nn.batch_norm(A, x, y, x1, y1)[0]) + check_relay_grad(relay.nn.bias_add(X, x, 0)) + + check_relay_grad(relay.nn.conv2d(A, w)) + check_relay_grad(relay.nn.conv2d(A, w, strides=(2, 1))) + check_relay_grad(relay.nn.conv2d(A, w, padding=(1, 0))) + check_relay_grad(relay.nn.conv2d(A, w, dilation=(1, 2))) + check_relay_grad(relay.nn.conv2d_transpose(A, w1)) + check_relay_grad(relay.nn.conv2d_transpose(A, w1, strides=(2, 1))) + check_relay_grad(relay.nn.conv2d_transpose(A, w1, padding=(1, 0))) + #check_relay_grad(relay.nn.conv2d_transpose(A, w1, dilation=(1, 2))) + check_relay_grad(relay.nn.dense(X, W)) + check_relay_grad(relay.nn.max_pool2d(A)) + check_relay_grad(relay.nn.max_pool2d(A, pool_size=(2, 2))) + check_relay_grad(relay.nn.max_pool2d(A, pool_size=(2, 2), strides=(2, 1))) + check_relay_grad(relay.nn.max_pool2d(A, pool_size=(3, 3), strides=(3, 2), padding=(1, 1))) + check_relay_grad(relay.nn.avg_pool2d(A)) + check_relay_grad(relay.nn.avg_pool2d(A, pool_size=(2, 2))) + check_relay_grad(relay.nn.avg_pool2d(A, pool_size=(2, 2), strides=(2, 1))) + check_relay_grad(relay.nn.avg_pool2d(A, pool_size=(3, 3), strides=(3, 2), padding=(1, 1))) + check_relay_grad(relay.nn.global_max_pool2d(A)) + check_relay_grad(relay.nn.global_avg_pool2d(A)) + check_relay_grad(relay.nn.upsampling(A, scale=2)) + check_relay_grad(relay.nn.batch_flatten(A)) + check_relay_grad(relay.nn.pad(A, ((0, 0), (0, 1), (1, 2), (2, 3)))) + check_relay_grad(relay.nn.lrn(A)) + check_relay_grad(relay.nn.l2_normalize(A, 0.01, axis=[1])) + #check_relay_grad(relay.nn.contrib_conv2d_winograd_without_weight_transform(...)) + check_relay_grad(relay.nn.contrib_conv2d_winograd_weight_transform(w, 4)) + + check_relay_grad(relay.nn.leaky_relu(x, 0.1)) + check_relay_grad(relay.nn.prelu(A, x)) + check_relay_grad(relay.reshape(Y, (2, 5, 5))) + check_relay_grad(relay.reshape_like(Y, Y1)) + check_relay_grad(relay.copy(x)) + check_relay_grad(relay.transpose(Y)) + check_relay_grad(relay.squeeze(Y1)) + check_relay_grad(relay.floor(x), acceptable_fail_fraction=0.2) + check_relay_grad(relay.ceil(x), acceptable_fail_fraction=0.2) + check_relay_grad(relay.trunc(x), acceptable_fail_fraction=0.2) + #check_relay_grad(relay.clip(x, -2, 2)) + check_relay_grad(relay.round(x), acceptable_fail_fraction=0.2) + check_relay_grad(relay.abs(x)) + check_relay_grad(relay.negative(x)) + #check_relay_grad(relay.take(x, ind), in_range=(0, 4)) + check_relay_grad(relay.zeros((5, 6), 'float32')) + check_relay_grad(relay.zeros_like(x)) + check_relay_grad(relay.ones((5, 6), 'float32')) + check_relay_grad(relay.ones_like(x)) + check_relay_grad(relay.full(k, (5, 6), 'float32')) + check_relay_grad(relay.full_like(x, k)) + #check_relay_grad(relay.cast(x, 'float64')) + #check_relay_grad(relay.split(x, (1, 3))) + + #check_relay_grad(relay.right_shift(ix, iy)) + #check_relay_grad(relay.left_shift(ix, iy)) + #check_relay_grad(relay.equal(ix, iy)) + #check_relay_grad(relay.not_equal(ix, iy)) + #check_relay_grad(relay.greater(x, y)) + #check_relay_grad(relay.greater_equal(x, y)) + #check_relay_grad(relay.less(x, y)) + #check_relay_grad(relay.less_equal(x, y)) + check_relay_grad(relay.maximum(x, y)) + check_relay_grad(relay.minimum(x, y)) + check_relay_grad(relay.power(relay.abs(x), y)) + #check_relay_grad(relay.where(ix, x, y)) + #check_relay_grad(relay.where(relay.greater(x, y), x, y)) + #check_relay_grad(relay.argmax(x)) + #check_relay_grad(relay.argmin(X, axis=1)) + check_relay_grad(relay.sum(x)) + check_relay_grad(relay.max(x)) + check_relay_grad(relay.min(x)) + check_relay_grad(relay.mean(x)) + check_relay_grad(relay.prod(x)) + check_relay_grad(relay.strided_slice(A, (0, 4, 2, 0), (1, 1, 5, 6), (1, -1, 2, 3))) + check_relay_grad(relay.broadcast_to(Y1, (3, 2, 5, 5))) + + check_relay_grad(relay.image.resize(A, (12, 10), method='BILINEAR')) + check_relay_grad(relay.image.resize(A, (12, 10), method='BILINEAR', align_corners=True)) + #check_relay_grad(relay.image.resize(A, (12, 10), method='NEAREST_NEIGHBOR')) + #check_relay_grad(relay.vision.multibox_prior(A)) + #check_relay_grad(relay.vision.multibox_transform_loc(...)) + #check_relay_grad(relay.vision.nms(...)) + + check_relay_grad(relay.broadcast_to_like(Y1, Y2)) + check_relay_grad(relay.collapse_sum_like(X, x)) + t1 = relay.var("t1", shape=(3, 4, 5)) + t2 = relay.var("t2", shape=(1, 2, 3)) + check_relay_grad(relay.slice_like(t1, t2)) + check_relay_grad(relay.layout_transform(w1, 'NCHW', 'NHCW2c')) + #check_relay_grad(relay.device_copy(...)) + #check_relay_grad(relay.annotation.on_device(...)) + +if __name__ == "__main__": + test_autogenerated_primal_gradients() From 5ae4ac1c352de4e893997c8d316eb5822cc563f9 Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Tue, 5 Feb 2019 13:50:07 +0300 Subject: [PATCH 10/11] Simplified overriding of gradients; Tutorial and docs --- docs/api/python/dev.rst | 5 + docs/api/python/index.rst | 1 + docs/api/python/topi.rst | 2 + include/tvm/autodiff.h | 14 +- python/tvm/autodiff.py | 158 ++++++++++++-------- src/pass/autodiff.cc | 30 +++- tests/python/unittest/test_pass_autodiff.py | 17 ++- tutorials/language/autodiff_basics.py | 129 ++++++++++++++++ 8 files changed, 278 insertions(+), 78 deletions(-) create mode 100644 tutorials/language/autodiff_basics.py diff --git a/docs/api/python/dev.rst b/docs/api/python/dev.rst index 3a1804f37d1e..f2ee86faaaa4 100644 --- a/docs/api/python/dev.rst +++ b/docs/api/python/dev.rst @@ -70,3 +70,8 @@ tvm.make ~~~~~~~~ .. automodule:: tvm.make :members: + +tvm.testing +~~~~~~~~~~~ +.. automodule:: tvm.testing + :members: diff --git a/docs/api/python/index.rst b/docs/api/python/index.rst index ddad9d10f8f9..b8361f2aaed9 100644 --- a/docs/api/python/index.rst +++ b/docs/api/python/index.rst @@ -15,6 +15,7 @@ Python API container function autotvm + autodiff graph_runtime rpc bridge diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index 856bad198e88..b6e6df43f498 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -67,6 +67,7 @@ List of operators topi.not_equal topi.greater_equal topi.less_equal + topi.tensordot topi.image.resize @@ -123,6 +124,7 @@ topi .. autofunction:: topi.power .. autofunction:: topi.greater .. autofunction:: topi.less +.. autofunction:: topi.tensordot topi.nn ~~~~~~~ diff --git a/include/tvm/autodiff.h b/include/tvm/autodiff.h index ade082d0b291..15ab436ede0b 100644 --- a/include/tvm/autodiff.h +++ b/include/tvm/autodiff.h @@ -129,6 +129,10 @@ EXPORT Tensor DiffBuildingBlock(const Tensor& output, const Tensor& input, const * `output.shape + output.shape` will be used. * \param fdiff The function performing differentiation and multiplication, see * ::FDiffBuildingBlock. + * \param override_deps A map from tensors to their dependencies (`InputTensors()` are used by + * default). Overriding dependencies may be useful to treat a group of tensors + * as a single supertensor. In this case the fdiff functions should also be + * modified accordingly. * \return An object of type DifferentiationResult which contains three fields: * - `result` An array of adjoints corresponding to \p inputs. * - `adjoints` A map from tensors to the corresponding adjoints (includes intermediate @@ -136,10 +140,12 @@ EXPORT Tensor DiffBuildingBlock(const Tensor& output, const Tensor& input, const * - `adjoint_summands` A map from tensors to maps from parent tensors to individual * summands of the adjoint. */ -EXPORT DifferentiationResult Differentiate(const Tensor& output, - const Array& inputs = Array(), - const Tensor& head = Tensor(), - const FDiffBuildingBlock& fdiff = DiffBuildingBlock); +EXPORT DifferentiationResult Differentiate( + const Tensor& output, + const Array& inputs = Array(), + const Tensor& head = Tensor(), + const FDiffBuildingBlock& fdiff = DiffBuildingBlock, + const Map>& override_deps = Map>()); } // namespace ir } // namespace tvm diff --git a/python/tvm/autodiff.py b/python/tvm/autodiff.py index 63364c42cad0..1d496185eb76 100644 --- a/python/tvm/autodiff.py +++ b/python/tvm/autodiff.py @@ -1,10 +1,4 @@ -"""Namespace of autodiff-related functions. - -The functions are automatically exported from C++ side via PackedFunc. -You can read "include/tvm/autodiff.h" for the function signature of these functions. -""" -import logging - +"""Automatic differentiation of tensor expressions.""" from ._ffi.function import _init_api from ._ffi.node import NodeBase, register_node @@ -16,52 +10,56 @@ class DifferentiationResult(NodeBase): Parameters ---------- - result : list of Tensor + result : List[Tensor] The requested adjoints, i.e. the jacobians or gradients of the given output wrt to the given inputs. - adjoints : dict from Tensor to Tensor + adjoints : Dict[Tensor, Tensor] A map from tensors to the corresponding adjoints (including internal nodes). - adjoint_summands : dict from Tensor to dict from Tensor to Tensor + adjoint_summands : Dict[Tensor, Dict[Tensor, Tensor]] Single summands of the adjoints. """ - def __getattr__(self, name): - # Here we convert tvm Maps to dicts because Map compares keys by reference which is - # wrong for Tensors. Hopefully, in the future Map gets fixed somehow, and this function - # may be removed then. - res = NodeBase.__getattr__(self, name) - if name == 'adjoints': - return dict(res.items()) - if name == 'adjoint_summands': - return {k: dict(v.items()) for k, v in res.items()} - return res + + # Here we convert tvm Maps to dicts because Map compares keys by reference which is + # wrong for Tensors. Hopefully, in the future Map gets fixed somehow, and these properties + # may be removed then. + + @property + def adjoints(self): + res = NodeBase.__getattr__(self, 'adjoints') + return dict(res.items()) + + @property + def adjoint_summands(self): + res = NodeBase.__getattr__(self, 'adjoint_summands') + return {k: dict(v.items()) for k, v in res.items()} + + def _check_not_empty(self): + if not self.result: + raise ValueError("The result of differentiation does not contain any explicitly " + "requested results, so using it as an iterable is probably a mistake. " + "Please explicitly use res.adjoints to get adjoints or res.result to " + "get the empty list.") def __getitem__(self, i): + self._check_not_empty() return self.result[i] def __len__(self): + self._check_not_empty() return len(self.result) -def differentiate(output, inputs=None, head=None, manual=None, fdiff=None): +def differentiate(output, inputs=None, head=None, override=None, fdiff=None): """Perform reverse-mode automatic differentiation. - Example:: - - x = tvm.placeholder((32, 3, 28, 28), name='x') - w1 = tvm.placeholder((10, 3, 3, 3), name='w1') - w2 = tvm.placeholder((10, 10, 3, 3), name='w2') - y = topi.sum(topi.nn.conv2d(topi.nn.conv2d(x, w1, 1, 0), w2, 1, 0)) - - [dw1, dw2] = tvm.differentiate(y, [w1, w2]) - Parameters ---------- output : Tensor The tensor to differentiate. - inputs : list of Tensor + inputs : List[Tensor] The list of input tensors. When the list is empty or None, will perform differentiation wrt all tensors the output depends on (i.e. will compute all adjoints and populate the corresponding dict, but the list of results @@ -73,15 +71,18 @@ def differentiate(output, inputs=None, head=None, manual=None, fdiff=None): If `None` is passed, the identity tensor of shape `output.shape + output.shape` will be used. - manual : dict (Tensor, Tensor) -> function - A dict providing custom multiplication-differentiation functions (see `fdiff`) - for certain pairs of tensors. Each pair consists of an output and an input tensor, - the input one being an immediate dependency of the output one. Pairs of the form - `(None, tensor)` and `(tensor, None)` are allowed, `None` working as a wildcard. + override : Dict[Tensor, (List[Tensor], Callable[[Tensor, List[Tensor], Tensor], List[Tensor]])] + Override differentiation for certain tensors. This dict maps tensors `t` to pairs + `(dependencies, custom_diff)` where `dependencies` is a list of tensors which are considered + to be inputs of `t` (which may differ from the immediate inputs), and `custom_diff` is a + custom differentiation function which will be called as `custom_diff(t, dependencies, + adjoint)` and should return a list of adjoints corresponding to dependencies. Note that this + function differs from the one required for `fdiff` in that it takes a list of inputs instead + of a single input and returns a list of adjoints instead of a single adjoint. - fdiff : function (Tensor, Tensor, Tensor) -> Tensor + fdiff : Callable[[Tensor, Tensor, Tensor], Tensor] The default function performing differentiation and multiplication, by default - `tvm.autodiff.FDiffBuildingBlock` is used. The function must accept three + `tvm.autodiff.DiffBuildingBlock` is used. The function must accept three parameters: - `output` - an output tensor - `input` - an input tensor @@ -91,6 +92,52 @@ def differentiate(output, inputs=None, head=None, manual=None, fdiff=None): Returns ------- differentiation_result : DifferentiationResult + + Example + ------- + .. code-block:: python + + x = tvm.placeholder((32, 3, 28, 28), name='x') + w1 = tvm.placeholder((10, 3, 3, 3), name='w1') + w2 = tvm.placeholder((10, 10, 3, 3), name='w2') + z1 = topi.nn.conv2d(x, w1, 1, 1, 1) + z2 = topi.nn.conv2d(z1, w2, 1, 1, 1) + y = topi.sum(z2) + + # produce gradients + [dw1, dw2] = tvm.differentiate(y, [w1, w2]) + + # produce Jacobians + [jw1, jw2] = tvm.differentiate(z2, [w1, w2]) + + # produce gradients, the head adjoint for z2 is provided manually + [dw1, dw2] = tvm.differentiate(z2, [w1, w2], topi.full_like(z2, 1.0)) + + # produce gradients wrt all inputs + res = tvm.differentiate(y) + dw1 = res.adjoints[w1] + dw2 = res.adjoints[w2] + + # a custom differentiation function + def my_fdiff(out, inp, head): + # this is the naive version, without any optimizations + return topi.tensordot(head, tvm.autodiff.Jacobian(out, inp, False), len(out.shape)) + + # using a custom differentiation function for everything + [dw1, dw2] = tvm.differentiate(y, [w1, w2], fdiff=my_fdiff) + + # accessing individual summands of the adjoint + y = z1 + z2 + res = tvm.differentiate(y, [w1, w2]) + [s1, s2] = res.adjoint_summands[z1].values() + + # a generalization of my_fdiff which works for non-immediate dependencies + # this is necessary because z1 is not an immediate dep of z2 because of padding + def my_diff(out, inputs, head): + return tvm.differentiate(out, inputs, head, fdiff=my_fdiff) + + # using a custom differentiation function only for z2 + res = tvm.differentiate(y, [w1, w2], override={z2: ([z1, w2], my_diff)}) """ if inputs is None: inputs = [] @@ -98,33 +145,18 @@ def differentiate(output, inputs=None, head=None, manual=None, fdiff=None): if fdiff is None: fdiff = DiffBuildingBlock - if manual is not None: - if not isinstance(manual, dict): - manual = dict(manual) - + if override is not None: # pylint: disable=dangerous-default-value - used_items = set() - - def _modified_fdiff(out, inp, head, manual=manual, old_fdiff=fdiff, used_items=used_items): - if (out, inp) in manual: - used_items.add((out, inp)) - return manual[(out, inp)](out, inp, head) - if (out, None) in manual: - used_items.add((out, None)) - return manual[(out, None)](out, inp, head) - if (None, inp) in manual: - used_items.add((None, inp)) - return manual[(None, inp)](out, inp, head) + def _modified_fdiff(out, inp, head, override=override, old_fdiff=fdiff, cache={}): + if out in override: + if (out, head) not in cache: + cache[(out, head)] = override[out][1](out, override[out][0], head) + idx = override[out][0].index(inp) + return cache[(out, head)][idx] return old_fdiff(out, inp, head) fdiff = _modified_fdiff - res = Differentiate(output, inputs, head, fdiff) - - if manual is not None: - for k in manual: - if k not in used_items: - logging.warning("The manually specified differentiation function " - "for %s hasn't been used", k) - - return res + override_deps = {t: deps for t, (deps, _) in override.items()} + return Differentiate(output, inputs, head, fdiff, override_deps) + return Differentiate(output, inputs, head, fdiff) diff --git a/src/pass/autodiff.cc b/src/pass/autodiff.cc index 05f843a164c4..ffd6386529e6 100644 --- a/src/pass/autodiff.cc +++ b/src/pass/autodiff.cc @@ -282,6 +282,16 @@ Expr Derivative(const Expr& expr, const VarExpr& var) { Tensor Jacobian(const Tensor& output, const Tensor& input, bool optimize) { if (const ComputeOpNode* op = output->op.as()) { + bool is_input_tensor = false; + for (const Tensor& child : op->InputTensors()) { + if (input == child) { + is_input_tensor = true; + break; + } + } + CHECK(is_input_tensor) << "Jacobian is called on a pair of tensors such that the output " + << "does not depend on the input. This is probably a mistake."; + // We have to clone the iteration axes because otherwise the original expression // cannot be used together with the derivative (it will lead to errors during lowering) Array new_axis; @@ -365,7 +375,8 @@ Tensor DiffBuildingBlock(const Tensor& output, const Tensor& input, const Tensor DifferentiationResult Differentiate(const Tensor& output, const Array& inputs, const Tensor& head_or_null, - const FDiffBuildingBlock& fdiff) { + const FDiffBuildingBlock& fdiff, + const Map>& override_deps) { Tensor head = head_or_null; // If the head is a null pointer, create an identity tensor @@ -389,13 +400,22 @@ DifferentiationResult Differentiate(const Tensor& output, // bodies) std::unordered_map> reverse_dependencies; + // Map doesn't work correctly for Tensors, so convert it to std::unordered_map + std::unordered_map> override_deps_map; + for (auto pair : override_deps) { + override_deps_map.insert(pair); + } + // Collect reverse dependencies std::vector stack({output}); while (!stack.empty()) { Tensor tensor = stack.back(); stack.pop_back(); - for (const Tensor& child : tensor->op->InputTensors()) { + auto it = override_deps_map.find(tensor); + Array deps = it != override_deps_map.end() ? it->second : tensor->op->InputTensors(); + + for (const Tensor& child : deps) { if (!reverse_dependencies.count(child)) { stack.push_back(child); } @@ -507,7 +527,11 @@ TVM_REGISTER_API("autodiff.Differentiate") [pfunc](const Tensor& o, const Tensor& i, const Tensor& h) { return pfunc(o, i, h); }; - *ret = Differentiate(args[0], args[1], args[2], fdiff); + if (args.size() >= 5) { + *ret = Differentiate(args[0], args[1], args[2], fdiff, args[4]); + } else { + *ret = Differentiate(args[0], args[1], args[2], fdiff); + } } }); diff --git a/tests/python/unittest/test_pass_autodiff.py b/tests/python/unittest/test_pass_autodiff.py index 645f54374de6..70d42dee6f14 100644 --- a/tests/python/unittest/test_pass_autodiff.py +++ b/tests/python/unittest/test_pass_autodiff.py @@ -15,7 +15,7 @@ def get_shape(tensor, param_values=None): return [tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(s, param_values)).value for s in tensor.shape] -def check_equivalence(outputs1, outputs2, inputs, in_range=(-10, 10), iters=10): +def check_equivalence(outputs1, outputs2, inputs, in_range=(-10, 10), iters=3): outputs1 = list(outputs1) outputs2 = list(outputs2) sched1 = tvm.create_schedule([o.op for o in outputs1]) @@ -160,17 +160,18 @@ def test_differentiate_function(): check_equivalence([dx1, dw1], [dx2, dw2], [x, w]) - def mydiff(out, inp, head): - return tvm.compute(inp.shape, - lambda ax0, ax1, ax2, ax3: head[ax0, ax3 + ax2*26 + ax1*676]) + def mydiff(out, inp, head, t1=t1, t2=t2): + assert out == t2 and inp == [t1] + return [tvm.compute(t1.shape, + lambda ax0, ax1, ax2, ax3: head[ax0, ax3 + ax2*26 + ax1*676])] - res = tvm.differentiate(t3, [x, w], manual={(t2, t1): mydiff}) + res = tvm.differentiate(t3, [x, w], override={t2: ([t1], mydiff)}) check_equivalence(res.result, [dx1, dw1], [x, w]) - res = tvm.differentiate(t3, [x, w], manual={(t2, None): mydiff}) - check_equivalence(res.result, [dx1, dw1], [x, w]) + def mydiff2(out, inputs, head): + return tvm.differentiate(out, inputs, head) - res = tvm.differentiate(t3, [x, w], manual={(None, t1): mydiff}) + res = tvm.differentiate(t3, [x, w], override={t1: ([x, w], mydiff2)}) check_equivalence(res.result, [dx1, dw1], [x, w]) # Test some simple expressions diff --git a/tutorials/language/autodiff_basics.py b/tutorials/language/autodiff_basics.py new file mode 100644 index 000000000000..916f4eefad6d --- /dev/null +++ b/tutorials/language/autodiff_basics.py @@ -0,0 +1,129 @@ +""" +Automatic Differentiation of Tensor Expressions +=============================================== +**Author**: `Sergei Grechanik `_ + +This tutorial describes how to use automatic differentiation of tensor expressions. + +Usually differentiation is done on the level of NNVM/Relay graphs. However there are some situations +when one might want to perform differentiation on the lower level of TVM tensor expressions, e.g.: + - When you are experimenting with a completely new kind of operations. + - When gradients for some operations haven't been implemented yet in NNVM/Relay. + - When you are implementing gradients for a new operation manually and need a starting point. + - When you want to train models in pure TVM without NNVM/Relay (if you really do, please tell us + why). + +.. note:: + + - Automatic differentiation is still work in progress. Some operations are not differentiated + very well yet. + - Automatic differentiation doesn't perform scheduling. The generated code should be scheduled + by hand or using some autoscheduling and autotuning methods (which may require manually + writing schedule templates). + +""" +from __future__ import absolute_import, print_function +import tvm +import topi + +###################################################################### +# How to use automatic differentiation +# ------------------------------------ +# +# Basically, all you need is the function :any:`tvm.differentiate` (also known as +# :any:`tvm.autodiff.differentiate`) which takes a tensor, differentiates it with respect to other +# given tensors using reverse accumulation, and applies certain optimizations. Let's consider an +# example: + +# inputs +X = tvm.placeholder((32, 100), name='X') +W = tvm.placeholder((10, 100), name='W') +B = tvm.placeholder((10,), name='B') + +# forward computation, basically topi.nn.dense(X, W, B) +k = tvm.reduce_axis((0, 100)) +T = tvm.compute((32, 10), lambda i, j: tvm.sum(X[i, k]*W[j, k], k)) +Y = topi.add(T, B) +L = topi.sum(Y) + +# gradients +[dL_dW, dL_dB] = tvm.differentiate(L, [W, B]) + +###################################################################### +# `L` is a scalar, so the results are gradients, however in general the result is a full Jacobian. +# :any:`tvm.differentiate` also accepts the third parameter if you want to multiply the Jacobian by +# another tensor. + +[dY_dW] = tvm.differentiate(Y, [W]) +print("Y.shape", Y.shape) +print("W.shape", W.shape) +print("dY_dW.shape", dY_dW.shape) + +[dL_dW] = tvm.differentiate(Y, [W], topi.full_like(Y, 1.0)) + +###################################################################### +# The result of :any:`tvm.differentiate` mimics a list, however it is an object that also contains +# all intermediate adjoints. Note also that the list of input tensors may be omitted, in which case +# the output will be differentiated with respect to all the inputs: + +res = tvm.differentiate(L) +dL_dX = res.adjoints[X] +dL_dT = res.adjoints[T] +dL_dY = res.adjoints[Y] + +###################################################################### +# Overriding the differentiation function +# --------------------------------------- +# +# :any:`tvm.differentiate` internally calls a function which performs differentiation of a given +# tensor with respect to one of its inputs. This functions may be overridden for every tensor or for +# some particular tensors, which is useful when the default differentiation function does a poor job +# and we need to provide some gradients manually. Let's define our own naive version of this +# function: + +def custom_fdiff(out, inp, head): + return topi.tensordot(head, tvm.autodiff.Jacobian(out, inp, False), len(out.shape)) + +###################################################################### +# This function must take the tensors `out`, `inp` and `head` where `out` is the tensor that should +# be differentiated with respect to `inp`, `inp` is an immediate dependency of `out`, and `head` is +# the adjoint of `out` which should be multiplied by the result of differentiation. The +# differentiation itself is done using the function :any:`tvm.autodiff.Jacobian`, and the +# multiplication is done with :any:`topi.tensordot`. The default differentiation function +# :any:`tvm.autodiff.DiffBuildingBlock` does the same thing, but it also applies certain optimizing +# transformations. +# +# A custom differentiation function may be used like this: + +res = tvm.differentiate(L, fdiff=custom_fdiff) + +###################################################################### +# A custom differentiation function may be used to override differentiation for certain operations +# by checking if `out` is the operation we want to differentiate differently. However, there is an +# alternative way: using the `override` keyword argument. `override` should be a dict mapping +# tensors to their dependencies and custom differentiation functions. +# +# Let's consider the following scenario: we want to block gradient flow from `Y` to `X` and compute +# gradients of `Y` wrt `B` and `W` using the unoptimized differentiation function `custom_fdiff`. +# Note that `W` and `X` are not immediate dependencies of `Y`. + +def custom_fdiff_2(out, inputs, head): + assert out == Y + assert inputs == [X, W, B] + # block gradients to X + dX = topi.full(head.shape[:-len(out.shape)] + list(X.shape), head.dtype, 0) + # use the custom unoptimized differentiation function for the rest + return [dX] + list(tvm.differentiate(out, [W, B], head, fdiff=custom_fdiff)) + +res = tvm.differentiate(L, override={Y: ([X, W, B], custom_fdiff_2)}) + +###################################################################### +# There are several things to note: +# - For efficiency reasons the custom differentiation function used in `override` has a slightly +# different interface than the custom differentiation functions used for `fdiff`, namely it +# takes a list of inputs instead of a single input, and returns the list of the corresponding +# adjoints. +# - We had overridden the dependencies for `Y` (its immediate dependencies are `T` and `B`, but we +# used `X`, `W` and `B` instead), so we couldn't use :any:`tvm.autodiff.Jacobian` or +# `custom_fdiff` directly, since they expect the input to be an immediate dependency for the +# output. That's why we had to wrap them in the call to :any:`tvm.differentiate`. From a342e5e93fa067b135f918cc0d9f2fef67dba1bd Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Tue, 19 Feb 2019 19:01:18 +0300 Subject: [PATCH 11/11] Several fixes --- include/tvm/autodiff.h | 42 +++++++++++++-------------- src/pass/autodiff.cc | 2 +- src/pass/zero_elimination.cc | 5 ++-- src/relay/op/autodiff_integration.cc | 3 -- tutorials/language/autodiff_basics.py | 31 ++++++++++++++++++++ 5 files changed, 55 insertions(+), 28 deletions(-) diff --git a/include/tvm/autodiff.h b/include/tvm/autodiff.h index 15ab436ede0b..d9f47009af2a 100644 --- a/include/tvm/autodiff.h +++ b/include/tvm/autodiff.h @@ -1,7 +1,7 @@ /*! * Copyright (c) 2018 by Contributors * \file autodiff.h - * \brief Automatic differentiation of IR Expr. + * \brief Automatic differentiation of tensor expressions. */ #ifndef TVM_AUTODIFF_H_ #define TVM_AUTODIFF_H_ @@ -12,24 +12,7 @@ namespace tvm { namespace ir { -class DifferentiationResultNode; - -/*! - * \brief A result of differentiation. - */ -class DifferentiationResult : public NodeRef { - public: - /*! \brief default constructor, used internally */ - DifferentiationResult() {} - explicit DifferentiationResult(NodePtr n) : NodeRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const DifferentiationResultNode* operator->() const; - /*! \brief specify container node */ - using ContainerType = DifferentiationResultNode; -}; +class DifferentiationResult; /*! \brief Node to represent a differentiation result */ class DifferentiationResultNode : public Node { @@ -56,9 +39,24 @@ class DifferentiationResultNode : public Node { TVM_DECLARE_NODE_TYPE_INFO(DifferentiationResultNode, Node); }; -inline const DifferentiationResultNode* DifferentiationResult::operator->() const { - return static_cast(node_.get()); -} +/*! + * \brief A result of differentiation. + */ +class DifferentiationResult : public NodeRef { + public: + /*! \brief default constructor, used internally */ + DifferentiationResult() {} + explicit DifferentiationResult(NodePtr n) : NodeRef(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const DifferentiationResultNode* operator->() const { + return static_cast(node_.get()); + } + /*! \brief specify container node */ + using ContainerType = DifferentiationResultNode; +}; /*! \brief A type of a "local" differentiation function for reverse mode AD diff --git a/src/pass/autodiff.cc b/src/pass/autodiff.cc index ffd6386529e6..2ee806dfb7db 100644 --- a/src/pass/autodiff.cc +++ b/src/pass/autodiff.cc @@ -1,7 +1,7 @@ /*! * Copyright (c) 2018 by Contributors * \file autodiff.cc - * \brief Automatic differentiation of IR Expr + * \brief Automatic differentiation of tensor expressions */ #include diff --git a/src/pass/zero_elimination.cc b/src/pass/zero_elimination.cc index 247ea17b2379..56e4006c824a 100644 --- a/src/pass/zero_elimination.cc +++ b/src/pass/zero_elimination.cc @@ -1298,7 +1298,8 @@ Expr ExtractAsTensorMaybe(const Expr& e, const Expr& cond, return e; } - Tensor tensor = op::TensorFromExpr(new_expr, IterVarsFromMap(res.axis, res.ranges)); + Tensor tensor = op::TensorFromExpr(new_expr, IterVarsFromMap(res.axis, res.ranges), + "extracted_tensor"); Array args; for (const Var& var : res.axis) { @@ -1473,7 +1474,7 @@ class ExtractReductionsMutator : public IRMutator { public: explicit ExtractReductionsMutator(const Array& outer_axis, Map vranges, - std::string name = "extracted") + std::string name = "extracted_reduction") : outer_axis_(outer_axis), vranges_(std::move(vranges)), name_(std::move(name)) {} Expr Mutate_(const Reduce* op, const Expr& e) { diff --git a/src/relay/op/autodiff_integration.cc b/src/relay/op/autodiff_integration.cc index bec601169ceb..a86d504e8798 100644 --- a/src/relay/op/autodiff_integration.cc +++ b/src/relay/op/autodiff_integration.cc @@ -35,9 +35,6 @@ bool AutogeneratedGradientRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - const AutogeneratedGradientAttrs* real_attrs = attrs.as(); - CHECK(real_attrs != nullptr) << "Attrs are null or have an invalid type."; - // There are just two types: the type of the input tuple and the type of the output tuple. CHECK(types.size() == 2) << "The size of the types array must be 2, not " << types.size(); const auto* tuple_type = types[0].as(); diff --git a/tutorials/language/autodiff_basics.py b/tutorials/language/autodiff_basics.py index 916f4eefad6d..32646a8952b2 100644 --- a/tutorials/language/autodiff_basics.py +++ b/tutorials/language/autodiff_basics.py @@ -71,6 +71,37 @@ dL_dT = res.adjoints[T] dL_dY = res.adjoints[Y] +###################################################################### +# Examples of generated gradients +# ------------------------------- +# +# Let's print out some generated code. We'll start with the simple matrix multiplication +# we've already differentiated. + +T1 = tvm.compute((32, 10), lambda i, j: tvm.sum(X[i, k]*W[j, k], k), name='matmul') +H1 = tvm.placeholder(T1.shape, name='H1') + +[dW] = tvm.differentiate(T1, [W], H1) +print(tvm.PrintTensorRecursively(dW)) + +###################################################################### +# (The only problem here is that an unnecessary intermediate tensor was extracted.) +# +# Now let's look at some problematic operations, like maxpool: + +X1 = tvm.placeholder((64, 32, 28, 28), name='X1') +W1 = tvm.placeholder((64, 64, 3, 3), name='W1') +Y1 = topi.nn.pool(X1, [2, 2], [2, 2], [0, 0, 0, 0], 'max') +H1 = tvm.placeholder(Y1.shape, name='H1') + +[dX1] = tvm.differentiate(Y1, [X1], H1) +print(tvm.PrintTensorRecursively(dX1)) + +###################################################################### +# Here the elements of the adjoint `H1` are multiplied by the elements of a mask (computed with +# the tensor called `extracted_tensor`). The mask represents whether an element is the maximum of +# its neighborhood. This is not the optimal solution. + ###################################################################### # Overriding the differentiation function # ---------------------------------------