Skip to content

Commit

Permalink
[TE][Fix] Comparison of the output tensor (#9829)
Browse files Browse the repository at this point in the history
* [TE][Fix] Comparison of the output tensor

* fix hybrid op issue

* fix tensor replacement in schedule ops

* fix compute inline
  • Loading branch information
leeexyz authored Feb 20, 2022
1 parent b445d66 commit 73cf51b
Show file tree
Hide file tree
Showing 11 changed files with 50 additions and 21 deletions.
8 changes: 6 additions & 2 deletions include/tvm/te/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,11 @@ class TVM_DLL OperationNode : public Object {
std::string name;
/*! \brief optional tag of the operation */
std::string tag;
/*! \brief additional attributes of the operation*/
/*! \brief additional attributes of the operation */
Map<String, ObjectRef> attrs;
/*! \brief output tensors */
Array<Tensor> outputs;

// virtual destructor.
virtual ~OperationNode() {}
/*! \return number of outputs */
Expand Down Expand Up @@ -473,7 +476,7 @@ class HybridOpNode : public OperationNode {
/*! \brief The input tensors */
Array<Tensor> inputs;
/*! \brief Symbolic placeholder representation of outputs */
Array<Tensor> outputs;
Array<Tensor> symbolic_outputs;
/*! \brief The axis of iterations */
Array<IterVar> axis;
/*! \brief the statement that generates the computation. This is
Expand Down Expand Up @@ -509,6 +512,7 @@ class HybridOpNode : public OperationNode {
v->Visit("attrs", &attrs);
v->Visit("inputs", &inputs);
v->Visit("outputs", &outputs);
v->Visit("symbolic_outputs", &symbolic_outputs);
v->Visit("axis", &axis);
v->Visit("body", &body);
}
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/te/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def visit_Assign(self, node):
"You should bind a pure name to the tensors",
)
self.add_symbol(node.targets[i].id, Symbol.GlobalBuffer, rhs.output(i))
rmap[rhs.outputs[i].op] = rhs.output(i)
rmap[rhs.symbolic_outputs[i].op] = rhs.output(i)
return utils.replace_io(rhs.body, rmap)

_internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/te/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,12 @@ def __eq__(self, other):
if isinstance(other, _expr.ExprOp):
return _expr.EqualOp(self, other)
return False
if self.same_as(other):
return True
if self.ndim == 0 and other.ndim == 0:
raise ValueError(
"Equal == comparison among rank-0 tensor is ambiguous, "
"use Tensor.equal for content expression equvalence, "
"use Tensor.same_as for exact reference comparison"
"use Tensor.equal for content expression equvalence."
)
return _ffi_api.TensorEqual(self, other)

Expand Down
1 change: 1 addition & 0 deletions src/te/operation/extern_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ Operation ExternOpNode::ReplaceInputs(const Operation& self,
ICHECK_EQ(self.operator->(), this);
auto n = make_object<ExternOpNode>(*this);
n->body = ReplaceTensor(this->body, rmap);
n->outputs = Array<Tensor>();
for (size_t i = 0; i < n->inputs.size(); ++i) {
Tensor t = n->inputs[i];
if (rmap.count(t)) {
Expand Down
11 changes: 6 additions & 5 deletions src/te/operation/hybrid_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

TVM_REGISTER_NODE_TYPE(HybridOpNode);

// Accessors of HybridOpNode. In this diff rendering each changed accessor
// appears twice: the first line of a pair reads `outputs` and the second
// reads `symbolic_outputs` (presumably removed vs. added line of the hunk —
// confirm against the real file, which can only contain one of them).

/*! \return number of outputs, i.e. the size of the output placeholder array, as int */
int HybridOpNode::num_outputs() const { return static_cast<int>(outputs.size()); }
int HybridOpNode::num_outputs() const { return static_cast<int>(symbolic_outputs.size()); }

/*! \return the root iteration variables, which is the stored `axis` array */
Array<IterVar> HybridOpNode::root_iter_vars() const { return this->axis; }

/*! \return dtype of the i-th output placeholder; `i` is used as an index as-is */
DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; }
DataType HybridOpNode::output_dtype(size_t i) const { return symbolic_outputs[i]->dtype; }

/*! \return shape of the i-th output placeholder; `i` is used as an index as-is */
Array<PrimExpr> HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; }
Array<PrimExpr> HybridOpNode::output_shape(size_t i) const { return symbolic_outputs[i]->shape; }

HybridOp::HybridOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<Tensor> inputs, Array<Tensor> outputs, Stmt body) {
Expand All @@ -67,7 +67,7 @@ HybridOp::HybridOp(std::string name, std::string tag, Map<String, ObjectRef> att
n->tag = std::move(tag);
n->attrs = std::move(attrs);
n->inputs = std::move(inputs);
n->outputs = std::move(outputs);
n->symbolic_outputs = std::move(outputs);
n->axis = te::GatherLoopVars(body);
n->body = std::move(body);
data_ = std::move(n);
Expand Down Expand Up @@ -104,6 +104,7 @@ Operation HybridOpNode::ReplaceInputs(const Operation& self,
ICHECK_EQ(self.operator->(), this);
auto n = make_object<HybridOpNode>(*this);
n->body = te::ReplaceTensor(this->body, rmap);
n->outputs = Array<Tensor>();
for (size_t i = 0; i < n->inputs.size(); ++i) {
Tensor t = n->inputs[i];
if (rmap.count(t)) {
Expand Down Expand Up @@ -166,7 +167,7 @@ Stmt HybridOpNode::BuildProvide(const Stage& stage,
Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
std::unordered_map<Tensor, Tensor> rmap;
for (int i = 0; i < this->num_outputs(); ++i) {
rmap[outputs[i]] = stage->op.output(i);
rmap[symbolic_outputs[i]] = stage->op.output(i);
}
auto n = make_object<HybridOpNode>(*this);
/* This is a slightly complicated story.
Expand Down
2 changes: 2 additions & 0 deletions src/te/operation/scan_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(ScanOpNode);

/*! \return number of outputs, which is the number of entries in `update`, as int */
int ScanOpNode::num_outputs() const { return static_cast<int>(update.size()); }

Array<IterVar> ScanOpNode::root_iter_vars() const {
Array<IterVar> ret{scan_axis};
for (IterVar iv : spatial_axis_) {
Expand Down Expand Up @@ -143,6 +144,7 @@ Operation ScanOpNode::ReplaceInputs(const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
ICHECK_EQ(self.operator->(), this);
auto n = make_object<ScanOpNode>(*this);
n->outputs = Array<Tensor>();
for (size_t i = 0; i < n->init.size(); ++i) {
if (rmap.count(n->init[i])) {
n->init.Set(i, rmap.at(n->init[i]));
Expand Down
1 change: 1 addition & 0 deletions src/te/operation/tensor_compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ Operation TensorComputeOpNode::ReplaceInputs(const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
ICHECK_EQ(self.operator->(), this);
auto n = make_object<TensorComputeOpNode>(*this);
n->outputs = Array<Tensor>();
auto intrin = make_object<TensorIntrinNode>(*(this->intrin.operator->()));
intrin->body = ReplaceTensor(this->intrin->body, rmap);
if (intrin->reduce_init.defined()) {
Expand Down
2 changes: 1 addition & 1 deletion src/te/schedule/schedule_dataflow_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) {
const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
ICHECK(hybrid);
Operation op = HybridOp(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
hybrid->outputs, new_hybrid_body[i]);
hybrid->symbolic_outputs, new_hybrid_body[i]);
op = op->ReplaceInputs(op, repl);
for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
repl[s->op.output(idx)] = op.output(idx);
Expand Down
12 changes: 9 additions & 3 deletions src/te/schedule/schedule_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,15 @@ class SchedulePostProc : public StmtExprMutator {
private:
/*!
 * \brief Record the replacements to apply for tensor `src`.
 *
 * Writes up to three tables, all keyed on `src` (or on `src->op`):
 *  - replace_buffer_[src]        = dst
 *  - replace_realize_[src]       = repl_realize
 *  - replace_op_[src->op.get()]  = repl_op
 *
 * \param src The tensor being replaced.
 * \param dst The tensor that replaces `src` in buffer accesses.
 * \param repl_realize Replacement used for realize; defaults to an empty Tensor().
 * \param repl_op Replacement operation for `src->op`; defaults to an empty Operation().
 *
 * NOTE(review): in this diff rendering the three unconditional assignments
 * directly below look like the removed lines of the hunk and the guarded
 * versions after them like the added lines — confirm against the real file.
 */
void AddReplace(Tensor src, Tensor dst, Tensor repl_realize = Tensor(),
Operation repl_op = Operation()) {
replace_buffer_[src] = dst;
replace_realize_[src] = repl_realize;
replace_op_[src->op.get()] = repl_op;
// Guarded form: skip the entry when source and replacement are the very same
// object (same_as is a reference comparison), so no identity mapping is stored.
// Presumably needed because Operation::output now returns cached tensors, so
// src and dst can be the same object — TODO confirm.
if (!src.same_as(dst)) {
replace_buffer_[src] = dst;
}
if (!src.same_as(repl_realize)) {
replace_realize_[src] = repl_realize;
}
if (!src->op.same_as(repl_op)) {
replace_op_[src->op.get()] = repl_op;
}
}
// The thread extent scope.
std::unordered_map<const Object*, PrimExpr> thread_extent_scope_;
Expand Down
23 changes: 16 additions & 7 deletions src/te/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,22 @@ String TensorNode::GetNameHint() const {
return op->num_outputs() == 1 ? op->name : (op->name + ".v" + std::to_string(value_index));
}

/*!
 * \brief Build the i-th output tensor of this operation.
 *
 * Allocates a new TensorNode on every call and fills it from this operation
 * (op, value_index, dtype, shape); nothing is cached here, so two calls with
 * the same `i` return two distinct TensorNode objects.
 *
 * \param i Index of the output; passed straight to output_dtype/output_shape.
 * \return A Tensor wrapping the freshly created node.
 */
Tensor Operation::output(size_t i) const {
auto node = make_object<TensorNode>();
node->op = *this;
node->value_index = i;
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
}
/*!
 * \brief Return the n-th output tensor of this operation, creating and caching
 *        all output tensors on first use.
 *
 * While the node's `outputs` array is empty, one TensorNode is built for each
 * of the num_outputs() outputs and appended to `outputs`; later calls return
 * the stored entry, so repeated calls with the same `n` yield the same object.
 *
 * \param n Index of the requested output; checked against outputs.size().
 * \return The cached Tensor at position `n`.
 *
 * \note Although the method is const, it mutates the underlying node through
 *       get_mutable() to fill the cache.
 * \note If num_outputs() is 0 the cache stays empty and the ICHECK_LT below
 *       fails for every `n`.
 *
 * NOTE(review): every cached Tensor stores `op = *this` while the op stores the
 * Tensor in `outputs` — this looks like a mutual (cyclic) reference between an
 * operation and its outputs; confirm the lifetime/leak implications.
 * NOTE(review): no synchronization is visible around the lazy fill — confirm
 * that concurrent first calls on the same operation cannot happen.
 */
Tensor Operation::output(size_t n) const {
// Lazily populate the cache: an empty `outputs` is treated as "not built yet".
if ((*this)->outputs.empty()) {
// Obtain a writable pointer to the node so the cache can be filled from a const method.
auto* ptr = static_cast<OperationNode*>(get_mutable());
size_t num = static_cast<size_t>((*this)->num_outputs());
// Build all outputs at once so that cache index == value_index.
for (size_t i = 0; i < num; ++i) {
auto node = make_object<TensorNode>();
node->op = *this;
node->value_index = i;
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
ptr->outputs.push_back(Tensor(node));
}
}
// Bounds check on the requested index before reading the cache.
ICHECK_LT(n, (*this)->outputs.size());
return (*this)->outputs[n];
}

Tensor::Tensor(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index) {
Expand Down
4 changes: 4 additions & 0 deletions tests/python/unittest/test_te_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def test_tensor():
assert T.op.output(0).__hash__() == T.__hash__()
d = {T.op.output(0): 1}
assert d[T] == 1
assert T == T.op.output(0)
assert T.same_as(T.op.output(0))
assert T[0][0][0].astype("float16").dtype == "float16"


Expand All @@ -49,6 +51,8 @@ def test_rank_zero():
print(T)
print(T.op.body)
assert tuple(T.shape) == ()
assert T == T.op.output(0)
assert T.same_as(T.op.output(0))


def test_conv1d():
Expand Down

0 comments on commit 73cf51b

Please sign in to comment.