Skip to content

Commit

Permalink
[Collage] SubGraphs (apache#11981)
Browse files Browse the repository at this point in the history
* [Collage] SubGraphs

See https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md.

Collage works in units of 'sub-graphs', which are potential partitions of the
overall Relay model. This PR introduces SubGraph (an arbitrary partitioning, without
any implication about how it is to be represented), it's companion SubSubGraph
(implying a representation as a function), and some supporting odds 'n ends.

* - make Integer <-> size_t conversion explicit
- make 'Compiler' name explicit

* - fix namespace ambiguity

* - review comments
  • Loading branch information
mbs-octoml authored and Mikael Sevenier committed Jul 26, 2022
1 parent 5a82893 commit 36fd2dd
Show file tree
Hide file tree
Showing 11 changed files with 2,609 additions and 0 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ tvm_file_glob(GLOB_RECURSE RELAY_OP_SRCS
)
tvm_file_glob(GLOB_RECURSE RELAY_PASS_SRCS
src/relay/analysis/*.cc
src/relay/collage/*.cc
src/relay/transforms/*.cc
src/relay/quantize/*.cc
)
Expand Down
26 changes: 26 additions & 0 deletions src/relay/collage/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->

<!--- http://www.apache.org/licenses/LICENSE-2.0 -->

<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->

The `CollagePartition` pass for finding optimal partitionings of Relay models.

See the [RFC](https://github.com/mbs-octoml/mbs-tvm-rfcs/blob/mbs-rfcs-collage/rfcs/xxxx-collage.md).

Based on:
> *Collage: Automated Integration of Deep Learning Backends*
> Byungsoo Jeon, Sunghyun Park, Peiyuan Liao, Sheng Xu, Tianqi Chen, Zhihao Jia
CAUTION: This is a prototype, do not use in prod.
48 changes: 48 additions & 0 deletions src/relay/collage/dataflow_graph.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/collage/dataflow_graph.cc
* \brief A representation of the dataflow for an overall Relay expression.
*/

#include "./dataflow_graph.h"

namespace tvm {
namespace relay {
namespace collage {

DataflowGraph::DataflowGraph(Expr expr) : expr_(std::move(expr)) {
indexed_graph_ = CreateIndexedGraph(expr_);
downstream_map_.reserve(indexed_graph_->size());
for (PostDfsIndex index = 0; index < indexed_graph_->size(); ++index) {
const Node* node = indexed_graph_->index_to_node(index);
std::unordered_set<const Node*> downstream_nodes;
node->AccumulateDownstreamNodes(&downstream_nodes);
IndexSet index_set(indexed_graph_->size());
for (const Node* downstream_node : downstream_nodes) {
index_set.Add(downstream_node->index_);
}
downstream_map_.emplace_back(std::move(index_set));
}
}

} // namespace collage
} // namespace relay
} // namespace tvm
77 changes: 77 additions & 0 deletions src/relay/collage/dataflow_graph.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/collage/dataflow_graph.h
* \brief A representation of the dataflow for an overall Relay expression.
*/
#ifndef TVM_RELAY_COLLAGE_DATAFLOW_GRAPH_H_
#define TVM_RELAY_COLLAGE_DATAFLOW_GRAPH_H_

#include <tvm/relay/expr.h>

#include <memory>
#include <vector>

#include "../ir/indexed_graph.h"
#include "./index_set.h"

namespace tvm {
namespace relay {
namespace collage {

/*!
* \brief Represents the dataflow of an overall Relay expression.
*/
class DataflowGraph {
public:
using Node = IndexedGraph<Expr>::Node;

explicit DataflowGraph(Expr expr);

size_t size() const { return indexed_graph_->size(); }
const Node* index_to_node(PostDfsIndex index) const {
return indexed_graph_->index_to_node(index);
}
const Node* item_to_node(const Expr& expr) const { return indexed_graph_->item_to_node(expr); }
const Node* item_to_node(const ExprNode* expr_node) const {
return indexed_graph_->item_to_node(expr_node);
}
const Expr& expr() const { return expr_; }
const IndexedGraph<Expr>& indexed_graph() const { return *indexed_graph_; }

const IndexSet& downstream_of(PostDfsIndex index) const {
ICHECK_LT(index, indexed_graph_->size());
return downstream_map_[index];
}

private:
/*! \brief The overall expression. */
Expr expr_;
/*! \brief The indexed graph which captures the main dataflow. */
std::unique_ptr<IndexedGraph<Expr>> indexed_graph_;
/*! \brief Map from a node's PostDfsIndex to the set of its downstream dataflow node indexes. */
std::vector<IndexSet> downstream_map_;
};

} // namespace collage
} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_COLLAGE_DATAFLOW_GRAPH_H_
231 changes: 231 additions & 0 deletions src/relay/collage/index_set.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/collage/index_set.cc
* \brief Efficient representation of a set of post-dfs indexes.
*/

#include "./index_set.h"

namespace tvm {
namespace relay {
namespace collage {

// TODO(mbs): These should operate one-word-at-a-time

IndexSet::IndexSet(size_t size, const std::vector<size_t>& indexes) : bitvec_(size, false) {
for (size_t index : indexes) {
ICHECK_LT(index, bitvec_.size());
ICHECK(!bitvec_[index]);
bitvec_[index] = true;
}
}

IndexSet IndexSet::operator&(const IndexSet& that) const {
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
std::vector<bool> result(bitvec_.size(), false);
for (size_t index = 0; index < bitvec_.size(); ++index) {
result[index] = bitvec_[index] && that.bitvec_[index];
}
return IndexSet(result);
}

IndexSet IndexSet::operator|(const IndexSet& that) const {
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
std::vector<bool> result(bitvec_.size(), false);
for (size_t index = 0; index < bitvec_.size(); ++index) {
result[index] = bitvec_[index] || that.bitvec_[index];
}
return IndexSet(result);
}

IndexSet IndexSet::operator-(const IndexSet& that) const {
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
std::vector<bool> result(bitvec_.size());
for (size_t index = 0; index < bitvec_.size(); ++index) {
result[index] = bitvec_[index] && !that.bitvec_[index];
}
return IndexSet(result);
}

bool IndexSet::AreDisjoint(const IndexSet& that) const {
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
for (size_t index = 0; index < bitvec_.size(); index++) {
if (bitvec_[index] && that.bitvec_[index]) {
return false;
}
}
return true;
}

bool IndexSet::IsSubset(const IndexSet& that) const {
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
for (size_t index = 0; index < bitvec_.size(); index++) {
if (bitvec_[index] && !that.bitvec_[index]) {
return false;
}
}
return true;
}

bool IndexSet::Intersects(const IndexSet& that) const {
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
for (size_t index = 0; index < bitvec_.size(); index++) {
if (bitvec_[index] && that.bitvec_[index]) {
return true;
}
}
return false;
}

IndexSet IndexSet::Subst(size_t new_size, const IndexSubst& subst) const {
std::vector<bool> result(new_size, false);
for (PostDfsIndex index = 0; index < bitvec_.size(); ++index) {
if (!bitvec_[index]) {
continue;
}
auto itr = subst.find(index);
ICHECK(itr != subst.end());
PostDfsIndex new_index = itr->second;
ICHECK(new_index < new_size);
ICHECK(!result[new_index]);
result[new_index] = true;
}
return IndexSet(result);
}

size_t IndexSet::PopCount() const {
size_t n = 0;
for (size_t index = 0; index < bitvec_.size(); index++) {
if (bitvec_[index]) {
++n;
}
}
return n;
}

bool IndexSet::IsZero() const {
for (size_t index = 0; index < bitvec_.size(); index++) {
if (bitvec_[index]) {
return false;
}
}
return true;
}

size_t IndexSet::FirstInsideIndex() const {
for (size_t index = 0; index < bitvec_.size(); index++) {
if (bitvec_[index]) {
return index;
}
}
return bitvec_.size();
}

size_t IndexSet::LastInsideIndex() const {
for (size_t i = bitvec_.size(); i > 0; i--) {
const size_t index = i - 1;
if (bitvec_[index]) {
return index;
}
}
return bitvec_.size();
}

size_t IndexSet::NextIndex(size_t index) const {
ICHECK_LT(index, bitvec_.size());
for (index++; index < bitvec_.size(); index++) {
if (bitvec_[index]) {
return index;
}
}
return bitvec_.size();
}

size_t IndexSet::FirstOutsideIndex() const {
for (size_t index = 0; index < bitvec_.size(); index++) {
if (!bitvec_[index]) {
return index;
}
}
return bitvec_.size();
}

bool IndexSet::operator==(const IndexSet& that) const {
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
return bitvec_ == that.bitvec_;
}

bool IndexSet::operator!=(const IndexSet& that) const {
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
return bitvec_ != that.bitvec_;
}

bool IndexSet::operator<(const IndexSet& that) const {
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
for (size_t index = 0; index < bitvec_.size(); index++) {
if (bitvec_[index] && !that.bitvec_[index]) {
return true;
}
if (!bitvec_[index] && that.bitvec_[index]) {
return false;
}
}
return false;
}

size_t IndexSet::hash() const {
std::hash<std::vector<bool>> h;
return h(bitvec_);
}

std::string IndexSet::ToString() const {
std::ostringstream os;
os << "{";
bool first = true;
for (size_t start = 0; start < bitvec_.size(); /*no-op*/) {
if (!bitvec_[start]) {
++start;
continue;
}
size_t end;
for (end = start + 1; end < bitvec_.size() && bitvec_[end]; ++end) {
/*no-op*/
}
if (first) {
first = false;
} else {
os << ",";
}
os << start;
if (end > start + 2) {
os << ".." << (end - 1);
start = end;
} else {
++start;
}
}
os << "}";
return os.str();
}

} // namespace collage
} // namespace relay
} // namespace tvm
Loading

0 comments on commit 36fd2dd

Please sign in to comment.