diff --git a/CMakeLists.txt b/CMakeLists.txt index a518931ac5..7405ded00d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -73,6 +73,7 @@ option(FF_BUILD_ALL_EXAMPLES "build all examples. Overrides others" OFF) option(FF_BUILD_UNIT_TESTS "build non-operator unit tests" OFF) option(FF_BUILD_SUBSTITUTION_TOOL "build substitution conversion tool" OFF) option(FF_BUILD_VISUALIZATION_TOOL "build substitution visualization tool" ON) +option(FF_BUILD_SP_IZATION_BENCHMARKING "build sp-ization benchmarking" ON) option(FF_BUILD_ARG_PARSER "build command line argument parser" OFF) set(FF_CUDA_ARCH "autodetect" CACHE STRING "Target CUDA Arch") diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index fcc19b33b9..9af1ed0985 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -10,6 +10,10 @@ if(FF_BUILD_VISUALIZATION_TOOL) add_subdirectory(substitutions-to-dot) endif() +if(FF_BUILD_SP_IZATION_BENCHMARKING) + add_subdirectory(sp_ization_benchmarking) +endif() + if(FF_BUILD_ARG_PARSER) add_subdirectory(arg_parser) endif() diff --git a/bin/sp_ization_benchmarking/CMakeLists.txt b/bin/sp_ization_benchmarking/CMakeLists.txt new file mode 100644 index 0000000000..a24e84e31d --- /dev/null +++ b/bin/sp_ization_benchmarking/CMakeLists.txt @@ -0,0 +1,9 @@ +ff_add_executable( + NAME + sp-ization-benchmarking + SRC_PATTERNS + *.cc + DEPS + utils + rapidcheck +) diff --git a/bin/sp_ization_benchmarking/distributions.cc b/bin/sp_ization_benchmarking/distributions.cc new file mode 100644 index 0000000000..6c59f58a34 --- /dev/null +++ b/bin/sp_ization_benchmarking/distributions.cc @@ -0,0 +1,55 @@ +#include "distributions.h" + +namespace FlexFlow { + +Constant::Constant(float val) : val(val) {} + +float Constant::operator()() const { + return val; +} + +Uniform::Uniform(float a, float b) : a(a), b(b) {} + +float Uniform::operator()() const { + return a + ((static_cast(std::rand()) / RAND_MAX) * (b - a)); +} + +Bernoulli::Bernoulli(float p) : p(p) {} + +float Bernoulli::operator()() const { + return (Uniform(0, 1)() < p); +} + +Binary::Binary(float a, float b, float p) : a(a), b(b), p(p) {} + +float Binary::operator()() const { + return (Bernoulli(p)() ? a : b); +} + +Chooser::Chooser(std::vector items) : items(items) {} + +float Chooser::operator()() const { + return items[std::rand() % items.size()]; +} + +UniformNoise::UniformNoise(float lower, float upper) + : lower(lower), upper(upper) {} + +float UniformNoise::operator()() const { + return Uniform(lower, upper)(); +} + +float NoNoise::operator()() const { + return 1; +} + +GaussianNoise::GaussianNoise(float mean, float stddev) + : mean(mean), stddev(stddev) {} + +float GaussianNoise::operator()() const { + static std::default_random_engine generator; + static std::normal_distribution distribution(mean, stddev); + return distribution(generator); +} + +} // namespace FlexFlow diff --git a/bin/sp_ization_benchmarking/distributions.h b/bin/sp_ization_benchmarking/distributions.h new file mode 100644 index 0000000000..ea24d55898 --- /dev/null +++ b/bin/sp_ization_benchmarking/distributions.h @@ -0,0 +1,81 @@ +#ifndef _FLEXFLOW_DISTRIBUTIONS_H +#define _FLEXFLOW_DISTRIBUTIONS_H + +#include "utils/graph/node/node.dtg.h" +#include +#include +#include + +namespace FlexFlow { + +struct Constant { + float val; + Constant(float val = 1); + float operator()() const; +}; + +struct Uniform { + float a, b; + Uniform(float a = 0, float b = 1); + float operator()() const; +}; + +struct Bernoulli { + float p; + Bernoulli(float p = 0.5); + float operator()() const; +}; + +struct Binary { + float a, b, p; + Binary(float a = 0, float b = 1, float p = 0.5); + float operator()() const; +}; + +struct Chooser { + std::vector items; + Chooser(std::vector); + float operator()() const; +}; + +struct UniformNoise { + float lower, upper; + UniformNoise(float lower = 0.9, float upper = 1.1); + float operator()() const; +}; + +struct NoNoise { + float operator()() const; +}; + +struct GaussianNoise { + float mean, stddev; + GaussianNoise(float mean = 1, float stddev = .1); + float operator()() const; +}; + +template +std::unordered_map + make_cost_map(std::unordered_set const &nodes, + Dist const &distribution) { + std::unordered_map cost_map; + for (Node const &node : nodes) { + cost_map[node] = distribution(); + } + return cost_map; +} + +template +std::unordered_map + add_noise_to_cost_map(std::unordered_map cost_map, + Noise const &noise) { + std::unordered_map noisy_cost_map; + for (auto const &[node, cost] : cost_map) { + noisy_cost_map[node] = noise() * cost; + } + return noisy_cost_map; +} + +} // namespace FlexFlow + +#endif diff --git a/bin/sp_ization_benchmarking/nasnet_bench_graph_generator.h b/bin/sp_ization_benchmarking/nasnet_bench_graph_generator.h new file mode 100644 index 0000000000..71c942976f --- /dev/null +++ b/bin/sp_ization_benchmarking/nasnet_bench_graph_generator.h @@ -0,0 +1,126 @@ +// For context, see https://arxiv.org/abs/1902.09635 && +// https://github.com/google-research/nasbench/blob/master/nasbench/api.py + +#include "utils/containers.h" +#include "utils/containers/all_of.h" +#include "utils/containers/repeat.h" +#include "utils/containers/transform.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/serial_parallel/digraph_generation.h" +#include +#include + +constexpr size_t MIN_NODES = 6; +constexpr size_t MAX_NODES = 8; +constexpr size_t MIN_EDGES = 8; +constexpr size_t MAX_EDGES = 11; +constexpr size_t NUM_CELLS = 9; + +using AdjacencyMatrix = std::vector>; +namespace FlexFlow { +struct NasNetBenchConfig { + AdjacencyMatrix adjacency_matrix; +}; + +bool is_valid_config(NasNetBenchConfig const &config) { + AdjacencyMatrix const &matrix = config.adjacency_matrix; + const size_t size = matrix.size(); + + auto is_valid_size = [](size_t s) { + return s >= MIN_NODES && s <= MAX_NODES; + }; + + auto is_square_matrix = [&](auto const &m) { + return all_of(m, [&](const auto &row) { return row.size() == size; }); + }; + + auto is_upper_triangular = [&](auto const &m) { + for (size_t i = 0; i < size; ++i) { + for (size_t j = 0; j <= i; ++j) { + if (matrix[i][j]) { + return false; + } + } + } + return true; + }; + + return is_valid_size(size) && is_square_matrix(matrix) && + is_upper_triangular(matrix); +} + +bool is_valid_cell(DiGraphView const &g) { + return (is_acyclic(g)) && (get_sources(g).size() == 1) && + (get_sinks(g).size() == 1) && (num_edges(g) <= MAX_EDGES) && + (num_edges(g) >= MIN_EDGES) && (num_edges(g) <= MAX_NODES) && + (num_edges(g) >= MIN_NODES) && + (num_edges(g) > num_nodes(g)); // filter linear cell and diamond cell +} + +NasNetBenchConfig generate_random_config() { + static std::uniform_int_distribution<> size_dist(MIN_NODES, MAX_NODES); + Binary bin = Binary(0, 1); + + size_t num_nodes = Uniform(MIN_NODES, MAX_NODES)(); + std::vector> matrix(num_nodes, + std::vector(num_nodes, false)); + + for (size_t i = 0; i < num_nodes; ++i) { + for (size_t j = i + 1; j < num_nodes; ++j) { + matrix[i][j] = bin(); + } + } + + return {matrix}; +} + +std::optional + maybe_generate_nasnet_bench_cell(NasNetBenchConfig const &config) { + if (!is_valid_config(config)) { + return std::nullopt; + } + + DiGraph g = DiGraph::create(); + std::vector nodes = add_nodes(g, config.adjacency_matrix.size()); + + for (size_t i = 0; i < nodes.size(); ++i) { + for (size_t j = i + 1; j < nodes.size(); ++j) { + if (config.adjacency_matrix[i][j]) { + g.add_edge(DirectedEdge{nodes[i], nodes[j]}); + } + } + } + + g = materialize_digraph_view(transitive_reduction(g)); + + if (!is_valid_cell(g)) { + return std::nullopt; + } + + return g; +} + +DiGraph generate_nasnet_bench_cell() { + while (true) { + NasNetBenchConfig config = generate_random_config(); + std::optional maybe_cell = + maybe_generate_nasnet_bench_cell(config); + if (maybe_cell) { + return maybe_cell.value(); + } + } +} + +DiGraph generate_nasnet_bench_network() { + DiGraph g = serial_composition( + transform(repeat(NUM_CELLS, generate_nasnet_bench_cell), + [](auto const cell) -> DiGraphView { return cell; })); + return g; +} +} // namespace FlexFlow diff --git a/bin/sp_ization_benchmarking/sample_graphs.h b/bin/sp_ization_benchmarking/sample_graphs.h new file mode 100644 index 0000000000..a3286f3337 --- /dev/null +++ b/bin/sp_ization_benchmarking/sample_graphs.h @@ -0,0 +1,351 @@ +#ifndef FLEXFLOW_GRAPH_GENERATION_H +#define FLEXFLOW_GRAPH_GENERATION_H + +#include "distributions.h" +#include "sample_graphs.h" +#include "utils/containers/get_only.h" +#include "utils/containers/transform.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/is_2_terminal_dag.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/serial_parallel/digraph_generation.h" +#include + +namespace FlexFlow { + +std::tuple make_normal_taso_nasnet_cell() { + DiGraph g = DiGraph::create(); + std::vector inputs = add_nodes(g, 2); + std::vector sep = add_nodes(g, 5); + std::vector id = add_nodes(g, 2); + std::vector avg = add_nodes(g, 3); + std::vector add = add_nodes(g, 5); + std::vector concat = add_nodes(g, 1); + + std::vector edges = {DirectedEdge{inputs.at(0), sep.at(1)}, + DirectedEdge{inputs.at(0), id.at(1)}, + DirectedEdge{inputs.at(0), avg.at(1)}, + DirectedEdge{inputs.at(0), avg.at(2)}, + DirectedEdge{inputs.at(0), sep.at(3)}, + DirectedEdge{inputs.at(0), sep.at(4)}, + DirectedEdge{inputs.at(1), sep.at(0)}, + DirectedEdge{inputs.at(1), id.at(0)}, + DirectedEdge{inputs.at(1), avg.at(0)}, + DirectedEdge{inputs.at(1), sep.at(2)}, + DirectedEdge{sep.at(0), add.at(0)}, + DirectedEdge{id.at(0), add.at(0)}, + DirectedEdge{sep.at(1), add.at(1)}, + DirectedEdge{sep.at(2), add.at(1)}, + DirectedEdge{avg.at(0), add.at(2)}, + DirectedEdge{id.at(1), add.at(2)}, + DirectedEdge{avg.at(1), add.at(3)}, + DirectedEdge{avg.at(2), add.at(3)}, + DirectedEdge{sep.at(3), add.at(4)}, + DirectedEdge{sep.at(4), add.at(4)}}; + add_edges(g, edges); + + for (Node const &a : add) { + g.add_edge(DirectedEdge{a, concat.at(0)}); + } + + assert(get_sinks(g).size() == 1); + assert(get_sources(g).size() == 2); + assert(is_acyclic(g)); + return {g, inputs.at(0), inputs.at(1)}; +} + +std::tuple make_reduction_taso_nasnet_cell() { + DiGraph g = DiGraph::create(); + std::vector inputs = add_nodes(g, 2); + std::vector sep = add_nodes(g, 5); + std::vector id = add_nodes(g, 1); + std::vector avg = add_nodes(g, 2); + std::vector max = add_nodes(g, 2); + std::vector add = add_nodes(g, 5); + std::vector concat = add_nodes(g, 1); + + std::vector edges = {DirectedEdge{inputs.at(0), sep.at(0)}, + DirectedEdge{inputs.at(0), sep.at(2)}, + DirectedEdge{inputs.at(0), sep.at(3)}, + DirectedEdge{inputs.at(1), max.at(1)}, + DirectedEdge{inputs.at(1), sep.at(1)}, + DirectedEdge{inputs.at(1), max.at(0)}, + DirectedEdge{inputs.at(1), avg.at(0)}, + DirectedEdge{sep.at(0), add.at(0)}, + DirectedEdge{sep.at(1), add.at(0)}, + DirectedEdge{max.at(0), add.at(1)}, + DirectedEdge{sep.at(2), add.at(1)}, + DirectedEdge{avg.at(0), add.at(2)}, + DirectedEdge{sep.at(3), add.at(2)}, + DirectedEdge{max.at(1), add.at(3)}, + DirectedEdge{sep.at(4), add.at(3)}, + DirectedEdge{avg.at(1), add.at(4)}, + DirectedEdge{id.at(0), add.at(4)}, + DirectedEdge{add.at(0), sep.at(4)}, + DirectedEdge{add.at(0), avg.at(1)}, + DirectedEdge{add.at(1), id.at(0)}, + DirectedEdge{add.at(2), concat.at(0)}, + DirectedEdge{add.at(3), concat.at(0)}, + DirectedEdge{add.at(4), concat.at(0)}}; + + add_edges(g, edges); + + assert(get_sinks(g).size() == 1); + assert(get_sources(g).size() == 2); + assert(is_acyclic(g)); + return {g, inputs.at(0), inputs.at(1)}; +} + +DiGraph make_full_taso_nasnet(size_t num_reduction_cells, size_t N) { + DiGraph g = DiGraph::create(); + Node input = g.add_node(); + std::deque outputting = {input, input, input}; + std::deque inputting; + size_t num_cells = num_reduction_cells + N * (num_reduction_cells + 1); + for (int i = 0; i < num_cells; i++) { + auto [s, earlier_input, later_input] = + (i % (N + 1) == N) ? make_reduction_taso_nasnet_cell() + : make_normal_taso_nasnet_cell(); + Node cell_output = get_only(get_sinks(s)); + std::unordered_map node_map = parallel_extend(g, s); + later_input = node_map.at(later_input); + earlier_input = node_map.at(earlier_input); + cell_output = node_map.at(cell_output); + + outputting.push_back(cell_output); + outputting.push_back(cell_output); + inputting.push_back(earlier_input); + inputting.push_back(later_input); + + Node a = outputting.front(); + Node b = inputting.front(); + inputting.pop_front(); + outputting.pop_front(); + g.add_edge(DirectedEdge{a, b}); + + a = outputting.front(); + b = inputting.front(); + inputting.pop_front(); + outputting.pop_front(); + g.add_edge(DirectedEdge{a, b}); + + assert(is_2_terminal_dag(g)); + assert(inputting.size() == 0); + assert(outputting.size() == 3); + } + return g; +} + +DiGraph make_linear(size_t length) { + DiGraph g = DiGraph::create(); + if (length == 0) { + return g; + } + std::vector nodes = add_nodes(g, length); + + for (size_t i = 0; i < length - 1; ++i) { + g.add_edge(DirectedEdge{nodes.at(i), nodes.at(i + 1)}); + } + + return g; +} + +DiGraph make_rhombus() { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + + std::vector edges = {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}}; + + add_edges(g, edges); + return g; +} + +DiGraph make_diamond() { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + + std::vector edges = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(4), n.at(5)}, + }; + + add_edges(g, edges); + return g; +} + +DiGraph make_fully_connected(std::vector layer_sizes) { + DiGraph g = DiGraph::create(); + std::vector> layers = + transform(layer_sizes, [&g](size_t size) { return add_nodes(g, size); }); + + std::vector edges; + + for (size_t i = 0; i < layers.size() - 1; ++i) { + for (Node const &n1 : layers.at(i)) { + for (Node const &n2 : layers.at(i + 1)) { + edges.push_back(DirectedEdge{n1, n2}); + } + } + } + + add_edges(g, edges); + return g; +} + +DiGraph make_parallel_chains(size_t chain_length, size_t chain_num) { + DiGraph g = DiGraph::create(); + assert(chain_length >= 3); + assert(chain_num >= 1); + std::vector> chains; + + for (size_t i = 0; i < chain_num; i++) { + std::vector chain_nodes = add_nodes(g, chain_length - 2); + chains.push_back(chain_nodes); + + for (size_t j = 0; j < chain_length - 3; j++) { + g.add_edge(DirectedEdge{chain_nodes.at(j), chain_nodes.at(j + 1)}); + } + } + + Node source = g.add_node(); + Node sink = g.add_node(); + + for (std::vector const &chain : chains) { + g.add_edge(DirectedEdge{source, chain.front()}); + g.add_edge(DirectedEdge{chain.back(), sink}); + } + + return g; +} + +DiGraph make_sample_dag_1() { + + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 7); + std::vector edges = {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(4), n.at(5)}, + DirectedEdge{n.at(0), n.at(6)}, + DirectedEdge{n.at(2), n.at(6)}, + DirectedEdge{n.at(6), n.at(5)}}; + add_edges(g, edges); + assert(is_2_terminal_dag(g)); + return g; +} + +DiGraph make_sample_dag_2() { + NOT_IMPLEMENTED(); +} + +DiGraph make_sample_dag_3() { + // Taken by "A New Algorithm for Mapping DAGs to Series-ParallelSplit Form, + // Escribano et Al, 2002" + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 18); + + std::vector edges = { + DirectedEdge{n.at(0), n.at(1)}, DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(10)}, DirectedEdge{n.at(2), n.at(11)}, + DirectedEdge{n.at(2), n.at(12)}, DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(3), n.at(6)}, DirectedEdge{n.at(4), n.at(6)}, + DirectedEdge{n.at(4), n.at(7)}, DirectedEdge{n.at(4), n.at(10)}, + DirectedEdge{n.at(5), n.at(8)}, DirectedEdge{n.at(6), n.at(8)}, + DirectedEdge{n.at(6), n.at(9)}, DirectedEdge{n.at(7), n.at(8)}, + DirectedEdge{n.at(8), n.at(17)}, DirectedEdge{n.at(9), n.at(17)}, + DirectedEdge{n.at(10), n.at(16)}, DirectedEdge{n.at(11), n.at(16)}, + DirectedEdge{n.at(12), n.at(13)}, DirectedEdge{n.at(12), n.at(14)}, + DirectedEdge{n.at(13), n.at(15)}, DirectedEdge{n.at(14), n.at(15)}, + DirectedEdge{n.at(15), n.at(16)}, DirectedEdge{n.at(16), n.at(17)}}; + + add_edges(g, edges); + return g; +} + +DiGraph make_taso_nasnet_cell() { + // From the TASO paper, pg 57 + DiGraph g = DiGraph::create(); + Node root = g.add_node(); + std::vector input = add_nodes(g, 2); + std::vector dwc = add_nodes(g, 5); + std::vector conv = add_nodes(g, 5); + std::vector avg = add_nodes(g, 3); + std::vector add = add_nodes(g, 5); + Node concat = g.add_node(); + + std::vector edges = {DirectedEdge{root, input.at(0)}, + DirectedEdge{root, input.at(1)}, + DirectedEdge{input.at(0), dwc.at(0)}, + DirectedEdge{input.at(0), dwc.at(1)}, + DirectedEdge{input.at(0), avg.at(0)}, + DirectedEdge{input.at(0), avg.at(1)}, + DirectedEdge{input.at(0), avg.at(2)}, + DirectedEdge{input.at(0), dwc.at(2)}, + DirectedEdge{input.at(1), add.at(2)}, + DirectedEdge{input.at(1), dwc.at(3)}, + DirectedEdge{input.at(1), dwc.at(4)}, + DirectedEdge{input.at(1), add.at(4)}, + DirectedEdge{dwc.at(0), conv.at(0)}, + DirectedEdge{dwc.at(1), conv.at(1)}, + DirectedEdge{dwc.at(2), conv.at(2)}, + DirectedEdge{dwc.at(3), conv.at(3)}, + DirectedEdge{dwc.at(4), conv.at(4)}, + DirectedEdge{conv.at(0), add.at(0)}, + DirectedEdge{conv.at(1), add.at(0)}, + DirectedEdge{avg.at(0), add.at(1)}, + DirectedEdge{avg.at(1), add.at(1)}, + DirectedEdge{avg.at(2), add.at(2)}, + DirectedEdge{conv.at(2), add.at(3)}, + DirectedEdge{conv.at(3), add.at(3)}, + DirectedEdge{conv.at(4), add.at(4)}}; + + add_edges(g, edges); + + for (Node const &a : add) { + g.add_edge(DirectedEdge{a, concat}); + } + return g; +} + +DiGraph make_2_terminal_random_dag(size_t num_nodes, float p, size_t step) { + DiGraph g = DiGraph::create(); + Bernoulli sampler = Bernoulli(p); + std::vector n = add_nodes(g, num_nodes - 2); + for (int i = 0; i < n.size(); i++) { + for (int j = i + step + 1; j < n.size(); j++) { + if (sampler()) { + g.add_edge(DirectedEdge{n.at(i), n.at(j)}); + } + } + } + std::unordered_set sinks = get_sinks(g); + std::unordered_set sources = get_sources(g); + Node sink = g.add_node(); + Node source = g.add_node(); + for (Node s : sources) { + g.add_edge(DirectedEdge{source, s}); + } + for (Node s : sinks) { + g.add_edge(DirectedEdge{s, sink}); + } + assert(is_2_terminal_dag(g)); + return g; +} + +} // namespace FlexFlow + +#endif diff --git a/bin/sp_ization_benchmarking/sp_ization_benchmarking.cc b/bin/sp_ization_benchmarking/sp_ization_benchmarking.cc new file mode 100644 index 0000000000..5febc22474 --- /dev/null +++ b/bin/sp_ization_benchmarking/sp_ization_benchmarking.cc @@ -0,0 +1,585 @@ +/** + * @file sp_ization_benchmarking.cpp + * @brief Benchmarking different SP-ization techniques on various graphs. + * + * @details + * Algorithms: + * - critical_path_preserving_sp_ization_with_coalescing + * - stratum_sync_sp_ization + * - cost_aware_stratum_sync_sp_ization + * Weight distributions: + * - Constant + * - Uniform(0, 1) + * - Binary(0, 100) + * - Chooser({1.0, 25.0, 500.0}) //sample uniformly from the given weights + * Noise distributions: + * - NoNoise + * - GaussianNoise(1, 0.1) + * - UniformNoise(0.8, 1.25) + * Graph types: + * ... + * + * @note To run the benchmark, go to build/normal/bin/sp_ization_benchmarking, + * run make and then ./sp_ization_benchmarking + */ + +#include "distributions.h" +#include "nasnet_bench_graph_generator.h" +#include "sample_graphs.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/serial_parallel/serial_parallel_metrics.h" +#include "utils/graph/serial_parallel/sp_ization/critical_path_preserving_sp_ization.h" +#include "utils/graph/serial_parallel/sp_ization/work_preserving_sp_ization.h" +#include +#include +#include +#include + +constexpr size_t REPEAT = 500; + +using namespace FlexFlow; +using Result = std::tuple; +using CombinedResult = std::tuple; + +template +CombinedResult perform_benchmark_given_graph(DiGraphView const &g, + D const &Dist, + N const &Noise = NoNoise(), + size_t repeat = REPEAT) { + Result critical_path_preserving = {0, 0, 0}; + Result barrier_sync = {0, 0, 0}; + Result cost_aware = {0, 0, 0}; + + for (int i = 0; i < repeat; i++) { + auto cost_map = make_cost_map(get_nodes(g), Dist); + + SerialParallelDecomposition sp1 = + critical_path_preserving_sp_ization_with_coalescing(g); + SerialParallelDecomposition sp2 = stratum_sync_sp_ization(g); + SerialParallelDecomposition sp3 = + cost_aware_stratum_sync_sp_ization(g, cost_map); + + auto noisy_cost_map = add_noise_to_cost_map(cost_map, Noise); + + std::get<0>(critical_path_preserving) += + relative_work_increase(g, sp1, noisy_cost_map); + std::get<1>(critical_path_preserving) += + relative_critical_path_cost_increase(g, sp1, noisy_cost_map); + std::get<2>(critical_path_preserving) += + relative_num_dependencies_increase(g, sp1); + + std::get<0>(barrier_sync) += relative_work_increase(g, sp2, noisy_cost_map); + std::get<1>(barrier_sync) += + relative_critical_path_cost_increase(g, sp2, noisy_cost_map); + std::get<2>(barrier_sync) += relative_num_dependencies_increase(g, sp2); + + std::get<0>(cost_aware) += relative_work_increase(g, sp3, noisy_cost_map); + std::get<1>(cost_aware) += + relative_critical_path_cost_increase(g, sp3, noisy_cost_map); + std::get<2>(cost_aware) += relative_num_dependencies_increase(g, sp3); + } + + std::vector results = { + critical_path_preserving, barrier_sync, cost_aware}; + + for (Result &r : results) { + std::get<0>(r) /= repeat; + std::get<1>(r) /= repeat; + std::get<2>(r) /= repeat; + } + + return {results[0], results[1], results[2]}; +} + +template +CombinedResult + perform_benchmark_given_graph_generator(G const &graph_generator, + D const &Dist, + N const &Noise = NoNoise(), + size_t repeat = REPEAT) { + Result critical_path_preserving = {0, 0, 0}; + Result barrier_sync = {0, 0, 0}; + Result cost_aware = {0, 0, 0}; + + for (int i = 0; i < repeat; i++) { + DiGraphView g = graph_generator(); + auto cost_map = make_cost_map(get_nodes(g), Dist); + + SerialParallelDecomposition sp1 = + critical_path_preserving_sp_ization_with_coalescing(g); + SerialParallelDecomposition sp2 = stratum_sync_sp_ization(g); + SerialParallelDecomposition sp3 = + cost_aware_stratum_sync_sp_ization(g, cost_map); + + auto noisy_cost_map = add_noise_to_cost_map(cost_map, Noise); + + std::get<0>(critical_path_preserving) += + relative_work_increase(g, sp1, noisy_cost_map); + std::get<1>(critical_path_preserving) += + relative_critical_path_cost_increase(g, sp1, noisy_cost_map); + std::get<2>(critical_path_preserving) += + relative_num_dependencies_increase(g, sp1); + + std::get<0>(barrier_sync) += relative_work_increase(g, sp2, noisy_cost_map); + std::get<1>(barrier_sync) += + relative_critical_path_cost_increase(g, sp2, noisy_cost_map); + std::get<2>(barrier_sync) += relative_num_dependencies_increase(g, sp2); + + std::get<0>(cost_aware) += relative_work_increase(g, sp3, noisy_cost_map); + std::get<1>(cost_aware) += + relative_critical_path_cost_increase(g, sp3, noisy_cost_map); + std::get<2>(cost_aware) += relative_num_dependencies_increase(g, sp3); + } + + std::vector results = { + critical_path_preserving, barrier_sync, cost_aware}; + + for (Result &r : results) { + std::get<0>(r) /= repeat; + std::get<1>(r) /= repeat; + std::get<2>(r) /= repeat; + } + + return {results[0], results[1], results[2]}; +} + +void output_benchmark(CombinedResult const &combined_result, + std::string const &title) { + auto [path_pres, stratum_sync, cost_aware_stratum_sync] = combined_result; + std::cout << std::fixed << std::setprecision(3); + std::cout << "Benchmark for " << title << std::endl; + std::cout << "Technique | Work-Increase | Critical-Path-Increase | " + "Dependencies-Increase" + << std::endl; + std::cout << "Barrier Sync | " << std::get<0>(stratum_sync) << " | " + << std::get<1>(stratum_sync) << " | " << std::get<2>(stratum_sync) + << std::endl; + std::cout << "Cost Aware B.S. | " << std::get<0>(cost_aware_stratum_sync) + << " | " << std::get<1>(cost_aware_stratum_sync) << " | " + << std::get<2>(cost_aware_stratum_sync) << std::endl; + std::cout << "Path Preserving | " << std::get<0>(path_pres) << " | " + << std::get<1>(path_pres) << " | " << std::get<2>(path_pres) + << std::endl; + std::cout << std::endl; +} + +template +void bench_mark_given_graph(std::string title, + DiGraphView const &g, + D const &Dist, + N const &Noise = NoNoise(), + size_t repeat = REPEAT) { + output_benchmark(perform_benchmark_given_graph(g, Dist, Noise, repeat), + title); +} + +template +void bench_mark_given_graph_generator(std::string title, + G const &generator, + D const &Dist, + N const &Noise = NoNoise(), + size_t repeat = REPEAT) { + output_benchmark( + perform_benchmark_given_graph_generator(generator, Dist, Noise, repeat), + title); +} + +int main() { + { + DiGraph g = make_sample_dag_3(); + bench_mark_given_graph("sample_dag_3, Constant(1)", g, Constant(1)); + bench_mark_given_graph("sample_dag_3, Constant(1), UniformNoise(0.8, 1.25)", + g, + Constant(1), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph("sample_dag_3, Constant(1), GaussianNoise(1, 0.1)", + g, + Constant(1), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("sample_dag_3, Uniform(0,1)", g, Uniform(0, 1)); + bench_mark_given_graph( + "sample_dag_3, Uniform(0,1), UniformNoise(0.8, 1.25)", + g, + Uniform(0, 1), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph("sample_dag_3, Uniform(0,1), GaussianNoise(1, 0.1)", + g, + Uniform(0, 1), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("sample_dag_3, Binary(1, 80)", g, Binary(1, 80)); + bench_mark_given_graph( + "sample_dag_3, Binary(1, 80), UniformNoise(0.8, 1.25)", + g, + Binary(1, 80), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph("sample_dag_3, Binary(1, 80), GaussianNoise(1, 0.1)", + g, + Binary(1, 80), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("sample_dag_3, Chooser({1.0, 20.0, 500.0})", + g, + Chooser({1.0, 20.0, 500.0})); + bench_mark_given_graph( + "sample_dag_3, Chooser({1.0, 20.0, 500.0}), UniformNoise(0.8, 1.25)", + g, + Chooser({1.0, 20.0, 500.0}), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph( + "sample_dag_3, Chooser({1.0, 20.0, 500.0}), GaussianNoise(1, 0.1)", + g, + Chooser({1.0, 20.0, 500.0}), + GaussianNoise(1, 0.1)); + } + + { + DiGraph g = make_taso_nasnet_cell(); + bench_mark_given_graph("taso_nasnet_cell, Constant(1)", g, Constant(1)); + bench_mark_given_graph( + "taso_nasnet_cell, Constant(1), UniformNoise(0.8, 1.25)", + g, + Constant(1), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph( + "taso_nasnet_cell, Constant(1), GaussianNoise(1, 0.1)", + g, + Constant(1), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("taso_nasnet_cell, Uniform(0,1)", g, Uniform(0, 1)); + bench_mark_given_graph( + "taso_nasnet_cell, Uniform(0,1), UniformNoise(0.8, 1.25)", + g, + Uniform(0, 1), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph( + "taso_nasnet_cell, Uniform(0,1), GaussianNoise(1, 0.1)", + g, + Uniform(0, 1), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("taso_nasnet_cell, Binary(1, 80)", g, Binary(1, 80)); + bench_mark_given_graph( + "taso_nasnet_cell, Binary(1, 80), UniformNoise(0.8, 1.25)", + g, + Binary(1, 80), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph( + "taso_nasnet_cell, Binary(1, 80), GaussianNoise(1, 0.1)", + g, + Binary(1, 80), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("taso_nasnet_cell, Chooser({1.0, 20.0, 500.0})", + g, + Chooser({1.0, 20.0, 500.0})); + bench_mark_given_graph("taso_nasnet_cell, Chooser({1.0, 20.0, 500.0}), " + "UniformNoise(0.8, 1.25)", + g, + Chooser({1.0, 20.0, 500.0}), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph( + "taso_nasnet_cell, Chooser({1.0, 20.0, 500.0}), GaussianNoise(1, 0.1)", + g, + Chooser({1.0, 20.0, 500.0}), + GaussianNoise(1, 0.1)); + } + + { + DiGraph g = make_parallel_chains(10, 5); + bench_mark_given_graph("parallel_chains, Constant(1)", g, Constant(1)); + bench_mark_given_graph( + "parallel_chains, Constant(1), UniformNoise(0.8, 1.25)", + g, + Constant(1), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph( + "parallel_chains, Constant(1), GaussianNoise(1, 0.1)", + g, + Constant(1), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("parallel_chains, Uniform(0,1)", g, Uniform(0, 1)); + bench_mark_given_graph( + "parallel_chains, Uniform(0,1), UniformNoise(0.8, 1.25)", + g, + Uniform(0, 1), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph( + "parallel_chains, Uniform(0,1), GaussianNoise(1, 0.1)", + g, + Uniform(0, 1), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("parallel_chains, Binary(1, 80)", g, Binary(1, 80)); + bench_mark_given_graph( + "parallel_chains, Binary(1, 80), UniformNoise(0.8, 1.25)", + g, + Binary(1, 80), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph( + "parallel_chains, Binary(1, 80), GaussianNoise(1, 0.1)", + g, + Binary(1, 80), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("parallel_chains, Chooser({1.0, 20.0, 500.0})", + g, + Chooser({1.0, 20.0, 500.0})); + bench_mark_given_graph( + "parallel_chains, Chooser({1.0, 20.0, 500.0}), UniformNoise(0.8, 1.25)", + g, + Chooser({1.0, 20.0, 500.0}), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph( + "parallel_chains, Chooser({1.0, 20.0, 500.0}), GaussianNoise(1, 0.1)", + g, + Chooser({1.0, 20.0, 500.0}), + GaussianNoise(1, 0.1)); + } + + { + + auto generate_2_terminal_random_dag = []() { + return make_2_terminal_random_dag(60, .12, 5); + }; + size_t repeat = 100; + bench_mark_given_graph_generator("make_2_terminal_random_dag, Constant(1)", + generate_2_terminal_random_dag, + Constant(1), + NoNoise(), + repeat); + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Constant(1), UniformNoise(0.8, 1.25)", + generate_2_terminal_random_dag, + Constant(1), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Constant(1), GaussianNoise(1, 0.1)", + generate_2_terminal_random_dag, + Constant(1), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph_generator("make_2_terminal_random_dag, Uniform(0,1)", + generate_2_terminal_random_dag, + Uniform(0, 1), + NoNoise(), + repeat); + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Uniform(0,1), UniformNoise(0.8, 1.25)", + generate_2_terminal_random_dag, + Uniform(0, 1), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Uniform(0,1), GaussianNoise(1, 0.1)", + generate_2_terminal_random_dag, + Uniform(0, 1), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Binary(1, 80)", + generate_2_terminal_random_dag, + Binary(1, 80), + NoNoise(), + repeat); + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Binary(1, 80), UniformNoise(0.8, 1.25)", + generate_2_terminal_random_dag, + Binary(1, 80), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Binary(1, 80), GaussianNoise(1, 0.1)", + generate_2_terminal_random_dag, + Binary(1, 80), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Chooser({1.0, 20.0, 500.0})", + generate_2_terminal_random_dag, + Chooser({1.0, 20.0, 500.0}), + NoNoise(), + repeat); + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Chooser({1.0, 20.0, 500.0}), " + "UniformNoise(0.8, 1.25)", + generate_2_terminal_random_dag, + Chooser({1.0, 20.0, 500.0}), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Chooser({1.0, 20.0, 500.0}), " + "GaussianNoise(1, 0.1)", + generate_2_terminal_random_dag, + Chooser({1.0, 20.0, 500.0}), + GaussianNoise(1, 0.1), + repeat); + } + + { + size_t repeat = 100; + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Constant(1)", + generate_nasnet_bench_network, + Constant(1), + NoNoise(), + repeat); + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Constant(1), UniformNoise(0.8, 1.25)", + generate_nasnet_bench_network, + Constant(1), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Constant(1), GaussianNoise(1, 0.1)", + generate_nasnet_bench_network, + Constant(1), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Uniform(0,1)", + generate_nasnet_bench_network, + Uniform(0, 1), + NoNoise(), + repeat); + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Uniform(0,1), UniformNoise(0.8, 1.25)", + generate_nasnet_bench_network, + Uniform(0, 1), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Uniform(0,1), GaussianNoise(1, 0.1)", + generate_nasnet_bench_network, + Uniform(0, 1), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Binary(1, 80)", + generate_nasnet_bench_network, + Binary(1, 80), + NoNoise(), + repeat); + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Binary(1, 80), UniformNoise(0.8, 1.25)", + generate_nasnet_bench_network, + Binary(1, 80), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Binary(1, 80), GaussianNoise(1, 0.1)", + generate_nasnet_bench_network, + Binary(1, 80), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Chooser({1.0, 20.0, 500.0})", + generate_nasnet_bench_network, + Chooser({1.0, 20.0, 500.0}), + NoNoise(), + repeat); + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Chooser({1.0, 20.0, 500.0}), " + "UniformNoise(0.8, 1.25)", + generate_nasnet_bench_network, + Chooser({1.0, 20.0, 500.0}), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Chooser({1.0, 20.0, 500.0}), " + "GaussianNoise(1, 0.1)", + generate_nasnet_bench_network, + Chooser({1.0, 20.0, 500.0}), + GaussianNoise(1, 0.1), + repeat); + } + + { + size_t repeat = 10; + DiGraph g = make_full_taso_nasnet(1, 1); + bench_mark_given_graph("make_full_taso_nasnet, Constant(1)", + g, + Constant(1), + NoNoise(), + repeat); + bench_mark_given_graph( + "make_full_taso_nasnet, Constant(1), UniformNoise(0.8, 1.25)", + g, + Constant(1), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph( + "make_full_taso_nasnet, Constant(1), GaussianNoise(1, 0.1)", + g, + Constant(1), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph("make_full_taso_nasnet, Uniform(0,1)", + g, + Uniform(0, 1), + NoNoise(), + repeat); + bench_mark_given_graph( + "make_full_taso_nasnet, Uniform(0,1), UniformNoise(0.8, 1.25)", + g, + Uniform(0, 1), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph( + "make_full_taso_nasnet, Uniform(0,1), GaussianNoise(1, 0.1)", + g, + Uniform(0, 1), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph("make_full_taso_nasnet, Binary(1, 80)", + g, + Binary(1, 80), + NoNoise(), + repeat); + bench_mark_given_graph( + "make_full_taso_nasnet, Binary(1, 80), UniformNoise(0.8, 1.25)", + g, + Binary(1, 80), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph( + "make_full_taso_nasnet, Binary(1, 80), GaussianNoise(1, 0.1)", + g, + Binary(1, 80), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph("make_full_taso_nasnet, Chooser({1.0, 20.0, 500.0})", + g, + Chooser({1.0, 20.0, 500.0}), + NoNoise(), + repeat); + bench_mark_given_graph("make_full_taso_nasnet, Chooser({1.0, 20.0, " + "500.0}), UniformNoise(0.8, 1.25)", + g, + Chooser({1.0, 20.0, 500.0}), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph("make_full_taso_nasnet, Chooser({1.0, 20.0, " + "500.0}), GaussianNoise(1, 0.1)", + g, + Chooser({1.0, 20.0, 500.0}), + GaussianNoise(1, 0.1), + repeat); + } +} diff --git a/lib/utils/include/utils/containers/invert_map.h b/lib/utils/include/utils/containers/invert_map.h new file mode 100644 index 0000000000..6f0c04a189 --- /dev/null +++ b/lib/utils/include/utils/containers/invert_map.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_INVERT_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_INVERT_MAP_H + +#include +#include +#include + +namespace FlexFlow { + +template +std::unordered_map> + invert_map(std::unordered_map const &m) { + std::unordered_map> result; + for (auto const &[key, value] : m) { + result[value].insert(key); + } + return result; +} +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/unordered_set_of.h b/lib/utils/include/utils/containers/unordered_set_of.h index 722ae66d43..22ed323891 100644 --- a/lib/utils/include/utils/containers/unordered_set_of.h +++ b/lib/utils/include/utils/containers/unordered_set_of.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNIQUE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNIQUE_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_SET_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_SET_OF_H #include diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 3f170b5652..f130b17421 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -161,8 +161,6 @@ std::vector std::vector get_bfs_ordering(DiGraphView const &, std::unordered_set const &starting_points); -std::vector get_topological_ordering(DiGraphView const &); -std::vector get_unchecked_topological_ordering(DiGraphView const &); std::vector get_edge_topological_ordering(DiGraphView const &); // std::vector diff --git a/lib/utils/include/utils/graph/digraph/algorithms.h b/lib/utils/include/utils/graph/digraph/algorithms.h index 370f181c37..d1f6147383 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms.h +++ b/lib/utils/include/utils/graph/digraph/algorithms.h @@ -6,6 +6,7 @@ namespace FlexFlow { std::unordered_set get_edges(DiGraphView const &); +int num_edges(DiGraphView const &); std::unordered_set get_sources(DiGraphView const &); std::unordered_set get_sinks(DiGraphView const &); diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_descendants.h b/lib/utils/include/utils/graph/digraph/algorithms/get_descendants.h new file mode 100644 index 0000000000..18031b47eb --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_descendants.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_DESCENDANTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_DESCENDANTS + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +std::unordered_set get_descendants(DiGraphView const &g, + Node const &starting_node); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h b/lib/utils/include/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h new file mode 100644 index 0000000000..c29cd6723c --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h @@ -0,0 +1,38 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_LONGEST_PATH_LENGTHS_FROM_ROOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_LONGEST_PATH_LENGTHS_FROM_ROOT_H + +#include "utils/graph/digraph/digraph_view.h" +#include + +namespace FlexFlow { + +/** + * @brief Computes the longest path lengths from the root in a single source, + * directed acyclic graph. + * + * @return std::unordered_map For each node n, returns the length + * (i.e. number of nodes) of the longest path from the root to n. + * + * @note The root has a path length of 1. g must be both acyclic and have a + * single source. + */ +std::unordered_map + get_longest_path_lengths_from_root(DiGraphView const &g); + +/** + * @brief Computes the weighted longest path lengths from the root in a single + * source, directed acyclic graph. + * + * @return std::unordered_map For each node n, returns the length + * (i.e. the sum of the weights of all the nodes) of the longest path from the + * root to n. + * + * @note The root has a path length equal to its weight. g must be both acyclic + * and have a single source. + */ +std::unordered_map get_weighted_longest_path_lengths_from_root( + DiGraphView const &g, std::unordered_map const &node_costs); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_topological_ordering_from_starting_node.h b/lib/utils/include/utils/graph/digraph/algorithms/get_topological_ordering_from_starting_node.h new file mode 100644 index 0000000000..db017c11da --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_topological_ordering_from_starting_node.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_TOPOLOGICAL_ORDERING_FROM_STARTING_NODE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_TOPOLOGICAL_ORDERING_FROM_STARTING_NODE_H + +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/node/node.dtg.h" + +namespace FlexFlow { + +/** + * @brief Returns a topologically ordered vector of nodes, with the topological + * traversal starting from the starting node. + * + * @note Nodes present within the graph that are not reachable by a traversal + * starting from the starting_node will not be included in the returned vector. + * g must be an acyclic graph + */ +std::vector + get_topological_ordering_from_starting_node(DiGraphView const &g, + Node const &starting_node); + +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_TOPOLOGICAL_ORDERING_FROM_STARTING_NODE_H diff --git a/lib/utils/include/utils/graph/digraph/algorithms/is_2_terminal_dag.h b/lib/utils/include/utils/graph/digraph/algorithms/is_2_terminal_dag.h new file mode 100644 index 0000000000..3b588c7984 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/is_2_terminal_dag.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_IS_2_TERMINAL_DAG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_IS_2_TERMINAL_DAG_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +bool is_2_terminal_dag(DiGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/is_acyclic.h b/lib/utils/include/utils/graph/digraph/algorithms/is_acyclic.h index 909dc3aef4..ce63f75395 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms/is_acyclic.h +++ b/lib/utils/include/utils/graph/digraph/algorithms/is_acyclic.h @@ -5,7 +5,7 @@ namespace FlexFlow { -std::optional is_acyclic(DiGraphView const &); +bool is_acyclic(DiGraphView const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/serial_parallel/digraph_generation.h b/lib/utils/include/utils/graph/serial_parallel/digraph_generation.h new file mode 100644 index 0000000000..40fddc4c59 --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/digraph_generation.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_DIGRAPH_GENERATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_DIGRAPH_GENERATION_H + +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +std::unordered_map parallel_extend(DiGraph &g, + DiGraphView const &ext); +std::unordered_map serial_extend(DiGraph &g, + DiGraphView const &ext); +DiGraph serial_composition(DiGraphView const &g1, DiGraphView const &g2); +DiGraph parallel_composition(DiGraphView const &g1, DiGraphView const &g2); +DiGraph serial_composition(std::vector const &graphs); +DiGraph parallel_composition(std::vector const &graphs); + +DiGraph digraph_from_sp_decomposition(SerialParallelDecomposition const &sp); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/normalize_sp_decomposition.h b/lib/utils/include/utils/graph/serial_parallel/normalize_sp_decomposition.h new file mode 100644 index 0000000000..00a85a7514 --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/normalize_sp_decomposition.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_normalize_sp_decomposition_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_normalize_sp_decomposition_H + +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +/** + * @brief Recursively normalizes a SerialParallelDecomposition. + * + * @details This function performs the following semantic substitutions: + * - Deletes every empty SerialSplit and ParallelSplit item, e.g., + * S(P(S()), Node(1), Node(2)) -> S(Node(1), Node(2)) + * + * - Replaces SerialSplit and ParallelSplit of size 1 with their content, e.g., + * S(S(Node(1)), P(Node(2))) -> S(Node(1), Node(2))) + * + */ +SerialParallelDecomposition + normalize_sp_decomposition(SerialParallelDecomposition const &sp); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h index 7d8efc96f2..85d8917038 100644 --- a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h +++ b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h @@ -17,6 +17,19 @@ std::unordered_set get_nodes(SerialSplit const &); std::unordered_set get_nodes(ParallelSplit const &); std::unordered_set get_nodes(Node const &); +bool is_empty(Node const &node); +bool is_empty(SerialSplit const &serial); +bool is_empty(ParallelSplit const ¶llel); +bool is_empty(SerialParallelDecomposition const &sp); + +// duplicate nodes within `sp` are counted multiple times +size_t num_nodes(SerialParallelDecomposition const &sp); + +SerialParallelDecomposition serial_composition( + std::vector const &sp_compositions); +SerialParallelDecomposition parallel_composition( + std::unordered_set const &sp_compositions); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_metrics.h b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_metrics.h new file mode 100644 index 0000000000..7d2c546117 --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_metrics.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_METRICS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_METRICS_H + +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include + +namespace FlexFlow { + +std::unordered_map get_node_frequency_map(Node const &node); +std::unordered_map + get_node_frequency_map(SerialSplit const &serial); +std::unordered_map + get_node_frequency_map(ParallelSplit const ¶llel); +std::unordered_map + get_node_frequency_map(SerialParallelDecomposition const &sp); + +float work_cost(SerialParallelDecomposition const &sp, + std::unordered_map cost_map); + +float work_cost(DiGraphView const &g, + std::unordered_map const &cost_map); + +int num_dependencies(SerialParallelDecomposition const &sp); + +int num_dependencies(DiGraphView const &g); + +float critical_path_cost(SerialParallelDecomposition const &sp, + std::unordered_map const &cost_map); + +float critical_path_cost(DiGraphView const &g, + std::unordered_map const &cost_map); + +float relative_work_increase(DiGraphView const &g, + SerialParallelDecomposition const &sp, + std::unordered_map const &cost_map); + +float relative_critical_path_cost_increase( + DiGraphView const &g, + SerialParallelDecomposition const &sp, + std::unordered_map const &cost_map); + +float relative_num_dependencies_increase(DiGraphView const &g, + SerialParallelDecomposition const &sp); + +} // namespace FlexFlow + +#endif // FLEXFLOW_SERIAL_PARALLEL_METRICS_H diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h index 081137e513..6e4b0e33fb 100644 --- a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h +++ b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_PARALLEL_SPLITS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_PARALLEL_SPLITS_H #include "utils/graph/node/node.dtg.h" #include @@ -12,10 +12,11 @@ struct ParallelSplit; struct SerialSplit { public: - SerialSplit() = delete; + SerialSplit(); explicit SerialSplit(std::vector> const &); explicit SerialSplit( std::initializer_list> const &); + explicit SerialSplit(std::vector const &nodes); bool operator==(SerialSplit const &) const; bool operator!=(SerialSplit const &) const; @@ -46,11 +47,12 @@ namespace FlexFlow { struct ParallelSplit { public: - ParallelSplit() = delete; + ParallelSplit(); explicit ParallelSplit( std::unordered_set> const &); explicit ParallelSplit( std::initializer_list> const &); + explicit ParallelSplit(std::unordered_set const &nodes); bool operator==(ParallelSplit const &) const; bool operator!=(ParallelSplit const &) const; diff --git a/lib/utils/include/utils/graph/serial_parallel/sp_ization/critical_path_preserving_sp_ization.h b/lib/utils/include/utils/graph/serial_parallel/sp_ization/critical_path_preserving_sp_ization.h new file mode 100644 index 0000000000..7cd756b68e --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/sp_ization/critical_path_preserving_sp_ization.h @@ -0,0 +1,119 @@ +#ifndef _FLEXFLOW_UTILS_GRAPH_SERIAL_PARALLEL_CRITICAL_PATH_PRESERVING_SP_IZATION_H +#define _FLEXFLOW_UTILS_GRAPH_SERIAL_PARALLEL_CRITICAL_PATH_PRESERVING_SP_IZATION_H + +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include + +namespace FlexFlow { + +/** + * @brief Transforms a directed acyclic graph (DAG) into a Serial Parallel (SP) + * graph. The critical path cost is unchanged, and the SP-ization is done solely + * through node (work) duplication. + * + * @details + * The resulting graph, encoded as a SerialParallelDecomposition, is a tree + * whose critical path is the same as that of the original graph. The tree is + * constructed as follows: + * - Denote SP(n) as the SerialParallelDecomposition of the subgraph of g whose + * nodes are all the ancestors of n. + * - Denote the predecessors of n as M. + * - Then: + * - SP(n) = S(n, P({SP(m) for m in M})) + * - SP(root) = root + * - SP(sink) = SP(g) + * Where P, S represent parallel, serial composition respectively. + * + * Example: + * + * digraph G { + * n1 -> n2; + * n1 -> n3; + * n2 -> n4; + * n2 -> n5; + * n3 -> n5; + * n5 -> n6; + * n4 -> n6; + * } + * + * becomes + * + * digraph SP { + * n1 [label="n1"]; + * n2 [label="n2"]; + * n3 [label="n4"]; + * n4 [label="n6"]; + * n5 [label="n1"]; + * n6 [label="n2"]; + * n7 [label="n5"]; + * n8 [label="n1"]; + * n9 [label="n3"]; + * n1 -> n2; + * n2 -> n3; + * n3 -> n4; + * n5 -> n6; + * n6 -> n7; + * n7 -> n4; + * n8 -> n9; + * n9 -> n7; + * } + * + * + * @note g must be a 2 terminal (i.e. single source and single sink) directed + * acyclic graph. + */ +SerialParallelDecomposition + critical_path_preserving_sp_ization(DiGraphView const &g); + +/** + * @brief Transforms a directed acyclic graph (DAG) into a Serial Parallel (SP) + * graph with coalescing. The critical path cost is unchanged, and the + * SP-ization is done solely through node (work) duplication. + * + * @details + * This SP-ization technique, compared to the previous step, adds an additional + * coalescing step during parallel composition to reduce node duplication. The + * recursive formulation is equivalent, but the parallelization performs an + * additional coalescing step, where parallel strands with common heads are + * merged together. Example: P(S(1,2), S(1,3)) -> P(1, S(2,3)). + * + * Example: + * + * digraph G { + * n1 -> n2; + * n1 -> n3; + * n2 -> n4; + * n2 -> n5; + * n3 -> n5; + * n5 -> n6; + * n4 -> n6; + * } + * + * becomes + * + * digraph SP { + * n1 [label="n1"]; + * n2 [label="n2"]; + * n3 [label="n4"]; + * n4 [label="n6"]; + * n6 [label="n2"]; + * n7 [label="n5"]; + * n9 [label="n3"]; + * n1 -> n2; + * n2 -> n3; + * n3 -> n4; + * n1 -> n6; + * n6 -> n7; + * n7 -> n4; + * n1 -> n9; + * n9 -> n7; + * } + * + */ +SerialParallelDecomposition + critical_path_preserving_sp_ization_with_coalescing(DiGraphView const &g); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/sp_ization/work_preserving_sp_ization.h b/lib/utils/include/utils/graph/serial_parallel/sp_ization/work_preserving_sp_ization.h new file mode 100644 index 0000000000..ffbf9bfcfe --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/sp_ization/work_preserving_sp_ization.h @@ -0,0 +1,69 @@ +#ifndef _FLEXFLOW_UTILS_GRAPH_SERIAL_PARALLEL_WORK_PRESERVING_SP_IZATION_H +#define _FLEXFLOW_UTILS_GRAPH_SERIAL_PARALLEL_WORK_PRESERVING_SP_IZATION_H + +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include + +namespace FlexFlow { + +/** + * @brief + * Transforms a directed acyclic graph (DAG) into a Serial Parallel (SP) + * graph. The total number of nodes remains unchanged, and the SP-ization is + *done solely through edge (dependency) duplication. + * @details + * The graph is first partitioned into strata: the i_th stratum contains all the + *nodes whose critical path length has length i. The nodes in a given stratum + *are composed in parallel, and the strata are serially composed in succession. + * + * Example: + * + * + * digraph G { + * n1 -> n2; + * n1 -> n3; + * n2 -> n4; + * n2 -> n5; + * n3 -> n5; + * n5 -> n6; + * n4 -> n6; + * } + * becomes + * + * digraph SP { + * n1 -> n2; + * n1 -> n3; + * n2 -> n4; + * n2 -> n5; + * n3 -> n5; + * n4 -> n6; + * n5 -> n6; + * n4 -> n6; + * } + * + * @note g must be a directed acyclic graph. + **/ +SerialParallelDecomposition stratum_sync_sp_ization(DiGraphView const &g); + +/** + * @brief + * Transforms a directed acyclic graph (DAG) into a Serial Parallel (SP) + * graph. The total number of nodes remains unchanged, and the SP-ization is + *done solely through edge (dependency) duplication. + * + * @details + * The algorithm operates under the same principles as + *`stratum_sync_sp_ization`: that is, a stratification step where the nodes are + *partitioned into strata, followed by a merging step where the strata are + *joined. The difference concerns the stratification step, which is cost-aware, + *so that the different disjoint subgraphs present within the same strata have a + *similar critical path cost, thus minimizing the overall critical path cost of + *the SP-ized graph. + **/ +SerialParallelDecomposition cost_aware_stratum_sync_sp_ization( + DiGraphView const &g, std::unordered_map const &cost_map); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/containers/invert_map.cc b/lib/utils/src/utils/containers/invert_map.cc new file mode 100644 index 0000000000..ca7308d4a0 --- /dev/null +++ b/lib/utils/src/utils/containers/invert_map.cc @@ -0,0 +1 @@ +#include "utils/containers/invert_map.h" diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index 323f444a22..e591c47161 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -13,6 +13,7 @@ #include "utils/graph/digraph/algorithms/get_incoming_edges.h" #include "utils/graph/digraph/algorithms/get_node_with_greatest_topo_rank.h" #include "utils/graph/digraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/digraph/directed_edge_query.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/node/node_query.h" diff --git a/lib/utils/src/utils/graph/digraph/algorithms.cc b/lib/utils/src/utils/graph/digraph/algorithms.cc index 8cd685e5c6..78afc994e8 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms.cc @@ -15,6 +15,10 @@ std::unordered_set get_edges(DiGraphView const &g) { return g.query_edges(directed_edge_query_all()); } +int num_edges(DiGraphView const &g) { + return get_edges(g).size(); +} + std::unordered_set get_sinks(DiGraphView const &g) { return get_sources(flipped(g)); } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_descendants.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_descendants.cc new file mode 100644 index 0000000000..09e0c2262c --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_descendants.cc @@ -0,0 +1,30 @@ +#include "utils/graph/digraph/algorithms/get_descendants.h" +#include "utils/containers/contains.h" +#include "utils/containers/filter.h" +#include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { +std::unordered_set get_descendants(DiGraphView const &g, + Node const &starting_node) { + assert(is_acyclic(g)); + std::unordered_set descendants; + std::stack to_visit; + for (Node const &successor : get_successors(g, starting_node)) { + to_visit.push(successor); + } + while (!to_visit.empty()) { + Node current = to_visit.top(); + to_visit.pop(); + descendants.insert(current); + for (auto const &s : filter(get_successors(g, current), [&](Node const &n) { + return !contains(descendants, n); + })) { + to_visit.push(s); + } + } + return descendants; +}; + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.cc new file mode 100644 index 0000000000..6e4dfe95aa --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.cc @@ -0,0 +1,57 @@ +#include "utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h" +#include "utils/containers.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" +#include + +// change to have unweigthed call the weighted version + +namespace FlexFlow { + +std::unordered_map get_weighted_longest_path_lengths_from_root( + DiGraphView const &g, std::unordered_map const &node_costs) { + + assert(is_acyclic(g)); + assert(get_sources(g).size() == 1); + + std::vector topo_order = get_topological_ordering(g); + std::unordered_map longest_path_lengths; + + for (Node const &n : topo_order) { + std::unordered_set predecessor_path_lengths = + transform(get_predecessors(g, n), [&](Node const &pred) { + return longest_path_lengths.at(pred); + }); + longest_path_lengths[n] = + (predecessor_path_lengths.size() == 0) + ? node_costs.at(n) + : maximum(predecessor_path_lengths) + node_costs.at(n); + } + return longest_path_lengths; +} + +std::unordered_map + get_longest_path_lengths_from_root(DiGraphView const &g) { + + assert(is_acyclic(g)); + assert(get_sources(g).size() == 1); + + std::vector topo_order = get_topological_ordering(g); + std::unordered_map longest_path_lengths; + + for (Node const &n : topo_order) { + std::unordered_set predecessor_path_lengths = + transform(get_predecessors(g, n), [&](Node const &pred) { + return longest_path_lengths.at(pred); + }); + longest_path_lengths[n] = (predecessor_path_lengths.size() == 0) + ? 1 + : maximum(predecessor_path_lengths) + 1; + } + + return longest_path_lengths; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc index 41fe3b67d5..fc0646ce5e 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc @@ -1,6 +1,8 @@ #include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_successors.h" #include "utils/graph/digraph/algorithms/is_acyclic.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/traversal.h" diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering_from_starting_node.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering_from_starting_node.cc new file mode 100644 index 0000000000..40a2fd46ed --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering_from_starting_node.cc @@ -0,0 +1,29 @@ +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_descendants.h" +#include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/traversal.h" + +namespace FlexFlow { + +static std::vector get_unchecked_topological_ordering_from_starting_node( + DiGraphView const &g, Node const &starting_node) { + + std::unordered_set descendants = get_descendants(g, starting_node); + descendants.insert(starting_node); + return get_topological_ordering(get_subgraph(g, descendants)); +} + +std::vector + get_topological_ordering_from_starting_node(DiGraphView const &g, + Node const &starting_node) { + assert(is_acyclic(g)); + return get_unchecked_topological_ordering_from_starting_node(g, + starting_node); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/is_2_terminal_dag.cc b/lib/utils/src/utils/graph/digraph/algorithms/is_2_terminal_dag.cc new file mode 100644 index 0000000000..84b95a6c38 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/is_2_terminal_dag.cc @@ -0,0 +1,11 @@ +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" + +namespace FlexFlow { + +bool is_2_terminal_dag(DiGraphView const &g) { + return (is_acyclic(g) && (get_sources(g).size() == 1) && + get_sinks(g).size() == 1); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc b/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc index dd660f193d..c26cf70ebc 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc @@ -1,30 +1,52 @@ #include "utils/graph/digraph/algorithms/is_acyclic.h" -#include "utils/graph/digraph/algorithms.h" +#include "utils/containers/generate_map.h" +#include "utils/graph/digraph/algorithms/get_successors.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/traversal.h" +#include namespace FlexFlow { -std::optional is_acyclic(DiGraphView const &g) { +enum class ExplorationStatus { NOT_EXPLORED, BEING_EXPLORED, FULLY_EXPLORED }; + +bool is_acyclic(DiGraphView const &g) { if (num_nodes(g) == 0) { - return std::nullopt; - } - std::unordered_set sources = get_sources(g); - if (sources.size() == 0) { - return false; + return true; // vacuously true } - auto dfs_view = unchecked_dfs(g, sources); - std::unordered_set seen; - for (unchecked_dfs_iterator it = dfs_view.begin(); it != dfs_view.end(); - it++) { - if (contains(seen, *it)) { - return false; - } else { - seen.insert(*it); + + std::unordered_map status = + generate_map(get_nodes(g), [](Node const &n) { + return ExplorationStatus::NOT_EXPLORED; + }); + + // recursively explore a given node and all its successors: if, while + // exploring, we find a node that was already being explored, then there is a + // cycle + std::function cycle_downstream_from_node = + [&](Node const &n) -> bool { + status[n] = ExplorationStatus::BEING_EXPLORED; + + for (Node const &successor : get_successors(g, n)) { + if (status.at(successor) == ExplorationStatus::NOT_EXPLORED) { + if (cycle_downstream_from_node( + successor)) { // one of the descendants is part of a cycle + return true; + } + } else if (status.at(successor) == ExplorationStatus::BEING_EXPLORED) { + return true; // we're exploring a node we were already exploring: we + // have hit a cycle + } } - } - if (seen != get_nodes(g)) { + + status[n] = ExplorationStatus::FULLY_EXPLORED; return false; + }; + + for (Node const &node : get_nodes(g)) { + if (status.at(node) == ExplorationStatus::NOT_EXPLORED) { + if (cycle_downstream_from_node(node)) { + return false; + } + } } return true; } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc b/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc index 10ffe4fc33..6205dfc225 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc @@ -1,7 +1,11 @@ #include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/containers/contains.h" #include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_descendants.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" #include "utils/graph/digraph/algorithms/materialize_digraph_view.h" #include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -23,30 +27,25 @@ DirectedEdgeMaskView *DirectedEdgeMaskView::clone() const { return new DirectedEdgeMaskView(this->g, this->edge_mask); } -DiGraphView transitive_reduction(DiGraphView const &g) { +DiGraphView unchecked_transitive_reduction(DiGraphView const &g) { std::unordered_set edge_mask = get_edges(g); - - while (true) { - std::unordered_set new_edge_mask = edge_mask; - for (DirectedEdge const &e1 : edge_mask) { - for (DirectedEdge const &e2 : edge_mask) { - if (e1.dst == e2.src && e1 != e2) { - DirectedEdge trans_edge = DirectedEdge{e1.src, e2.dst}; - if (contains(new_edge_mask, trans_edge)) { - new_edge_mask.erase(trans_edge); - } - } + std::unordered_set nodes = get_nodes(g); + for (Node const &n1 : nodes) { + for (Node const &n2 : get_descendants(g, n1)) { + for (Node const &n3 : get_descendants(g, n2)) { + // if there is a path from n1 to n2, and a path from n2 to n3, edge + // {n1,n3} is redundant + // if edge {n1, n3} does not exist, this is a no-op + edge_mask.erase(DirectedEdge{n1, n3}); } } - - if (new_edge_mask == edge_mask) { - break; - } else { - edge_mask = new_edge_mask; - } } - return DiGraphView::create(g, edge_mask); } +DiGraphView transitive_reduction(DiGraphView const &g) { + assert(is_acyclic(g)); + return unchecked_transitive_reduction(g); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/digraph_generation.cc b/lib/utils/src/utils/graph/serial_parallel/digraph_generation.cc new file mode 100644 index 0000000000..b1b572d676 --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/digraph_generation.cc @@ -0,0 +1,100 @@ +#include "utils/graph/serial_parallel/digraph_generation.h" +#include "utils/containers/as_vector.h" +#include "utils/containers/transform.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/serial_parallel/serial_parallel_splits.h" + +namespace FlexFlow { + +std::unordered_map parallel_extend(DiGraph &g, + DiGraphView const &ext) { + std::unordered_map node_map; + for (Node const &node : get_nodes(ext)) { + node_map.emplace(node, g.add_node()); + } + for (DirectedEdge const &edge : get_edges(ext)) { + g.add_edge(DirectedEdge{node_map.at(edge.src), node_map.at(edge.dst)}); + } + return node_map; +} + +std::unordered_map serial_extend(DiGraph &g, + DiGraphView const &ext) { + std::unordered_set original_sinks = get_sinks(g); + std::unordered_map node_map = parallel_extend(g, ext); + for (Node const &node1 : original_sinks) { + for (Node const &node2 : get_sources(ext)) { + g.add_edge(DirectedEdge{node1, node_map.at(node2)}); + } + } + return node_map; +} + +DiGraph serial_composition(DiGraphView const &g1, DiGraphView const &g2) { + DiGraph g = materialize_digraph_view(g1); + serial_extend(g, g2); + return g; +} + +DiGraph parallel_composition(DiGraphView const &g1, DiGraphView const &g2) { + DiGraph g = materialize_digraph_view(g1); + parallel_extend(g, g2); + return g; +} + +DiGraph serial_composition(std::vector const &graphs) { + DiGraph g = DiGraph::create(); + for (DiGraphView const &gs : graphs) { + g = materialize_digraph_view(serial_composition(g, gs)); + } + return g; +} + +// TODO(@pietro): should be std::unordered_set, but DiGraphs are +// currently non-hashable +DiGraph parallel_composition(std::vector const &graphs) { + DiGraph g = DiGraph::create(); + for (DiGraphView const &gs : graphs) { + g = materialize_digraph_view(parallel_composition(g, gs)); + } + return g; +} + +DiGraph digraph_from_sp_decomposition(Node const &node) { + DiGraph g = DiGraph::create(); + g.add_node(); + return g; +} + +DiGraph digraph_from_sp_decomposition(SerialSplit const &serial) { + std::vector children = + transform(serial.children, [](auto const &child) { + return widen(child); + }); + return serial_composition( + transform(children, [](auto const child) -> DiGraphView { + return digraph_from_sp_decomposition(child); + })); +} + +DiGraph digraph_from_sp_decomposition(ParallelSplit const ¶llel) { + std::vector children = + transform(as_vector(parallel.children), [](auto const &child) { + return widen(child); + }); + return parallel_composition( + transform(children, [](auto const child) -> DiGraphView { + return digraph_from_sp_decomposition(child); + })); +} + +DiGraph digraph_from_sp_decomposition(SerialParallelDecomposition const &sp) { + return sp.visit( + [](auto const &x) { return digraph_from_sp_decomposition(x); }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/normalize_sp_decomposition.cc b/lib/utils/src/utils/graph/serial_parallel/normalize_sp_decomposition.cc new file mode 100644 index 0000000000..47ec908fe4 --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/normalize_sp_decomposition.cc @@ -0,0 +1,56 @@ +#include "utils/graph/serial_parallel/normalize_sp_decomposition.h" +#include "utils/containers/filter.h" +#include "utils/containers/get_only.h" +#include "utils/containers/transform.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" +#include "utils/variant.h" + +namespace FlexFlow { + +template +static auto filter_empty(T const &container) { + return filter(container, [](auto const &child) { + return !is_empty(widen(child)); + }); +} + +SerialParallelDecomposition normalize_sp_decomposition(Node const &node) { + return SerialParallelDecomposition(node); +} + +SerialParallelDecomposition + normalize_sp_decomposition(SerialSplit const &serial) { + std::vector normalized_children = + transform(filter_empty(serial.children), [](auto const &child) { + return normalize_sp_decomposition( + widen(child)); + }); + + if (normalized_children.size() == 1) { + return get_only(normalized_children); + } + return serial_composition(normalized_children); +} + +SerialParallelDecomposition + normalize_sp_decomposition(ParallelSplit const ¶llel) { + std::unordered_set normalized_children = + transform(filter_empty(parallel.children), [](auto const &child) { + return normalize_sp_decomposition( + widen(child)); + }); + + if (normalized_children.size() == 1) { + return get_only(normalized_children); + } + return parallel_composition(normalized_children); +} + +SerialParallelDecomposition + normalize_sp_decomposition(SerialParallelDecomposition const &sp) { + return sp.visit( + [](auto const &x) { return normalize_sp_decomposition(x); }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc b/lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc index 666bf40f10..b1c0f058cc 100644 --- a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc @@ -1,8 +1,14 @@ #include "utils/graph/serial_parallel/serial_parallel_decomposition.h" +#include "utils/containers.h" +#include "utils/containers/all_of.h" +#include "utils/containers/extend.h" +#include "utils/containers/get_only.h" #include "utils/containers/set_union.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" #include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/graph/serial_parallel/serial_parallel_metrics.h" #include "utils/hash/unordered_set.h" #include "utils/variant.h" @@ -71,4 +77,61 @@ std::unordered_set get_nodes(Node const &node) { return {node}; } +bool is_empty(Node const &node) { + return false; +} + +bool is_empty(SerialSplit const &serial) { + return all_of(serial.children, [](auto const &child) { + return is_empty(widen(child)); + }); +} + +bool is_empty(ParallelSplit const ¶llel) { + return all_of(parallel.children, [](auto const &child) { + return is_empty(widen(child)); + }); +} + +bool is_empty(SerialParallelDecomposition const &sp) { + return sp.visit([](auto const &t) { return is_empty(t); }); +} + +size_t num_nodes(SerialParallelDecomposition const &sp) { + return sum(values(get_node_frequency_map(sp))); +} + +SerialParallelDecomposition serial_composition( + std::vector const &sp_compositions) { + SerialSplit composition{}; + for (SerialParallelDecomposition const &sp_comp : sp_compositions) { + if (sp_comp.has()) { + extend(composition.children, sp_comp.get().children); + } else if (sp_comp.has()) { + composition.children.push_back(sp_comp.get()); + } else { + assert(sp_comp.has()); + composition.children.push_back(sp_comp.get()); + } + } + return SerialParallelDecomposition(composition); +} + +SerialParallelDecomposition parallel_composition( + std::unordered_set const &sp_compositions) { + ParallelSplit composition{}; + for (SerialParallelDecomposition const &sp_comp : sp_compositions) { + if (sp_comp.has()) { + composition.children = set_union(composition.children, + sp_comp.get().children); + } else if (sp_comp.has()) { + composition.children.insert(sp_comp.get()); + } else { + assert(sp_comp.has()); + composition.children.insert(sp_comp.get()); + } + } + return SerialParallelDecomposition(composition); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_metrics.cc b/lib/utils/src/utils/graph/serial_parallel/serial_parallel_metrics.cc new file mode 100644 index 0000000000..e25a8dc80a --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/serial_parallel_metrics.cc @@ -0,0 +1,126 @@ +#include "utils/graph/serial_parallel/serial_parallel_metrics.h" +#include "utils/containers.h" +#include "utils/containers/as_vector.h" +#include "utils/containers/values.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h" +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/serial_parallel/digraph_generation.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include + +namespace FlexFlow { + +std::unordered_map get_node_frequency_map(Node const &node) { + return {{node, 1}}; +} + +std::unordered_map + get_node_frequency_map(ParallelSplit const ¶llel) { + std::unordered_map counter; + for (std::variant const &child : parallel.children) { + for (auto const &[node, count] : + get_node_frequency_map(widen(child))) { + counter[node] += count; + } + } + return counter; +} + +std::unordered_map + get_node_frequency_map(SerialSplit const &serial) { + std::unordered_map counter; + for (std::variant const &child : serial.children) { + for (auto const &[node, count] : + get_node_frequency_map(widen(child))) { + counter[node] += count; + } + } + return counter; +} + +std::unordered_map + get_node_frequency_map(SerialParallelDecomposition const &sp) { + return sp.visit>( + [](auto const &t) { return get_node_frequency_map(t); }); +} + +float work_cost(SerialParallelDecomposition const &sp, + std::unordered_map cost_map) { + auto cost_per_node_group = [&](std::pair const &pair) { + return pair.second * cost_map.at(pair.first); + }; + std::unordered_map counter = get_node_frequency_map(sp); + std::vector> pairs(counter.cbegin(), counter.cend()); + return sum(transform(pairs, cost_per_node_group)); +} + +float work_cost(DiGraphView const &g, + std::unordered_map const &cost_map) { + return sum(transform(as_vector(get_nodes(g)), + [&](Node const &node) { return cost_map.at(node); })); +} + +float critical_path_cost(Node const &node, + std::unordered_map const &cost_map) { + return cost_map.at(node); +} + +float critical_path_cost(SerialSplit const &serial, + std::unordered_map const &cost_map) { + return sum(transform( + serial.children, [&](std::variant const &child) { + return critical_path_cost(widen(child), + cost_map); + })); +} + +float critical_path_cost(ParallelSplit const ¶llel, + std::unordered_map const &cost_map) { + return maximum(transform( + parallel.children, [&](std::variant const &child) { + return critical_path_cost(widen(child), + cost_map); + })); +} + +float critical_path_cost(SerialParallelDecomposition const &sp, + std::unordered_map const &cost_map) { + return sp.visit( + [&](auto const &t) { return critical_path_cost(t, cost_map); }); +} + +float critical_path_cost(DiGraphView const &g, + std::unordered_map const &cost_map) { + return maximum( + values(get_weighted_longest_path_lengths_from_root(g, cost_map))); +} + +int num_dependencies(SerialParallelDecomposition const &sp) { + return num_dependencies(digraph_from_sp_decomposition(sp)); +} + +int num_dependencies(DiGraphView const &g) { + return num_edges(g); +} + +float relative_work_increase(DiGraphView const &g, + SerialParallelDecomposition const &sp, + std::unordered_map const &cost_map) { + return work_cost(sp, cost_map) / work_cost(g, cost_map); +} + +float relative_critical_path_cost_increase( + DiGraphView const &g, + SerialParallelDecomposition const &sp, + std::unordered_map const &cost_map) { + return critical_path_cost(sp, cost_map) / critical_path_cost(g, cost_map); +} + +float relative_num_dependencies_increase( + DiGraphView const &g, SerialParallelDecomposition const &sp) { + return static_cast(num_dependencies(sp)) / num_dependencies(g); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc b/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc index 8fa42d4b22..1dfda56e5b 100644 --- a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc +++ b/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc @@ -1,4 +1,5 @@ #include "utils/graph/serial_parallel/serial_parallel_splits.h" +#include "utils/containers/transform.h" #include "utils/fmt/unordered_set.h" #include "utils/fmt/variant.h" #include "utils/fmt/vector.h" @@ -8,6 +9,8 @@ namespace FlexFlow { +SerialSplit::SerialSplit() : children{} {} + SerialSplit::SerialSplit( std::vector> const &children) : children(children) {} @@ -16,6 +19,11 @@ SerialSplit::SerialSplit( std::initializer_list> const &children) : children(children) {} +SerialSplit::SerialSplit(std::vector const &nodes) + : children(transform(nodes, [](Node const &node) { + return std::variant(node); + })) {} + bool SerialSplit::operator==(SerialSplit const &other) const { return this->tie() == other.tie(); } @@ -36,6 +44,8 @@ std::ostream &operator<<(std::ostream &s, SerialSplit const &split) { return s << fmt::to_string(split); } +ParallelSplit::ParallelSplit() : children{} {} + ParallelSplit::ParallelSplit( std::unordered_set> const &children) : children(children) {} @@ -44,6 +54,11 @@ ParallelSplit::ParallelSplit( std::initializer_list> const &children) : children(children) {} +ParallelSplit::ParallelSplit(std::unordered_set const &nodes) + : children(transform(nodes, [](Node const &node) { + return std::variant(node); + })) {} + bool ParallelSplit::operator==(ParallelSplit const &other) const { return this->tie() == other.tie(); } diff --git a/lib/utils/src/utils/graph/serial_parallel/sp_ization/critical_path_preserving_sp_ization.cc b/lib/utils/src/utils/graph/serial_parallel/sp_ization/critical_path_preserving_sp_ization.cc new file mode 100644 index 0000000000..4f7c8bc87e --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/sp_ization/critical_path_preserving_sp_ization.cc @@ -0,0 +1,119 @@ +#include "utils/graph/serial_parallel/sp_ization/critical_path_preserving_sp_ization.h" +#include "utils/containers/as_vector.h" +#include "utils/containers/get_only.h" +#include "utils/containers/transform.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "utils/graph/digraph/algorithms/is_2_terminal_dag.h" +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/serial_parallel/normalize_sp_decomposition.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" +#include "utils/variant.h" + +namespace FlexFlow { + +static SerialSplit cut_off_head(SerialSplit const &s) { + assert(s.children.size() > 0); + return SerialSplit{std::vector>( + s.children.begin() + 1, s.children.end())}; +} + +/* Performs a parallel composition with coalescing, where components with a + * common starting child are merged together + * Example: to parallel compose S(1, 2, 5), S(1, 3, 4): + * without coalescing: P(S(1, 2, 5), S(1, 3, 4)) + * with coalescing: S(1, P( S(2,5), S(3,4) )) + */ +static SerialParallelDecomposition parallel_composition_with_coalescing( + std::unordered_set const &strands) { + if (strands.size() == 1) { + return SerialParallelDecomposition(get_only(strands)); + } + + // group strands by their first ("head") node + std::unordered_map, + std::unordered_set> + grouped_strands; + for (SerialSplit predecessor : filter(strands, [](SerialSplit const &serial) { + return !is_empty(serial); + })) { + grouped_strands[predecessor.children.at(0)].insert( + cut_off_head(predecessor)); + } + + // recursively coalesce the strands + std::unordered_set coalesced_strands; + for (auto const &[head, tails] : grouped_strands) { + SerialParallelDecomposition parallel_comp = + parallel_composition_with_coalescing(tails); + coalesced_strands.insert(serial_composition( + {widen(head), parallel_comp})); + } + + return normalize_sp_decomposition(parallel_composition(coalesced_strands)); +} + +static SerialParallelDecomposition + critical_path_preserving_sp_ization_unchecked_with_coalescing( + DiGraphView const &g) { + std::unordered_map node_to_sp; + + Node source = get_only(get_sources(g)); + node_to_sp[source] = SerialSplit{{source}}; + + for (Node const &node : get_topological_ordering(g)) { + if (node == source) { + continue; + } + std::unordered_set predecessors_as_sp = + transform(get_predecessors(g, node), + [&](Node const &p) { return node_to_sp.at(p); }); + + SerialParallelDecomposition parallel_composed_predecessors = + parallel_composition_with_coalescing(predecessors_as_sp); + SerialParallelDecomposition sp_decomp = serial_composition( + {parallel_composed_predecessors, SerialParallelDecomposition(node)}); + node_to_sp[node] = sp_decomp.get(); + } + + Node sink = get_only(get_sinks(g)); + return normalize_sp_decomposition( + SerialParallelDecomposition(node_to_sp.at(sink))); +} + +SerialParallelDecomposition + critical_path_preserving_sp_ization_with_coalescing(DiGraphView const &g) { + assert(is_2_terminal_dag(g)); + return critical_path_preserving_sp_ization_unchecked_with_coalescing(g); +} + +static SerialParallelDecomposition + critical_path_preserving_sp_ization_unchecked(DiGraphView const &g) { + std::unordered_map node_to_sp; + + for (Node const &node : get_topological_ordering(g)) { + + std::unordered_set predecessors_as_sp = + transform(get_predecessors(g, node), + [&](Node const &p) { return node_to_sp.at(p); }); + + SerialParallelDecomposition sp_decomp = serial_composition( + {normalize_sp_decomposition(parallel_composition(predecessors_as_sp)), + SerialParallelDecomposition(node)}); + + node_to_sp.emplace(node, sp_decomp); + } + + Node sink = get_only(get_sinks(g)); + return node_to_sp.at(sink); +} + +SerialParallelDecomposition + critical_path_preserving_sp_ization(DiGraphView const &g) { + assert(is_2_terminal_dag(g)); + return critical_path_preserving_sp_ization_unchecked(g); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/sp_ization/work_preserving_sp_ization.cc b/lib/utils/src/utils/graph/serial_parallel/sp_ization/work_preserving_sp_ization.cc new file mode 100644 index 0000000000..fe1c87f6bd --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/sp_ization/work_preserving_sp_ization.cc @@ -0,0 +1,190 @@ +#include "utils/graph/serial_parallel/sp_ization/work_preserving_sp_ization.h" +#include "utils/containers.h" +#include "utils/containers/all_of.h" +#include "utils/containers/as_vector.h" +#include "utils/containers/get_only.h" +#include "utils/containers/invert_map.h" +#include "utils/containers/keys.h" +#include "utils/containers/sorted.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h" +#include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering_from_starting_node.h" +#include "utils/graph/digraph/algorithms/is_2_terminal_dag.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/serial_parallel/normalize_sp_decomposition.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" +#include "utils/graph/serial_parallel/serial_parallel_metrics.h" +#include "utils/hash/unordered_set.h" +#include "utils/hash/vector.h" +#include +#include + +namespace FlexFlow { + +std::vector> + stratum_split_assuming_unit_cost(DiGraphView const &g) { + std::unordered_map node_to_stratum = + get_longest_path_lengths_from_root(g); + std::vector> result( + maximum(values(node_to_stratum))); + for (auto const &[node, depth] : node_to_stratum) { + result[depth - 1].insert(node); + } + return result; +} + +static SerialParallelDecomposition + naive_stratum_merge(std::vector> stratum_split) { + std::vector strata = transform( + stratum_split, [](std::unordered_set const &stratum_nodes) { + return SerialParallelDecomposition(ParallelSplit{stratum_nodes}); + }); + return normalize_sp_decomposition(serial_composition(strata)); +} + +SerialParallelDecomposition + stratum_sync_sp_ization_unchecked(DiGraphView const &g) { + + std::vector> stratum_split = + stratum_split_assuming_unit_cost(g); + return naive_stratum_merge(stratum_split); +} + +SerialParallelDecomposition stratum_sync_sp_ization(DiGraphView const &g) { + assert(is_acyclic(g)); + return stratum_sync_sp_ization_unchecked(g); +} + +static std::unordered_set get_heads( + DiGraphView const &g, + std::unordered_set> previous_stratum_metanodes, + std::unordered_set explored) { + std::unordered_set previous_stratum_nodes = + set_union(previous_stratum_metanodes); + std::unordered_set candidate_heads = + set_union(values(get_successors(g, previous_stratum_nodes))); + + auto is_valid_head = [&](Node const &n) { + return (!contains(explored, n) && + all_of(get_predecessors(g, n), + [&](Node const &p) { return contains(explored, p); })); + }; + + return filter(candidate_heads, is_valid_head); +} + +// returns a set of filtered topological orderings starting from `heads` such +// that all nodes present in multiple orderings are not included +static std::unordered_set> + get_non_overlapping_topological_orderings( + DiGraphView const &g, std::unordered_set const &heads) { + + std::unordered_set> topo_orderings = + transform(heads, [&](Node const &head) { + return get_topological_ordering_from_starting_node(g, head); + }); + + std::unordered_map node_frequency; + for (std::vector const &ordering : topo_orderings) { + for (Node const &node : ordering) { + node_frequency[node]++; + } + } + + std::unordered_set visitable_nodes = + filter(keys(node_frequency), + [&](Node const &n) { return node_frequency.at(n) == 1; }); + + std::unordered_set> non_overlapping_topo_orderings = + transform(topo_orderings, [&](std::vector const &ordering) { + return filter(ordering, [&](Node const &n) { + return contains(visitable_nodes, n); + }); + }); + + return non_overlapping_topo_orderings; +} + +static std::unordered_set> + get_metanodes(DiGraphView const &g, + std::unordered_set> const &topo_orderings, + float stratum_cost, + std::unordered_map const &cost_map) { + + auto get_metanode = [&](std::vector const &topo_ordering) { + std::unordered_set explored_nodes; + for (Node const &node : topo_ordering) { + float metanode_cost = + critical_path_cost(stratum_sync_sp_ization(get_subgraph( + g, set_union(explored_nodes, {node}))), + cost_map); + if (metanode_cost > stratum_cost * 1.01) { + break; + } + explored_nodes.insert(node); + } + return explored_nodes; + }; + + return transform(topo_orderings, get_metanode); +} + +static std::vector>> + cost_aware_stratum_split(DiGraphView const &g, + std::unordered_map const &cost_map) { + std::vector>> strata; + Node source = get_only(get_sources(g)); + std::unordered_set explored = {source}; + strata.push_back({{source}}); + while (get_nodes(g) != explored) { + + std::unordered_set heads = get_heads(g, strata.back(), explored); + std::unordered_set> non_overlapping_topo_orderings = + get_non_overlapping_topological_orderings(g, heads); + float stratum_cost = maximum( + transform(heads, [&](Node const &n) { return cost_map.at(n); })); + std::unordered_set> metanodes = get_metanodes( + g, non_overlapping_topo_orderings, stratum_cost, cost_map); + strata.push_back(metanodes); + + explored = set_union(explored, set_union(metanodes)); + } + return strata; +} + +SerialParallelDecomposition cost_aware_stratum_sync_sp_ization_unchecked( + DiGraphView const &g, std::unordered_map const &cost_map) { + if (get_nodes(g).size() == 1) { + return SerialParallelDecomposition(get_only(get_nodes(g))); + } + std::vector>> stratum_split = + cost_aware_stratum_split(g, cost_map); + std::vector> sp_ized_strata; + for (auto const &stratum : stratum_split) { + auto sp_ized_stratum = + transform(stratum, [&](std::unordered_set const &nodes) { + return cost_aware_stratum_sync_sp_ization_unchecked( + get_subgraph(g, nodes), cost_map); + }); + sp_ized_strata.push_back(sp_ized_stratum); + } + + return normalize_sp_decomposition( + serial_composition(transform(sp_ized_strata, parallel_composition))); +} + +SerialParallelDecomposition cost_aware_stratum_sync_sp_ization( + DiGraphView const &g, std::unordered_map const &cost_map) { + assert(is_acyclic(g)); + return cost_aware_stratum_sync_sp_ization_unchecked(g, cost_map); +} + +} // namespace FlexFlow diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc index a1dd75504e..2148b13a0c 100644 --- a/lib/utils/test/src/test_algorithms.cc +++ b/lib/utils/test/src/test_algorithms.cc @@ -64,13 +64,14 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); SUBCASE("get_imm_dominators") { - std::unordered_map> result = get_imm_dominators(g); + std::unordered_map> result = + get_imm_dominators(g); - std::unordered_map> expected_result = { + std::unordered_map> expected_result = { {n[2], n[0]}, {n[1], n[0]}, {n[3], n[0]}, - {n[0], nullopt}, + {n[0], std::nullopt}, }; CHECK(result == expected_result); } @@ -138,7 +139,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("nonlinear") { g.add_edge({n[1], n[3]}); - CHECK(is_acyclic(g) == true); // TODO, maybe a bug about the unchecked_dfs + CHECK(is_acyclic(g) == true); } SUBCASE("not connected") { diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root_node.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root_node.cc new file mode 100644 index 0000000000..861a36d85b --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root_node.cc @@ -0,0 +1,60 @@ +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_longest_path_lengths_from_root - linear graph") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 5); + std::vector edges = { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[3]}, + DirectedEdge{n[3], n[4]}, + }; + + add_edges(g, edges); + + std::unordered_map expected_lengths = { + {n[0], 1}, + {n[1], 2}, + {n[2], 3}, + {n[3], 4}, + {n[4], 5}, + }; + + CHECK(get_longest_path_lengths_from_root(g) == expected_lengths); + } + + TEST_CASE("get_longest_path_lengths_from_root - more complex graph") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 7); + std::vector edges = {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[3]}, + DirectedEdge{n[0], n[4]}, + DirectedEdge{n[0], n[6]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[3]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[4], n[5]}, + DirectedEdge{n[5], n[6]}}; + + add_edges(g, edges); + + std::unordered_map expected_lengths = { + {n[0], 1}, + {n[1], 2}, + {n[2], 3}, + {n[3], 4}, + {n[4], 2}, + {n[5], 5}, + {n[6], 6}, + }; + + CHECK(get_longest_path_lengths_from_root(g) == expected_lengths); + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc index fd2f469f93..ac4057c3b2 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc @@ -4,6 +4,7 @@ #include "utils/containers/transform.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/multidigraph/algorithms/get_directed_edge.h" #include "utils/graph/multidigraph/algorithms/get_edge_counts.h" diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc index 3ad506f40a..f07a5405a9 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc @@ -8,23 +8,107 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("is_acyclic") { DiGraph g = DiGraph::create(); + SUBCASE("empty graph") { + CHECK(is_acyclic(g)); + } - std::vector n = add_nodes(g, 6); + SUBCASE("single node") { + add_nodes(g, 1); + CHECK(is_acyclic(g)); + } - add_edges(g, - { - DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(1), n.at(2)}, - DirectedEdge{n.at(1), n.at(3)}, - DirectedEdge{n.at(1), n.at(5)}, - DirectedEdge{n.at(2), n.at(4)}, - DirectedEdge{n.at(3), n.at(1)}, - DirectedEdge{n.at(3), n.at(4)}, - }); + SUBCASE("simple acyclic graph") { + std::vector n = add_nodes(g, 3); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + }); + CHECK(is_acyclic(g)); + } - std::optional correct = false; - std::optional result = is_acyclic(g); + SUBCASE("simple cyclic graph") { + std::vector n = add_nodes(g, 3); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(0)}, + }); + CHECK_FALSE(is_acyclic(g)); + } - CHECK(result == correct); + SUBCASE("2 parallel chains") { + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(4), n.at(5)}, + }); + CHECK(is_acyclic(g)); + } + SUBCASE("traversal with root") { + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(2)}}); + CHECK_FALSE(is_acyclic(g)); + } + + SUBCASE("traversal without root") { + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(0)}}); + CHECK_FALSE(is_acyclic(g)); + } + + SUBCASE("traversal nonlinear") { + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}}); + CHECK(is_acyclic(g)); + } + + SUBCASE("complex cyclic graph") { + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(4), n.at(5)}, + DirectedEdge{n.at(5), n.at(1)}, + }); + CHECK_FALSE(is_acyclic(g)); + } + + SUBCASE("complex cyclic graph #2") { + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(5)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(1)}, + DirectedEdge{n.at(3), n.at(4)}, + }); + CHECK_FALSE(is_acyclic(g)); + } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc index b8a35346f4..0c00357ccb 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc @@ -38,6 +38,50 @@ TEST_SUITE(FF_TEST_SUITE) { } } + SUBCASE("linear graph with additional edge") { + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(0), n.at(3)}, + }); + + DiGraphView result = transitive_reduction(g); + std::unordered_set result_edges = get_edges(result); + std::unordered_set correct_edges = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + }; + CHECK(result_edges == correct_edges); + } + + SUBCASE("linear graph with 2 additional edges") { + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + }); + + DiGraphView result = transitive_reduction(g); + std::unordered_set result_edges = get_edges(result); + std::unordered_set correct_edges = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + }; + CHECK(result_edges == correct_edges); + } + SUBCASE("nontrivial graph") { // from // https://en.wikipedia.org/w/index.php?title=Transitive_reduction&oldid=1226082357#In_directed_acyclic_graphs diff --git a/lib/utils/test/src/utils/graph/serial_parallel/digraph_generation.cc b/lib/utils/test/src/utils/graph/serial_parallel/digraph_generation.cc new file mode 100644 index 0000000000..07f080e180 --- /dev/null +++ b/lib/utils/test/src/utils/graph/serial_parallel/digraph_generation.cc @@ -0,0 +1,110 @@ +#include "utils/graph/serial_parallel/digraph_generation.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("digraph_from_sp_decomposition") { + SUBCASE("Empty") { + SerialParallelDecomposition input = + SerialParallelDecomposition(ParallelSplit{}); + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 0); + CHECK(num_edges(result) == 0); + } + SUBCASE("Complex Empty") { + SerialParallelDecomposition input = SerialParallelDecomposition( + ParallelSplit{SerialSplit{}, SerialSplit{ParallelSplit{}}}); + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 0); + CHECK(num_edges(result) == 0); + } + + SUBCASE("Single Node") { + SerialParallelDecomposition input = SerialParallelDecomposition(Node(1)); + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 1); + CHECK(num_edges(result) == 0); + } + + SUBCASE("Simple SerialSplit") { + SerialParallelDecomposition input = + SerialParallelDecomposition{SerialSplit{Node(1), Node(2), Node(3)}}; + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 3); + CHECK(num_edges(result) == 2); + CHECK(get_sources(result).size() == 1); + CHECK(get_sinks(result).size() == 1); + } + + SUBCASE("Simple ParallelSplit") { + SerialParallelDecomposition input = + SerialParallelDecomposition{ParallelSplit{Node(1), Node(2), Node(3)}}; + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 3); + CHECK(num_edges(result) == 0); + CHECK(get_sources(result).size() == 3); + CHECK(get_sinks(result).size() == 3); + } + + SUBCASE("Mixed Serial-Parallel") { + SerialParallelDecomposition input = SerialParallelDecomposition{ + SerialSplit{ParallelSplit{Node(1), Node(2)}, + ParallelSplit{Node(3), Node(4)}}}; + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 4); + CHECK(num_edges(result) == 4); + CHECK(get_sources(result).size() == 2); + CHECK(get_sinks(result).size() == 2); + } + + SUBCASE("Mixed Parallel-Serial") { + SerialParallelDecomposition input = + SerialParallelDecomposition{ParallelSplit{ + SerialSplit{Node(1), Node(2)}, SerialSplit{Node(3), Node(4)}}}; + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 4); + CHECK(num_edges(result) == 2); + CHECK(get_sources(result).size() == 2); + CHECK(get_sinks(result).size() == 2); + } + + SUBCASE("Rhombus") { + SerialParallelDecomposition input = SerialParallelDecomposition{ + SerialSplit{Node(1), ParallelSplit{Node(2), Node(3)}, Node(4)}}; + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 4); + CHECK(num_edges(result) == 4); + CHECK(get_sources(result).size() == 1); + CHECK(get_sinks(result).size() == 1); + } + + SUBCASE("Duplicate Nodes") { + SerialParallelDecomposition input = SerialParallelDecomposition{ + SerialSplit{Node(1), ParallelSplit{Node(1), Node(2)}, Node(1)}}; + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 4); + CHECK(num_edges(result) == 4); + CHECK(get_sources(result).size() == 1); + CHECK(get_sinks(result).size() == 1); + } + + SUBCASE("Complex Graph") { + SerialParallelDecomposition input = SerialParallelDecomposition{ + SerialSplit{ParallelSplit{SerialSplit{ParallelSplit{Node(1), Node(2)}, + ParallelSplit{Node(3), Node(4)}, + Node(5)}, + SerialSplit{Node(6), Node(7)}}, + Node(8)}}; + + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 8); + CHECK(num_edges(result) == 9); + CHECK(get_sources(result).size() == 3); + CHECK(get_sinks(result).size() == 1); + } + } +} diff --git a/lib/utils/test/src/utils/graph/serial_parallel/normalize_sp_decomposition.cc b/lib/utils/test/src/utils/graph/serial_parallel/normalize_sp_decomposition.cc new file mode 100644 index 0000000000..d1300da137 --- /dev/null +++ b/lib/utils/test/src/utils/graph/serial_parallel/normalize_sp_decomposition.cc @@ -0,0 +1,72 @@ +#include "utils/graph/serial_parallel/normalize_sp_decomposition.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("normalize_sp_decomposition") { + Node n1 = Node(1); + Node n2 = Node(2); + Node n3 = Node(3); + + SUBCASE("Empty") { + SerialParallelDecomposition input = SerialParallelDecomposition{ + SerialSplit{ParallelSplit{}, ParallelSplit{}}}; + SerialParallelDecomposition correct = + SerialParallelDecomposition{SerialSplit{}}; + SerialParallelDecomposition result = normalize_sp_decomposition(input); + CHECK(correct == result); + } + + SUBCASE("Node Decomposition") { + SerialParallelDecomposition input = SerialParallelDecomposition{n1}; + SerialParallelDecomposition correct = SerialParallelDecomposition{n1}; + SerialParallelDecomposition result = normalize_sp_decomposition(input); + CHECK(correct == result); + } + + SUBCASE("Serial with Single Node") { + SerialParallelDecomposition input = + SerialParallelDecomposition{SerialSplit{n1}}; + SerialParallelDecomposition correct = SerialParallelDecomposition{n1}; + SerialParallelDecomposition result = normalize_sp_decomposition(input); + CHECK(correct == result); + } + + SUBCASE("Parallel with Single Node") { + SerialParallelDecomposition input = + SerialParallelDecomposition{ParallelSplit{n1}}; + SerialParallelDecomposition correct = SerialParallelDecomposition{n1}; + SerialParallelDecomposition result = normalize_sp_decomposition(input); + CHECK(correct == result); + } + + SUBCASE("Mixed Serial") { + SerialParallelDecomposition input = + SerialParallelDecomposition{SerialSplit{ParallelSplit{n1}, n2}}; + SerialParallelDecomposition correct = + SerialParallelDecomposition{SerialSplit{n1, n2}}; + SerialParallelDecomposition result = normalize_sp_decomposition(input); + CHECK(correct == result); + } + + SUBCASE("Mixed Parallel") { + SerialParallelDecomposition input = + SerialParallelDecomposition{ParallelSplit{SerialSplit{n1}, n2}}; + SerialParallelDecomposition correct = + SerialParallelDecomposition{ParallelSplit{n1, n2}}; + SerialParallelDecomposition result = normalize_sp_decomposition(input); + CHECK(correct == result); + } + + SUBCASE("Nested") { + SerialParallelDecomposition input = SerialParallelDecomposition{ + ParallelSplit{SerialSplit{ParallelSplit{n1, n2}}, n3, SerialSplit{}}}; + SerialParallelDecomposition correct = + SerialParallelDecomposition{ParallelSplit{n1, n2, n3}}; + SerialParallelDecomposition result = normalize_sp_decomposition(input); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc index 7b8548eac7..b2ce2d71a1 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc +++ b/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc @@ -156,4 +156,54 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_set correct = {input}; CHECK(result == correct); } + + TEST_CASE("is_empty(SerialParallelDecomposition)") { + Node n1{1}; + Node n2{2}; + + SUBCASE("Node Decomposition") { + SerialParallelDecomposition sp{n1}; + CHECK_FALSE(is_empty(sp)); + } + + SUBCASE("Empty Serial") { + SerialParallelDecomposition sp{SerialSplit{}}; + CHECK(is_empty(sp)); + } + + SUBCASE("Empty Parallel") { + SerialParallelDecomposition sp{ParallelSplit{}}; + CHECK(is_empty(sp)); + } + + SUBCASE("Serial with Node") { + SerialParallelDecomposition sp{SerialSplit{n1}}; + CHECK_FALSE(is_empty(sp)); + } + + SUBCASE("Parallel with Node") { + SerialParallelDecomposition sp{ParallelSplit{n1}}; + CHECK_FALSE(is_empty(sp)); + } + + SUBCASE("Nested Serial") { + SerialParallelDecomposition sp{SerialSplit{ParallelSplit{}}}; + CHECK(is_empty(sp)); + } + + SUBCASE("Nested Parallel") { + SerialParallelDecomposition sp{ParallelSplit{SerialSplit{}}}; + CHECK(is_empty(sp)); + } + + SUBCASE("Sparse") { + SerialSplit sp{ParallelSplit{}, ParallelSplit{SerialSplit{}}}; + CHECK(is_empty(sp)); + } + + SUBCASE("Sparse with Node") { + SerialSplit sp{ParallelSplit{}, ParallelSplit{SerialSplit{}, n2}}; + CHECK_FALSE(is_empty(sp)); + } + } } diff --git a/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_splits.cc b/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_splits.cc new file mode 100644 index 0000000000..c08a926875 --- /dev/null +++ b/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_splits.cc @@ -0,0 +1,48 @@ +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("ParallelSplit and SerialSplit equality checks") { + + SUBCASE("ParallelSplit::operator== - commutativity") { + ParallelSplit p1 = ParallelSplit{Node(1), Node(2), Node(3)}; + ParallelSplit p2 = ParallelSplit{Node(2), Node(1), Node(3)}; + ParallelSplit p3 = ParallelSplit{Node(3), Node(2), Node(1)}; + CHECK(p1 == p2); + CHECK(p2 == p3); + CHECK(p1 == p3); + } + + SUBCASE("SerialSplit::operator== - non-commutativity") { + SerialSplit p1 = SerialSplit{Node(1), Node(2), Node(3)}; + SerialSplit p2 = SerialSplit{Node(2), Node(1), Node(3)}; + SerialSplit p3 = SerialSplit{Node(3), Node(2), Node(1)}; + CHECK(p1 != p2); + CHECK(p2 != p3); + CHECK(p1 != p3); + } + + SUBCASE("operator==, mixed case, nested commutativity") { + std::vector n = {Node(0), Node(1), Node(2), Node(3)}; + + // All definitions are equivalent, since ParallelSplit commutes + ParallelSplit p1 = ParallelSplit{ + n.at(3), SerialSplit{ParallelSplit{n.at(2), n.at(1)}, n.at(2)}}; + ParallelSplit p2 = ParallelSplit{ + n.at(3), SerialSplit{ParallelSplit{n.at(1), n.at(2)}, n.at(2)}}; + ParallelSplit p3 = ParallelSplit{ + SerialSplit{ParallelSplit{n.at(1), n.at(2)}, n.at(2)}, n.at(3)}; + ParallelSplit p4 = ParallelSplit{ + SerialSplit{ParallelSplit{n.at(2), n.at(1)}, n.at(2)}, n.at(3)}; + + CHECK(p1 == p2); + CHECK(p1 == p3); + CHECK(p1 == p4); + CHECK(p2 == p3); + CHECK(p2 == p4); + CHECK(p3 == p4); + } + } +} diff --git a/lib/utils/test/src/utils/graph/serial_parallel/sp_ization/critical_path_preserving_sp_ization.cc b/lib/utils/test/src/utils/graph/serial_parallel/sp_ization/critical_path_preserving_sp_ization.cc new file mode 100644 index 0000000000..130fa93411 --- /dev/null +++ b/lib/utils/test/src/utils/graph/serial_parallel/sp_ization/critical_path_preserving_sp_ization.cc @@ -0,0 +1,301 @@ +#include "utils/graph/serial_parallel/sp_ization/critical_path_preserving_sp_ization.h" +#include "test/utils/doctest.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/serial_parallel/serial_parallel_metrics.h" +#include "utils/graph/serial_parallel/serial_parallel_splits.h" + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("critical_path_preserving_sp_ization") { + + SUBCASE("Sample Graph #1") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(4), n.at(5)}, + }); + + std::unordered_map cost_map = {{n.at(0), 3}, + {n.at(1), 2}, + {n.at(2), 1}, + {n.at(3), 1}, + {n.at(4), 1}, + {n.at(5), 5}}; + + CHECK(work_cost(g, cost_map) == 13); + CHECK(critical_path_cost(g, cost_map) == 12); + + SerialParallelDecomposition sp = critical_path_preserving_sp_ization(g); + + SUBCASE("structure") { + Node sp0 = n.at(0); + SerialSplit sp1 = SerialSplit{sp0, n.at(1)}; + SerialSplit sp2 = SerialSplit{ParallelSplit{sp0, sp1}, n.at(2)}; + SerialSplit sp3 = SerialSplit{n.at(0), n.at(1), n.at(3)}; + SerialSplit sp4 = SerialSplit{ParallelSplit{sp2, sp3}, n.at(4)}; + SerialSplit sp5 = SerialSplit{ParallelSplit{sp3, sp4}, n.at(5)}; + SerialParallelDecomposition correct(sp5); + SerialParallelDecomposition result = sp; + CHECK(correct == result); + } + SUBCASE("work cost") { + float correct = 3 * 4 + 2 * 3 + 1 * 1 + 1 * 2 + 1 * 1 + 5 * 1; + float result = work_cost(sp, cost_map); + CHECK(correct == result); + } + + SUBCASE("critical path cost") { + float correct = critical_path_cost(g, cost_map); + float result = critical_path_cost(sp, cost_map); + CHECK(correct == result); + } + } + + SUBCASE("Sample Graph #2") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(5)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(5)}, + }); + + std::unordered_map cost_map = {{n.at(0), 1}, + {n.at(1), 1}, + {n.at(2), 10}, + {n.at(3), 1}, + {n.at(4), 1}, + {n.at(5), 1}}; + + CHECK(work_cost(g, cost_map) == 15); + CHECK(critical_path_cost(g, cost_map) == 12); + + SerialParallelDecomposition sp = critical_path_preserving_sp_ization(g); + + SUBCASE("structure") { + SerialParallelDecomposition correct(SerialSplit{ + ParallelSplit{SerialSplit{n.at(0), n.at(1), n.at(3), n.at(4)}, + SerialSplit{n.at(0), n.at(2)}}, + n.at(5)}); + SerialParallelDecomposition result = sp; + CHECK(correct == result); + } + SUBCASE("work cost") { + float correct = 16; + float result = work_cost(sp, cost_map); + CHECK(correct == result); + } + + SUBCASE("critical path cost") { + float correct = critical_path_cost(g, cost_map); + float result = critical_path_cost(sp, cost_map); + CHECK(correct == result); + } + } + } + + TEST_CASE("critical_path_preserving_sp_ization_with_coalescing") { + + SUBCASE("Sample Graph #1") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(4), n.at(5)}, + }); + + std::unordered_map cost_map = {{n.at(0), 1}, + {n.at(1), 1}, + {n.at(2), 2}, + {n.at(3), 3}, + {n.at(4), 1}, + {n.at(5), 1}}; + + CHECK(work_cost(g, cost_map) == 9); + CHECK(critical_path_cost(g, cost_map) == 7); + + SerialParallelDecomposition sp = + critical_path_preserving_sp_ization_with_coalescing(g); + + SUBCASE("structure") { + SerialParallelDecomposition correct(SerialSplit{ + n.at(0), + n.at(1), + ParallelSplit{SerialSplit{ParallelSplit{n.at(2), n.at(3)}, n.at(4)}, + n.at(3)}, + n.at(5)}); + SerialParallelDecomposition result = sp; + CHECK(correct == result); + } + SUBCASE("work cost") { + float correct = 12; + float result = work_cost(sp, cost_map); + CHECK(correct == result); + } + + SUBCASE("critical path cost") { + float correct = critical_path_cost(g, cost_map); + float result = critical_path_cost(sp, cost_map); + CHECK(correct == result); + } + } + + SUBCASE("Sample Graph #2") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(5)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(5)}, + }); + + std::unordered_map cost_map = {{n.at(0), 1}, + {n.at(1), 1}, + {n.at(2), 10}, + {n.at(3), 1}, + {n.at(4), 1}, + {n.at(5), 1}}; + + CHECK(work_cost(g, cost_map) == 15); + CHECK(critical_path_cost(g, cost_map) == 12); + + SerialParallelDecomposition sp = + critical_path_preserving_sp_ization_with_coalescing(g); + + SUBCASE("structure") { + SerialParallelDecomposition correct(SerialSplit{ + n.at(0), + ParallelSplit{SerialSplit{n.at(1), n.at(3), n.at(4)}, n.at(2)}, + n.at(5)}); + SerialParallelDecomposition result = sp; + CHECK(correct == result); + } + SUBCASE("work cost") { + float correct = 15; + float result = work_cost(sp, cost_map); + CHECK(correct == result); + } + + SUBCASE("critical path cost") { + float correct = critical_path_cost(g, cost_map); + float result = critical_path_cost(sp, cost_map); + CHECK(correct == result); + } + } + + SUBCASE("Sample Graph #3") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 10); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(5)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(6)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(3), n.at(8)}, + DirectedEdge{n.at(4), n.at(8)}, + DirectedEdge{n.at(5), n.at(7)}, + DirectedEdge{n.at(7), n.at(8)}, + DirectedEdge{n.at(6), n.at(9)}, + DirectedEdge{n.at(8), n.at(9)}}); + + std::unordered_map cost_map = {{n.at(0), 1}, + {n.at(1), 1}, + {n.at(2), 4}, + {n.at(3), 10}, + {n.at(4), 10}, + {n.at(5), 5}, + {n.at(6), 4}, + {n.at(7), 3}, + {n.at(8), 4}, + {n.at(9), 1}}; + + CHECK(work_cost(g, cost_map) == 43); + CHECK(critical_path_cost(g, cost_map) == 26); + + SerialParallelDecomposition sp = + critical_path_preserving_sp_ization_with_coalescing(g); + + SUBCASE("structure") { + + SerialParallelDecomposition correct(SerialSplit{ + n.at(0), + ParallelSplit{ + SerialSplit{n.at(1), n.at(2), n.at(6)}, + SerialSplit{ParallelSplit{ + SerialSplit{ParallelSplit{n.at(1), n.at(3)}, + ParallelSplit{ + n.at(4), + SerialSplit{n.at(5), n.at(7)}}}, + n.at(3)}, + n.at(8)}}, + n.at(9)}); + SerialParallelDecomposition result = sp; + CHECK(correct == result); + }; + SUBCASE("work cost") { + float correct = 54; + float result = work_cost(sp, cost_map); + CHECK(correct == result); + } + + SUBCASE("critical path cost") { + float correct = critical_path_cost(g, cost_map); + float result = critical_path_cost(sp, cost_map); + CHECK(correct == result); + } + } + SUBCASE("Transitive Reduction") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 5); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + }); + + SerialParallelDecomposition result = + critical_path_preserving_sp_ization_with_coalescing(g); + SerialParallelDecomposition correct = + SerialParallelDecomposition{SerialSplit{ + {n.at(0), n.at(1), ParallelSplit{{n.at(2), n.at(3)}}, n.at(4)}}}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/serial_parallel/sp_ization/work_preserving_sp_ization.cc b/lib/utils/test/src/utils/graph/serial_parallel/sp_ization/work_preserving_sp_ization.cc new file mode 100644 index 0000000000..892af6ed26 --- /dev/null +++ b/lib/utils/test/src/utils/graph/serial_parallel/sp_ization/work_preserving_sp_ization.cc @@ -0,0 +1,364 @@ +#include "utils/graph/serial_parallel/sp_ization/work_preserving_sp_ization.h" +#include "test/utils/doctest.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/serial_parallel/serial_parallel_metrics.h" +#include "utils/graph/serial_parallel/serial_parallel_splits.h" + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("work_preserving_") { + + SUBCASE("Sample Graph #1") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[4], n[5]}, + }); + + std::unordered_map cost_map = { + {n[0], 1}, {n[1], 1}, {n[2], 2}, {n[3], 3}, {n[4], 1}, {n[5], 1}}; + + CHECK(work_cost(g, cost_map) == 9); + CHECK(critical_path_cost(g, cost_map) == 7); + + SerialParallelDecomposition sp = stratum_sync_sp_ization(g); + + SUBCASE("structure") { + SerialParallelDecomposition correct( + SerialSplit{n[0], n[1], ParallelSplit{n[2], n[3]}, n[4], n[5]}); + SerialParallelDecomposition result = sp; + CHECK(correct == result); + } + SUBCASE("work cost") { + float correct = work_cost(g, cost_map); + float result = work_cost(sp, cost_map); + CHECK(correct == result); + } + + SUBCASE("critical path cost") { + float correct = 7; + float result = critical_path_cost(sp, cost_map); + CHECK(correct == result); + } + } + + SUBCASE("Sample Graph #2") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[5]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[4], n[5]}, + }); + + std::unordered_map cost_map = { + {n[0], 1}, {n[1], 1}, {n[2], 10}, {n[3], 1}, {n[4], 1}, {n[5], 1}}; + + CHECK(work_cost(g, cost_map) == 15); + CHECK(critical_path_cost(g, cost_map) == 12); + + SerialParallelDecomposition sp = stratum_sync_sp_ization(g); + + SUBCASE("structure") { + SerialParallelDecomposition correct( + SerialSplit{n[0], ParallelSplit{n[1], n[2]}, n[3], n[4], n[5]}); + SerialParallelDecomposition result = sp; + CHECK(correct == result); + } + SUBCASE("work cost") { + float correct = work_cost(g, cost_map); + float result = work_cost(sp, cost_map); + CHECK(correct == result); + } + + SUBCASE("critical path cost") { + float correct = 14; + float result = critical_path_cost(sp, cost_map); + CHECK(correct == result); + } + } + + SUBCASE("Sample Graph #3") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 9); + add_edges(g, + { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[3]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[1], n[5]}, + DirectedEdge{n[1], n[4]}, + DirectedEdge{n[2], n[6]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[3], n[8]}, + DirectedEdge{n[4], n[8]}, + DirectedEdge{n[5], n[7]}, + DirectedEdge{n[7], n[8]}, + }); + + std::unordered_map cost_map = {{n[0], 1}, + {n[1], 1}, + {n[2], 10}, + {n[3], 10}, + {n[4], 1}, + {n[5], 1}, + {n[6], 10}, + {n[7], 10}, + {n[8], 1}}; + + CHECK(work_cost(g, cost_map) == 45); + CHECK(critical_path_cost(g, cost_map) == 23); + + SerialParallelDecomposition sp = stratum_sync_sp_ization(g); + + SUBCASE("structure") { + SerialParallelDecomposition correct( + SerialSplit{n[0], + ParallelSplit{n[1], n[3]}, + ParallelSplit{n[2], n[4], n[5]}, + ParallelSplit{n[6], n[7]}, + n[8]}); + SerialParallelDecomposition result = sp; + CHECK(correct == result); + } + SUBCASE("work cost") { + float correct = work_cost(g, cost_map); + float result = work_cost(sp, cost_map); + CHECK(correct == result); + } + + SUBCASE("critical path cost") { + float correct = 32; + float result = critical_path_cost(sp, cost_map); + CHECK(correct == result); + } + } + } + + TEST_CASE("cost_aware_stratum_sync_sp_ization") { + + SUBCASE("Sample Graph #1") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[4], n[5]}, + }); + + std::unordered_map cost_map = { + {n[0], 1}, {n[1], 1}, {n[2], 2}, {n[3], 3}, {n[4], 1}, {n[5], 1}}; + + CHECK(work_cost(g, cost_map) == 9); + CHECK(critical_path_cost(g, cost_map) == 7); + + SerialParallelDecomposition sp = + cost_aware_stratum_sync_sp_ization(g, cost_map); + + SUBCASE("structure") { + SerialParallelDecomposition correct( + SerialSplit{n[0], n[1], ParallelSplit{n[2], n[3]}, n[4], n[5]}); + SerialParallelDecomposition result = sp; + CHECK(correct == result); + } + SUBCASE("work cost") { + float correct = work_cost(g, cost_map); + float result = work_cost(sp, cost_map); + CHECK(correct == result); + } + + SUBCASE("critical path cost") { + float correct = 7; + float result = critical_path_cost(sp, cost_map); + CHECK(correct == result); + } + } + + SUBCASE("Sample Graph #2") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[5]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[4], n[5]}, + }); + + std::unordered_map cost_map = { + {n[0], 1}, {n[1], 1}, {n[2], 10}, {n[3], 1}, {n[4], 1}, {n[5], 1}}; + + CHECK(work_cost(g, cost_map) == 15); + CHECK(critical_path_cost(g, cost_map) == 12); + + SerialParallelDecomposition sp = + cost_aware_stratum_sync_sp_ization(g, cost_map); + + SUBCASE("structure") { + SerialParallelDecomposition correct(SerialSplit{ + n[0], ParallelSplit{SerialSplit{n[1], n[3], n[4]}, n[2]}, n[5]}); + SerialParallelDecomposition result = sp; + CHECK(correct == result); + } + SUBCASE("work cost") { + float correct = work_cost(g, cost_map); + float result = work_cost(sp, cost_map); + CHECK(correct == result); + } + + SUBCASE("critical path cost") { + float correct = 12; + float result = critical_path_cost(sp, cost_map); + CHECK(correct == result); + } + } + + SUBCASE("Sample Graph #3") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 9); + add_edges(g, + { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[3]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[1], n[5]}, + DirectedEdge{n[1], n[4]}, + DirectedEdge{n[2], n[6]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[3], n[8]}, + DirectedEdge{n[4], n[8]}, + DirectedEdge{n[5], n[7]}, + DirectedEdge{n[7], n[8]}, + }); + + std::unordered_map cost_map = {{n[0], 1}, + {n[1], 1}, + {n[2], 4}, + {n[3], 10}, + {n[4], 10}, + {n[5], 5}, + {n[6], 4}, + {n[7], 3}, + {n[8], 4}}; + + CHECK(work_cost(g, cost_map) == 42); + CHECK(critical_path_cost(g, cost_map) == 25); + + SerialParallelDecomposition sp = + cost_aware_stratum_sync_sp_ization(g, cost_map); + + SUBCASE("structure") { + SerialParallelDecomposition correct( + SerialSplit{n[0], + ParallelSplit{SerialSplit{n[1], n[2], n[6]}, n[3]}, + ParallelSplit{n[4], SerialSplit{n[5], n[7]}}, + n[8]}); + SerialParallelDecomposition result = sp; + CHECK(correct == result); + } + SUBCASE("work cost") { + float correct = work_cost(g, cost_map); + float result = work_cost(sp, cost_map); + CHECK(correct == result); + } + + SUBCASE("critical path cost") { + float correct = 25; + float result = critical_path_cost(sp, cost_map); + CHECK(correct == result); + } + } + + SUBCASE("Sample Graph #4") { + + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 15); + add_edges(g, + { + DirectedEdge{n[0], n[1]}, DirectedEdge{n[0], n[5]}, + DirectedEdge{n[1], n[2]}, DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[4]}, DirectedEdge{n[2], n[7]}, + DirectedEdge{n[3], n[4]}, DirectedEdge{n[3], n[6]}, + DirectedEdge{n[4], n[6]}, DirectedEdge{n[5], n[6]}, + DirectedEdge{n[5], n[10]}, DirectedEdge{n[6], n[9]}, + DirectedEdge{n[6], n[11]}, DirectedEdge{n[7], n[8]}, + DirectedEdge{n[8], n[9]}, DirectedEdge{n[8], n[13]}, + DirectedEdge{n[9], n[13]}, DirectedEdge{n[10], n[11]}, + DirectedEdge{n[10], n[12]}, DirectedEdge{n[11], n[14]}, + DirectedEdge{n[12], n[14]}, DirectedEdge{n[13], n[14]}, + }); + + std::unordered_map cost_map = {{n[0], 1}, + {n[1], 1}, + {n[2], 3}, + {n[3], 3}, + {n[4], 1}, + {n[5], 5}, + {n[6], 5}, + {n[7], 1}, + {n[8], 1}, + {n[9], 1}, + {n[10], 3}, + {n[11], 3}, + {n[12], 2}, + {n[13], 1}, + {n[14], 10}}; + + CHECK(work_cost(g, cost_map) == 41); + CHECK(critical_path_cost(g, cost_map) == 24); + + SerialParallelDecomposition sp = + cost_aware_stratum_sync_sp_ization(g, cost_map); + + SUBCASE("structure") { + SerialParallelDecomposition correct(SerialSplit{ + n[0], + ParallelSplit{SerialSplit{n[1], ParallelSplit{n[2], n[3]}, n[7]}, + n[5]}, + ParallelSplit{n[4], n[8], n[10]}, + ParallelSplit{n[6], n[12]}, + ParallelSplit{n[11], SerialSplit{n[9], n[13]}}, + n[14]}); + SerialParallelDecomposition result = sp; + CHECK(correct == result); + } + SUBCASE("work cost") { + float correct = work_cost(g, cost_map); + float result = work_cost(sp, cost_map); + CHECK(correct == result); + } + + SUBCASE("critical path cost") { + float correct = 27; + float result = critical_path_cost(sp, cost_map); + CHECK(correct == result); + } + } + } +}