Skip to content

Commit

Permalink
allow variable composition (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong authored Jul 28, 2017
1 parent e9c2672 commit 9321635
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions src/core/symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,11 @@ inline std::vector<std::string> GetKeys(

// whether the symbol is atomic functor
inline bool IsAtomic(const std::vector<NodeEntry>& outputs) {
return outputs[0].node->inputs.size() == 0 &&
outputs[0].node->control_deps.size() == 0;
Node* node = outputs[0].node.get();
for (const NodeEntry& e : outputs) {
if (node != e.node.get()) return false;
}
return node->inputs.size() == 0 && node->control_deps.size() == 0;
}

// public functions
Expand Down Expand Up @@ -261,7 +264,14 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
static auto& fset_attrs = Op::GetAttr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose");

CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed";
for (size_t i = 0; i < outputs.size(); ++i) {
if (outputs[i].node->is_variable()) {
CHECK_EQ(args.size(), 0) << "Variable composition only supports keyword arguments";
const auto it = kwargs.find(outputs[i].node->attrs.name);
if (it != kwargs.end()) outputs[i] = it->second->outputs[0];
}
}

// parameter check.
for (size_t i = 0; i < args.size(); ++i) {
CHECK_EQ(args[i]->outputs.size(), 1U)
Expand All @@ -271,13 +281,13 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
CHECK_EQ(kv.second->outputs.size(), 1U)
<< "Keyword Argument " << kv.first << " is a tuple, single value is required";
}
// assign new name
outputs[0].node->attrs.name = name;

// Atomic functor composition.
if (IsAtomic(outputs)) {
Node* n = outputs[0].node.get();
uint32_t n_req = n->num_inputs();
// assign new name
if (!name.empty()) n->attrs.name = name;

if (n_req != kVarg) {
n->inputs.resize(n_req);
Expand Down

0 comments on commit 9321635

Please sign in to comment.