Skip to content

Commit

Permalink
[Topi] Breakdown topi.cc into smaller files (apache#5253)
Browse files Browse the repository at this point in the history
* [Topi] Breakdown topi.cc into smaller files

* add missing file
  • Loading branch information
icemelon authored and dpankratz committed Apr 24, 2020
1 parent f92fcee commit f67af3a
Show file tree
Hide file tree
Showing 10 changed files with 1,127 additions and 896 deletions.
6 changes: 3 additions & 3 deletions topi/include/topi/broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ namespace topi {
* \return A Tensor whose op member is a broadcast operation
*/
inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t,
const tvm::Array<tvm::PrimExpr>& output_shape,
std::string name = "T_broadcast_to",
std::string tag = kBroadcast) {
const tvm::Array<tvm::PrimExpr>& output_shape,
std::string name = "T_broadcast_to",
std::string tag = kBroadcast) {
CHECK_GE(output_shape.size(), t->shape.size())
<< "Not a broadcast, output dimensionality smaller than input.\noutput: "
<< output_shape << "\nvs\ninput: " << t;
Expand Down
53 changes: 53 additions & 0 deletions topi/include/topi/util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \brief Topi utility function
* \file topi/util.h
*/
#ifndef TOPI_UTIL_H_
#define TOPI_UTIL_H_

#include <tvm/ir/expr.h>
#include <tvm/runtime/packed_func.h>

namespace topi {

using namespace tvm;
using namespace tvm::runtime;

/*! \brief Canonicalize an argument that may be Array<Expr> or int to Array<Expr> */
inline Array<Integer> ArrayOrInt(TVMArgValue arg) {
if (arg.type_code() == kDLInt || arg.type_code() == kDLUInt) {
Array<Integer> result;
result.push_back(arg.operator int());
return result;
} else {
return arg;
}
}

inline bool IsTensorType(TVMArgValue arg) {
return (arg.type_code() == kTVMObjectHandle &&
static_cast<Object*>(
arg.value().v_handle)->IsInstance<tvm::te::TensorNode>());
}

} // namespace topi
#endif // TOPI_UTIL_H_
84 changes: 84 additions & 0 deletions topi/src/broadcast.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \brief Registration of broadcast operators
* \file broadcast.cc
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include <topi/broadcast.h>
#include <topi/util.h>

namespace topi {

using namespace tvm;
using namespace tvm::runtime;

#define TOPI_REGISTER_BCAST_OP(OpName, Op) \
TVM_REGISTER_GLOBAL(OpName) \
.set_body([](TVMArgs args, TVMRetValue *rv) { \
bool lhs_is_tensor = IsTensorType(args[0]); \
bool rhs_is_tensor = IsTensorType(args[1]); \
if (lhs_is_tensor && rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::te::Tensor(), \
args[1].operator tvm::te::Tensor()); \
} else if (!lhs_is_tensor && rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::PrimExpr(), \
args[1].operator tvm::te::Tensor()); \
} else if (lhs_is_tensor && !rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::te::Tensor(), \
args[1].operator tvm::PrimExpr()); \
} else if (!lhs_is_tensor && !rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::PrimExpr(), \
args[1].operator tvm::PrimExpr()); \
} \
}); \

TOPI_REGISTER_BCAST_OP("topi.add", topi::add);
TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract);
TOPI_REGISTER_BCAST_OP("topi.multiply", topi::multiply);
TOPI_REGISTER_BCAST_OP("topi.divide", topi::divide);
TOPI_REGISTER_BCAST_OP("topi.floor_divide", topi::floor_divide);
TOPI_REGISTER_BCAST_OP("topi.mod", topi::mod);
TOPI_REGISTER_BCAST_OP("topi.floor_mod", topi::floor_mod);
TOPI_REGISTER_BCAST_OP("topi.maximum", topi::maximum);
TOPI_REGISTER_BCAST_OP("topi.minimum", topi::minimum);
TOPI_REGISTER_BCAST_OP("topi.power", topi::power);
TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift);
TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and);
TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or);
TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and);
TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or);
TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor);
TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift);
TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater);
TOPI_REGISTER_BCAST_OP("topi.less", topi::less);
TOPI_REGISTER_BCAST_OP("topi.equal", topi::equal);
TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal);
TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal);
TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal);

TVM_REGISTER_GLOBAL("topi.broadcast_to")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = broadcast_to(args[0], args[1]);
});

} // namespace topi
154 changes: 154 additions & 0 deletions topi/src/elemwise.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \brief Registration of elemwise operators
* \file elemwise.cc
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include <topi/elemwise.h>

namespace topi {

using namespace tvm;
using namespace tvm::runtime;

TVM_REGISTER_GLOBAL("topi.exp")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = exp(args[0]);
});

TVM_REGISTER_GLOBAL("topi.fast_exp")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = fast_exp(args[0]);
});

TVM_REGISTER_GLOBAL("topi.erf")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = erf(args[0]);
});

TVM_REGISTER_GLOBAL("topi.tan")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tan(args[0]);
});

TVM_REGISTER_GLOBAL("topi.cos")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = cos(args[0]);
});

TVM_REGISTER_GLOBAL("topi.sin")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sin(args[0]);
});

TVM_REGISTER_GLOBAL("topi.tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tanh(args[0]);
});

TVM_REGISTER_GLOBAL("topi.fast_tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = fast_tanh(args[0]);
});

TVM_REGISTER_GLOBAL("topi.atan")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = atan(args[0]);
});

TVM_REGISTER_GLOBAL("topi.sigmoid")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sigmoid(args[0]);
});

TVM_REGISTER_GLOBAL("topi.sqrt")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sqrt(args[0]);
});

TVM_REGISTER_GLOBAL("topi.rsqrt")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = rsqrt(args[0]);
});

TVM_REGISTER_GLOBAL("topi.log")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = log(args[0]);
});

TVM_REGISTER_GLOBAL("topi.identity")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = identity(args[0]);
});

TVM_REGISTER_GLOBAL("topi.negative")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = negative(args[0]);
});

TVM_REGISTER_GLOBAL("topi.clip")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = clip(args[0], args[1], args[2]);
});

TVM_REGISTER_GLOBAL("topi.cast")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = cast(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("topi.reinterpret")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = reinterpret(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("topi.elemwise_sum")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = elemwise_sum(args[0]);
});

TVM_REGISTER_GLOBAL("topi.sign")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sign(args[0]);
});

TVM_REGISTER_GLOBAL("topi.full")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = full(args[0], args[1], args[2]);
});

TVM_REGISTER_GLOBAL("topi.full_like")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = full_like(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("topi.logical_not")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = logical_not(args[0]);
});

TVM_REGISTER_GLOBAL("topi.bitwise_not")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = bitwise_not(args[0]);
});

} // namespace topi
Loading

0 comments on commit f67af3a

Please sign in to comment.