Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
wmdi committed Aug 28, 2023
1 parent 2fb2c7d commit a1bffc5
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 31 deletions.
6 changes: 4 additions & 2 deletions lib/utils/include/utils/graph/algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,13 @@ using GraphSplit =
std::pair<OutputMultiDiEdge, InputMultiDiEdge> split_edge(MultiDiEdge const &e);
MultiDiEdge unsplit_edge(OutputMultiDiEdge const &, InputMultiDiEdge const &);

std::unordered_set<MultiDiEdge> get_cut_set(MultiDiGraphView const &, GraphSplit const &);
std::unordered_set<MultiDiEdge> get_cut_set(OpenMultiDiGraphView const &,
GraphSplit const &);

bidict<MultiDiEdge, std::pair<OutputMultiDiEdge, InputMultiDiEdge>>
get_edge_splits(OpenMultiDiGraphView const &, GraphSplit const &);

std::unordered_set<MultiDiEdge> get_cut(OpenMultiDiGraphView const &,
GraphSplit const &);

UndirectedGraphView get_subgraph(UndirectedGraphView const &,
std::unordered_set<Node> const &);
Expand Down
48 changes: 47 additions & 1 deletion lib/utils/include/utils/graph/labelled/output_labelled_open.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,53 @@ struct OutputLabelledOpenMultiDiGraph {
OutputLabelledOpenMultiDiGraph &
operator=(OutputLabelledOpenMultiDiGraph const &) = default;

operator OpenMultiDiGraphView();
operator OpenMultiDiGraphView() {
NOT_IMPLEMENTED();
}

Node add_node(NodeLabel const &) {
NOT_IMPLEMENTED();
}
NodeLabel const &at(Node const &) const {
NOT_IMPLEMENTED();
}
NodeLabel &at(Node const &) const {
NOT_IMPLEMENTED();
}

void add_edge(MultiDiEdge const &) {
NOT_IMPLEMENTED();
}
void add_edge(InputMultiDiEdge const &) {
NOT_IMPLEMENTED();
}
void add_edge(OutputMultiDiEdge const &) {
NOT_IMPLEMENTED();
}

InputLabel const &at(InputMultiDiEdge const &) const {
NOT_IMPLEMENTED();
}
OutputLabel const &at(OutputMultiDiEdge const &) const {
NOT_IMPLEMENTED();
}

InputLabel &at(InputMultiDiEdge const &) {
NOT_IMPLEMENTED();
}
OutputLabel &at(OutputMultiDiEdge const &) {
NOT_IMPLEMENTED();
}

void add_output(MultiDiOutput const &, OutputLabel const &) {
NOT_IMPLEMENTED();
}
OutputLabel const &at(MultiDiOutput const &) const {
NOT_IMPLEMENTED();
}
OutputLabel &at(MultiDiOutput const &) {
NOT_IMPLEMENTED();
}
};

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/utils/include/utils/graph/open_graphs.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ struct OpenMultiDiGraphView {

OpenMultiDiGraphView() = delete;

operator MultiDiGraphView() const;

friend void swap(OpenMultiDiGraphView &, OpenMultiDiGraphView &);

std::unordered_set<Node> query_nodes(NodeQuery const &);
Expand Down
68 changes: 40 additions & 28 deletions lib/utils/src/graph/algorithms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,27 @@ std::unordered_set<Node> get_nodes(GraphView const &g) {
return g.query_nodes(NodeQuery::all());
}

std::unordered_set<Node> get_nodes(OpenMultiDiEdge const &pattern_edge) {
if (is_input_edge(pattern_edge)) {
return {mpark::get<InputMultiDiEdge>(pattern_edge).dst};
} else if (is_output_edge(pattern_edge)) {
return {mpark::get<OutputMultiDiEdge>(pattern_edge).src};
} else {
assert(is_standard_edge(pattern_edge));
auto standard_edge = mpark::get<MultiDiEdge>(pattern_edge);
return {standard_edge.src, standard_edge.dst};
std::unordered_set<Node> get_nodes(InputMultiDiEdge const &edge) {
return {edge.dst};
}

std::unordered_set<Node> get_nodes(OutputMultiDiEdge const &edge) {
return {edge.src};
}

std::unordered_set<Node> get_nodes(MultiDiEdge const &edge) {
return {edge.src, edge.src};
}

struct GetNodesFunctor {
template <typename T>
std::unordered_set<Node> operator()(T const &t) {
return get_nodes(t);
}
};

std::unordered_set<Node> get_nodes(OpenMultiDiEdge const &edge) {
return visit(GetNodesFunctor{}, edge);
}

std::unordered_set<Node> query_nodes(IGraphView const &g,
Expand Down Expand Up @@ -480,35 +491,36 @@ MultiDiEdge unsplit_edge(OutputMultiDiEdge const &output_edge,
output_edge.src, input_edge.dst, output_edge.srcIdx, input_edge.dstIdx};
}

bidict<MultiDiEdge, std::pair<OutputMultiDiEdge, InputMultiDiEdge>>
get_edge_splits(IOpenMultiDiGraphView const &pattern,
GraphSplit const &split) {
std::unordered_set<MultiDiEdge> get_cut_set(MultiDiGraphView const &graph, GraphSplit const &split) {
auto prefix = split.first;
auto postfix = split.second;

bidict<MultiDiEdge, std::pair<OutputMultiDiEdge, InputMultiDiEdge>> result;

for (OpenMultiDiEdge const &pattern_edge : get_edges(pattern)) {
if (!is_standard_edge(pattern_edge)) {
continue;
}
std::unordered_set<MultiDiEdge> result;

auto standard_edge = mpark::get<MultiDiEdge>(pattern_edge);
if (is_subseteq_of(get_nodes(standard_edge), prefix) ||
is_subseteq_of(get_nodes(standard_edge), postfix)) {
continue;
for (MultiDiEdge const &edge : get_edges(graph)) {
if (!is_subseteq_of(get_nodes(edge), prefix) &&
!is_subseteq_of(get_nodes(edge), postfix)) {
result.insert(edge);
}

auto divided = split_edge(standard_edge);
result.equate(standard_edge, divided);
}

return result;
}

std::unordered_set<MultiDiEdge> get_cut(OpenMultiDiGraphView const &g,
GraphSplit const &s) {
return keys(get_edge_splits(g, s));
std::unordered_set<MultiDiEdge> get_cut_set(OpenMultiDiGraphView const &graph,
GraphSplit const &split) {
return get_cut_set(graph, split);
}

bidict<MultiDiEdge, std::pair<OutputMultiDiEdge, InputMultiDiEdge>>
get_edge_splits(OpenMultiDiGraphView const &graph,
GraphSplit const &split) {
bidict<MultiDiEdge, std::pair<OutputMultiDiEdge, InputMultiDiEdge>> result;
std::unordered_set<MultiDiEdge> cut_set = get_cut_set(graph, split);
for (MultiDiEdge const &edge : cut_set) {
result.equate(edge, split_edge(edge));
}
return result;
}

Node get_src_node(MultiDiEdge const &) {
Expand Down

0 comments on commit a1bffc5

Please sign in to comment.