Skip to content

Commit

Permalink
[FIX] Make master compile
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 29, 2018
1 parent 8521944 commit 79a0603
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
8 changes: 4 additions & 4 deletions nnvm/python/nnvm/top/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ def compute_max_pool2d(attrs, inputs, _):
strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
layout = attrs["layout"]
ceil_mode = attrs["ceil_mode"]
ceil_mode = attrs.get_bool("ceil_mode")
assert layout == "NCHW", "only support nchw for now"
assert ceil_mode == "False", "not support ceil_mode now"
assert not ceil_mode, "not support ceil_mode now"
return topi.nn.pool(inputs[0], pool_size, strides, padding, pool_type='max')

@reg.register_schedule("max_pool2d")
Expand All @@ -165,9 +165,9 @@ def compute_avg_pool2d(attrs, inputs, _):
strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
layout = attrs["layout"]
ceil_mode = attrs["ceil_mode"]
ceil_mode = attrs.get_bool("ceil_mode")
assert layout == "NCHW", "only support nchw for now"
assert ceil_mode == "False", "not support ceil_mode now"
assert not ceil_mode, "not support ceil_mode now"
return topi.nn.pool(inputs[0], pool_size, strides, padding, pool_type='avg')

@reg.register_schedule("avg_pool2d")
Expand Down
20 changes: 18 additions & 2 deletions nnvm/src/compiler/graph_fuse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,29 @@
#include <nnvm/compiler/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/lowered_func.h>
#include <dmlc/parameter.h>
#include "./compile_engine.h"
#include "../../tvm/src/runtime/graph/graph_runtime.h"

namespace nnvm {
namespace compiler {

using tvm::runtime::TVMOpParam;

struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
std::string func_name;
uint32_t num_inputs;
uint32_t num_outputs;
uint32_t flatten_data;

DMLC_DECLARE_PARAMETER(TVMOpParam) {
DMLC_DECLARE_FIELD(func_name);
DMLC_DECLARE_FIELD(num_inputs).set_default(1);
DMLC_DECLARE_FIELD(num_outputs).set_default(1);
DMLC_DECLARE_FIELD(flatten_data).set_default(0);
}
};

DMLC_REGISTER_PARAMETER(TVMOpParam);

// parser
inline void TVMOpParamParser(nnvm::NodeAttrs* attrs) {
Expand Down Expand Up @@ -368,7 +384,7 @@ nnvm::Graph GraphFuseCompile(nnvm::Graph g) {
nnvm::NodePtr np = nnvm::Node::Create();
np->attrs.op = tvm_op;
np->attrs.name = inode.source->attrs.name;
runtime::TVMOpParam param;
TVMOpParam param;
param.func_name = fe.compiled_func->func_name;
param.num_inputs = static_cast<uint32_t>(fe.imap.size());
param.num_outputs = static_cast<uint32_t>(fe.subgraph.outputs.size());
Expand Down

0 comments on commit 79a0603

Please sign in to comment.