Skip to content

Commit

Permalink
Fix subgraph with custom_op (apache#15671)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhennanQin authored and Ubuntu committed Aug 20, 2019
1 parent ee24f26 commit cd798f4
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 3 deletions.
4 changes: 3 additions & 1 deletion src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1046,8 +1046,10 @@ int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend_name,
for (auto property : subgraph_prop_list) {
nnvm::Graph g = Symbol2Graph(*s);
property->SetAttr("graph", g);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
g = ApplyPass(std::move(g), "BuildSubgraph");
property->RemoveAttr("graph");
g.attrs.erase("subgraph_property");
s->outputs = g.outputs;
}
*ret_sym_handle = s;
Expand Down
4 changes: 3 additions & 1 deletion src/c_api/c_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ int MXBuildSubgraphByOpNames(SymbolHandle sym_handle,
g.outputs = s->outputs;
property->SetAttr("graph", g);
property->SetAttr("op_names", op_name_set);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
g = nnvm::ApplyPass(std::move(g), "BuildSubgraph");
property->RemoveAttr("graph");
g.attrs.erase("subgraph_property");
s->outputs = g.outputs;
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1688,8 +1688,10 @@ static nnvm::Symbol BuildSubgraph(const nnvm::Symbol& src, op::SubgraphPropertyP
g = InferForwardAttrs(g, arg_shapes, arg_dtypes, arg_stypes, default_ctx, ctx_map, in_arg_ctxes,
aux_state_ctxes, true);
subgraph_prop->SetAttr("graph", g);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(subgraph_prop));
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(subgraph_prop);
g = ApplyPass(std::move(g), "BuildSubgraph");
subgraph_prop->RemoveAttr("graph");
g.attrs.erase("subgraph_property");
ret.outputs = g.outputs;
return ret;
}
Expand Down
19 changes: 19 additions & 0 deletions src/operator/subgraph/subgraph_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,15 @@ class SubgraphProperty {
auto it = attrs_.find(name);
return it != attrs_.end();
}
/*!
* \brief Remove attr if the attr exists.
*/
void RemoveAttr(const std::string& name) {
auto it = attrs_.find(name);
if (it != attrs_.end()) {
attrs_.erase(it);
}
}
/*!
* \brief Get the property type.
*/
Expand Down Expand Up @@ -384,6 +393,16 @@ class SubgraphBackend {
return it != attrs_.end();
}

/*!
* \brief Remove attr if the attr exists.
*/
void RemoveAttr(const std::string& name) {
auto it = attrs_.find(name);
if (it != attrs_.end()) {
attrs_.erase(it);
}
}

SubgraphPropertyPtr& RegisterSubgraphProperty(const SubgraphPropertyPtr prop) {
prop_ptr_.push_back(prop);
return prop_ptr_.back();
Expand Down
53 changes: 53 additions & 0 deletions tests/python/unittest/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,59 @@ def make_subgraph4(stype):
rtol=0.001, atol=0.0001)


def test_subgraph_with_customOp():
class MyAdd(mx.operator.CustomOp):
def forward(self, is_train, req, in_data, out_data, aux):
self.assign(out_data[0], req[0], in_data[0] + 1)

def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
self.assign(in_grad[0], req[0], out_grad[0])

@mx.operator.register('MyAdd1')
class MyAdd1Prop(mx.operator.CustomOpProp):
def __init__(self):
super(MyAdd1Prop, self).__init__(need_top_grad=True)

def list_arguments(self):
return ['data']

def list_outputs(self):
return ['output']

def infer_shape(self, in_shape):
# inputs, outputs, aux
return [in_shape[0]], [in_shape[0]], []

def create_operator(self, ctx, shapes, dtypes):
return MyAdd()

@mx.operator.register('MyAdd2')
class MyAdd2Prop(mx.operator.CustomOpProp):
def __init__(self):
super(MyAdd2Prop, self).__init__(need_top_grad=True)

def list_arguments(self):
return ['data']

def list_outputs(self):
return ['output']

def infer_shape(self, in_shape):
# inputs, outputs, aux
return [in_shape[0]], [in_shape[0]], []

def create_operator(self, ctx, shapes, dtypes):
return MyAdd()

inp = mx.nd.zeros(shape=(100, 100))
a = mx.symbol.Variable('a')
b = a + 1
b = mx.symbol.Custom(data=a, op_type='MyAdd1')
c = mx.symbol.Custom(data=a, op_type='MyAdd2')
b.bind(mx.cpu(), {'a': inp}).forward()
c.bind(mx.cpu(), {'a': inp}).forward()
mx.nd.waitall()

if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit cd798f4

Please sign in to comment.