From ae8bbf915b3de5a376d741b516944b75bb8cdfb3 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sat, 14 Sep 2024 15:10:02 -0700 Subject: [PATCH 1/3] Add tool for exporting and visualizing model architectures and SP decompositions (#1490) * Start on pcg builder * Add tests and some implementation for pcg builder * Add pcg tests, make dtgen constructors explicit to fix bug * Add remainder of PCG tests * Fix build issues in local-execution * Format * Address Reyna comments, add topological_order function for PCG * Pre multidigraph refactor * Removing visitable from sp code * Add open dataflow graph, start to replace pcg dataflow graph * Start refactoring substitutions * Add utility functions to support pattern matching * Pre-refactor inputs * Fix proj url * Get back to substitutions, now with unordered graph inputs * Get substitutions building * substitutions-tests now builds * Fix bug in filter, pass some initial substitution tests * Add tests for fmt::to_string, fix some substitutions bugs * Pass initial unit tests for find_pattern_matches * Start on unit tests for pcg pattern * Pass initial test for find_pattern_matches * Fix small build issue in tests * Format * Sync tests in CI with tests in proj * Fix minor build errors in kernels and local-execution * Format * Remove outdated code * More outdated code removal * More cleanup, add test for sp decomposition * Pull apart containers.h * More sp testing and fixes * Break up graph algorithms.h * Pre- full SP algo commit * Add initial implementation and tests for cbc decomposition and inverse line graph * Pass test for get_inverse_line_graph * Add new multidigraph * Fix get_inverse_line_graph to return a MultiDiGraph instead of a DiGraph * Add tests for parallel and series reduction finding * Add really rough implementation of valdez sp decomposition * Fix local-execution build * Add implementations and tests for applying series/parallel reductions * Add transformer pcg prototype * Format * Clean up sp decomposition interface and tests * Format * Add comments for top-level substitutions functions, add proj doxygen support * Start sketching out substitutions code * Fix build errors * Add ability to permute node ids * Cleanup and start to test new substitutions code * Add test case for evaluate_substitution_output * Add naive isomorphism detection code * Add graph inputs to open dataflow graph isomorphism * Add input permutation to evaluate_substitution_output * Update based on review comments * Fix permute_node_ids * Add test for permute_input_ids * Add models makefile and test * Update * Pass test * Enhance Transformer implementation * Reflect review comments * Migrate over to mutable implementation of apply_substitution * Add fast isomorphism checking and an initial implementation of full substitution logic * Pass initial full substitutions test * Cleanup old isomorphism checking code * Fix post-merge bugs * [WIP] Save initial refactor * Fix broken pcg builder test * Format * Reorganize code and remove some outdated code pre-code-review * Format * Implement actual encorder decoder architecture * Remove duplicated definition * Update based on review * Update argument order * Address review comments * Address missed comment * Remove latex dependency to avoid CI out-of-disk-space * Format * Implement most of the shape inference and ComputationGraphBuilder support * Fix bug in LayerNorm shape inference tests, disable {?} doctest default * Fix transformer test * Format * Get initial export-model-arch binary building * Actually dump valid json from 
export-model-arch * Some minor polishing of export-model-arch * Add binary sp tree logic and start on sp decomposition of computation graphs * Fix sp decomposition export of transformer (a lot of cleanup now needed) * Flesh out export-model-arch CLI and features * Format * Add split_test model * Add single_operator model for testing * Add substitution-to-dot and export-model-arch build to CI * Cleanup generic_binary_sp_decomposition_tree * Format * Fix substitution-to-dot name in CI * Add testing for cli_get_help_message * Add testing for cli_parse * Add missing include in export_model_arch * Add basic test for cli_parse on raw argv * Rename serial-parallel -> series-parallel * Add a bunch of testing for new code * Add tests for computation graph sp decomposition * Format * Fix build error in export-model-arch --------- Co-authored-by: hsdfzhsdfz --- .../{build_libs.sh => build_target.sh} | 0 .../helpers/{test_libs.sh => test_target.sh} | 0 .github/workflows/per-lib-check.yml | 42 +- .proj.toml | 2 + CMakeLists.txt | 1 + bin/CMakeLists.txt | 6 +- bin/export-model-arch/CMakeLists.txt | 12 + .../json_sp_model_export.struct.toml | 27 + .../src/export_model_arch.cc | 208 +++++++ .../CMakeLists.txt | 0 .../substitution_to_dot.cc | 21 +- cmake/flexflow-utils.cmake | 5 + lib/compiler/include/compiler/graph_utils.h | 6 +- .../include/compiler/machine_mapping.h | 3 +- .../compiler/optimal_cost_state.struct.toml | 6 +- ...omputation_graph_binary_sp_decomposition.h | 30 + ..._graph_binary_sp_decomposition.struct.toml | 22 + ...tion_graph_series_parallel_decomposition.h | 17 + ...mputation_graph_binary_sp_decomposition.cc | 90 +++ ...ion_graph_series_parallel_decomposition.cc | 98 ++++ lib/compiler/src/graph_utils.cc | 14 +- lib/compiler/src/machine_mapping.cc | 67 +-- lib/compiler/test/CMakeLists.txt | 1 + ...ion_graph_series_parallel_decomposition.cc | 340 ++++++++++++ lib/compiler/test/src/test_generator.h | 10 +- lib/kernels/include/kernels/accessor.h | 16 +- lib/kernels/include/kernels/array_shape.h | 2 +- .../include/kernels/attention_kernels.h | 1 - .../include/kernels/batch_matmul_kernels.h | 1 - .../include/kernels/initializer_kernels.h | 1 + lib/kernels/src/allocation.cc | 1 + lib/kernels/src/cpu/initializer_kernels.cc | 2 +- lib/kernels/src/cuda/embedding_kernels.cu | 8 +- lib/kernels/src/cuda/ops/combine_kernels.cu | 2 +- .../src/cuda/ops/element_unary_kernels.cu | 12 +- lib/kernels/src/cuda/ops/partition_kernels.cu | 12 +- lib/kernels/src/cuda/ops/reduction_kernels.cu | 2 +- lib/kernels/src/cuda/ops/replicate_kernels.cu | 2 +- lib/kernels/src/cuda/ops/reshape_kernels.cu | 4 +- lib/kernels/src/hip/ops/replicate_kernels.cpp | 19 +- lib/kernels/src/hip/ops/reshape_kernels.cpp | 4 +- .../include/local-execution/cost_estimate.h | 4 +- .../local-execution/legion_tensor_shape.h | 3 +- .../local-execution/local_slots_backing.h | 3 + .../local-execution/local_training_backing.h | 1 + .../include/local-execution/op_arg_ref.h | 2 +- .../local-execution/op_task_invocation.h | 4 - .../include/local-execution/sim_environment.h | 2 +- .../local-execution/task_registry.struct.toml | 1 + .../src/legion_tensor_shape.cc | 1 + .../src/local_cost_estimator.cc | 2 +- .../src/local_slots_backing.cc | 2 + .../src/local_training_backing.cc | 1 + lib/local-execution/src/op_task_signature.cc | 1 + lib/local-execution/src/ops/element_unary.cc | 1 + .../test/src/test_local_cost_estimator.cc | 6 +- .../test/src/test_local_slots_backing.cc | 22 +- .../test/src/test_local_task_arg_accessor.cc | 4 +- 
.../test/src/test_task_registry.cc | 4 +- lib/models/CMakeLists.txt | 3 +- .../include/models/split_test/split_test.h | 19 + .../models/{ => transformer}/transformer.h | 6 +- .../transformer_config.struct.toml | 0 .../src/models/split_test/split_test.cc | 39 ++ .../models/{ => transformer}/transformer.cc | 10 +- lib/models/test/src/models/transformer.cc | 4 +- lib/op-attrs/include/op-attrs/datatype.h | 9 +- .../op-attrs/datatype_value.variant.toml | 25 + lib/op-attrs/include/op-attrs/dim_ordered.h | 7 +- .../include/op-attrs/dim_ordered/slice.h | 4 +- .../include/op-attrs/dim_ordered/transform.h | 4 +- .../include/op-attrs/dim_ordered/zip.h | 4 +- lib/op-attrs/include/op-attrs/ops/broadcast.h | 3 + .../op-attrs/ops/conv_2d_attrs.struct.toml | 3 +- .../ops/element_unary_attrs.struct.toml | 6 +- .../op-attrs/ops/embedding_attrs.struct.toml | 2 + .../op-attrs/ops/linear_attrs.struct.toml | 4 +- lib/op-attrs/include/op-attrs/tensor_shape.h | 5 - .../op-attrs/computation_graph_op_attrs.cc | 15 + lib/op-attrs/src/op-attrs/ops/broadcast.cc | 18 + .../src/op-attrs/parallel_tensor_dims.cc | 4 +- lib/op-attrs/src/op-attrs/tensor_dims.cc | 8 +- lib/op-attrs/src/op-attrs/tensor_shape.cc | 31 -- lib/op-attrs/test/src/datatype.cc | 4 +- lib/op-attrs/test/src/dim_ordered/slice.cc | 4 +- .../src/op-attrs/dim_ordered/enumerate.cc | 2 +- .../test/src/op-attrs/dim_ordered/zip.cc | 2 +- lib/op-attrs/test/src/op-attrs/ops/dropout.cc | 1 + .../test/src/op-attrs/ops/layer_norm.cc | 2 +- lib/op-attrs/test/src/op-attrs/ops/softmax.cc | 1 + lib/op-attrs/test/src/op-attrs/tensor_dims.cc | 1 + .../test/src/op-attrs/tensor_shape.cc | 64 --- lib/op-attrs/test/src/ops/attention.cc | 5 +- lib/op-attrs/test/src/ops/batch_matmul.cc | 5 +- lib/op-attrs/test/src/ops/cast.cc | 5 +- lib/op-attrs/test/src/ops/combine.cc | 5 +- lib/op-attrs/test/src/ops/element_binary.cc | 5 +- lib/op-attrs/test/src/ops/element_unary.cc | 5 +- lib/op-attrs/test/src/ops/embedding.cc | 5 +- lib/op-attrs/test/src/ops/linear.cc | 5 +- lib/op-attrs/test/src/ops/reduction.cc | 5 +- lib/op-attrs/test/src/ops/repartition.cc | 5 +- lib/op-attrs/test/src/ops/replicate.cc | 4 +- lib/op-attrs/test/src/test_operator_attrs.cc | 14 +- .../test/src/test_regularizer_attrs.cc | 4 +- lib/pcg/CMakeLists.txt | 1 + lib/pcg/include/pcg/computation_graph.h | 14 + .../computation_graph_edge.h | 14 + .../computation_graph_edge.struct.toml | 16 + .../include/pcg/computation_graph_builder.h | 35 +- lib/pcg/include/pcg/file_format/file_format.h | 19 - .../include/pcg/file_format/keyed_variant.h | 29 +- .../pcg/file_format/v1/data_type_value.h | 2 +- lib/pcg/include/pcg/file_format/v1/graphs.h | 26 - .../v1/graphs/v1_dataflow_graph.struct.toml | 9 +- .../v1/graphs/v1_labelled_dataflow_graph.h | 18 +- .../v1_labelled_dataflow_graph.struct.toml | 12 +- lib/pcg/include/pcg/file_format/v1/v1.h | 9 - .../pcg/file_format/v1/v1_computation_graph.h | 17 + .../v1/v1_computation_graph.struct.toml | 18 + .../v1/v1_parallel_computation_graph.h | 13 + .../v1_parallel_computation_graph.struct.toml | 18 + .../constant_initializer_attrs.struct.toml | 7 +- lib/pcg/include/pcg/layer_attrs.struct.toml | 2 +- .../parallel_layer_attrs.struct.toml | 2 + .../parallel_tensor_attrs.struct.toml | 2 + lib/pcg/include/pcg/tensor_attrs.struct.toml | 1 + lib/pcg/src/file_format.cc | 14 - lib/pcg/src/file_format/v1/graphs.cc | 16 - lib/pcg/src/pcg/computation_graph.cc | 104 +++- .../computation_graph_edge.cc | 15 + lib/pcg/src/pcg/computation_graph_builder.cc | 165 ++++-- 
.../file_format/v1/v1_computation_graph.cc | 24 + .../v1/v1_parallel_computation_graph.cc | 12 + .../file_format/v1/v1_computation_graph.cc | 30 + .../v1/v1_parallel_computation_graph.cc | 36 ++ .../initializers/uniform_initializer_attrs.cc | 4 +- .../parallel_computation_graph_builder.cc | 6 +- .../src/test_computation_graph_builder.cc | 2 +- lib/pcg/test/src/test_machine_view.cc | 2 +- lib/pcg/test/src/test_strided_rectangle.cc | 2 +- lib/runtime/src/accessor.cc | 6 +- .../operator_attribute_value.variant.toml | 1 + .../sub_parallel_computation_graph.h | 4 - .../operator_pattern/get_attribute.cc | 4 +- .../tensor_pattern/get_attribute.cc | 6 +- .../operator_pattern/get_attribute.cc | 1 + .../test/src/substitutions/pcg_pattern.cc | 2 +- .../substitutions/unlabelled/pattern_split.cc | 2 +- .../unlabelled/unlabelled_graph_pattern.cc | 2 +- .../test/src/test_pattern_matches.cc | 4 +- lib/utils/CMakeLists.txt | 1 - .../utils/cli/cli_argument_key.variant.toml | 19 + .../utils/cli/cli_flag_key.struct.toml | 13 + .../utils/cli/cli_flag_spec.struct.toml | 28 + .../include/utils/cli/cli_get_help_message.h | 13 + lib/utils/include/utils/cli/cli_parse.h | 19 + .../include/utils/cli/cli_parse_result.h | 14 + .../utils/cli/cli_parse_result.struct.toml | 27 + .../cli_positional_argument_key.struct.toml | 13 + .../cli_positional_argument_spec.struct.toml | 31 ++ lib/utils/include/utils/cli/cli_spec.h | 20 + .../include/utils/cli/cli_spec.struct.toml | 29 + lib/utils/include/utils/containers.decl.h | 3 - lib/utils/include/utils/containers.h | 5 - .../utils/containers/enumerate_vector.h | 1 - lib/utils/include/utils/containers/foldl1.h | 29 + lib/utils/include/utils/containers/foldr1.h | 28 + .../include/utils/containers/generate_map.h | 4 +- .../include/utils/containers/get_first.h | 6 + lib/utils/include/utils/containers/maximum.h | 20 + .../include/utils/containers/multiset_union.h | 48 ++ .../utils/containers/require_no_duplicates.h | 40 ++ lib/utils/include/utils/containers/reversed.h | 11 +- .../include/utils/containers/set_minus.h | 10 + lib/utils/include/utils/containers/set_of.h | 19 + .../include/utils/containers/to_uppercase.h | 12 + .../containers/{as_vector.h => vector_of.h} | 6 +- lib/utils/include/utils/fmt/expected.h | 14 +- lib/utils/include/utils/fmt/map.h | 12 - lib/utils/include/utils/fmt/multiset.h | 12 - lib/utils/include/utils/fmt/optional.h | 12 - lib/utils/include/utils/fmt/pair.h | 12 - lib/utils/include/utils/fmt/set.h | 12 - lib/utils/include/utils/fmt/unordered_map.h | 12 - .../include/utils/fmt/unordered_multiset.h | 12 - lib/utils/include/utils/fmt/unordered_set.h | 12 - lib/utils/include/utils/fmt/variant.h | 12 - lib/utils/include/utils/fmt/vector.h | 12 - .../algorithms/get_subgraph_incoming_edges.h | 14 + .../get_cbc_decomposition.h | 3 + .../is_complete_bipartite_digraph.h | 14 + .../graph/digraph/algorithms/digraph_as_dot.h | 14 + .../digraph/algorithms/digraph_has_edge.h | 12 + .../algorithms/get_subgraph_outgoing_edges.h | 14 + .../algorithms/get_subgraph_successors.h | 14 + .../digraph/algorithms/transitive_closure.h | 12 + .../unordered_set_undirected_graph.h | 37 ++ .../algorithms/find_isomorphism.h | 1 - .../include/utils/graph/node/node.struct.toml | 1 + .../get_serial_parallel_decomposition.h | 17 - .../intermediate_sp_decomposition_tree.h | 13 - .../serial_parallel_decomposition.h | 22 - .../binary_sp_decomposition_tree.h | 23 + .../binary_sp_decomposition_tree.struct.toml | 22 + .../fmt.h | 63 +++ .../generic_binary_sp_decomposition_tree.h | 155 ++++++ 
.../get.h | 15 + .../get_leaves.h | 40 ++ .../get_left_child.h | 44 ++ .../get_node_type.h | 29 + .../get_num_tree_nodes.h | 40 ++ .../get_right_child.h | 44 ++ .../hash.h | 34 ++ .../generic_binary_sp_decomposition_tree/is.h | 25 + .../is_binary_sp_tree_left_associative.h | 34 ++ .../is_binary_sp_tree_right_associative.h | 34 ++ .../json.h | 103 ++++ .../make.h | 39 ++ .../require.h | 28 + .../transform.h | 43 ++ .../visit.h | 37 ++ ...eft_associative_binary_sp_tree_from_nary.h | 14 + .../nary_sp_tree_from_binary.h | 14 + ...ght_associative_binary_sp_tree_from_nary.h | 14 + .../get_series_parallel_decomposition.h | 17 + .../graph_generation.h | 12 +- .../intermediate_sp_decomposition_tree.h | 17 + ...rmediate_sp_decomposition_tree.struct.toml | 2 +- .../parallel_reduction.h | 6 +- .../parallel_reduction.struct.toml | 0 .../series_parallel_decomposition.h | 22 + ...eries_parallel_decomposition.variant.toml} | 6 +- .../series_parallel_splits.h} | 32 +- .../series_reduction.h | 6 +- .../series_reduction.struct.toml | 0 .../sink_settings.enum.toml | 0 .../source_settings.enum.toml | 0 .../sp_decomposition_tree_node_type.enum.toml | 17 + .../split_type.enum.toml | 2 +- .../graph/undirected/algorithms/get_edges.h | 12 + .../algorithms/get_neighboring_nodes.h | 13 + .../graph/undirected/i_undirected_graph.h | 2 +- .../graph/undirected/undirected_edge_query.h | 2 + lib/utils/include/utils/hash/multiset.h | 20 + .../include/utils/hash/unordered_multiset.h | 20 + .../include/utils/json/check_is_jsonable.h | 17 + .../utils/json/is_json_deserializable.h | 25 + .../include/utils/json/is_json_serializable.h | 24 + lib/utils/include/utils/json/is_jsonable.h | 18 + lib/utils/include/utils/json/optional.h | 33 ++ lib/utils/include/utils/json/variant.h | 89 +++ .../utils/{json.h => json/visitable.h} | 125 +---- lib/utils/include/utils/optional.h | 23 +- lib/utils/include/utils/rapidcheck/optional.h | 21 + lib/utils/include/utils/required.h | 10 +- lib/utils/include/utils/stack_string.h | 6 +- lib/utils/include/utils/stack_vector.h | 6 +- .../src/utils/cli/cli_get_help_message.cc | 101 ++++ lib/utils/src/utils/cli/cli_parse.cc | 96 ++++ lib/utils/src/utils/cli/cli_parse_result.cc | 14 + lib/utils/src/utils/cli/cli_spec.cc | 37 ++ lib/utils/src/utils/containers/as_vector.cc | 1 - .../src/utils/containers/enumerate_vector.cc | 1 + lib/utils/src/utils/containers/foldl1.cc | 1 + lib/utils/src/utils/containers/foldr1.cc | 1 + .../utils/containers/get_element_counts.cc | 4 +- lib/utils/src/utils/containers/maximum.cc | 1 + .../src/utils/containers/multiset_union.cc | 1 + .../utils/containers/require_no_duplicates.cc | 1 + lib/utils/src/utils/containers/set_of.cc | 1 + .../src/utils/containers/to_uppercase.cc | 10 + lib/utils/src/utils/containers/vector_of.cc | 1 + lib/utils/src/utils/graph/algorithms.cc | 4 - .../algorithms/get_subgraph_incoming_edges.cc | 24 + .../get_cbc_decomposition.cc | 37 +- .../is_complete_bipartite_digraph.cc | 29 + .../digraph/algorithms/digraph_as_dot.cc | 32 ++ .../digraph/algorithms/digraph_has_edge.cc | 13 + .../algorithms/get_imm_dominators_map.cc | 6 +- .../algorithms/get_subgraph_outgoing_edges.cc | 16 + .../algorithms/get_subgraph_successors.cc | 16 + .../digraph/algorithms/transitive_closure.cc | 51 ++ .../algorithms/transitive_reduction.cc | 64 ++- .../graph/instances/adjacency_digraph.cc | 6 +- .../unordered_set_undirected_graph.cc | 58 ++ .../algorithms/get_edge_counts.cc | 4 +- .../algorithms/find_isomorphisms.cc | 6 +- .../intermediate_sp_decomposition_tree.cc | 48 -- 
.../binary_sp_decomposition_tree.cc | 43 ++ .../fmt.cc | 1 + .../generic_binary_sp_decomposition_tree.cc | 1 + .../get.cc | 1 + .../get_leaves.cc | 1 + .../get_left_child.cc | 1 + .../get_node_type.cc | 1 + .../get_num_tree_nodes.cc | 1 + .../get_right_child.cc | 1 + .../hash.cc | 1 + .../is.cc | 1 + .../is_binary_sp_tree_left_associative.cc | 1 + .../is_binary_sp_tree_right_associative.cc | 1 + .../json.cc | 1 + .../make.cc | 1 + .../require.cc | 1 + .../transform.cc | 1 + .../visit.cc | 1 + ...ft_associative_binary_sp_tree_from_nary.cc | 75 +++ .../nary_sp_tree_from_binary.cc | 12 + ...ht_associative_binary_sp_tree_from_nary.cc | 72 +++ .../get_series_parallel_decomposition.cc} | 48 +- .../graph_generation.cc | 12 +- .../intermediate_sp_decomposition_tree.cc | 84 +++ .../parallel_reduction.cc | 2 +- .../series_parallel_decomposition.cc} | 45 +- .../series_parallel_splits.cc} | 30 +- .../series_reduction.cc | 2 +- .../graph/undirected/algorithms/get_edges.cc | 10 + .../algorithms/get_neighboring_nodes.cc | 19 + .../graph/undirected/undirected_edge_query.cc | 4 + lib/utils/src/utils/hash/multiset.cc | 1 + .../src/utils/hash/unordered_multiset.cc | 1 + lib/utils/src/utils/json/check_is_jsonable.cc | 1 + .../src/utils/json/is_json_deserializable.cc | 1 + .../src/utils/json/is_json_serializable.cc | 1 + lib/utils/src/utils/json/is_jsonable.cc | 1 + lib/utils/src/utils/json/optional.cc | 1 + lib/utils/src/utils/rapidcheck/optional.cc | 1 + .../test/common/include/test/utils/all.h | 2 - .../check_without_stringify.h} | 0 .../include/test/utils/doctest/fmt/expected.h | 18 + .../include/test/utils/doctest/fmt/map.h | 18 + .../include/test/utils/doctest/fmt/multiset.h | 18 + .../include/test/utils/doctest/fmt/optional.h | 18 + .../include/test/utils/doctest/fmt/pair.h | 18 + .../include/test/utils/doctest/fmt/set.h | 18 + .../test/utils/doctest/fmt/unordered_map.h | 18 + .../utils/doctest/fmt/unordered_multiset.h | 18 + .../test/utils/doctest/fmt/unordered_set.h | 18 + .../include/test/utils/doctest/fmt/variant.h | 18 + .../include/test/utils/doctest/fmt/vector.h | 18 + lib/utils/test/common/src/common.cc | 1 - .../src/test/utils/doctest/fmt/expected.cc | 1 + .../common/src/test/utils/doctest/fmt/map.cc | 1 + .../src/test/utils/doctest/fmt/multiset.cc | 1 + .../src/test/utils/doctest/fmt/optional.cc | 1 + .../common/src/test/utils/doctest/fmt/pair.cc | 1 + .../common/src/test/utils/doctest/fmt/set.cc | 1 + .../test/utils/doctest/fmt/unordered_map.cc | 1 + .../utils/doctest/fmt/unordered_multiset.cc | 1 + .../test/utils/doctest/fmt/unordered_set.cc | 1 + .../src/test/utils/doctest/fmt/variant.cc | 1 + .../src/test/utils/doctest/fmt/vector.cc | 1 + lib/utils/test/src/test_algorithms.cc | 2 +- lib/utils/test/src/test_containers.cc | 6 +- .../src/test_deduplicated_priority_queue.cc | 2 +- lib/utils/test/src/test_disjoint_set.cc | 2 +- lib/utils/test/src/test_dot_file.cc | 2 +- lib/utils/test/src/test_format.cc | 2 +- lib/utils/test/src/test_hash.cc | 2 +- lib/utils/test/src/test_multidigraph.cc | 2 +- lib/utils/test/src/test_random_utils.cc | 2 +- lib/utils/test/src/test_sequence.cc | 2 +- lib/utils/test/src/test_stack_map.cc | 2 +- lib/utils/test/src/test_stack_string.cc | 2 +- lib/utils/test/src/test_stack_vector.cc | 2 +- lib/utils/test/src/test_tuple.cc | 2 +- lib/utils/test/src/test_type_index.cc | 2 +- lib/utils/test/src/test_undirected_graph.cc | 3 +- lib/utils/test/src/test_variant.cc | 2 +- lib/utils/test/src/test_vector.cc | 2 +- .../algorithms/bidict_from_enumerating.cc | 2 +- 
lib/utils/test/src/utils/bidict/bidict.cc | 6 +- .../bidict/try_merge_nondisjoint_bidicts.cc | 4 +- .../src/utils/cli/cli_get_help_message.cc | 519 ++++++++++++++++++ lib/utils/test/src/utils/cli/cli_parse.cc | 477 ++++++++++++++++ .../test/src/utils/containers/contains_key.cc | 5 +- .../test/src/utils/containers/enumerate.cc | 26 +- lib/utils/test/src/utils/containers/extend.cc | 4 +- lib/utils/test/src/utils/containers/filter.cc | 12 +- .../src/utils/containers/filtermap_keys.cc | 6 +- .../src/utils/containers/filtermap_values.cc | 6 +- .../test/src/utils/containers/filtrans.cc | 6 +- lib/utils/test/src/utils/containers/foldl1.cc | 27 + lib/utils/test/src/utils/containers/foldr1.cc | 27 + .../utils/containers/get_all_permutations.cc | 5 +- .../utils/containers/get_element_counts.cc | 2 +- .../src/utils/containers/inplace_filter.cc | 13 +- .../test/src/utils/containers/intersection.cc | 4 +- .../test/src/utils/containers/maximum.cc | 60 ++ .../src/utils/containers/multiset_union.cc | 29 + lib/utils/test/src/utils/containers/repeat.cc | 2 +- .../utils/containers/require_no_duplicates.cc | 62 +++ .../test/src/utils/containers/reversed.cc | 27 + .../test/src/utils/containers/to_uppercase.cc | 15 + .../test/src/utils/containers/transform.cc | 6 +- .../try_merge_nondisjoint_unordered_maps.cc | 6 +- .../utils/containers/unordered_multiset_of.cc | 2 +- .../src/utils/containers/unordered_set_of.cc | 2 +- .../test/src/utils/containers/vector_of.cc | 17 + .../src/utils/containers/without_order.cc | 2 +- lib/utils/test/src/utils/expected.cc | 4 +- lib/utils/test/src/utils/fmt/expected.cc | 23 +- lib/utils/test/src/utils/fmt/map.cc | 2 +- lib/utils/test/src/utils/fmt/optional.cc | 2 +- lib/utils/test/src/utils/fmt/pair.cc | 2 +- lib/utils/test/src/utils/fmt/set.cc | 2 +- lib/utils/test/src/utils/fmt/unordered_map.cc | 3 +- lib/utils/test/src/utils/fmt/unordered_set.cc | 4 +- lib/utils/test/src/utils/fmt/variant.cc | 2 +- lib/utils/test/src/utils/fmt/vector.cc | 2 +- lib/utils/test/src/utils/graph/cow_ptr_t.cc | 2 +- .../algorithms/get_subgraph_incoming_edges.cc | 43 ++ .../algorithms/get_subgraph_outgoing_edges.cc | 3 +- .../unordered_open_dataflow_graph.cc | 4 +- .../get_cbc_decomposition.cc | 45 ++ .../is_complete_bipartite_graph.cc | 175 ++++++ .../get_inverse_line_graph.cc | 23 + .../graph/digraph/algorithms/is_acyclic.cc | 1 + .../digraph/algorithms/transitive_closure.cc | 50 ++ .../algorithms/transitive_reduction.cc | 62 +++ .../fmt.cc | 51 ++ .../get_leaves.cc | 86 +++ .../get_left_child.cc | 41 ++ .../get_num_tree_nodes.cc | 85 +++ .../get_right_child.cc | 41 ++ .../hash.cc | 117 ++++ .../is_binary_sp_tree_left_associative.cc | 102 ++++ .../is_binary_sp_tree_right_associative.cc | 102 ++++ .../json.cc | 131 +++++ .../transform.cc | 28 + ...ft_associative_binary_sp_tree_from_nary.cc | 95 ++++ .../nary_sp_tree_from_binary.cc | 132 +++++ ...ht_associative_binary_sp_tree_from_nary.cc | 93 ++++ .../get_series_parallel_decomposition.cc} | 108 ++-- .../intermediate_sp_decomposition_tree.cc | 8 +- .../parallel_reduction.cc | 2 +- .../series_parallel_decomposition.cc} | 68 +-- .../series_reduction.cc | 2 +- lib/utils/test/src/utils/hash/multiset.cc | 34 ++ .../test/src/utils/hash/unordered_multiset.cc | 34 ++ lib/utils/test/src/utils/json/optional.cc | 49 ++ .../src/utils/{ => rapidcheck}/optional.cc | 7 +- 443 files changed, 8032 insertions(+), 1245 deletions(-) rename .github/workflows/helpers/{build_libs.sh => build_target.sh} (100%) rename .github/workflows/helpers/{test_libs.sh => test_target.sh} 
(100%) create mode 100644 bin/export-model-arch/CMakeLists.txt create mode 100644 bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml create mode 100644 bin/export-model-arch/src/export_model_arch.cc rename bin/{substitutions-to-dot => substitution-to-dot}/CMakeLists.txt (100%) rename bin/{substitutions-to-dot => substitution-to-dot}/substitution_to_dot.cc (89%) create mode 100644 lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h create mode 100644 lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml create mode 100644 lib/compiler/include/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h create mode 100644 lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc create mode 100644 lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc create mode 100644 lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc create mode 100644 lib/models/include/models/split_test/split_test.h rename lib/models/include/models/{ => transformer}/transformer.h (90%) rename lib/models/include/models/{ => transformer}/transformer_config.struct.toml (100%) create mode 100644 lib/models/src/models/split_test/split_test.cc rename lib/models/src/models/{ => transformer}/transformer.cc (95%) create mode 100644 lib/op-attrs/include/op-attrs/datatype_value.variant.toml delete mode 100644 lib/op-attrs/test/src/op-attrs/tensor_shape.cc create mode 100644 lib/pcg/include/pcg/computation_graph/computation_graph_edge.h create mode 100644 lib/pcg/include/pcg/computation_graph/computation_graph_edge.struct.toml delete mode 100644 lib/pcg/include/pcg/file_format/file_format.h delete mode 100644 lib/pcg/include/pcg/file_format/v1/graphs.h delete mode 100644 lib/pcg/include/pcg/file_format/v1/v1.h create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_computation_graph.struct.toml create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.struct.toml delete mode 100644 lib/pcg/src/file_format.cc delete mode 100644 lib/pcg/src/file_format/v1/graphs.cc create mode 100644 lib/pcg/src/pcg/computation_graph/computation_graph_edge.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc create mode 100644 lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc create mode 100644 lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc create mode 100644 lib/utils/include/utils/cli/cli_argument_key.variant.toml create mode 100644 lib/utils/include/utils/cli/cli_flag_key.struct.toml create mode 100644 lib/utils/include/utils/cli/cli_flag_spec.struct.toml create mode 100644 lib/utils/include/utils/cli/cli_get_help_message.h create mode 100644 lib/utils/include/utils/cli/cli_parse.h create mode 100644 lib/utils/include/utils/cli/cli_parse_result.h create mode 100644 lib/utils/include/utils/cli/cli_parse_result.struct.toml create mode 100644 lib/utils/include/utils/cli/cli_positional_argument_key.struct.toml create mode 100644 lib/utils/include/utils/cli/cli_positional_argument_spec.struct.toml create mode 100644 lib/utils/include/utils/cli/cli_spec.h create mode 100644 
lib/utils/include/utils/cli/cli_spec.struct.toml create mode 100644 lib/utils/include/utils/containers/foldl1.h create mode 100644 lib/utils/include/utils/containers/foldr1.h create mode 100644 lib/utils/include/utils/containers/maximum.h create mode 100644 lib/utils/include/utils/containers/multiset_union.h create mode 100644 lib/utils/include/utils/containers/require_no_duplicates.h create mode 100644 lib/utils/include/utils/containers/set_of.h create mode 100644 lib/utils/include/utils/containers/to_uppercase.h rename lib/utils/include/utils/containers/{as_vector.h => vector_of.h} (54%) create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h create mode 100644 lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h create mode 100644 lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h create mode 100644 lib/utils/include/utils/graph/digraph/algorithms/digraph_has_edge.h create mode 100644 lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h create mode 100644 lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_successors.h create mode 100644 lib/utils/include/utils/graph/digraph/algorithms/transitive_closure.h create mode 100644 lib/utils/include/utils/graph/instances/unordered_set_undirected_graph.h delete mode 100644 lib/utils/include/utils/graph/serial_parallel/get_serial_parallel_decomposition.h delete mode 100644 lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h delete mode 100644 lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h create mode 100644 
lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h create mode 100644 lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/graph_generation.h (56%) create mode 100644 lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.h rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/intermediate_sp_decomposition_tree.struct.toml (90%) rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/parallel_reduction.h (70%) rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/parallel_reduction.struct.toml (100%) create mode 100644 lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h rename lib/utils/include/utils/graph/{serial_parallel/serial_parallel_decomposition.variant.toml => series_parallel/series_parallel_decomposition.variant.toml} (62%) rename lib/utils/include/utils/graph/{serial_parallel/serial_parallel_splits.h => series_parallel/series_parallel_splits.h} (59%) rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/series_reduction.h (77%) rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/series_reduction.struct.toml (100%) rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/sink_settings.enum.toml (100%) rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/source_settings.enum.toml (100%) create mode 100644 lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.enum.toml rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/split_type.enum.toml (90%) create mode 100644 lib/utils/include/utils/graph/undirected/algorithms/get_edges.h create mode 100644 lib/utils/include/utils/graph/undirected/algorithms/get_neighboring_nodes.h create mode 100644 lib/utils/include/utils/hash/multiset.h create mode 100644 lib/utils/include/utils/hash/unordered_multiset.h create mode 100644 lib/utils/include/utils/json/check_is_jsonable.h create mode 100644 lib/utils/include/utils/json/is_json_deserializable.h create mode 100644 
lib/utils/include/utils/json/is_json_serializable.h create mode 100644 lib/utils/include/utils/json/is_jsonable.h create mode 100644 lib/utils/include/utils/json/optional.h create mode 100644 lib/utils/include/utils/json/variant.h rename lib/utils/include/utils/{json.h => json/visitable.h} (52%) create mode 100644 lib/utils/include/utils/rapidcheck/optional.h create mode 100644 lib/utils/src/utils/cli/cli_get_help_message.cc create mode 100644 lib/utils/src/utils/cli/cli_parse.cc create mode 100644 lib/utils/src/utils/cli/cli_parse_result.cc create mode 100644 lib/utils/src/utils/cli/cli_spec.cc delete mode 100644 lib/utils/src/utils/containers/as_vector.cc create mode 100644 lib/utils/src/utils/containers/enumerate_vector.cc create mode 100644 lib/utils/src/utils/containers/foldl1.cc create mode 100644 lib/utils/src/utils/containers/foldr1.cc create mode 100644 lib/utils/src/utils/containers/maximum.cc create mode 100644 lib/utils/src/utils/containers/multiset_union.cc create mode 100644 lib/utils/src/utils/containers/require_no_duplicates.cc create mode 100644 lib/utils/src/utils/containers/set_of.cc create mode 100644 lib/utils/src/utils/containers/to_uppercase.cc create mode 100644 lib/utils/src/utils/containers/vector_of.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc create mode 100644 lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc create mode 100644 lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc create mode 100644 lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc create mode 100644 lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc create mode 100644 lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_successors.cc create mode 100644 lib/utils/src/utils/graph/digraph/algorithms/transitive_closure.cc create mode 100644 lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc delete mode 100644 lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc create mode 100644 
lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc rename lib/utils/src/utils/graph/{serial_parallel/get_serial_parallel_decomposition.cc => series_parallel/get_series_parallel_decomposition.cc} (62%) rename lib/utils/src/utils/graph/{serial_parallel => series_parallel}/graph_generation.cc (79%) create mode 100644 lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc rename lib/utils/src/utils/graph/{serial_parallel => series_parallel}/parallel_reduction.cc (93%) rename lib/utils/src/utils/graph/{serial_parallel/serial_parallel_decomposition.cc => series_parallel/series_parallel_decomposition.cc} (52%) rename lib/utils/src/utils/graph/{serial_parallel/serial_parallel_splits.cc => series_parallel/series_parallel_splits.cc} (65%) rename lib/utils/src/utils/graph/{serial_parallel => series_parallel}/series_reduction.cc (97%) create mode 100644 lib/utils/src/utils/graph/undirected/algorithms/get_edges.cc create mode 100644 lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc create mode 100644 lib/utils/src/utils/hash/multiset.cc create mode 100644 lib/utils/src/utils/hash/unordered_multiset.cc create mode 100644 lib/utils/src/utils/json/check_is_jsonable.cc create mode 100644 lib/utils/src/utils/json/is_json_deserializable.cc create mode 100644 lib/utils/src/utils/json/is_json_serializable.cc create mode 100644 lib/utils/src/utils/json/is_jsonable.cc create mode 100644 lib/utils/src/utils/json/optional.cc create mode 100644 lib/utils/src/utils/rapidcheck/optional.cc delete mode 100644 lib/utils/test/common/include/test/utils/all.h rename lib/utils/test/common/include/test/utils/{doctest.h => doctest/check_without_stringify.h} (100%) create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/expected.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/map.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/multiset.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/optional.h create mode 100644 
lib/utils/test/common/include/test/utils/doctest/fmt/pair.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/set.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/unordered_map.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/unordered_multiset.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/unordered_set.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/variant.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/vector.h delete mode 100644 lib/utils/test/common/src/common.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/expected.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/map.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/multiset.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/set.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/unordered_map.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/unordered_multiset.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/unordered_set.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/variant.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/vector.cc create mode 100644 lib/utils/test/src/utils/cli/cli_get_help_message.cc create mode 100644 lib/utils/test/src/utils/cli/cli_parse.cc create mode 100644 lib/utils/test/src/utils/containers/foldl1.cc create mode 100644 lib/utils/test/src/utils/containers/foldr1.cc create mode 100644 lib/utils/test/src/utils/containers/maximum.cc create mode 100644 lib/utils/test/src/utils/containers/multiset_union.cc create mode 100644 lib/utils/test/src/utils/containers/require_no_duplicates.cc create mode 100644 lib/utils/test/src/utils/containers/reversed.cc create mode 100644 lib/utils/test/src/utils/containers/to_uppercase.cc create mode 100644 lib/utils/test/src/utils/containers/vector_of.cc create mode 100644 lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_graph.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/transitive_closure.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc create mode 100644 
lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc rename lib/utils/test/src/utils/graph/{serial_parallel/get_serial_parallel_decomposition.cc => series_parallel/get_series_parallel_decomposition.cc} (50%) rename lib/utils/test/src/utils/graph/{serial_parallel => series_parallel}/intermediate_sp_decomposition_tree.cc (83%) rename lib/utils/test/src/utils/graph/{serial_parallel => series_parallel}/parallel_reduction.cc (99%) rename lib/utils/test/src/utils/graph/{serial_parallel/serial_parallel_decomposition.cc => series_parallel/series_parallel_decomposition.cc} (66%) rename lib/utils/test/src/utils/graph/{serial_parallel => series_parallel}/series_reduction.cc (99%) create mode 100644 lib/utils/test/src/utils/hash/multiset.cc create mode 100644 lib/utils/test/src/utils/hash/unordered_multiset.cc create mode 100644 lib/utils/test/src/utils/json/optional.cc rename lib/utils/test/src/utils/{ => rapidcheck}/optional.cc (67%) diff --git a/.github/workflows/helpers/build_libs.sh b/.github/workflows/helpers/build_target.sh similarity index 100% rename from .github/workflows/helpers/build_libs.sh rename to .github/workflows/helpers/build_target.sh diff --git a/.github/workflows/helpers/test_libs.sh b/.github/workflows/helpers/test_target.sh similarity index 100% rename from .github/workflows/helpers/test_libs.sh rename to .github/workflows/helpers/test_target.sh diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index 639f4d82b5..a5ac6fd29f 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -62,71 +62,79 @@ jobs: - name: Build utils run: | - build_libs.sh utils + build_target.sh utils - name: Build op-attrs run: | - build_libs.sh op-attrs + build_target.sh op-attrs - name: Build pcg run: | - build_libs.sh pcg + build_target.sh pcg - name: Build kernels run: | - build_libs.sh kernels + build_target.sh kernels - name: Build substitutions run: | - build_libs.sh substitutions + build_target.sh substitutions - name: Build compiler run: | - build_libs.sh compiler + build_target.sh compiler - name: Build substitution-generator run: | - build_libs.sh substitution-generator + build_target.sh substitution-generator - name: Build local-execution run: | - build_libs.sh local-execution + build_target.sh local-execution - name: Build models run: | - build_libs.sh models + build_target.sh models + + - name: Build substitution-to-dot + run: | + build_target.sh substitution-to-dot + + - name: Build export-model-arch + run: | + build_target.sh export-model-arch - name: Test utils run: | - 
test_libs.sh utils + test_target.sh utils - name: Test op-attrs run: | - test_libs.sh op-attrs + test_target.sh op-attrs - name: Test pcg run: | - test_libs.sh pcg + test_target.sh pcg - name: Test substitutions run: | - test_libs.sh substitutions + test_target.sh substitutions # - name: Test compiler # run: | - # test_libs.sh compiler + # test_target.sh compiler - name: Test substitution-generator run: | - test_libs.sh substitution-generator + test_target.sh substitution-generator - name: Test local-execution run: | - test_libs.sh local-execution + test_target.sh local-execution - name: Test models run: | - test_libs.sh models + test_target.sh models - name: Generate code coverage run: | diff --git a/.proj.toml b/.proj.toml index 721d212e31..5592f184ad 100644 --- a/.proj.toml +++ b/.proj.toml @@ -13,6 +13,8 @@ build_targets = [ "substitution-generator", "local-execution", "models", + "export-model-arch", + "substitution-to-dot", ] test_targets = [ diff --git a/CMakeLists.txt b/CMakeLists.txt index a518931ac5..792126449b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -74,6 +74,7 @@ option(FF_BUILD_UNIT_TESTS "build non-operator unit tests" OFF) option(FF_BUILD_SUBSTITUTION_TOOL "build substitution conversion tool" OFF) option(FF_BUILD_VISUALIZATION_TOOL "build substitution visualization tool" ON) option(FF_BUILD_ARG_PARSER "build command line argument parser" OFF) +option(FF_BUILD_BIN_EXPORT_MODEL_ARCH "build export-model-arch utility" ON) set(FF_CUDA_ARCH "autodetect" CACHE STRING "Target CUDA Arch") if (FF_CUDA_ARCH STREQUAL "") diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index fcc19b33b9..1cd7068cfd 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -7,9 +7,13 @@ if(FF_BUILD_SUBSTITUTION_TOOL) endif() if(FF_BUILD_VISUALIZATION_TOOL) - add_subdirectory(substitutions-to-dot) + add_subdirectory(substitution-to-dot) endif() if(FF_BUILD_ARG_PARSER) add_subdirectory(arg_parser) endif() + +if(FF_BUILD_BIN_EXPORT_MODEL_ARCH) + add_subdirectory(export-model-arch) +endif() diff --git a/bin/export-model-arch/CMakeLists.txt b/bin/export-model-arch/CMakeLists.txt new file mode 100644 index 0000000000..b931668594 --- /dev/null +++ b/bin/export-model-arch/CMakeLists.txt @@ -0,0 +1,12 @@ +ff_add_executable( + NAME + export-model-arch + SRC_PATTERNS + src/*.cc + PRIVATE_INCLUDE + include/ + DEPS + utils + models + compiler +) diff --git a/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml new file mode 100644 index 0000000000..efaf368bc8 --- /dev/null +++ b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "JsonSPModelExport" +features = [ + "eq", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/file_format/v1/v1_computation_graph.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", +] + +src_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h", +] + +[[fields]] +name = "sp_decomposition" +type = "::FlexFlow::GenericBinarySPDecompositionTree" + +[[fields]] +name = "computation_graph" +type = "::FlexFlow::V1ComputationGraph" 
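[Illustrative sketch, not part of the patch] The .struct.toml above appears to be a dtgen spec: with the "json" feature it should generate a JsonSPModelExport struct holding the two fields named above plus nlohmann-based (de)serialization. A minimal sketch of how that generated type might be serialized, assuming the generated header path shown in the diff below and that to_json/from_json are keyed by the field names; the helper name dump_sp_model_export is hypothetical:

#include "export_model_arch/json_sp_model_export.dtg.h"
#include <nlohmann/json.hpp>
#include <iostream>

// Sketch only: serialize a JsonSPModelExport and print it, mirroring what
// export_model_arch.cc does with json_output.dump(2) further down.
void dump_sp_model_export(::FlexFlow::JsonSPModelExport const &m) {
  nlohmann::json j = m; // assumes dtgen's "json" feature provides to_json
  // Expected top-level keys (assumption): "sp_decomposition", "computation_graph"
  std::cout << j.dump(2) << std::endl;
}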
diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc new file mode 100644 index 0000000000..ccc720ed14 --- /dev/null +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -0,0 +1,208 @@ +#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" +#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "export_model_arch/json_sp_model_export.dtg.h" +#include "models/split_test/split_test.h" +#include "models/transformer/transformer.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "pcg/computation_graph.h" +#include "pcg/file_format/v1/v1_computation_graph.h" +#include "utils/cli/cli_get_help_message.h" +#include "utils/cli/cli_parse.h" +#include "utils/cli/cli_parse_result.h" +#include "utils/cli/cli_spec.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" + +using namespace ::FlexFlow; + +ComputationGraph get_single_operator_computation_graph() { + ComputationGraphBuilder b; + + size_t batch_size = 8; + size_t in_channels = 16; + size_t out_channels = 12; + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered<size_t>{ + batch_size, + in_channels, + out_channels, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + + InitializerAttrs kernel_initializer = + InitializerAttrs{GlorotUniformAttrs{/*seed=*/12}}; + InitializerAttrs bias_initializer = InitializerAttrs{ZeroInitializerAttrs{}}; + tensor_guid_t output = b.dense(input, + in_channels, + Activation::RELU, + /*use_bias=*/true, + DataType::FLOAT, + kernel_initializer, + bias_initializer, + "my_example_operator"); + + return b.computation_graph; +} + +ComputationGraph get_default_transformer_computation_graph() { + TransformerConfig config = get_default_transformer_config(); + ComputationGraph cg = get_transformer_computation_graph(config); + + return cg; } + +tl::expected<ComputationGraph, std::string> + get_model_computation_graph(std::string const &model_name) { + if (model_name == "transformer") { + return get_default_transformer_computation_graph(); + } else if (model_name == "split_test") { + int batch_size = 8; + return get_split_test_computation_graph(batch_size); + } else if (model_name == "single_operator") { + return get_single_operator_computation_graph(); + } else { + return tl::unexpected(fmt::format("Unknown model name: {}", model_name)); + } +} + +tl::expected<JsonSPModelExport, std::string> + get_sp_model_export(std::string const &model_name) { + ComputationGraph computation_graph = ({ + tl::expected<ComputationGraph, std::string> result = + get_model_computation_graph(model_name); + if (!result.has_value()) { + return tl::unexpected(result.error()); + } + result.value(); + }); + + ComputationGraphBinarySPDecomposition sp_decomposition = ({ + std::optional<ComputationGraphBinarySPDecomposition> result = + get_computation_graph_right_assoc_binary_sp_decomposition( + computation_graph); + if (!result.has_value()) { + return tl::unexpected("Failed to generate series-parallel decomposition " + "of computation graph."); + } + result.value(); + }); + + std::pair<V1ComputationGraph, bidict<int, layer_guid_t>> v1_result = + to_v1_including_node_numbering(computation_graph); + V1ComputationGraph v1_cg = v1_result.first; + bidict<int, layer_guid_t> layer_numbering = v1_result.second; + GenericBinarySPDecompositionTree<int> v1_sp_decomposition = + transform(sp_decomposition.raw_tree,
[&](layer_guid_t const &l) { return layer_numbering.at_r(l); }); + + return JsonSPModelExport{ + v1_sp_decomposition, + v1_cg, + }; +} + +int main(int argc, char **argv) { + CLISpec cli = empty_cli_spec(); + + CLIArgumentKey arg_key_help = cli_add_help_flag(cli); + + CLIArgumentKey key_sp_decomposition = + cli_add_flag(cli, + CLIFlagSpec{"sp-decomposition", + std::nullopt, + "also output a series-parallel decomposition of " + "the model's computation graph"}); + + CLIArgumentKey key_dot = cli_add_flag( + cli, + CLIFlagSpec{ + "dot", + std::nullopt, + "output a dot representation of the model's computation graph"}); + + CLIArgumentKey key_preprocessed_dot = cli_add_flag( + cli, + CLIFlagSpec{"preprocessed-dot", + std::nullopt, + "output a dot representation of the model's computation graph " + "preprocessed to help check its series-parallel structure"}); + + std::vector<std::string> model_options = { + "transformer", "split_test", "single_operator"}; + CLIArgumentKey key_model_name = cli_add_positional_argument( + cli, + CLIPositionalArgumentSpec{ + "model", model_options, "name of the model to export"}); + + assert(argc >= 1); + std::string prog_name = argv[0]; + + CLIParseResult parsed = ({ + tl::expected<CLIParseResult, std::string> result = + cli_parse(cli, argc, argv); + if (!result.has_value()) { + std::string error_msg = result.error(); + std::cerr << cli_get_help_message(prog_name, cli); + std::cerr << std::endl; + std::cerr << "error: " << error_msg << std::endl; + return 1; + } + + result.value(); + }); + + bool help = cli_get_flag(parsed, arg_key_help); + if (help) { + std::cerr << cli_get_help_message(prog_name, cli); + return 1; + } + + std::string model_name = cli_get_argument(parsed, key_model_name); + bool sp_decomposition = cli_get_flag(parsed, key_sp_decomposition); + bool dot = cli_get_flag(parsed, key_dot); + bool preprocessed_dot = cli_get_flag(parsed, key_preprocessed_dot); + + auto handle_error = [](auto const &result) { + if (!result.has_value()) { + std::cerr << "error: " << result.error() << std::endl; + exit(1); + } + + return result.value(); + }; + + if (dot) { + ComputationGraph cg = handle_error(get_model_computation_graph(model_name)); + + std::cout << as_dot(cg) << std::endl; + return 0; + } + + if (preprocessed_dot) { + ComputationGraph cg = handle_error(get_model_computation_graph(model_name)); + std::string rendered = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + + std::cout << rendered << std::endl; + return 0; + } + + nlohmann::json json_output; + if (sp_decomposition) { + JsonSPModelExport model_export = + handle_error(get_sp_model_export(model_name)); + + json_output = model_export; + } else { + ComputationGraph cg = handle_error(get_model_computation_graph(model_name)); + + json_output = to_v1(cg); + } + std::cout << json_output.dump(2) << std::endl; + + return 0; +} diff --git a/bin/substitutions-to-dot/CMakeLists.txt b/bin/substitution-to-dot/CMakeLists.txt similarity index 100% rename from bin/substitutions-to-dot/CMakeLists.txt rename to bin/substitution-to-dot/CMakeLists.txt diff --git a/bin/substitutions-to-dot/substitution_to_dot.cc b/bin/substitution-to-dot/substitution_to_dot.cc similarity index 89% rename from bin/substitutions-to-dot/substitution_to_dot.cc rename to bin/substitution-to-dot/substitution_to_dot.cc index 49a199ddd3..1b5f715bcd 100644 --- a/bin/substitutions-to-dot/substitution_to_dot.cc +++ b/bin/substitution-to-dot/substitution_to_dot.cc @@ -1,4 +1,4 @@ -#include "substitution-generator/json.h" +#include "substitution-generator/legacy_rules.h"
#include "utils/dot_file.h" #include #include @@ -24,10 +24,11 @@ int main(int argc, char **argv) { std::string json_path(argv[1]); std::string rule_name(argv[2]); - RuleCollection rule_collection = load_rule_collection_from_path(json_path); + LegacyRuleCollection rule_collection = + load_rule_collection_from_path(json_path); - std::optional found = std::nullopt; - for (Rule const &r : rule_collection.rules) { + std::optional found = std::nullopt; + for (LegacyRule const &r : rule_collection.rules) { if (r.name == rule_name) { found = r; break; @@ -39,7 +40,7 @@ int main(int argc, char **argv) { return 1; } - Rule r = found.value(); + LegacyRule r = found.value(); using Node = std::tuple; @@ -82,14 +83,14 @@ int main(int argc, char **argv) { }; for (int i = 0; i < r.srcOp.size(); i++) { - Operator const &o = r.srcOp[i]; + LegacyOperator const &o = r.srcOp[i]; Node srcOpNode = {NodeType::SRC, i, 0}; { dot.add_node(srcOpNode, label_map(fmt::to_string(o.op_type), srcOpNode)); dot.add_node_to_subgraph(srcOpNode, src_body_subgraph); } - for (Tensor const &t : o.input) { + for (LegacyTensor const &t : o.input) { if (t.opId < 0) { assert(t.tsId == 0); Node inputOpNode = {NodeType::SRC_INPUT_TENSOR, t.opId, 0}; @@ -106,14 +107,14 @@ int main(int argc, char **argv) { } } for (int j = 0; j < r.dstOp.size(); j++) { - Operator const &o = r.dstOp[j]; + LegacyOperator const &o = r.dstOp[j]; Node dstOpNode = {NodeType::DST, j, 0}; { dot.add_node(dstOpNode, label_map(fmt::to_string(o.op_type), dstOpNode)); dot.add_node_to_subgraph(dstOpNode, dst_body_subgraph); } - for (Tensor const &t : o.input) { + for (LegacyTensor const &t : o.input) { if (t.opId < 0) { assert(t.tsId == 0); Node inputOpNode = {NodeType::DST_INPUT_TENSOR, t.opId, 0}; @@ -128,7 +129,7 @@ int main(int argc, char **argv) { } } } - for (MapOutput const &mo : r.mappedOutput) { + for (LegacyMapOutput const &mo : r.mappedOutput) { Node srcOutputNode = {NodeType::SRC_OUTPUT_TENSOR, mo.srcOpId, mo.srcTsId}; Node dstOutputNode = {NodeType::DST_OUTPUT_TENSOR, mo.dstOpId, mo.dstTsId}; { diff --git a/cmake/flexflow-utils.cmake b/cmake/flexflow-utils.cmake index 1dbd16bdb1..90e100bb1b 100644 --- a/cmake/flexflow-utils.cmake +++ b/cmake/flexflow-utils.cmake @@ -149,6 +149,11 @@ function(ff_add_executable) ${FF_EXEC_NAME} ${SRC}) + target_include_directories( + ${FF_EXEC_NAME} + PRIVATE + ${FF_EXEC_PRIVATE_INCLUDE}) + target_link_libraries( ${FF_EXEC_NAME} ${FF_EXEC_DEPS}) diff --git a/lib/compiler/include/compiler/graph_utils.h b/lib/compiler/include/compiler/graph_utils.h index 1370357837..75fd369434 100644 --- a/lib/compiler/include/compiler/graph_utils.h +++ b/lib/compiler/include/compiler/graph_utils.h @@ -5,12 +5,12 @@ #include "pcg/computation_graph.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include "substitutions/sub_parallel_computation_graph.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" namespace FlexFlow { -SerialParallelDecomposition - get_serial_parallel_decomposition(ParallelComputationGraph const &pcg); +SeriesParallelDecomposition + get_series_parallel_decomposition(ParallelComputationGraph const &pcg); ParallelComputationGraph cg_to_pcg(ComputationGraph const &g); SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &g); diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index 5d17cbb373..3774f2cd52 100644 
--- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -9,7 +9,8 @@ #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "substitutions/sub_parallel_computation_graph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/visitable.h" namespace FlexFlow { diff --git a/lib/compiler/include/compiler/optimal_cost_state.struct.toml b/lib/compiler/include/compiler/optimal_cost_state.struct.toml index 50496f661b..036647c0b1 100644 --- a/lib/compiler/include/compiler/optimal_cost_state.struct.toml +++ b/lib/compiler/include/compiler/optimal_cost_state.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h", + "utils/graph/series_parallel/series_parallel_decomposition.dtg.h", "pcg/machine_specification.dtg.h", "pcg/machine_view.dtg.h", "utils/graph/node/node.dtg.h", @@ -21,7 +21,7 @@ includes = [ [[fields]] name = "subgraph" -type = "::FlexFlow::SerialParallelDecomposition" +type = "::FlexFlow::SeriesParallelDecomposition" [[fields]] name = "resource" @@ -33,4 +33,4 @@ type = "std::unordered_map<::FlexFlow::Node, ::FlexFlow::MachineView>" [[fields]] name = "frontier_machine_views" -type = "std::unordered_map<::FlexFlow::OpenDataflowEdge, ::FlexFlow::MachineView>" \ No newline at end of file +type = "std::unordered_map<::FlexFlow::OpenDataflowEdge, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h new file mode 100644 index 0000000000..3032e3efe9 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_COMPUTATION_GRAPH_BINARY_SP_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_COMPUTATION_GRAPH_BINARY_SP_DECOMPOSITION_H + +#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.dtg.h" +#include "pcg/computation_graph.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" + +namespace FlexFlow { + +SPDecompositionTreeNodeType + get_node_type(ComputationGraphBinarySPDecomposition const &); +ComputationGraphBinarySPDecomposition + get_left_child(ComputationGraphBinarySPDecomposition const &); +ComputationGraphBinarySPDecomposition + get_right_child(ComputationGraphBinarySPDecomposition const &); +layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &); +std::optional + get_computation_graph_left_assoc_binary_sp_decomposition( + ComputationGraph const &); +std::optional + get_computation_graph_right_assoc_binary_sp_decomposition( + ComputationGraph const &); +bool is_left_associative(ComputationGraphBinarySPDecomposition const &); +bool is_right_associative(ComputationGraphBinarySPDecomposition const &); +std::unordered_multiset + get_layers(ComputationGraphBinarySPDecomposition const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml new file mode 100644 index 0000000000..147b1e3acf --- /dev/null +++ 
b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "ComputationGraphBinarySPDecomposition" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "pcg/layer_guid_t.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", +] + +src_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", +] + +[[fields]] +name = "raw_tree" +type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::layer_guid_t>" diff --git a/lib/compiler/include/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h b/lib/compiler/include/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h new file mode 100644 index 0000000000..e85843ed26 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_COMPUTATION_GRAPH_SERIES_PARALLEL_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_COMPUTATION_GRAPH_SERIES_PARALLEL_DECOMPOSITION_H + +#include "pcg/computation_graph.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +std::string render_preprocessed_computation_graph_for_sp_decomposition( + ComputationGraph const &); +std::optional + get_computation_graph_series_parallel_decomposition( + ComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc new file mode 100644 index 0000000000..63054385ac --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc @@ -0,0 +1,90 @@ +#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" +#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" + +namespace FlexFlow { + +SPDecompositionTreeNodeType + get_node_type(ComputationGraphBinarySPDecomposition const &d) { + return get_node_type(d.raw_tree); +} + +ComputationGraphBinarySPDecomposition + get_left_child(ComputationGraphBinarySPDecomposition const &d) { + return ComputationGraphBinarySPDecomposition{ + get_left_child(d.raw_tree), + }; +} + +ComputationGraphBinarySPDecomposition + get_right_child(ComputationGraphBinarySPDecomposition const &d) { + return ComputationGraphBinarySPDecomposition{ + get_right_child(d.raw_tree), + }; +} + +layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &d) { + return require_node(d.raw_tree); +} + +std::optional + get_computation_graph_left_assoc_binary_sp_decomposition( + ComputationGraph const &cg) { + SeriesParallelDecomposition sp_decomposition = ({ + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + if (!result.has_value()) { + return std::nullopt; + } + result.value(); + }); + + BinarySPDecompositionTree raw_binary_tree = + left_associative_binary_sp_tree_from_nary(sp_decomposition); + + return ComputationGraphBinarySPDecomposition{transform( + raw_binary_tree.raw_tree, [](Node const &n) { return layer_guid_t{n}; })}; +} + +std::optional + get_computation_graph_right_assoc_binary_sp_decomposition( + ComputationGraph const &cg) { + SeriesParallelDecomposition sp_decomposition = ({ + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + if (!result.has_value()) { + return std::nullopt; + } + result.value(); + }); + + BinarySPDecompositionTree raw_binary_tree = + right_associative_binary_sp_tree_from_nary(sp_decomposition); + + return ComputationGraphBinarySPDecomposition{transform( + raw_binary_tree.raw_tree, [](Node const &n) { return layer_guid_t{n}; })}; +} + +bool is_left_associative(ComputationGraphBinarySPDecomposition const &d) { + return is_binary_sp_tree_left_associative(d.raw_tree); +} + +bool is_right_associative(ComputationGraphBinarySPDecomposition const &d) { + return is_binary_sp_tree_right_associative(d.raw_tree); +} + +std::unordered_multiset + get_layers(ComputationGraphBinarySPDecomposition const &d) { + return get_leaves(d.raw_tree); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc new file mode 100644 index 0000000000..184ad93f4d --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc @@ -0,0 +1,98 @@ +#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph/computation_graph_edge.h" +#include "utils/graph/digraph/algorithms/digraph_as_dot.h" +#include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" +#include "utils/record_formatter.h" + +namespace FlexFlow { + +std::string render_preprocessed_computation_graph_for_sp_decomposition( + ComputationGraph const &cg) { + std::unordered_set weight_and_input_layers = + filter(get_layers(cg), 
[&](layer_guid_t const &l) { + ComputationGraphOpAttrs op_attrs = get_layer_attrs(cg, l).attrs; + return op_attrs.has() || op_attrs.has(); + }); + + std::unordered_set weight_and_input_layer_successors = + get_subgraph_successors(cg, weight_and_input_layers); + + // dot is incapable of rendering the number of edges in the all-to-all + // connection, so for visualization purposes we instead insert a "fake" node + // to reduce the n^2 edges to 2*n edges + DiGraph preprocessed_digraph = + materialize_digraph_view(cg.raw_graph); + Node fake_node = preprocessed_digraph.add_node(); + for (layer_guid_t const &src : weight_and_input_layers) { + preprocessed_digraph.add_edge(DirectedEdge{src.raw_node, fake_node}); + } + for (layer_guid_t const &dst : weight_and_input_layer_successors) { + preprocessed_digraph.add_edge(DirectedEdge{fake_node, dst.raw_node}); + } + + std::function get_node_label = + [&](Node const &n) -> std::string { + if (n == fake_node) { + return "FAKE"; + } + LayerAttrs a = cg.raw_graph.at(n); + RecordFormatter r = as_dot(a.attrs); + + if (a.name.has_value()) { + RecordFormatter rr; + rr << "Name" << a.name.value(); + r << rr; + } + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + std::string preprocessed_dot = digraph_as_dot( + transitive_reduction(preprocessed_digraph), get_node_label); + + return preprocessed_dot; +} + +std::optional + get_computation_graph_series_parallel_decomposition( + ComputationGraph const &cg) { + + { + DiGraphView unpreprocessed_digraph = cg.raw_graph; + std::optional unpreprocessed_sp_decomposition = + get_series_parallel_decomposition(unpreprocessed_digraph); + if (unpreprocessed_sp_decomposition.has_value()) { + return unpreprocessed_sp_decomposition.value(); + } + } + + DiGraphView preprocessed_digraph = [&] { + std::unordered_set weight_and_input_layers = + filter(get_layers(cg), [&](layer_guid_t const &l) { + ComputationGraphOpAttrs op_attrs = get_layer_attrs(cg, l).attrs; + return op_attrs.has() || op_attrs.has(); + }); + + std::unordered_set weight_and_input_layer_successors = + get_subgraph_successors(cg, weight_and_input_layers); + + DiGraph digraph = materialize_digraph_view(cg.raw_graph); + for (layer_guid_t const &src : weight_and_input_layers) { + for (layer_guid_t const &dst : weight_and_input_layer_successors) { + digraph.add_edge(DirectedEdge{src.raw_node, dst.raw_node}); + } + } + + return digraph; + }(); + + return get_series_parallel_decomposition(preprocessed_digraph); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index 08db219a21..a19c5e8597 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -4,13 +4,13 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "substitutions/sub_parallel_computation_graph.dtg.h" #include "utils/containers/without_order.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" namespace FlexFlow { -SerialParallelDecomposition - get_serial_parallel_decomposition(ParallelComputationGraph const &pcg) { +SeriesParallelDecomposition + get_series_parallel_decomposition(ParallelComputationGraph const &pcg) { NOT_IMPLEMENTED(); - // return get_serial_parallel_decomposition(pcg.raw_graph); + // return get_series_parallel_decomposition(pcg.raw_graph); } ParallelComputationGraph cg_to_pcg(ComputationGraph const &g) { @@ -126,11 +126,11 @@ SubParallelComputationGraph
pcg_to_subpcg(ParallelComputationGraph const &pcg) { // } // }; -// std::unordered_set get_nodes(SerialParallelDecomposition const &sp) { +// std::unordered_set get_nodes(SeriesParallelDecomposition const &sp) { // return std::visit(GetNodes{}, sp.raw_variant); // } -// std::unordered_set get_nodes(SerialSplit const &serial) { +// std::unordered_set get_nodes(SeriesSplit const &serial) { // return set_union( // transform(serial.children, [](std::variant const // child) { @@ -140,7 +140,7 @@ SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &pcg) { // std::unordered_set get_nodes(ParallelSplit const ¶llel) { // return set_union( -// transform(parallel.children, [](std::variant const +// transform(parallel.children, [](std::variant const // child) { // return std::visit(GetNodes{}, child); // })); diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index af7756c635..fddd825109 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -8,18 +8,19 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers.h" #include "utils/containers/are_disjoint.h" -#include "utils/containers/as_vector.h" #include "utils/containers/contains_key.h" #include "utils/containers/get_only.h" #include "utils/containers/keys.h" #include "utils/containers/merge_maps.h" +#include "utils/containers/require_no_duplicates.h" +#include "utils/containers/vector_of.h" #include "utils/exception.h" #include "utils/graph/graph_split.dtg.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" -#include "utils/graph/serial_parallel/serial_parallel_splits.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_splits.h" namespace FlexFlow { @@ -83,39 +84,43 @@ std::vector> } // We may replace this by having unflattened AST -std::pair - decompose(SerialSplit const &serial) { +std::pair + decompose(SeriesSplit const &serial) { if (serial.children.size() == 2) { - return {widen(serial.children[0]), - widen(serial.children[1])}; + return {widen(serial.children[0]), + widen(serial.children[1])}; } - SerialSplit decompn1 = serial; + SeriesSplit decompn1 = serial; decompn1.children.pop_back(); - return {SerialParallelDecomposition(decompn1), - widen(serial.children.back())}; + return {SeriesParallelDecomposition(decompn1), + widen(serial.children.back())}; } -std::pair +std::pair decompose(ParallelSplit const ¶llel) { if (parallel.children.size() == 2) { - std::vector children = - transform(as_vector(parallel.children), [&](auto const &child) { - return widen(child); + std::vector children = + transform(vector_of(parallel.children), [&](auto const &child) { + return widen(child); }); return {children[0], children[1]}; } ParallelSplit decompn1 = parallel; - std::variant child = *parallel.children.begin(); + std::variant child = *parallel.children.begin(); decompn1.children.erase(child); - return {SerialParallelDecomposition(decompn1), - widen(child)}; + return {SeriesParallelDecomposition(decompn1), + widen(child)}; } GraphSplit - get_graph_split(SerialParallelDecomposition const &pre_decomposition, - SerialParallelDecomposition const &post_decomposition) 
{ - return GraphSplit{get_nodes(pre_decomposition), - get_nodes(post_decomposition)}; + get_graph_split(SeriesParallelDecomposition const &pre_decomposition, + SeriesParallelDecomposition const &post_decomposition) { + std::unordered_set pre_nodes = + require_no_duplicates(get_nodes(pre_decomposition)); + std::unordered_set post_nodes = + require_no_duplicates(get_nodes(post_decomposition)); + assert(are_disjoint(pre_nodes, post_nodes)); + return GraphSplit{pre_nodes, post_nodes}; } float estimate_cost(SubParallelComputationGraph const &g, @@ -181,7 +186,7 @@ struct MachineMappingSearcher { template OptimalCostResult operator()(T const &t) { - OptimalCostState state{SerialParallelDecomposition{t}, + OptimalCostState state{SeriesParallelDecomposition{t}, resource, given_machine_views, frontier_machine_views}; @@ -202,13 +207,13 @@ struct MachineMappingSearcher { OptimalCostResult optimal_cost(SubParallelComputationGraph const &g, MachineSpecification resource, - SerialParallelDecomposition const &sp_decomposition) { + SeriesParallelDecomposition const &sp_decomposition) { return std::visit(OptimalCostFunctor(this, g, resource, {}, {}), sp_decomposition.raw_variant); } OptimalCostResult optimal_cost( - SerialSplit const &serial, + SeriesSplit const &serial, SubParallelComputationGraph const &g, MachineSpecification const &resource, std::unordered_map const &given_machine_views, @@ -218,8 +223,8 @@ struct MachineMappingSearcher { // OptimalCostResult optimal_result = OptimalCostResult::infinity(); // auto decomposed = decompose(serial); - // SerialParallelDecomposition pre_decompn = decomposed.first; - // SerialParallelDecomposition post_decompn = decomposed.second; + // SeriesParallelDecomposition pre_decompn = decomposed.first; + // SeriesParallelDecomposition post_decompn = decomposed.second; // GraphSplit graph_split = get_graph_split(pre_decompn, post_decompn); // SubParallelComputationGraph pre_graph = @@ -273,8 +278,8 @@ struct MachineMappingSearcher { NOT_IMPLEMENTED(); // auto decomposed = decompose(parallel); - // SerialParallelDecomposition decompn1 = decomposed.first; - // SerialParallelDecomposition decompn2 = decomposed.second; + // SeriesParallelDecomposition decompn1 = decomposed.first; + // SeriesParallelDecomposition decompn2 = decomposed.second; // GraphSplit graph_split = get_graph_split(decompn1, decompn2); // SubParallelComputationGraph g1 = get_subgraph(g, graph_split.first), @@ -350,8 +355,8 @@ OptimalCostResult optimal_cost( CostEstimator const &cost_estimator, MachineSpecification const &resources, OptimalCostCache &cached_subgraph_costs) { - SerialParallelDecomposition sp_decomposition = - get_serial_parallel_decomposition(g); + SeriesParallelDecomposition sp_decomposition = + get_series_parallel_decomposition(g); SubParallelComputationGraph subpcg = pcg_to_subpcg(g); MachineMappingSearcher searcher( cost_estimator, allowed_machine_views, cached_subgraph_costs); diff --git a/lib/compiler/test/CMakeLists.txt b/lib/compiler/test/CMakeLists.txt index 13b1fd3b83..3399a45f0f 100644 --- a/lib/compiler/test/CMakeLists.txt +++ b/lib/compiler/test/CMakeLists.txt @@ -10,4 +10,5 @@ ff_add_test_executable( compiler doctest utils-test-common + models ) diff --git a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc new file mode 100644 index 0000000000..ab537e73de --- /dev/null +++ 
b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc @@ -0,0 +1,340 @@ +#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "models/split_test/split_test.h" +#include "models/transformer/transformer.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph_builder.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE( + "get_computation_graph_series_parallel_decomposition(ComputationGraph)") { + SUBCASE("empty computation graph") { + ComputationGraph cg = make_empty_computation_graph(); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + // technically an empty graph is non-SP + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("just a single input") { + std::string input_layer_name = "my input"; + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT}; + b.create_input(input_shape, CreateGrad::YES, input_layer_name); + + return b.computation_graph; + }(); + + layer_guid_t input_layer = get_layer_by_name(cg, input_layer_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = + SeriesParallelDecomposition{input_layer.raw_node}; + + CHECK(result == correct); + } + + SUBCASE("single operator plus inputs and weights") { + std::string input_layer_name = "my input"; + std::string projection_weights_layer_name = "my projection weights"; + std::string bias_weights_layer_name = "my bias weights"; + std::string operator_name = "my operator"; + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT}; + tensor_guid_t input = + b.create_input(input_shape, CreateGrad::YES, input_layer_name); + + b.dense(input, + /*outDim=*/14, + /*activation=*/std::nullopt, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/operator_name, + /*projection_name=*/projection_weights_layer_name, + /*bias_name=*/bias_weights_layer_name); + + return b.computation_graph; + }(); + + layer_guid_t input_layer = get_layer_by_name(cg, input_layer_name); + layer_guid_t projection_weights_layer = + get_layer_by_name(cg, projection_weights_layer_name); + layer_guid_t bias_weights_layer = + get_layer_by_name(cg, bias_weights_layer_name); + layer_guid_t operator_layer = get_layer_by_name(cg, operator_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{ + ParallelSplit{ + input_layer.raw_node, + projection_weights_layer.raw_node, + bias_weights_layer.raw_node, + }, + operator_layer.raw_node, + }}; + + CHECK(result == correct); + } + + SUBCASE("SP without weight nodes but non-SP with weight nodes") { + // A minimal computation graph where without weights (w1 and w2) the + // computation graph is series-parallel, but with weight nodes it is not + // + // w1 input w2 + // \ / \ / + // op1 op2 + + std::string w1_name = "w1"; + std::string input_name = "input"; + std::string w2_name = "w2"; + std::string op1_name = "op1"; + std::string op2_name = "op2"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + 
TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + tensor_guid_t input = + b.create_input(input_shape, CreateGrad::YES, input_name); + + b.dense(input, + /*outDim=*/14, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/op1_name, + /*projection_name=*/w1_name); + b.dense(input, + /*outDim=*/14, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/op2_name, + /*projection_name=*/w2_name); + + return b.computation_graph; + }(); + + layer_guid_t w1 = get_layer_by_name(cg, w1_name); + layer_guid_t input = get_layer_by_name(cg, input_name); + layer_guid_t w2 = get_layer_by_name(cg, w2_name); + layer_guid_t op1 = get_layer_by_name(cg, op1_name); + layer_guid_t op2 = get_layer_by_name(cg, op2_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{ + ParallelSplit{ + w1.raw_node, + input.raw_node, + w2.raw_node, + }, + ParallelSplit{ + op1.raw_node, + op2.raw_node, + }, + }}; + } + + SUBCASE("SP with or without preprocessing, but preprocessing would SP " + "decomposition") { + // computation graph: + // + // input1 input2 + // | | + // op1 op2 + + std::string input1_name = "input1"; + std::string input2_name = "input2"; + std::string op1_name = "op1"; + std::string op2_name = "op2"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + tensor_guid_t input1 = + b.create_input(input_shape, CreateGrad::YES, input1_name); + tensor_guid_t input2 = + b.create_input(input_shape, CreateGrad::YES, input2_name); + + b.relu(input1, op1_name); + b.relu(input2, op2_name); + + return b.computation_graph; + }(); + + layer_guid_t input1 = get_layer_by_name(cg, input1_name); + layer_guid_t input2 = get_layer_by_name(cg, input2_name); + layer_guid_t op1 = get_layer_by_name(cg, op1_name); + layer_guid_t op2 = get_layer_by_name(cg, op2_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = + SeriesParallelDecomposition{ParallelSplit{ + SeriesSplit{ + input1.raw_node, + op1.raw_node, + }, + SeriesSplit{ + input2.raw_node, + op2.raw_node, + }, + }}; + } + + SUBCASE("not SP with or without weight nodes") { + // computation graph: + // + // input1 + // / \ + // op1 op2 + // | \ | + // | \ | + // op3 op4 + + std::string input1_name = "input1"; + std::string op1_name = "op1"; + std::string op2_name = "op2"; + std::string op3_name = "op3"; + std::string op4_name = "op4"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + tensor_guid_t input1 = + b.create_input(input_shape, CreateGrad::YES, input1_name); + + tensor_guid_t op1_output = b.relu(input1, op1_name); + tensor_guid_t op2_output = b.relu(input1, op2_name); + b.relu(op1_output, op3_name); + b.add(op1_output, op2_output, op4_name); + + return b.computation_graph; + }(); + + layer_guid_t input1 = get_layer_by_name(cg, input1_name); + layer_guid_t op1 = get_layer_by_name(cg, op1_name); + layer_guid_t op2 = get_layer_by_name(cg, op2_name); + layer_guid_t op3 = get_layer_by_name(cg, op3_name); + layer_guid_t op4 
= get_layer_by_name(cg, op4_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = std::nullopt; + } + + SUBCASE("real models") { + SUBCASE("split_test") { + ComputationGraph cg = + get_split_test_computation_graph(/*batch_size=*/8); + + std::optional sp_decomposition = + get_computation_graph_series_parallel_decomposition(cg); + + CHECK(sp_decomposition.has_value()); + } + + SUBCASE("transformer") { + ComputationGraph cg = + get_transformer_computation_graph(get_default_transformer_config()); + + std::optional sp_decomposition = + get_computation_graph_series_parallel_decomposition(cg); + + CHECK(sp_decomposition.has_value()); + } + } + } + + TEST_CASE("render_preprocessed_computation_graph_for_sp_decomposition(" + "ComputationGraph)") { + // currently there's not really a good way to test this, and its arguable + // how much its output really should be validated as its primarily for + // visualization and so there's not really a strict definition of + // correctness, so for now we just run it on some models and make sure it + // doesn't crash. Don't use this as an example. + + SUBCASE("basic single-operator model") { + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT}; + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + + b.dense(input, /*outDim=*/14); + + return b.computation_graph; + }(); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } + + SUBCASE("split_test") { + ComputationGraph cg = get_split_test_computation_graph(/*batch_size=*/8); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } + + SUBCASE("transformer") { + ComputationGraph cg = + get_transformer_computation_graph(get_default_transformer_config()); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } + } +} diff --git a/lib/compiler/test/src/test_generator.h b/lib/compiler/test/src/test_generator.h index d6b8222968..9f5a768b27 100644 --- a/lib/compiler/test/src/test_generator.h +++ b/lib/compiler/test/src/test_generator.h @@ -46,7 +46,7 @@ using namespace FlexFlow; // namespace rc { // Gen serialParallelMultiDiGraph() { -// return gen::map(gen::arbitrary(), +// return gen::map(gen::arbitrary(), // multidigraph_from_sp_decomposition); // } @@ -113,12 +113,12 @@ using namespace FlexFlow; // }; // template <> -// struct Arbitrary { -// static Gen arbitrary() { +// struct Arbitrary { +// static Gen arbitrary() { // return gen::mapcat(gen::arbitrary(), [](bool is_serial) { -// return is_serial ? gen::construct( +// return is_serial ? gen::construct( // gen::arbitrary()) -// : gen::construct( +// : gen::construct( // gen::arbitrary()); // }); // } diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h index d3221474c0..f523520f9f 100644 --- a/lib/kernels/include/kernels/accessor.h +++ b/lib/kernels/include/kernels/accessor.h @@ -16,7 +16,7 @@ class GenericTensorAccessorW { template typename data_type_enum_to_class
::type *get() const { if (this->data_type == DT) { - return static_cast *>(this->ptr); + return static_cast *>(this->ptr); } else { throw mk_runtime_error( "Invalid access data type ({} != {})", this->data_type, DT); @@ -47,7 +47,7 @@ class GenericTensorAccessorR { template typename data_type_enum_to_class
::type const *get() const { if (this->data_type == DT) { - return static_cast const *>(this->ptr); + return static_cast const *>(this->ptr); } else { throw mk_runtime_error( "Invalid access data type ({} != {})", this->data_type, DT); @@ -94,7 +94,7 @@ template typename data_type_enum_to_class
::type * get(GenericTensorAccessorW const &a) { if (a.data_type == DT) { - return static_cast *>(a.ptr); + return static_cast *>(a.ptr); } else { throw mk_runtime_error( "Invalid access data type ({} != {})", a.data_type, DT); @@ -102,9 +102,9 @@ typename data_type_enum_to_class
::type * } template -std::vector *> +std::vector *> get(std::vector const &accs) { - std::vector *> out; + std::vector *> out; for (auto acc : accs) { out.push_back(get
(acc)); } @@ -115,7 +115,7 @@ template typename data_type_enum_to_class
::type const * get(GenericTensorAccessorR const &a) { if (a.data_type == DT) { - return static_cast const *>(a.ptr); + return static_cast const *>(a.ptr); } else { throw mk_runtime_error( "Invalid access data type ({} != {})", a.data_type, DT); @@ -139,9 +139,9 @@ std::vector get_half_ptrs(std::vector const &); template -std::vector const *> +std::vector const *> get(std::vector const &accs) { - std::vector const *> out; + std::vector const *> out; for (auto acc : accs) { out.push_back(get
(acc)); } diff --git a/lib/kernels/include/kernels/array_shape.h b/lib/kernels/include/kernels/array_shape.h index 5de9fae7ad..96a3b3b281 100644 --- a/lib/kernels/include/kernels/array_shape.h +++ b/lib/kernels/include/kernels/array_shape.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_KERNELS_ARRAY_SHAPE_H #include "legion_dim.h" -#include "op-attrs/tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" #include "utils/stack_vector.h" #include "utils/visitable.h" #include diff --git a/lib/kernels/include/kernels/attention_kernels.h b/lib/kernels/include/kernels/attention_kernels.h index 575de57f09..eb5a1b8198 100644 --- a/lib/kernels/include/kernels/attention_kernels.h +++ b/lib/kernels/include/kernels/attention_kernels.h @@ -5,7 +5,6 @@ #include "kernels/allocation.h" #include "kernels/device.h" #include "kernels/ff_handle.h" -#include "op-attrs/ops/attention.h" #include namespace FlexFlow { diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 38be2118fa..bfd72647b0 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -4,7 +4,6 @@ #include "device.h" #include "kernels/allocation.h" #include "kernels/ff_handle.h" -#include "utils/visitable.h" namespace FlexFlow { namespace Kernels { diff --git a/lib/kernels/include/kernels/initializer_kernels.h b/lib/kernels/include/kernels/initializer_kernels.h index 14bb9d2cd2..52609a303f 100644 --- a/lib/kernels/include/kernels/initializer_kernels.h +++ b/lib/kernels/include/kernels/initializer_kernels.h @@ -3,6 +3,7 @@ #include "accessor.h" #include "kernels/cpu.h" +#include "op-attrs/datatype_value.dtg.h" #include "utils/variant.h" namespace FlexFlow { diff --git a/lib/kernels/src/allocation.cc b/lib/kernels/src/allocation.cc index a892e14a54..ccd88580db 100644 --- a/lib/kernels/src/allocation.cc +++ b/lib/kernels/src/allocation.cc @@ -1,4 +1,5 @@ #include "kernels/allocation.h" +#include "op-attrs/tensor_shape.h" namespace FlexFlow { diff --git a/lib/kernels/src/cpu/initializer_kernels.cc b/lib/kernels/src/cpu/initializer_kernels.cc index f3b4c9b8fd..91f4f46ef8 100644 --- a/lib/kernels/src/cpu/initializer_kernels.cc +++ b/lib/kernels/src/cpu/initializer_kernels.cc @@ -24,7 +24,7 @@ struct ConstantInitKernel { void operator()(GenericTensorAccessorW const &tensor, DataTypeValue value) const { auto arr = get
(tensor); - auto unwrapped_value = get>(value); + auto unwrapped_value = value.get>(); for (size_t i = 0; i < get_volume(tensor.shape); i++) { arr[i] = unwrapped_value; } diff --git a/lib/kernels/src/cuda/embedding_kernels.cu b/lib/kernels/src/cuda/embedding_kernels.cu index 371b45f760..e6a614ba70 100644 --- a/lib/kernels/src/cuda/embedding_kernels.cu +++ b/lib/kernels/src/cuda/embedding_kernels.cu @@ -358,7 +358,7 @@ struct ForwardKernel { weight.data_type == DataType::DOUBLE); if (!aggr.has_value()) { - embed_forward_no_aggr, real_type> + embed_forward_no_aggr, real_type_t> <<, real_type> + embed_forward_with_aggr, real_type_t> <<, real_type> + embed_backward_no_aggr, real_type_t> <<, real_type> + embed_backward_with_aggr, real_type_t> <<> + add_kernel> <<>>( input_grad.get
(), output_grad.get
(), num_elements); } diff --git a/lib/kernels/src/cuda/ops/element_unary_kernels.cu b/lib/kernels/src/cuda/ops/element_unary_kernels.cu index 3eb9c486f2..a35d28fa8c 100644 --- a/lib/kernels/src/cuda/ops/element_unary_kernels.cu +++ b/lib/kernels/src/cuda/ops/element_unary_kernels.cu @@ -267,16 +267,16 @@ struct ForwardKernel { } else if (use_scalar(op_type)) { assert(scalar.has_value()); size_t num_elements = input.shape.num_elements(); - elewise_scalar_unary_forward_kernel> + elewise_scalar_unary_forward_kernel> <<>>( num_elements, - static_cast>(scalar.value()), + static_cast>(scalar.value()), op_type, input.get(), output.get()); } else { size_t num_elements = input.shape.num_elements(); - elewise_unary_forward_kernel> + elewise_unary_forward_kernel> <<>>( num_elements, op_type, input.get(), output.get()); } @@ -313,10 +313,10 @@ struct BackwardKernel { } else if (use_scalar(op_type)) { assert(scalar.has_value()); size_t num_elements = input.shape.num_elements(); - elewise_scalar_unary_backward_kernel> + elewise_scalar_unary_backward_kernel> <<>>( num_elements, - static_cast>(scalar.value()), + static_cast>(scalar.value()), op_type, output.get(), output_grad.get(), @@ -324,7 +324,7 @@ struct BackwardKernel { input_grad.get()); } else { size_t num_elements = input.shape.num_elements(); - elewise_unary_backward_kernel> + elewise_unary_backward_kernel> <<>>( num_elements, op_type, diff --git a/lib/kernels/src/cuda/ops/partition_kernels.cu b/lib/kernels/src/cuda/ops/partition_kernels.cu index e356f83d2a..1d07efb5fa 100644 --- a/lib/kernels/src/cuda/ops/partition_kernels.cu +++ b/lib/kernels/src/cuda/ops/partition_kernels.cu @@ -41,12 +41,12 @@ struct BackwardKernel { RepartitionPerDeviceState const &m, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &output_grad) { - add_kernel><<>>(input_grad.get(), - output_grad.get(), - input_grad.shape.num_elements()); + add_kernel><<>>(input_grad.get(), + output_grad.get(), + input_grad.shape.num_elements()); } }; diff --git a/lib/kernels/src/cuda/ops/reduction_kernels.cu b/lib/kernels/src/cuda/ops/reduction_kernels.cu index 992d27fe60..0c6ba7d8e3 100644 --- a/lib/kernels/src/cuda/ops/reduction_kernels.cu +++ b/lib/kernels/src/cuda/ops/reduction_kernels.cu @@ -42,7 +42,7 @@ struct ForwardKernel { size_t num_replicas) { size_t total_elements = input.shape.num_elements() * num_replicas; - reduction_forward_kernel> + reduction_forward_kernel> <<>>( input.get(), output.get(), diff --git a/lib/kernels/src/cuda/ops/replicate_kernels.cu b/lib/kernels/src/cuda/ops/replicate_kernels.cu index 0c87418f58..76bfbe2658 100644 --- a/lib/kernels/src/cuda/ops/replicate_kernels.cu +++ b/lib/kernels/src/cuda/ops/replicate_kernels.cu @@ -54,7 +54,7 @@ struct BackwardKernel { GenericTensorAccessorR const &output, size_t num_replicas) { size_t total_elements = input.shape.num_elements() * num_replicas; - replicate_backward_kernel> + replicate_backward_kernel> <<>>( input.get(), output.get(), diff --git a/lib/kernels/src/cuda/ops/reshape_kernels.cu b/lib/kernels/src/cuda/ops/reshape_kernels.cu index c4da408952..5b7843a3a5 100644 --- a/lib/kernels/src/cuda/ops/reshape_kernels.cu +++ b/lib/kernels/src/cuda/ops/reshape_kernels.cu @@ -45,14 +45,14 @@ struct BackwardKernel { GenericTensorAccessorW const &input, GenericTensorAccessorR const &output) { float alpha = 1.0f; - apply_add_with_scale> + apply_add_with_scale> <<>>(input.get(), output.get(), input.shape.num_elements(), - static_cast>(alpha)); + static_cast>(alpha)); } }; diff --git 
a/lib/kernels/src/hip/ops/replicate_kernels.cpp b/lib/kernels/src/hip/ops/replicate_kernels.cpp index 9a5fc813c3..8d27bb1908 100644 --- a/lib/kernels/src/hip/ops/replicate_kernels.cpp +++ b/lib/kernels/src/hip/ops/replicate_kernels.cpp @@ -55,15 +55,16 @@ struct BackwardKernel { GenericTensorAccessorR const &output, size_t num_replicas) { size_t total_elements = input.shape.num_elements() * num_replicas; - hipLaunchKernelGGL(HIP_KERNEL_NAME(replicate_backward_kernel>), - GET_BLOCKS(total_elements), - CUDA_NUM_THREADS, - 0, - stream, - input.get(), - output.get(), - input.shape.num_elements(), - num_replicas); + hipLaunchKernelGGL( + HIP_KERNEL_NAME(replicate_backward_kernel>), + GET_BLOCKS(total_elements), + CUDA_NUM_THREADS, + 0, + stream, + input.get(), + output.get(), + input.shape.num_elements(), + num_replicas); } } diff --git a/lib/kernels/src/hip/ops/reshape_kernels.cpp b/lib/kernels/src/hip/ops/reshape_kernels.cpp index 941495c0fd..47978a5f4a 100644 --- a/lib/kernels/src/hip/ops/reshape_kernels.cpp +++ b/lib/kernels/src/hip/ops/reshape_kernels.cpp @@ -47,7 +47,7 @@ struct BackwardKernel { GenericTensorAccessorW const &input, GenericTensorAccessorR const &output) { float alpha = 1.0f; - hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_add_with_scale>), + hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_add_with_scale>), GET_BLOCKS(input.shape.num_elements()), CUDA_NUM_THREADS, 0, @@ -55,7 +55,7 @@ struct BackwardKernel { input.get(), output.get(), input.shape.num_elements(), - static_cast> alpha); + static_cast> alpha); } } diff --git a/lib/local-execution/include/local-execution/cost_estimate.h b/lib/local-execution/include/local-execution/cost_estimate.h index 33954827bd..31503e0da9 100644 --- a/lib/local-execution/include/local-execution/cost_estimate.h +++ b/lib/local-execution/include/local-execution/cost_estimate.h @@ -4,8 +4,8 @@ #include "local-execution/cost_details.dtg.h" #include "local-execution/local_training_backing.h" -#include "op-attrs/operator_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" diff --git a/lib/local-execution/include/local-execution/legion_tensor_shape.h b/lib/local-execution/include/local-execution/legion_tensor_shape.h index f1d2ad252a..2f2ed50d41 100644 --- a/lib/local-execution/include/local-execution/legion_tensor_shape.h +++ b/lib/local-execution/include/local-execution/legion_tensor_shape.h @@ -4,8 +4,9 @@ #include "kernels/legion_dim.h" #include "op-attrs/datatype.h" #include "op-attrs/ff_dim.h" -#include "op-attrs/tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" #include "utils/stack_vector.h" +#include "utils/visitable.h" #include namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/local_slots_backing.h b/lib/local-execution/include/local-execution/local_slots_backing.h index 6a0c28e988..5b826c7022 100644 --- a/lib/local-execution/include/local-execution/local_slots_backing.h +++ b/lib/local-execution/include/local-execution/local_slots_backing.h @@ -7,6 +7,9 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/per_device_op_state.h" #include "local-execution/runtime_arg_config.h" +#include "pcg/computation_graph.dtg.h" +#include "pcg/layer_guid_t.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/local_training_backing.h 
b/lib/local-execution/include/local-execution/local_training_backing.h index b398bb8cc3..6789624076 100644 --- a/lib/local-execution/include/local-execution/local_training_backing.h +++ b/lib/local-execution/include/local-execution/local_training_backing.h @@ -3,6 +3,7 @@ #include "local-execution/local_slots_backing.h" #include "local-execution/task_registry.h" +#include "pcg/computation_graph.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/op_arg_ref.h b/lib/local-execution/include/local-execution/op_arg_ref.h index 20d6ccb1c5..102a8d4362 100644 --- a/lib/local-execution/include/local-execution/op_arg_ref.h +++ b/lib/local-execution/include/local-execution/op_arg_ref.h @@ -5,7 +5,7 @@ #include "local-execution/device_specific.h" #include "local-execution/op_arg_ref_type.dtg.h" #include "local-execution/per_device_op_state.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/op_task_invocation.h b/lib/local-execution/include/local-execution/op_task_invocation.h index 73a0460554..0f351c3a0e 100644 --- a/lib/local-execution/include/local-execution/op_task_invocation.h +++ b/lib/local-execution/include/local-execution/op_task_invocation.h @@ -13,10 +13,6 @@ #include "local-execution/slot_grad_id.dtg.h" #include "local-execution/task_id_t.dtg.h" #include "local-execution/variadic_tensor_ref.h" -#include "op-attrs/computation_graph_op_attrs.h" -#include "pcg/computation_graph.h" -#include "utils/bidict/bidict.h" -#include "utils/stack_map.h" #include #include #include diff --git a/lib/local-execution/include/local-execution/sim_environment.h b/lib/local-execution/include/local-execution/sim_environment.h index 3ba17ea3ff..7c81cba408 100644 --- a/lib/local-execution/include/local-execution/sim_environment.h +++ b/lib/local-execution/include/local-execution/sim_environment.h @@ -7,7 +7,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/task_argument_accessor.h" #include "local-execution/task_signature_impl.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" #include "pcg/machine_view.h" #include diff --git a/lib/local-execution/include/local-execution/task_registry.struct.toml b/lib/local-execution/include/local-execution/task_registry.struct.toml index 308527efac..ada467a67d 100644 --- a/lib/local-execution/include/local-execution/task_registry.struct.toml +++ b/lib/local-execution/include/local-execution/task_registry.struct.toml @@ -15,6 +15,7 @@ includes = [ src_includes = [ "utils/hash/unordered_map.h", "utils/fmt/unordered_map.h", + "utils/fmt/optional.h", ] [[fields]] diff --git a/lib/local-execution/src/legion_tensor_shape.cc b/lib/local-execution/src/legion_tensor_shape.cc index b3a045bab4..bce29fafeb 100644 --- a/lib/local-execution/src/legion_tensor_shape.cc +++ b/lib/local-execution/src/legion_tensor_shape.cc @@ -1,4 +1,5 @@ #include "local-execution/legion_tensor_shape.h" +#include "op-attrs/tensor_shape.h" namespace FlexFlow { diff --git a/lib/local-execution/src/local_cost_estimator.cc b/lib/local-execution/src/local_cost_estimator.cc index d4e0467cbf..5203991f25 100644 --- a/lib/local-execution/src/local_cost_estimator.cc +++ b/lib/local-execution/src/local_cost_estimator.cc @@ -51,7 +51,7 @@ CostDetails LocalCostEstimator::estimate_cost( for (ParallelTensorShape const &input : inputs) { TensorShape tensor_shape = get_piece_shape(input); tensor_guid_t 
tensor_id = - cg_builder.create_tensor(tensor_shape, CreateGrad::YES); + cg_builder.create_input(tensor_shape, CreateGrad::YES); GenericTensorAccessorW tensor_backing = allocator.allocate_tensor(tensor_shape); tensor_backing_map.insert({tensor_id, tensor_backing}); diff --git a/lib/local-execution/src/local_slots_backing.cc b/lib/local-execution/src/local_slots_backing.cc index 0ec9068c6a..ac35d63c0b 100644 --- a/lib/local-execution/src/local_slots_backing.cc +++ b/lib/local-execution/src/local_slots_backing.cc @@ -1,4 +1,6 @@ #include "local-execution/local_slots_backing.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/computation_graph.h" #include "utils/containers/contains_key.h" #include "utils/overload.h" diff --git a/lib/local-execution/src/local_training_backing.cc b/lib/local-execution/src/local_training_backing.cc index a2ee06a95a..0fdf1761e3 100644 --- a/lib/local-execution/src/local_training_backing.cc +++ b/lib/local-execution/src/local_training_backing.cc @@ -1,5 +1,6 @@ #include "local-execution/local_training_backing.h" #include "local-execution/task_signature_impl.h" +#include "pcg/computation_graph.h" #include "utils/containers/reversed.h" #include "utils/exception.h" diff --git a/lib/local-execution/src/op_task_signature.cc b/lib/local-execution/src/op_task_signature.cc index 36a1dd708d..932b330453 100644 --- a/lib/local-execution/src/op_task_signature.cc +++ b/lib/local-execution/src/op_task_signature.cc @@ -1,4 +1,5 @@ #include "local-execution/op_task_signature.h" +#include "utils/fmt/optional.h" #include "utils/fmt/unordered_map.h" #include "utils/fmt/unordered_set.h" diff --git a/lib/local-execution/src/ops/element_unary.cc b/lib/local-execution/src/ops/element_unary.cc index a52ebb8089..4ee609bd6c 100644 --- a/lib/local-execution/src/ops/element_unary.cc +++ b/lib/local-execution/src/ops/element_unary.cc @@ -1,6 +1,7 @@ #include "element_unary.h" #include "kernels/element_unary_kernels.h" #include "op-attrs/get_output_shapes.h" +#include "op-attrs/parallel_tensor_shape.h" #include "utils/hash-utils.h" namespace FlexFlow { diff --git a/lib/local-execution/test/src/test_local_cost_estimator.cc b/lib/local-execution/test/src/test_local_cost_estimator.cc index 2bd0acc222..4c01df53e9 100644 --- a/lib/local-execution/test/src/test_local_cost_estimator.cc +++ b/lib/local-execution/test/src/test_local_cost_estimator.cc @@ -2,10 +2,12 @@ #include "kernels/local_cuda_allocator.h" #include "kernels/managed_per_device_ff_handle.h" #include "local-execution/local_cost_estimator.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/parallel_tensor_shape.h" #include "pcg/computation_graph_builder.h" #include "test_utils.h" -namespace FlexFlow { +using namespace ::FlexFlow; TEST_SUITE(FF_CUDA_TEST_SUITE) { TEST_CASE("Local Cost Estimator") { @@ -73,5 +75,3 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { } } } - -} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_local_slots_backing.cc b/lib/local-execution/test/src/test_local_slots_backing.cc index 542aa66087..1ec441fbca 100644 --- a/lib/local-execution/test/src/test_local_slots_backing.cc +++ b/lib/local-execution/test/src/test_local_slots_backing.cc @@ -1,15 +1,19 @@ -#include "doctest/doctest.h" #include "kernels/attention_kernels.h" #include "local-execution/local_cost_estimator.h" #include "local-execution/local_cpu_allocator.h" #include "local-execution/local_slots_backing.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/computation_graph.h" 
#include "pcg/computation_graph_builder.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/variant.h" +#include "test/utils/doctest/fmt/vector.h" #include "test_utils.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/variant.h" -#include "utils/fmt/vector.h" +#include -namespace FlexFlow { +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("LocalSlotsBacking -- Attention Op") { @@ -37,11 +41,11 @@ TEST_SUITE(FF_TEST_SUITE) { // build graph ComputationGraphBuilder cg_builder; tensor_guid_t query_guid = - cg_builder.create_tensor(query_shape, CreateGrad::YES); + cg_builder.create_input(query_shape, CreateGrad::YES); tensor_guid_t key_guid = - cg_builder.create_tensor(key_shape, CreateGrad::YES); + cg_builder.create_input(key_shape, CreateGrad::YES); tensor_guid_t value_guid = - cg_builder.create_tensor(value_shape, CreateGrad::YES); + cg_builder.create_input(value_shape, CreateGrad::YES); std::string layer_name = "attn1"; tensor_guid_t output_guid = @@ -269,5 +273,3 @@ TEST_SUITE(FF_TEST_SUITE) { } } } - -} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_local_task_arg_accessor.cc b/lib/local-execution/test/src/test_local_task_arg_accessor.cc index 0637faaf1c..f52fccb1ed 100644 --- a/lib/local-execution/test/src/test_local_task_arg_accessor.cc +++ b/lib/local-execution/test/src/test_local_task_arg_accessor.cc @@ -4,7 +4,7 @@ #include "local-execution/task_signature_impl.h" #include "utils/fmt/variant.h" -namespace FlexFlow { +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("LocalTaskArgumentAccessor") { @@ -140,5 +140,3 @@ TEST_SUITE(FF_TEST_SUITE) { } } } - -} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_task_registry.cc b/lib/local-execution/test/src/test_task_registry.cc index fa3b068425..e18b7ea2de 100644 --- a/lib/local-execution/test/src/test_task_registry.cc +++ b/lib/local-execution/test/src/test_task_registry.cc @@ -7,7 +7,7 @@ #include "utils/fmt/optional.h" #include "utils/fmt/unordered_map.h" -namespace FlexFlow { +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Task Registry") { @@ -127,5 +127,3 @@ TEST_SUITE(FF_TEST_SUITE) { } } } - -} // namespace FlexFlow diff --git a/lib/models/CMakeLists.txt b/lib/models/CMakeLists.txt index 7dd7f48700..4f4b22ed47 100644 --- a/lib/models/CMakeLists.txt +++ b/lib/models/CMakeLists.txt @@ -11,6 +11,7 @@ ff_add_library( op-attrs utils pcg + rapidcheck ) -add_subdirectory(test) \ No newline at end of file +add_subdirectory(test) diff --git a/lib/models/include/models/split_test/split_test.h b/lib/models/include/models/split_test/split_test.h new file mode 100644 index 0000000000..b03e45b2d2 --- /dev/null +++ b/lib/models/include/models/split_test/split_test.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_SPLIT_TEST_SPLIT_TEST_H +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_SPLIT_TEST_SPLIT_TEST_H + +#include "pcg/computation_graph.dtg.h" + +namespace FlexFlow { + +/** + * @brief Get the computation graph of the old FlexFlow test model + * split_test + * + * @note This is a tiny model developed for testing the original Unity + * implementation. It is not a "real" model and has never been trained. 
+ */ +ComputationGraph get_split_test_computation_graph(int batch_size); + +} // namespace FlexFlow + +#endif diff --git a/lib/models/include/models/transformer.h b/lib/models/include/models/transformer/transformer.h similarity index 90% rename from lib/models/include/models/transformer.h rename to lib/models/include/models/transformer/transformer.h index e50fa37709..385100a4c9 100644 --- a/lib/models/include/models/transformer.h +++ b/lib/models/include/models/transformer/transformer.h @@ -1,7 +1,7 @@ -#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_H -#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_H +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_TRANSFORMER_H +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_TRANSFORMER_H -#include "models/transformer_config.dtg.h" +#include "models/transformer/transformer_config.dtg.h" #include "pcg/computation_graph_builder.h" namespace FlexFlow { diff --git a/lib/models/include/models/transformer_config.struct.toml b/lib/models/include/models/transformer/transformer_config.struct.toml similarity index 100% rename from lib/models/include/models/transformer_config.struct.toml rename to lib/models/include/models/transformer/transformer_config.struct.toml diff --git a/lib/models/src/models/split_test/split_test.cc b/lib/models/src/models/split_test/split_test.cc new file mode 100644 index 0000000000..118f94ec06 --- /dev/null +++ b/lib/models/src/models/split_test/split_test.cc @@ -0,0 +1,39 @@ +#include "models/split_test/split_test.h" +#include "pcg/computation_graph_builder.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +ComputationGraph get_split_test_computation_graph(int batch_size) { + ComputationGraphBuilder cgb; + + int layer_dim1 = 256; + int layer_dim2 = 128; + int layer_dim3 = 64; + int layer_dim4 = 32; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(batch_size), + size_t_from_int(layer_dim1), + }}, + DataType::FLOAT, + }; + + tensor_guid_t t = cgb.create_input(input_shape, CreateGrad::YES); + t = cgb.dense(t, layer_dim2); + t = cgb.relu(t); + tensor_guid_t t1 = cgb.dense(t, layer_dim3); + tensor_guid_t t2 = cgb.dense(t, layer_dim3); + t = cgb.add(t1, t2); + t = cgb.relu(t); + t1 = cgb.dense(t, layer_dim4); + t2 = cgb.dense(t, layer_dim4); + t = cgb.add(t1, t2); + t = cgb.relu(t); + t = cgb.softmax(t); + + return cgb.computation_graph; +} + +} // namespace FlexFlow diff --git a/lib/models/src/models/transformer.cc b/lib/models/src/models/transformer/transformer.cc similarity index 95% rename from lib/models/src/models/transformer.cc rename to lib/models/src/models/transformer/transformer.cc index 874cd85787..e179359940 100644 --- a/lib/models/src/models/transformer.cc +++ b/lib/models/src/models/transformer/transformer.cc @@ -1,4 +1,4 @@ -#include "models/transformer.h" +#include "models/transformer/transformer.h" #include "pcg/computation_graph.h" namespace FlexFlow { @@ -100,7 +100,7 @@ tensor_guid_t assert(are_tensor_guid_shapes_equivalent( cgb.computation_graph, input, self_attention_normalized)); - tensor_guid_t mha = cgb.multihead_attention(input, + tensor_guid_t mha = cgb.multihead_attention(self_attention_normalized, encoder_output, encoder_output, config.num_features, @@ -149,11 +149,13 @@ ComputationGraph config.batch_size, config.sequence_length, config.num_features}}, DataType::FLOAT, }; - tensor_guid_t input = cgb.create_tensor(input_shape, CreateGrad::YES); + tensor_guid_t input = cgb.create_input(input_shape, CreateGrad::YES, "input"); + 
tensor_guid_t target = + cgb.create_input(input_shape, CreateGrad::YES, "target"); tensor_guid_t encoder_output = create_transformer_encoder(cgb, config, input); tensor_guid_t decoder_output = - create_transformer_decoder(cgb, config, input, encoder_output); + create_transformer_decoder(cgb, config, target, encoder_output); tensor_guid_t out_prob = cgb.softmax(cgb.dense(decoder_output, /*outDim=*/config.vocab_size, diff --git a/lib/models/test/src/models/transformer.cc b/lib/models/test/src/models/transformer.cc index 2133e9965b..20274c4151 100644 --- a/lib/models/test/src/models/transformer.cc +++ b/lib/models/test/src/models/transformer.cc @@ -1,4 +1,4 @@ -#include "models/transformer.h" +#include "models/transformer/transformer.h" #include "pcg/computation_graph.h" #include @@ -12,7 +12,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("num layers") { int result_num_layers = get_layers(result).size(); - int correct_num_layers = 317; + int correct_num_layers = 258; CHECK(result_num_layers == correct_num_layers); } } diff --git a/lib/op-attrs/include/op-attrs/datatype.h b/lib/op-attrs/include/op-attrs/datatype.h index 6204b9ca49..5af00fb510 100644 --- a/lib/op-attrs/include/op-attrs/datatype.h +++ b/lib/op-attrs/include/op-attrs/datatype.h @@ -47,14 +47,7 @@ typename data_type_enum_to_class
<DT>::type cast_to(T t) { } template <DataType DT> -using real_type = typename data_type_enum_to_class
<DT>::type; - -using DataTypeValue = std::variant<real_type<DataType::FLOAT>, - real_type<DataType::DOUBLE>, - real_type<DataType::INT32>, - real_type<DataType::INT64>, - /* real_type<DataType::HALF>, */ - real_type<DataType::BOOL>>; +using real_type_t = typename data_type_enum_to_class<DT>
::type; size_t size_of_datatype(DataType); diff --git a/lib/op-attrs/include/op-attrs/datatype_value.variant.toml b/lib/op-attrs/include/op-attrs/datatype_value.variant.toml new file mode 100644 index 0000000000..3386e9d131 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/datatype_value.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "DataTypeValue" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +type = "float" + +[[values]] +type = "double" + +[[values]] +type = "int32_t" + +[[values]] +type = "int64_t" + +[[values]] +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered.h index 6868ba083f..34d186e74e 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered.h @@ -3,8 +3,8 @@ #include "op-attrs/ff_dim.dtg.h" #include "utils/fmt/vector.h" -#include "utils/json.h" #include "utils/stack_vector.h" +#include namespace FlexFlow { @@ -202,11 +202,12 @@ FFOrdered const &outer_to_inner(FFOrdered const &ff_ordered) { namespace nlohmann { template struct adl_serializer<::FlexFlow::DimOrdered> { - static ::FlexFlow::DimOrdered from_json(json const &j) { + static ::FlexFlow::DimOrdered from_json(nlohmann::json const &j) { return {j.template get>()}; } - static void to_json(json &j, ::FlexFlow::DimOrdered const &x) { + static void to_json(nlohmann::json &j, + ::FlexFlow::DimOrdered const &x) { j = std::vector{x.cbegin(), x.cend()}; } }; diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h index f3dfe5d199..d39bac1bde 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h @@ -2,9 +2,9 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H #include "op-attrs/dim_ordered.h" -#include "utils/containers/as_vector.h" #include "utils/containers/subvec.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/optional.h" namespace FlexFlow { @@ -18,7 +18,7 @@ DimOrdered nonoverloaded_slice(DimOrdered const &d, }; return DimOrdered{ - subvec(as_vector(d), to_raw_idx(start), to_raw_idx(end))}; + subvec(vector_of(d), to_raw_idx(start), to_raw_idx(end))}; } template diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h index 3a31ea511d..ae6e552243 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H #include "op-attrs/dim_ordered.h" -#include "utils/containers/as_vector.h" +#include "utils/containers/vector_of.h" #include "utils/containers/vector_transform.h" namespace FlexFlow { @@ -12,7 +12,7 @@ DimOrdered> transform(DimOrdered const &d, F f) { using Out = std::invoke_result_t; - return DimOrdered{vector_transform(as_vector(d), f)}; + return DimOrdered{vector_transform(vector_of(d), f)}; } } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/zip.h b/lib/op-attrs/include/op-attrs/dim_ordered/zip.h index 54554afb81..023dcfc586 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/zip.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/zip.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ZIP_H #include "op-attrs/dim_ordered.h" -#include "utils/containers/as_vector.h" +#include 
"utils/containers/vector_of.h" #include "utils/containers/zip.h" namespace FlexFlow { @@ -11,7 +11,7 @@ template DimOrdered> zip(DimOrdered const &lhs, DimOrdered const &rhs) { return DimOrdered>{ - zip(as_vector(lhs), as_vector(rhs))}; + zip(vector_of(lhs), vector_of(rhs))}; } } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.h b/lib/op-attrs/include/op-attrs/ops/broadcast.h index 0a5f057578..4fd7d49234 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.h @@ -5,11 +5,14 @@ #include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include "utils/record_formatter.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(BroadcastAttrs); +RecordFormatter as_dot(BroadcastAttrs const &); + tl::expected get_output_shape(BroadcastAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(BroadcastAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml index 2fb385b64d..5bef144cd9 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml @@ -12,11 +12,12 @@ features = [ includes = [ "", "op-attrs/activation.dtg.h", - "utils/json.h", ] src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] fields = [ diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml index 4b9c8a9f45..403bb87592 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml @@ -11,12 +11,14 @@ features = [ ] includes = [ - "utils/json.h", - "op-attrs/operator_type.h", + "op-attrs/operator_type.dtg.h", + "", ] src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml index 38d5a4371e..66d6f99253 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml @@ -17,6 +17,8 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml index eaa34cc496..0a35a6c5ec 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml @@ -13,11 +13,13 @@ includes = [ "op-attrs/datatype.dtg.h", "op-attrs/activation.dtg.h", "op-attrs/regularizer_attrs.dtg.h", - "utils/json.h", + "", ] src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.h b/lib/op-attrs/include/op-attrs/tensor_shape.h index 108df58dce..14ee637f92 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.h @@ -11,11 +11,6 @@ size_t &dim_at_idx(TensorShape &, ff_dim_t); size_t get_num_elements(TensorShape const &); size_t get_size_in_bytes(TensorShape const &); -bool tensor_shape_is_broadcastable_to(TensorShape const &curr, - 
TensorShape const &goal); -std::optional - get_broadcast_target_shape(std::unordered_set const &); - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc index 166416cbad..054930cebd 100644 --- a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc +++ b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc @@ -1,5 +1,8 @@ #include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/get_op_type.h" +#include "op-attrs/ops/broadcast.h" +#include "op-attrs/ops/linear.h" +#include "utils/overload.h" namespace FlexFlow { @@ -8,4 +11,16 @@ OperatorType get_op_type(ComputationGraphOpAttrs const &attrs) { [](auto const &x) { return get_op_type(x); }); } +RecordFormatter as_dot(ComputationGraphOpAttrs const &attrs) { + return attrs.visit(overload{ + [](LinearAttrs const &l) { return as_dot(l); }, + [](BroadcastAttrs const &a) { return as_dot(a); }, + [&](auto const &) { + RecordFormatter r; + r << fmt::to_string(get_op_type(attrs)); + return r; + }, + }); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/broadcast.cc b/lib/op-attrs/src/op-attrs/ops/broadcast.cc index bd69864aff..aa3c95f551 100644 --- a/lib/op-attrs/src/op-attrs/ops/broadcast.cc +++ b/lib/op-attrs/src/op-attrs/ops/broadcast.cc @@ -1,8 +1,26 @@ #include "op-attrs/ops/broadcast.h" #include "op-attrs/tensor_dims.h" +#include "utils/record_formatter.h" namespace FlexFlow { +RecordFormatter as_dot(BroadcastAttrs const &attrs) { + RecordFormatter r; + + auto kv = [](std::string const &label, auto const &val) { + RecordFormatter rr; + rr << label << fmt::to_string(val); + return rr; + }; + + for (int i = 0; i < num_dims(attrs.target_dims); i++) { + r << kv(fmt::format("target_dims[{}]", i), + dim_at_idx(attrs.target_dims, ff_dim_t{i})); + } + + return r; +} + tl::expected get_output_shape(BroadcastAttrs const &attrs, TensorShape const &input_shape) { diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 73c0068826..4bce5449f4 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -4,9 +4,9 @@ #include "op-attrs/replica_parallel_dim_set.h" #include "op-attrs/shard_parallel_dim.h" #include "utils/containers/all_of.h" -#include "utils/containers/as_vector.h" #include "utils/containers/product.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/integer_conversions.h" namespace FlexFlow { @@ -35,7 +35,7 @@ int total_replica_degree(ParallelTensorDims const &dims) { } int total_shard_degree(ParallelTensorDims const &dims) { - return product(transform(as_vector(dims.shard_dims), + return product(transform(vector_of(dims.shard_dims), [](ShardParallelDim const &d) { return d.degree; })); } diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index e716793a8f..ba7d6e8357 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -3,9 +3,9 @@ #include "op-attrs/replica_parallel_dim_set.h" #include "op-attrs/shard_parallel_dim.dtg.h" #include "utils/containers/all_of.h" -#include "utils/containers/as_vector.h" #include "utils/containers/reversed.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/containers/zip.h" #include "utils/integer_conversions.h" @@ -33,8 +33,8 @@ bool 
tensor_dims_is_broadcastable_to(TensorDims const &curr, return false; } - std::vector curr_dims = as_vector(curr.ff_ordered); - std::vector goal_dims = as_vector(goal.ff_ordered); + std::vector curr_dims = vector_of(curr.ff_ordered); + std::vector goal_dims = vector_of(goal.ff_ordered); for (auto const &[curr_dim, goal_dim] : zip(reversed(curr_dims), reversed(goal_dims))) { @@ -72,7 +72,7 @@ ParallelTensorDims DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees) { std::vector lifted = - transform(zip(as_vector(dims.ff_ordered), as_vector(shard_degrees)), + transform(zip(vector_of(dims.ff_ordered), vector_of(shard_degrees)), [](std::pair const &p) { size_t size = p.first; int degree = p.second; diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.cc b/lib/op-attrs/src/op-attrs/tensor_shape.cc index b604d442cb..07508e3065 100644 --- a/lib/op-attrs/src/op-attrs/tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/tensor_shape.cc @@ -27,35 +27,4 @@ size_t get_size_in_bytes(TensorShape const &s) { return get_num_elements(s) * size_of_datatype(s.data_type); } -bool tensor_shape_is_broadcastable_to(TensorShape const &curr, - TensorShape const &goal) { - return tensor_dims_is_broadcastable_to(curr.dims, goal.dims) && - curr.data_type == goal.data_type; -} - -std::optional - get_broadcast_target_shape(std::unordered_set const &shapes) { - std::unordered_set datatypes = - transform(shapes, [](TensorShape const &s) { return s.data_type; }); - - if (datatypes.size() != 1) { - return std::nullopt; - } - - std::unordered_set shapes_dims = - transform(shapes, [](TensorShape const &s) { return s.dims; }); - - std::optional maybe_result_dims = - get_broadcast_target_dims(shapes_dims); - std::optional result = - transform(maybe_result_dims, [&](TensorDims const &result_dims) { - return TensorShape{ - result_dims, - get_only(datatypes), - }; - }); - - return result; -} - } // namespace FlexFlow diff --git a/lib/op-attrs/test/src/datatype.cc b/lib/op-attrs/test/src/datatype.cc index cc7e496c60..d45c156d59 100644 --- a/lib/op-attrs/test/src/datatype.cc +++ b/lib/op-attrs/test/src/datatype.cc @@ -1,6 +1,8 @@ #include "op-attrs/datatype.h" -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("can_promote_datatype_from_to(DataType, DataType)") { diff --git a/lib/op-attrs/test/src/dim_ordered/slice.cc b/lib/op-attrs/test/src/dim_ordered/slice.cc index 8640b077dc..8d5f247756 100644 --- a/lib/op-attrs/test/src/dim_ordered/slice.cc +++ b/lib/op-attrs/test/src/dim_ordered/slice.cc @@ -1,5 +1,7 @@ #include "op-attrs/dim_ordered/slice.h" -#include "test/utils/doctest.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE( diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc index d2c758a05f..180bc2a01f 100644 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc @@ -1,5 +1,5 @@ #include "op-attrs/dim_ordered/enumerate.h" -#include "utils/fmt/map.h" +#include "test/utils/doctest/fmt/map.h" #include using namespace ::FlexFlow; diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc index 11e09dc43f..8e3d0f1b80 100644 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc @@ -1,6 +1,6 @@ #include "op-attrs/dim_ordered/zip.h" #include 
"op-attrs/ff_dim.dtg.h" -#include "utils/fmt/pair.h" +#include "test/utils/doctest/fmt/pair.h" #include using namespace ::FlexFlow; diff --git a/lib/op-attrs/test/src/op-attrs/ops/dropout.cc b/lib/op-attrs/test/src/op-attrs/ops/dropout.cc index 17a68ccbc8..7580de24e5 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/dropout.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/dropout.cc @@ -1,5 +1,6 @@ #include "op-attrs/ops/dropout.h" #include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest/fmt/optional.h" #include "utils/expected.h" #include "utils/fmt/expected.h" #include diff --git a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc index b9dd66df5d..cbcebdbce1 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc @@ -1,9 +1,9 @@ #include "op-attrs/ops/layer_norm.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" #include "utils/expected.h" #include "utils/fmt/expected.h" #include "utils/fmt/optional.h" +#include using namespace ::FlexFlow; diff --git a/lib/op-attrs/test/src/op-attrs/ops/softmax.cc b/lib/op-attrs/test/src/op-attrs/ops/softmax.cc index f6a8da016f..65a74932cb 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/softmax.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/softmax.cc @@ -1,5 +1,6 @@ #include "op-attrs/ops/softmax.h" #include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest/fmt/optional.h" #include "utils/expected.h" #include "utils/fmt/expected.h" #include diff --git a/lib/op-attrs/test/src/op-attrs/tensor_dims.cc b/lib/op-attrs/test/src/op-attrs/tensor_dims.cc index 25c7eb036f..60d87300c1 100644 --- a/lib/op-attrs/test/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/test/src/op-attrs/tensor_dims.cc @@ -1,4 +1,5 @@ #include "op-attrs/tensor_dims.h" +#include "test/utils/doctest/fmt/optional.h" #include using namespace ::FlexFlow; diff --git a/lib/op-attrs/test/src/op-attrs/tensor_shape.cc b/lib/op-attrs/test/src/op-attrs/tensor_shape.cc deleted file mode 100644 index bc715c183a..0000000000 --- a/lib/op-attrs/test/src/op-attrs/tensor_shape.cc +++ /dev/null @@ -1,64 +0,0 @@ -#include "op-attrs/tensor_shape.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_broadcast_target_shape(std::unordered_set)") { - SUBCASE("target exists in inputs") { - DataType datatype = DataType::FLOAT; - - TensorShape s1 = TensorShape{ - TensorDims{FFOrdered{ - 1, - }}, - datatype, - }; - - TensorShape s2 = TensorShape{ - TensorDims{FFOrdered{10, 4, 3}}, - datatype, - }; - - TensorShape s3 = TensorShape{ - TensorDims{FFOrdered{ - 4, - 1, - }}, - datatype, - }; - - std::optional result = - get_broadcast_target_shape({s1, s2, s3}); - std::optional correct = s2; - - CHECK(result == correct); - } - - SUBCASE("datatypes don't match") { - TensorDims dims = TensorDims{FFOrdered{10, 4, 3}}; - - TensorShape s1 = TensorShape{ - dims, - DataType::FLOAT, - }; - - TensorShape s2 = TensorShape{ - dims, - DataType::DOUBLE, - }; - - std::optional result = get_broadcast_target_shape({s1, s2}); - std::optional correct = std::nullopt; - - CHECK(result == correct); - } - - SUBCASE("inputs is empty") { - std::optional result = get_broadcast_target_shape({}); - std::optional correct = std::nullopt; - - CHECK(result == correct); - } - } -} diff --git a/lib/op-attrs/test/src/ops/attention.cc b/lib/op-attrs/test/src/ops/attention.cc index ade219a6a9..2fb804ca8c 100644 --- a/lib/op-attrs/test/src/ops/attention.cc +++ 
b/lib/op-attrs/test/src/ops/attention.cc @@ -1,7 +1,10 @@ #include "op-attrs/ops/attention.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" #include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_output_shape(MultiHeadAttentionAttrs, TensorShape, " diff --git a/lib/op-attrs/test/src/ops/batch_matmul.cc b/lib/op-attrs/test/src/ops/batch_matmul.cc index 3ff02ccece..56a2e3fa52 100644 --- a/lib/op-attrs/test/src/ops/batch_matmul.cc +++ b/lib/op-attrs/test/src/ops/batch_matmul.cc @@ -1,5 +1,8 @@ #include "op-attrs/ops/batch_matmul.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_output_shape(BatchMatmulAttrs, TensorShape)") { diff --git a/lib/op-attrs/test/src/ops/cast.cc b/lib/op-attrs/test/src/ops/cast.cc index 31030ca0f9..c7395316ad 100644 --- a/lib/op-attrs/test/src/ops/cast.cc +++ b/lib/op-attrs/test/src/ops/cast.cc @@ -1,6 +1,9 @@ #include "op-attrs/ops/cast.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Cast shape inference") { diff --git a/lib/op-attrs/test/src/ops/combine.cc b/lib/op-attrs/test/src/ops/combine.cc index ac18bbc798..bf74a072e0 100644 --- a/lib/op-attrs/test/src/ops/combine.cc +++ b/lib/op-attrs/test/src/ops/combine.cc @@ -1,5 +1,8 @@ #include "op-attrs/ops/combine.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Combine shape inference") { diff --git a/lib/op-attrs/test/src/ops/element_binary.cc b/lib/op-attrs/test/src/ops/element_binary.cc index 0ed695eb89..b091833f10 100644 --- a/lib/op-attrs/test/src/ops/element_binary.cc +++ b/lib/op-attrs/test/src/ops/element_binary.cc @@ -1,6 +1,9 @@ #include "op-attrs/ops/element_binary.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("EWAdd shape inference") { diff --git a/lib/op-attrs/test/src/ops/element_unary.cc b/lib/op-attrs/test/src/ops/element_unary.cc index 4239782d55..94c382356e 100644 --- a/lib/op-attrs/test/src/ops/element_unary.cc +++ b/lib/op-attrs/test/src/ops/element_unary.cc @@ -1,6 +1,9 @@ #include "op-attrs/ops/element_unary.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ReLU shape inference") { diff --git a/lib/op-attrs/test/src/ops/embedding.cc b/lib/op-attrs/test/src/ops/embedding.cc index 9180f7055d..134737f6c0 100644 --- a/lib/op-attrs/test/src/ops/embedding.cc +++ b/lib/op-attrs/test/src/ops/embedding.cc @@ -1,7 +1,10 @@ #include "op-attrs/ops/embedding.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" #include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Sum embedding shape inference") { diff --git a/lib/op-attrs/test/src/ops/linear.cc b/lib/op-attrs/test/src/ops/linear.cc index 0d23dc35df..f838ff4285 100644 --- 
a/lib/op-attrs/test/src/ops/linear.cc +++ b/lib/op-attrs/test/src/ops/linear.cc @@ -1,7 +1,10 @@ #include "op-attrs/ops/linear.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" #include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Linear shape inference") { diff --git a/lib/op-attrs/test/src/ops/reduction.cc b/lib/op-attrs/test/src/ops/reduction.cc index 59ed5bb5ee..0d1c8bdf98 100644 --- a/lib/op-attrs/test/src/ops/reduction.cc +++ b/lib/op-attrs/test/src/ops/reduction.cc @@ -1,5 +1,8 @@ #include "op-attrs/ops/reduction.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Reduction shape inference") { diff --git a/lib/op-attrs/test/src/ops/repartition.cc b/lib/op-attrs/test/src/ops/repartition.cc index af28a6d471..8bc8205183 100644 --- a/lib/op-attrs/test/src/ops/repartition.cc +++ b/lib/op-attrs/test/src/ops/repartition.cc @@ -1,5 +1,8 @@ #include "op-attrs/ops/repartition.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Repartition shape inference") { diff --git a/lib/op-attrs/test/src/ops/replicate.cc b/lib/op-attrs/test/src/ops/replicate.cc index a0ec40cc14..60a1018479 100644 --- a/lib/op-attrs/test/src/ops/replicate.cc +++ b/lib/op-attrs/test/src/ops/replicate.cc @@ -1,5 +1,7 @@ #include "op-attrs/ops/replicate.h" -#include "test/utils/doctest.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Replicate shape inference") { diff --git a/lib/op-attrs/test/src/test_operator_attrs.cc b/lib/op-attrs/test/src/test_operator_attrs.cc index f485b07b02..20825f5d73 100644 --- a/lib/op-attrs/test/src/test_operator_attrs.cc +++ b/lib/op-attrs/test/src/test_operator_attrs.cc @@ -1,8 +1,8 @@ #include "op-attrs/computation_graph_op_attrs.dtg.h" #include "op-attrs/pcg_operator_attrs.dtg.h" -#include "utils/json.h" #include #include +#include #include using namespace ::FlexFlow; @@ -10,16 +10,16 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("BatchNormAttrs to/from json") { BatchNormAttrs correct = BatchNormAttrs{true}; - json j = correct; - auto result = j.get(); + nlohmann::json j = correct; + BatchNormAttrs result = j.get(); CHECK(result == correct); } TEST_CASE("ComputationGraphAttrs to/from json") { ComputationGraphOpAttrs correct = ComputationGraphOpAttrs{BatchNormAttrs{true}}; - json j = correct; - auto result = j.get(); + nlohmann::json j = correct; + ComputationGraphOpAttrs result = j.get(); CHECK(result == correct); } @@ -29,8 +29,8 @@ TEST_SUITE(FF_TEST_SUITE) { /*repartition_dim=*/ff_dim_t{1}, /*repartition_degree=*/4, }}; - json j = correct; - auto result = j.get(); + nlohmann::json j = correct; + PCGOperatorAttrs result = j.get(); CHECK(result == correct); } diff --git a/lib/op-attrs/test/src/test_regularizer_attrs.cc b/lib/op-attrs/test/src/test_regularizer_attrs.cc index 35851463bb..6e172d1e8e 100644 --- a/lib/op-attrs/test/src/test_regularizer_attrs.cc +++ b/lib/op-attrs/test/src/test_regularizer_attrs.cc @@ -1,6 +1,8 @@ #include "op-attrs/regularizer_attrs.dtg.h" -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Arbitrary") { diff --git a/lib/pcg/CMakeLists.txt 
b/lib/pcg/CMakeLists.txt index e1875ca694..e6eb182740 100644 --- a/lib/pcg/CMakeLists.txt +++ b/lib/pcg/CMakeLists.txt @@ -10,6 +10,7 @@ ff_add_library( DEPS op-attrs utils + rapidcheck ) add_subdirectory(ffi) diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h index 088139a0f3..499b26af89 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_H #include "pcg/computation_graph.dtg.h" +#include "pcg/computation_graph/computation_graph_edge.dtg.h" #include "pcg/computation_graph/layer_added_result.dtg.h" #include "pcg/layer_guid_t.dtg.h" #include "pcg/tensor_attrs.dtg.h" @@ -30,11 +31,24 @@ std::vector get_outgoing_tensors(ComputationGraph const &cg, std::vector get_incoming_tensors(ComputationGraph const &cg, layer_guid_t n); +std::unordered_set + get_subgraph_incoming_edges(ComputationGraph const &, + std::unordered_set const &); +std::unordered_set + get_subgraph_outgoing_edges(ComputationGraph const &, + std::unordered_set const &); +std::unordered_set + get_subgraph_successors(ComputationGraph const &, + std::unordered_set const &); + LayerAttrs get_layer_attrs(ComputationGraph const &cg, layer_guid_t const &n); layer_guid_t get_layer_by_name(ComputationGraph const &cg, std::string const &name); +std::string as_dot(ComputationGraph const &); +void debug_print_dot(ComputationGraph const &); + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/computation_graph/computation_graph_edge.h b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.h new file mode 100644 index 0000000000..2a9a9ee04a --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_COMPUTATION_GRAPH_EDGE_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_COMPUTATION_GRAPH_EDGE_H + +#include "pcg/computation_graph/computation_graph_edge.dtg.h" +#include "pcg/layer_guid_t.dtg.h" + +namespace FlexFlow { + +layer_guid_t get_computation_graph_edge_src_layer(ComputationGraphEdge const &); +layer_guid_t get_computation_graph_edge_dst_layer(ComputationGraphEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/computation_graph/computation_graph_edge.struct.toml b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.struct.toml new file mode 100644 index 0000000000..311c47d277 --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "ComputationGraphEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_edge.dtg.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::DataflowEdge" diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index c641aed6a4..a35763cacc 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -159,9 +159,12 @@ struct ComputationGraphBuilder { std::optional activation = std::nullopt, bool use_bias = true, DataType data_type = DataType::FLOAT, - std::optional const &kernel_initializer = std::nullopt, + std::optional const &projection_initializer = + std::nullopt, std::optional const &bias_initializer = std::nullopt, - std::optional const &name = std::nullopt); + std::optional const &name = std::nullopt, + std::optional const 
&projection_name = std::nullopt, + std::optional const &bias_name = std::nullopt); // Add a cast layer tensor_guid_t cast(tensor_guid_t const &input, DataType dtype, @@ -225,12 +228,16 @@ struct ComputationGraphBuilder { bool add_zero_attn = false, std::optional initializer = std::nullopt, std::optional const &maybe_name = std::nullopt); - tensor_guid_t create_tensor(TensorShape const &, CreateGrad); + tensor_guid_t + create_input(TensorShape const &, + CreateGrad, + std::optional const &maybe_name = std::nullopt); tensor_guid_t create_weight( TensorShape const &, - bool create_grad = true, + CreateGrad create_grad = CreateGrad::YES, std::optional const &initializer = std::nullopt, - std::optional sync_type = std::nullopt); + std::optional sync_type = std::nullopt, + std::optional const &name = std::nullopt); std::vector get_outputs(LayerAttrs const &) const; tensor_guid_t get_output(LayerAttrs const &, int idx) const; @@ -243,9 +250,8 @@ struct ComputationGraphBuilder { private: TensorShape get_shape(tensor_guid_t const &) const; - tensor_guid_t broadcast(tensor_guid_t const &, - TensorShape const &, - std::string const &); + tensor_guid_t + broadcast(tensor_guid_t const &, TensorDims const &, std::string const &); tensor_guid_t as_type(tensor_guid_t const &, DataType, std::string const &); @@ -259,13 +265,22 @@ struct ComputationGraphBuilder { std::vector const &weights, std::vector const &outputs); + tensor_guid_t add_layer(LayerAttrs const &layer, + std::vector const &inputs, + TensorShape const &output); + tensor_guid_t add_layer(LayerAttrs const &layer, std::vector const &inputs, std::vector const &weights, TensorShape const &output); - TensorShape get_broadcast_target_shape(std::vector const &); - TensorShape get_broadcast_target_shape(std::vector const &); + tensor_guid_t add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorShape const &output); + + TensorDims get_broadcast_target_dims(std::vector const &); + TensorDims get_broadcast_target_dims(std::vector const &); tensor_guid_t element_binary(OperatorType, diff --git a/lib/pcg/include/pcg/file_format/file_format.h b/lib/pcg/include/pcg/file_format/file_format.h deleted file mode 100644 index 823846754c..0000000000 --- a/lib/pcg/include/pcg/file_format/file_format.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_FILE_FORMAT_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_FILE_FORMAT_H - -#include "graphs.h" -#include "utils/json.h" - -namespace FlexFlow { - -enum class FileFormatVersion { - V1, - UNSTABLE, -}; - -json to_json(ComputationGraph const &, FileFormatVersion); -ComputationGraph from_json(json const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/keyed_variant.h b/lib/pcg/include/pcg/file_format/keyed_variant.h index 11044de12b..5e29d8c252 100644 --- a/lib/pcg/include/pcg/file_format/keyed_variant.h +++ b/lib/pcg/include/pcg/file_format/keyed_variant.h @@ -1,10 +1,11 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_KEYED_VARIANT_H #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_KEYED_VARIANT_H -#include "utils/json.h" +#include "utils/json/is_jsonable.h" #include "utils/sequence.h" #include "utils/strong_typedef.h" #include "utils/variant.h" +#include namespace FlexFlow { @@ -29,9 +30,9 @@ struct KeyedVariant { }; struct ToJsonFunctor { - ToJsonFunctor(json &j) : j(j) {} + ToJsonFunctor(nlohmann::json &j) : j(j) {} - json &j; + nlohmann::json &j; template void operator()(T const &t) { @@ -42,20 +43,20 @@ 
struct ToJsonFunctor { }; template -void to_json(json &j, KeyedVariant const &v) { +void to_json(nlohmann::json &j, KeyedVariant const &v) { static_assert(is_jsonable::value, ""); K key = static_cast(v.value.index()); j["type"] = key; - json &jj = j["value"]; + nlohmann::json &jj = j["value"]; visit(ToJsonFunctor{j["value"]}, v.value); } template struct FromJsonFunctor { - FromJsonFunctor(json const &j, int idx) : j(j), idx(idx) {} + FromJsonFunctor(nlohmann::json const &j, int idx) : j(j), idx(idx) {} - json const &j; + nlohmann::json const &j; int idx; template @@ -68,31 +69,31 @@ struct FromJsonFunctor { template std::string get_json_name(T const &t) { - return json{t}.get(); + return nlohmann::json{t}.get(); } template struct FromJsonMoveOnlyFunctor { - FromJsonMoveOnlyFunctor(json const &j, Key const &key) : j(j) {} + FromJsonMoveOnlyFunctor(nlohmann::json const &j, Key const &key) : j(j) {} - json const &j; + nlohmann::json const &j; Key const &key; template Variant operator()(std::integral_constant const &) const { - return j.get::type>(); + return j.get::type>(); } }; template -Variant from_json_moveonly(json const &j, K const &key) { +Variant from_json_moveonly(nlohmann::json const &j, K const &key) { FromJsonMoveOnlyFunctor func(j); return seq_get(func, idx, seq_count_t::value>{}); } template typename std::enable_if::value>::type - from_json(json const &j, KeyedVariant &v) { + from_json(nlohmann::json const &j, KeyedVariant &v) { K key = j.at("type").get(); std::string key_string = j.at("type").get(); @@ -100,7 +101,7 @@ typename std::enable_if::value>::type } template -KeyedVariant keyed_variant_from_json(json const &j) { +KeyedVariant keyed_variant_from_json(nlohmann::json const &j) { K key = j.at("type").get(); return KeyedVariant{ diff --git a/lib/pcg/include/pcg/file_format/v1/data_type_value.h b/lib/pcg/include/pcg/file_format/v1/data_type_value.h index 6e4e5abc54..ec3910aab3 100644 --- a/lib/pcg/include/pcg/file_format/v1/data_type_value.h +++ b/lib/pcg/include/pcg/file_format/v1/data_type_value.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_DATA_TYPE_H #include "utils/fp16.h" -#include "utils/json.h" +#include namespace FlexFlow { diff --git a/lib/pcg/include/pcg/file_format/v1/graphs.h b/lib/pcg/include/pcg/file_format/v1/graphs.h deleted file mode 100644 index 702c79c2b6..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/graphs.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_H - -#include "pcg/computation_graph.dtg.h" -#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h" -#include "pcg/layer_attrs.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" -#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" -#include "pcg/tensor_attrs.dtg.h" -#include "utils/json.h" - -namespace FlexFlow { - -using V1ComputationGraph = V1LabelledDataflowGraph; -CHECK_IS_JSONABLE(V1ComputationGraph); -V1ComputationGraph to_v1(ComputationGraph const &); - -using V1ParallelComputationGraph = - V1LabelledDataflowGraph; -CHECK_IS_JSONABLE(V1ParallelComputationGraph); -V1ParallelComputationGraph to_v1(ParallelComputationGraph const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml index 
d9aade739c..c332b6b41d 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml @@ -1,9 +1,9 @@ namespace = "FlexFlow" name = "V1DataflowGraph" features = [ - # "eq", + "eq", # "ord", - # "hash", + "hash", "json", # "rapidcheck", "fmt", @@ -13,8 +13,13 @@ includes = [ "", "", "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h", +] + +src_includes = [ "utils/fmt/vector.h", + "utils/hash/vector.h", "utils/fmt/unordered_set.h", + "utils/hash/unordered_set.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h index 48203d73ae..fc9dfcef9a 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h @@ -13,8 +13,9 @@ namespace FlexFlow { template -V1LabelledDataflowGraph - to_v1(LabelledDataflowGraphView const &g) { +std::pair, bidict> + to_v1_including_node_numbering( + LabelledDataflowGraphView const &g) { bidict nodes = bidict_from_enumerating(get_nodes(g)); @@ -29,8 +30,17 @@ V1LabelledDataflowGraph [&](DataflowOutput const &o) { return g.at(o); }); }); - return V1LabelledDataflowGraph{ - node_labels, output_labels, unlabelled}; + return { + V1LabelledDataflowGraph{ + node_labels, output_labels, unlabelled}, + nodes, + }; +} + +template +V1LabelledDataflowGraph + to_v1(LabelledDataflowGraphView const &g) { + return to_v1_including_node_numbering(g).first; } } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml index fd8d4c39c4..b440d0f03d 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml @@ -1,9 +1,9 @@ namespace = "FlexFlow" name = "V1LabelledDataflowGraph" features = [ - # "eq", + "eq", # "ord", - # "hash", + "hash", "json", # "rapidcheck", "fmt", @@ -20,6 +20,13 @@ includes = [ "pcg/file_format/v1/graphs/v1_graph_output.dtg.h", ] +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + [[fields]] name = "node_labels" type = "std::unordered_map" @@ -31,4 +38,3 @@ type = "std::unordered_map>" [[fields]] name = "graph" type = "::FlexFlow::V1DataflowGraph" - diff --git a/lib/pcg/include/pcg/file_format/v1/v1.h b/lib/pcg/include/pcg/file_format/v1/v1.h deleted file mode 100644 index e2557af4f5..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/v1.h +++ /dev/null @@ -1,9 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_H - -#include "graphs.h" -#include "pcg/computation_graph.h" - -namespace FlexFlow {} - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h new file mode 100644 index 0000000000..5590d6999b --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_COMPUTATION_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_COMPUTATION_GRAPH_H + +#include "pcg/computation_graph.dtg.h" +#include "pcg/file_format/v1/v1_computation_graph.dtg.h" +#include 
"pcg/layer_guid_t.dtg.h" + +namespace FlexFlow { + +V1ComputationGraph to_v1(ComputationGraph const &); + +std::pair> + to_v1_including_node_numbering(ComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.struct.toml new file mode 100644 index 0000000000..0d7135ec74 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "V1ComputationGraph" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "pcg/layer_attrs.dtg.h", + "pcg/tensor_attrs.dtg.h", + "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::V1LabelledDataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs>" diff --git a/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h new file mode 100644 index 0000000000..aceb59f5af --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_PARALLEL_COMPUTATION_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_PARALLEL_COMPUTATION_GRAPH_H + +#include "pcg/file_format/v1/v1_parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" + +namespace FlexFlow { + +V1ParallelComputationGraph to_v1(ParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.struct.toml new file mode 100644 index 0000000000..16be4a9561 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "V1ParallelComputationGraph" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", + "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::V1LabelledDataflowGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml index 12917d0989..4e3c31bd36 100644 --- a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml +++ b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml @@ -10,12 +10,7 @@ features = [ ] includes = [ - "op-attrs/datatype.h", - "utils/json.h", -] - -src_includes = [ - "utils/fmt/variant.h", + "op-attrs/datatype_value.dtg.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/layer_attrs.struct.toml b/lib/pcg/include/pcg/layer_attrs.struct.toml index d062f6cd78..8290795174 100644 --- a/lib/pcg/include/pcg/layer_attrs.struct.toml +++ b/lib/pcg/include/pcg/layer_attrs.struct.toml @@ -13,11 +13,11 @@ includes = [ "op-attrs/computation_graph_op_attrs.dtg.h", "utils/stack_string.h", "", - "utils/json.h" ] src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml 
b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml index 60cfc426cc..4d61f24d37 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml @@ -17,6 +17,8 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml index d9e6cf113b..323932fec6 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml @@ -19,6 +19,8 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/tensor_attrs.struct.toml b/lib/pcg/include/pcg/tensor_attrs.struct.toml index c0b89cfc99..7f16e60914 100644 --- a/lib/pcg/include/pcg/tensor_attrs.struct.toml +++ b/lib/pcg/include/pcg/tensor_attrs.struct.toml @@ -19,6 +19,7 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", ] [[fields]] diff --git a/lib/pcg/src/file_format.cc b/lib/pcg/src/file_format.cc deleted file mode 100644 index bb01ac2dbf..0000000000 --- a/lib/pcg/src/file_format.cc +++ /dev/null @@ -1,14 +0,0 @@ -#include "pcg/file_format/v1/v1.h" - -namespace FlexFlow { - -/* void thing() { */ -/* static_assert(is_visitable::value, ""); */ - -/* json j; */ -/* auto g = j.get(); */ - -/* /1* IllBehaved v = j.get(); *1/ */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/graphs.cc b/lib/pcg/src/file_format/v1/graphs.cc deleted file mode 100644 index de8d5dddb4..0000000000 --- a/lib/pcg/src/file_format/v1/graphs.cc +++ /dev/null @@ -1,16 +0,0 @@ -#include "pcg/file_format/v1/graphs.h" -#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" -#include "utils/graph/algorithms.h" -#include "utils/integer_conversions.h" - -namespace FlexFlow { - -V1ComputationGraph to_v1(ComputationGraph const &g) { - return to_v1(g.raw_graph); -} - -V1ParallelComputationGraph to_v1(ParallelComputationGraph const &g) { - return to_v1(g.raw_graph); -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index deaa440ef8..cf4b1496cf 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -1,11 +1,18 @@ #include "pcg/computation_graph.h" +#include "op-attrs/computation_graph_op_attrs.h" #include "utils/containers/get_only.h" #include "utils/containers/reversed.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h" +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" +#include "utils/graph/digraph/algorithms/get_subgraph_successors.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" #include "utils/graph/node/algorithms.h" +#include "utils/record_formatter.h" namespace FlexFlow { @@ -20,6 +27,23 @@ 
std::unordered_set get_layers(ComputationGraph const &cg) { [&](Node const &n) { return layer_guid_t{n}; }); } +LayerAddedResult add_layer(ComputationGraph &computation_graph, + LayerAttrs const &attrs, + std::vector const &inputs, + std::vector const &outputs) { + std::vector raw_inputs = transform( + inputs, [](tensor_guid_t const &t) { return t.raw_graph_output; }); + + NodeAddedResult added = + computation_graph.raw_graph.add_node(attrs, raw_inputs, outputs); + + return LayerAddedResult{ + layer_guid_t{added.node}, + transform(added.outputs, + [](DataflowOutput const &o) { return tensor_guid_t{o}; }), + }; +} + TensorAttrs get_tensor_attrs(ComputationGraph const &cg, tensor_guid_t const &t) { return cg.raw_graph.at(t.raw_graph_output); @@ -39,8 +63,7 @@ std::vector topological_ordering(ComputationGraph const &cg) { std::vector reverse_topological_ordering(ComputationGraph const &cg) { - std::vector layers = - reversed>(get_topological_ordering(cg.raw_graph)); + std::vector layers = reversed(get_topological_ordering(cg.raw_graph)); return transform( layers, [&](Node const &e) -> layer_guid_t { return layer_guid_t{e}; }); } @@ -57,6 +80,47 @@ std::vector get_incoming_tensors(ComputationGraph const &cg, [](DataflowOutput const &o) { return tensor_guid_t{o}; }); } +std::unordered_set get_subgraph_incoming_edges( + ComputationGraph const &cg, + std::unordered_set const &subgraph_nodes) { + + std::unordered_set raw_subgraph_nodes = transform( + subgraph_nodes, [](layer_guid_t const &l) { return l.raw_node; }); + std::unordered_set raw_incoming_edges = + get_subgraph_incoming_edges(cg.raw_graph, raw_subgraph_nodes); + + return transform(raw_incoming_edges, [](DataflowEdge const &e) { + return ComputationGraphEdge{e}; + }); +} + +std::unordered_set get_subgraph_outgoing_edges( + ComputationGraph const &cg, + std::unordered_set const &subgraph_nodes) { + + std::unordered_set raw_subgraph_nodes = transform( + subgraph_nodes, [](layer_guid_t const &l) { return l.raw_node; }); + std::unordered_set raw_outgoing_edges = + get_subgraph_outgoing_edges(cg.raw_graph, raw_subgraph_nodes); + + return transform(raw_outgoing_edges, [](DataflowEdge const &e) { + return ComputationGraphEdge{e}; + }); +} + +std::unordered_set get_subgraph_successors( + ComputationGraph const &cg, + std::unordered_set const &subgraph_nodes) { + + std::unordered_set raw_subgraph_nodes = transform( + subgraph_nodes, [](layer_guid_t const &l) { return l.raw_node; }); + std::unordered_set raw_successors = + get_subgraph_successors(cg.raw_graph, raw_subgraph_nodes); + + return transform(raw_successors, + [](Node const &n) { return layer_guid_t{n}; }); +} + LayerAttrs get_layer_attrs(ComputationGraph const &cg, layer_guid_t const &n) { return cg.raw_graph.at(n.raw_node); } @@ -70,4 +134,40 @@ layer_guid_t get_layer_by_name(ComputationGraph const &cg, return get_only(found); } +std::string as_dot(ComputationGraph const &cg) { + std::function get_node_label = + [](LayerAttrs const &a) -> std::string { + RecordFormatter r = as_dot(a.attrs); + + if (a.name.has_value()) { + RecordFormatter rr; + rr << "Name" << a.name.value(); + r << rr; + } + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + + std::function get_input_label = + [](TensorAttrs const &a) -> std::string { + RecordFormatter r; + + r << fmt::to_string(a.shape); + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + + return as_dot(view_as_labelled_open_dataflow_graph(cg.raw_graph), + get_node_label, + get_input_label); +} + +void 
debug_print_dot(ComputationGraph const &cg) { + std::cout << as_dot(cg) << std::endl; +} + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph/computation_graph_edge.cc b/lib/pcg/src/pcg/computation_graph/computation_graph_edge.cc new file mode 100644 index 0000000000..0efa0620c4 --- /dev/null +++ b/lib/pcg/src/pcg/computation_graph/computation_graph_edge.cc @@ -0,0 +1,15 @@ +#include "pcg/computation_graph/computation_graph_edge.h" + +namespace FlexFlow { + +layer_guid_t + get_computation_graph_edge_src_layer(ComputationGraphEdge const &e) { + return layer_guid_t{e.raw_edge.src.node}; +} + +layer_guid_t + get_computation_graph_edge_dst_layer(ComputationGraphEdge const &e) { + return layer_guid_t{e.raw_edge.dst.node}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 3f2feaf619..e0b6935a6d 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -15,6 +15,7 @@ #include "op-attrs/ops/linear.h" #include "op-attrs/ops/softmax.h" #include "op-attrs/ops/weight_attrs.dtg.h" +#include "op-attrs/tensor_dims.h" #include "pcg/computation_graph.h" #include "utils/containers/any_of.h" #include "utils/containers/concat_vectors.h" @@ -26,6 +27,16 @@ namespace FlexFlow { +static TensorAttrs make_weight_attrs( + TensorShape const &shape, + std::optional const &initializer_attrs) { + return TensorAttrs{shape, initializer_attrs, std::nullopt, CreateGrad::YES}; +} + +static TensorAttrs make_output_attrs(TensorShape const &shape) { + return TensorAttrs{shape, std::nullopt, std::nullopt, CreateGrad::YES}; +} + ComputationGraphBuilder::ComputationGraphBuilder() : computation_graph(make_empty_computation_graph()) {} @@ -33,13 +44,31 @@ TensorShape ComputationGraphBuilder::get_shape(tensor_guid_t const &t) const { return get_tensor_attrs(this->computation_graph, t).shape; } -tensor_guid_t ComputationGraphBuilder::create_tensor(TensorShape const &shape, - CreateGrad create_grad) { +tensor_guid_t ComputationGraphBuilder::create_input( + TensorShape const &shape, + CreateGrad create_grad, + std::optional const &maybe_name) { TensorAttrs tensor_attrs = TensorAttrs{shape, std::nullopt, std::nullopt, create_grad}; LayerAttrs layer_attrs = LayerAttrs{ ComputationGraphOpAttrs{InputAttrs{}}, - std::nullopt, + maybe_name, + }; + + return this->add_layer(layer_attrs, {}, {}, tensor_attrs); +} + +tensor_guid_t ComputationGraphBuilder::create_weight( + TensorShape const &shape, + CreateGrad create_grad, + std::optional const &initializer, + std::optional param_sync, + std::optional const &maybe_name) { + TensorAttrs tensor_attrs = + TensorAttrs{shape, initializer, param_sync, create_grad}; + LayerAttrs layer_attrs = LayerAttrs{ + ComputationGraphOpAttrs{InputAttrs{}}, + maybe_name, }; return this->add_layer(layer_attrs, {}, {}, tensor_attrs); @@ -98,9 +127,31 @@ std::vector ComputationGraphBuilder::add_layer( std::vector const &weights, std::vector const &outputs) { return this->add_layer( - layer, inputs, weights, transform(outputs, [](TensorShape const &s) { - return TensorAttrs{s, std::nullopt, std::nullopt, CreateGrad::YES}; - })); + layer, inputs, weights, transform(outputs, make_output_attrs)); +} + +tensor_guid_t ComputationGraphBuilder::add_layer( + LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorShape const &output_shape) { + + TensorAttrs output_attrs = make_output_attrs(output_shape); + LayerAddedResult added 
= + ::FlexFlow::add_layer(this->computation_graph, + layer, + concat_vectors(inputs, weights), + {output_attrs}); + return get_only(added.outputs); +} + +tensor_guid_t + ComputationGraphBuilder::add_layer(LayerAttrs const &layer, + std::vector const &inputs, + TensorShape const &output_shape) { + + std::vector weights = {}; + return this->add_layer(layer, inputs, weights, output_shape); } tensor_guid_t @@ -129,25 +180,28 @@ tensor_guid_t ComputationGraphBuilder::as_type(tensor_guid_t const &x, } } -tensor_guid_t - ComputationGraphBuilder::broadcast(tensor_guid_t const &input, - TensorShape const &target_shape, - std::string const &name) { +tensor_guid_t ComputationGraphBuilder::broadcast(tensor_guid_t const &input, + TensorDims const &target_dims, + std::string const &name) { TensorShape input_shape = this->get_shape(input); - if (!tensor_shape_is_broadcastable_to(input_shape, target_shape)) { + if (input_shape.dims == target_dims) { + return input; + } + + if (!tensor_dims_is_broadcastable_to(input_shape.dims, target_dims)) { throw mk_runtime_error(fmt::format( - "Cannot broadcast input tensor of shape {} to target shape {}", - input_shape, - target_shape)); + "Cannot broadcast input tensor of dims {} to target dims {}", + input_shape.dims, + target_dims)); } - BroadcastAttrs attrs = BroadcastAttrs{target_shape.dims}; + BroadcastAttrs attrs = BroadcastAttrs{target_dims}; LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } tensor_guid_t @@ -184,7 +238,7 @@ tensor_guid_t ComputationGraphBuilder::element_unary( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } tensor_guid_t ComputationGraphBuilder::element_binary( @@ -194,18 +248,18 @@ tensor_guid_t ComputationGraphBuilder::element_binary( std::optional const &maybe_name) { std::string name = maybe_name.value_or(get_default_name(op_type)); - TensorShape compute_shape = this->get_broadcast_target_shape({lhs, rhs}); + TensorDims compute_dims = this->get_broadcast_target_dims({lhs, rhs}); DataType compute_type = std::max(this->get_shape(lhs).data_type, this->get_shape(rhs).data_type); tensor_guid_t lhs_input = this->as_type( this->broadcast( - lhs, compute_shape, fmt::format("{}_inputl_broadcast", name)), + lhs, compute_dims, fmt::format("{}_inputl_broadcast", name)), compute_type, name + "_inputl_cast"); tensor_guid_t rhs_input = this->as_type( this->broadcast( - rhs, compute_shape, fmt::format("{}_inputr_broadcast", name)), + rhs, compute_dims, fmt::format("{}_inputr_broadcast", name)), compute_type, name + "_inputr_cast"); @@ -217,7 +271,7 @@ tensor_guid_t ComputationGraphBuilder::element_binary( TensorShape output_shape = throw_if_unexpected(get_output_shape( attrs, this->get_shape(lhs_input), this->get_shape(rhs_input))); - return this->add_layer(layer, {lhs_input, rhs_input}, {}, output_shape); + return this->add_layer(layer, {lhs_input, rhs_input}, output_shape); } tensor_guid_t @@ -359,12 +413,6 @@ tensor_guid_t return this->element_unary(OperatorType::ELU, input, std::nullopt, name); } -static TensorAttrs make_weight_attrs( - TensorShape const &shape, - std::optional const &initializer_attrs) { - return TensorAttrs{shape, initializer_attrs, std::nullopt, 
CreateGrad::YES}; -} - tensor_guid_t ComputationGraphBuilder::conv2d( tensor_guid_t const &x, int outChannels, @@ -431,7 +479,7 @@ tensor_guid_t ComputationGraphBuilder::dropout( TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } tensor_guid_t ComputationGraphBuilder::embedding( @@ -483,7 +531,7 @@ tensor_guid_t ComputationGraphBuilder::gather( TensorShape output_shape = get_output_shape(attrs, this->get_shape(input), this->get_shape(index)); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } /* std::vector @@ -531,7 +579,7 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } tensor_guid_t ComputationGraphBuilder::multihead_attention( @@ -581,26 +629,26 @@ tensor_guid_t ComputationGraphBuilder::multihead_attention( output_shape); } -TensorShape ComputationGraphBuilder::get_broadcast_target_shape( +TensorDims ComputationGraphBuilder::get_broadcast_target_dims( std::vector const &inputs) { - std::vector input_shapes = transform( - inputs, [&](tensor_guid_t const &t) { return this->get_shape(t); }); + std::vector inputs_dims = transform( + inputs, [&](tensor_guid_t const &t) { return this->get_shape(t).dims; }); - return this->get_broadcast_target_shape(input_shapes); + return this->get_broadcast_target_dims(inputs_dims); } -TensorShape ComputationGraphBuilder::get_broadcast_target_shape( - std::vector const &input_shapes) { - std::optional maybe_result = - ::FlexFlow::get_broadcast_target_shape(unordered_set_of(input_shapes)); +TensorDims ComputationGraphBuilder::get_broadcast_target_dims( + std::vector const &inputs_dims) { + std::optional maybe_result = + ::FlexFlow::get_broadcast_target_dims(unordered_set_of(inputs_dims)); if (maybe_result.has_value()) { return maybe_result.value(); } else { throw mk_runtime_error(fmt::format( - "ComputationGraphBuilder::get_broadcast_target_shape failed to find " - "target tensor shape for input tensor shapes {}", - input_shapes)); + "ComputationGraphBuilder::get_broadcast_target_dims failed to find " + "target tensor dims for input tensor dims {}", + inputs_dims)); } } @@ -610,9 +658,11 @@ tensor_guid_t ComputationGraphBuilder::dense( std::optional activation, bool use_bias, DataType data_type, - std::optional const &kernel_initializer, + std::optional const &projection_initializer, std::optional const &bias_initializer, - std::optional const &maybe_name) { + std::optional const &maybe_name, + std::optional const &projection_name, + std::optional const &bias_name) { LinearAttrs attrs = LinearAttrs{outDim, use_bias, data_type, activation, std::nullopt}; @@ -623,15 +673,30 @@ tensor_guid_t ComputationGraphBuilder::dense( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); - std::vector weights; - TensorShape kernel_shape = + std::vector weights; + + TensorShape projection_shape = throw_if_unexpected(get_kernel_shape(attrs, this->get_shape(input))); - weights.push_back(make_weight_attrs(kernel_shape, kernel_initializer)); + + tensor_guid_t projection_weights = + this->create_weight(projection_shape, + CreateGrad::YES, + projection_initializer, + /*sync_type=*/std::nullopt, + projection_name); + + 
weights.push_back(projection_weights); if (use_bias) { TensorShape bias_shape = throw_if_unexpected(get_bias_shape(attrs, this->get_shape(input))); - weights.push_back(make_weight_attrs(bias_shape, bias_initializer)); + + tensor_guid_t bias_weights = this->create_weight(bias_shape, + CreateGrad::YES, + bias_initializer, + /*sync_type=*/std::nullopt, + bias_name); + weights.push_back(bias_weights); } return this->add_layer(layer, {input}, weights, output_shape); @@ -677,13 +742,13 @@ tensor_guid_t ComputationGraphBuilder::layer_norm( TensorShape gamma_shape = throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); InitializerAttrs gamma_initializer = - InitializerAttrs{ConstantInitializerAttrs{float{1}}}; + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); TensorShape beta_shape = throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); InitializerAttrs beta_initializer = - InitializerAttrs{ConstantInitializerAttrs{float{0}}}; + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); } @@ -716,7 +781,7 @@ tensor_guid_t ComputationGraphBuilder::softmax( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc new file mode 100644 index 0000000000..975e92dfb7 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc @@ -0,0 +1,24 @@ +#include "pcg/file_format/v1/v1_computation_graph.h" +#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" + +namespace FlexFlow { + +V1ComputationGraph to_v1(ComputationGraph const &g) { + return V1ComputationGraph{ + to_v1(g.raw_graph), + }; +} + +std::pair> + to_v1_including_node_numbering(ComputationGraph const &cg) { + std::pair, bidict> + raw = + to_v1_including_node_numbering(cg.raw_graph); + V1ComputationGraph v1_cg = V1ComputationGraph{raw.first}; + bidict v1_node_ids = + map_values(raw.second, [](Node const &n) { return layer_guid_t{n}; }); + + return {v1_cg, v1_node_ids}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc new file mode 100644 index 0000000000..9da58fcf6e --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc @@ -0,0 +1,12 @@ +#include "pcg/file_format/v1/v1_parallel_computation_graph.h" +#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" + +namespace FlexFlow { + +V1ParallelComputationGraph to_v1(ParallelComputationGraph const &g) { + return V1ParallelComputationGraph{ + to_v1(g.raw_graph), + }; +} + +} // namespace FlexFlow diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc new file mode 100644 index 0000000000..8336d81bb4 --- /dev/null +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc @@ -0,0 +1,30 @@ +#include "pcg/file_format/v1/v1_computation_graph.h" +#include "pcg/computation_graph_builder.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("V1ComputationGraph") { + ComputationGraph cg = [] { + ComputationGraphBuilder b; + + 
TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 12, + 16, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + tensor_guid_t mm_output = b.dense(input, 8); + tensor_guid_t relu_output = b.relu(mm_output); + + return b.computation_graph; + }(); + + V1ComputationGraph v1_cg = to_v1(cg); + nlohmann::json j = v1_cg; + } +} diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc new file mode 100644 index 0000000000..8ce25c4bc5 --- /dev/null +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc @@ -0,0 +1,36 @@ +#include "pcg/file_format/v1/v1_parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("V1ParallelComputationGraph") { + ParallelComputationGraph pcg = [] { + ParallelComputationGraphBuilder b; + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + parallel_tensor_guid_t input = b.create_input_tensor(input_shape); + parallel_tensor_guid_t mm_output = b.dense(input, 8); + parallel_tensor_guid_t relu_output = b.relu(mm_output); + + return b.pcg; + }(); + + V1ParallelComputationGraph v1_pcg = to_v1(pcg); + nlohmann::json j = v1_pcg; + } +} diff --git a/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc b/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc index 0b75e3ae1a..703c129da4 100644 --- a/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc +++ b/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc @@ -1,6 +1,8 @@ #include "pcg/initializers/uniform_initializer_attrs.h" -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Arbitrary") { diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 440f735e80..f46f267859 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -3,7 +3,6 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_layer_attrs.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" -#include "test/utils/doctest.h" #include "utils/containers/count.h" #include "utils/containers/generate_map.h" #include "utils/containers/get_only.h" @@ -12,6 +11,9 @@ #include "utils/containers/values.h" #include "utils/containers/without_nullopts.h" #include "utils/hash/pair.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ParallelComputationGraphBuilder::add") { @@ -227,7 +229,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(num_replicate_attrs == 2); parallel_layer_guid_t conv_guid = get_only(without_nullopts(transform( - as_vector(items(layers)), + vector_of(items(layers)), [](std::pair const &kv) -> std::optional { if (get_op_type(kv.second) == OperatorType::CONV2D) { diff --git a/lib/pcg/test/src/test_computation_graph_builder.cc 
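The two serialization tests above show the intended export flow: build a (parallel) computation graph, convert it with to_v1, and assign the result to an nlohmann::json value. A minimal sketch of taking that one step further to a serialized string (illustrative only, not part of this patch; input_shape is assumed to be defined as in the test above, and json::dump is standard nlohmann API):

    ComputationGraphBuilder b;
    b.create_input(input_shape, CreateGrad::YES);      // input_shape as in the test above
    nlohmann::json j = to_v1(b.computation_graph);     // V1ComputationGraph converts to json
    std::string serialized = j.dump(/*indent=*/2);     // pretty-printed JSON, ready to write out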
b/lib/pcg/test/src/test_computation_graph_builder.cc index 936c2de00d..ff169d8312 100644 --- a/lib/pcg/test/src/test_computation_graph_builder.cc +++ b/lib/pcg/test/src/test_computation_graph_builder.cc @@ -15,7 +15,7 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - tensor_guid_t input = b.create_tensor(input_shape, CreateGrad::YES); + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); tensor_guid_t output = b.conv2d(input, /*outChannels=*/5, /*kernelH=*/3, diff --git a/lib/pcg/test/src/test_machine_view.cc b/lib/pcg/test/src/test_machine_view.cc index 70fe958d8c..25c6e21b87 100644 --- a/lib/pcg/test/src/test_machine_view.cc +++ b/lib/pcg/test/src/test_machine_view.cc @@ -1,7 +1,7 @@ #include "pcg/machine_view.h" #include "pcg/strided_rectangle.h" #include "pcg/strided_rectangle_side.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/pcg/test/src/test_strided_rectangle.cc b/lib/pcg/test/src/test_strided_rectangle.cc index 2fe3005b15..ac6af9fa19 100644 --- a/lib/pcg/test/src/test_strided_rectangle.cc +++ b/lib/pcg/test/src/test_strided_rectangle.cc @@ -1,6 +1,6 @@ #include "pcg/strided_rectangle.h" #include "pcg/strided_rectangle_side.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/runtime/src/accessor.cc b/lib/runtime/src/accessor.cc index 44ad8ab40d..84573fb4aa 100644 --- a/lib/runtime/src/accessor.cc +++ b/lib/runtime/src/accessor.cc @@ -129,7 +129,7 @@ struct GetTensorPointerWOFunctor { FieldID fid, Context ctx, Runtime *runtime) const { - return (void *)helperGetTensorPointerWO>( + return (void *)helperGetTensorPointerWO>( region, req, fid, ctx, runtime); } }; @@ -141,7 +141,7 @@ struct GetTensorPointerROFunctor { FieldID fid, Context ctx, Runtime *runtime) const { - return (void const *)helperGetTensorPointerRO>( + return (void const *)helperGetTensorPointerRO>( region, req, fid, ctx, runtime); } }; @@ -153,7 +153,7 @@ struct GetTensorPointerRWFUnctor { FieldID fid, Context ctx, Runtime *runtime) const { - return (void *)helperGetTensorPointerRW>( + return (void *)helperGetTensorPointerRW>( region, req, fid, ctx, runtime); } }; diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml index 7df65ef361..ad36f1bc4b 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml @@ -25,6 +25,7 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", "utils/fmt/vector.h", "utils/hash/vector.h", ] diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index 00032045c0..2d76352ccf 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -45,10 +45,6 @@ std::unordered_set get_subgraph_outgoing_edges( SubParallelComputationGraph const &, std::unordered_set const &); -std::unordered_set get_subgraph_incoming_edges( - SubParallelComputationGraph const &, - std::unordered_set const &); - std::unordered_set get_parallel_tensor_uses(SubParallelComputationGraph const &, open_parallel_tensor_guid_t const &); diff --git 
a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index 26f8ff5062..a18737085a 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -1,6 +1,6 @@ #include "substitutions/operator_pattern/get_attribute.h" #include "op-attrs/get_op_type.h" -#include "utils/containers/as_vector.h" +#include "utils/containers/vector_of.h" namespace FlexFlow { @@ -364,7 +364,7 @@ std::optional get_attribute(TransposeAttrs const &p, case OperatorAttributeKey::OP_TYPE: return get_op_type(p); case OperatorAttributeKey::PERMUTATION: - return as_vector(p.perm); + return vector_of(p.perm); default: return std::nullopt; } diff --git a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc index 05f21247c7..286bc69b84 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc @@ -1,7 +1,7 @@ #include "substitutions/tensor_pattern/get_attribute.h" #include "op-attrs/parallel_tensor_dims.h" -#include "utils/containers/as_vector.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/integer_conversions.h" namespace FlexFlow { @@ -11,13 +11,13 @@ TensorAttributeValue get_attribute(ParallelTensorAttrs const &attrs, switch (key) { case TensorAttributeKey::DIM_SIZES: { std::vector sizes = - transform(as_vector(ff_ordered_shard_dims(attrs.shape.dims)), + transform(vector_of(ff_ordered_shard_dims(attrs.shape.dims)), [](ShardParallelDim const &d) { return d.size; }); return TensorAttributeValue{sizes}; } case TensorAttributeKey::DIM_DEGREES: { std::vector degrees = transform( - as_vector(ff_ordered_shard_dims(attrs.shape.dims)), + vector_of(ff_ordered_shard_dims(attrs.shape.dims)), [](ShardParallelDim const &d) { return size_t_from_int(d.degree); }); return TensorAttributeValue{degrees}; } diff --git a/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc index 70e960bc73..95b61e0ef4 100644 --- a/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc @@ -1,4 +1,5 @@ #include "substitutions/operator_pattern/get_attribute.h" +#include "test/utils/doctest/fmt/optional.h" #include using namespace ::FlexFlow; diff --git a/lib/substitutions/test/src/substitutions/pcg_pattern.cc b/lib/substitutions/test/src/substitutions/pcg_pattern.cc index 6922798a97..4f56a76d0d 100644 --- a/lib/substitutions/test/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/test/src/substitutions/pcg_pattern.cc @@ -5,9 +5,9 @@ #include "substitutions/operator_pattern/operator_attribute_constraint.h" #include "substitutions/sub_parallel_computation_graph.h" #include "substitutions/tensor_pattern/tensor_attribute_pattern.h" -#include "test/utils/doctest.h" #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include using namespace ::FlexFlow; diff --git a/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc b/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc index 6621145d39..e4d763d9c3 100644 --- 
a/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc +++ b/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc @@ -1,10 +1,10 @@ #include "substitutions/unlabelled/pattern_split.h" #include "substitutions/unlabelled/pattern_value.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.h" -#include "test/utils/doctest.h" #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include using namespace ::FlexFlow; diff --git a/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc b/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc index 3475c10235..e0805dbfd4 100644 --- a/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc +++ b/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc @@ -1,8 +1,8 @@ #include "substitutions/unlabelled/unlabelled_graph_pattern.h" -#include "test/utils/doctest.h" #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include using namespace ::FlexFlow; diff --git a/lib/substitutions/test/src/test_pattern_matches.cc b/lib/substitutions/test/src/test_pattern_matches.cc index 9478195523..aeedd65f82 100644 --- a/lib/substitutions/test/src/test_pattern_matches.cc +++ b/lib/substitutions/test/src/test_pattern_matches.cc @@ -1,9 +1,6 @@ -#include "doctest/doctest.h" -#include "rapidcheck.h" #include "substitutions/unlabelled/find_pattern_matches.h" #include "substitutions/unlabelled/match_additional_criterion.h" #include "substitutions/unlabelled/pattern_matching.h" -#include "test/utils/all.h" #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/node/algorithms.h" @@ -13,6 +10,7 @@ #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" #include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" #include "utils/overload.h" +#include using namespace FlexFlow; diff --git a/lib/utils/CMakeLists.txt b/lib/utils/CMakeLists.txt index ae5e120fad..a0d77b9f76 100644 --- a/lib/utils/CMakeLists.txt +++ b/lib/utils/CMakeLists.txt @@ -13,7 +13,6 @@ ff_add_library( fmt json cuda - doctest ) add_subdirectory(ffi) diff --git a/lib/utils/include/utils/cli/cli_argument_key.variant.toml b/lib/utils/include/utils/cli/cli_argument_key.variant.toml new file mode 100644 index 0000000000..be118160ce --- /dev/null +++ b/lib/utils/include/utils/cli/cli_argument_key.variant.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "CLIArgumentKey" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "utils/cli/cli_positional_argument_key.dtg.h", + "utils/cli/cli_flag_key.dtg.h", +] + +[[values]] +type = "::FlexFlow::CLIPositionalArgumentKey" + +[[values]] +type = "::FlexFlow::CLIFlagKey" diff --git a/lib/utils/include/utils/cli/cli_flag_key.struct.toml b/lib/utils/include/utils/cli/cli_flag_key.struct.toml new file mode 100644 index 0000000000..790a752911 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_flag_key.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "CLIFlagKey" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [] + +[[fields]] +name = "raw_idx" +type = "int" diff --git a/lib/utils/include/utils/cli/cli_flag_spec.struct.toml b/lib/utils/include/utils/cli/cli_flag_spec.struct.toml new 
file mode 100644 index 0000000000..66a47de067 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_flag_spec.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "CLIFlagSpec" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "long_flag" +type = "std::string" + +[[fields]] +name = "short_flag" +type = "std::optional" + +[[fields]] +name = "description" +type = "std::optional" diff --git a/lib/utils/include/utils/cli/cli_get_help_message.h b/lib/utils/include/utils/cli/cli_get_help_message.h new file mode 100644 index 0000000000..d51579a8e2 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_get_help_message.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_GET_HELP_MESSAGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_GET_HELP_MESSAGE_H + +#include "utils/cli/cli_spec.dtg.h" + +namespace FlexFlow { + +std::string cli_get_help_message(std::string const &program_name, + CLISpec const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/cli/cli_parse.h b/lib/utils/include/utils/cli/cli_parse.h new file mode 100644 index 0000000000..3c91a8423b --- /dev/null +++ b/lib/utils/include/utils/cli/cli_parse.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_PARSE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_PARSE_H + +#include "utils/cli/cli_parse_result.dtg.h" +#include "utils/cli/cli_spec.dtg.h" +#include + +namespace FlexFlow { + +tl::expected cli_parse_flag(CLISpec const &cli, + std::string const &arg); +tl::expected + cli_parse(CLISpec const &, std::vector const &); +tl::expected + cli_parse(CLISpec const &, int argc, char const *const *argv); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/cli/cli_parse_result.h b/lib/utils/include/utils/cli/cli_parse_result.h new file mode 100644 index 0000000000..155caac7ae --- /dev/null +++ b/lib/utils/include/utils/cli/cli_parse_result.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_PARSE_RESULT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_PARSE_RESULT_H + +#include "utils/cli/cli_argument_key.dtg.h" +#include "utils/cli/cli_parse_result.dtg.h" + +namespace FlexFlow { + +bool cli_get_flag(CLIParseResult const &, CLIArgumentKey const &); +std::string cli_get_argument(CLIParseResult const &, CLIArgumentKey const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/cli/cli_parse_result.struct.toml b/lib/utils/include/utils/cli/cli_parse_result.struct.toml new file mode 100644 index 0000000000..b63da7be14 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_parse_result.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "CLIParseResult" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "", + "utils/cli/cli_flag_key.dtg.h", + "utils/cli/cli_positional_argument_key.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "flags" +type = "std::unordered_map<::FlexFlow::CLIFlagKey, bool>" + +[[fields]] +name = "positional_arguments" +type = "std::unordered_map<::FlexFlow::CLIPositionalArgumentKey, std::string>" diff --git a/lib/utils/include/utils/cli/cli_positional_argument_key.struct.toml b/lib/utils/include/utils/cli/cli_positional_argument_key.struct.toml new file mode 100644 index 0000000000..d571d0deb3 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_positional_argument_key.struct.toml @@ -0,0 +1,13 @@ 
+namespace = "FlexFlow" +name = "CLIPositionalArgumentKey" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [] + +[[fields]] +name = "raw_idx" +type = "int" diff --git a/lib/utils/include/utils/cli/cli_positional_argument_spec.struct.toml b/lib/utils/include/utils/cli/cli_positional_argument_spec.struct.toml new file mode 100644 index 0000000000..b1e74701ee --- /dev/null +++ b/lib/utils/include/utils/cli/cli_positional_argument_spec.struct.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "CLIPositionalArgumentSpec" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "", + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "name" +type = "std::string" + +[[fields]] +name = "choices" +type = "std::optional>" + +[[fields]] +name = "description" +type = "std::optional" diff --git a/lib/utils/include/utils/cli/cli_spec.h b/lib/utils/include/utils/cli/cli_spec.h new file mode 100644 index 0000000000..2c0df08c55 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_spec.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_SPEC_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_SPEC_H + +#include "utils/cli/cli_argument_key.dtg.h" +#include "utils/cli/cli_flag_spec.dtg.h" +#include "utils/cli/cli_spec.dtg.h" +#include + +namespace FlexFlow { + +CLISpec empty_cli_spec(); +std::vector cli_get_flag_keys(CLISpec const &); +CLIArgumentKey cli_add_help_flag(CLISpec &); +CLIArgumentKey cli_add_flag(CLISpec &, CLIFlagSpec const &); +CLIArgumentKey cli_add_positional_argument(CLISpec &, + CLIPositionalArgumentSpec const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/cli/cli_spec.struct.toml b/lib/utils/include/utils/cli/cli_spec.struct.toml new file mode 100644 index 0000000000..9f64f62c15 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_spec.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "CLISpec" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "utils/cli/cli_flag_spec.dtg.h", + "utils/cli/cli_positional_argument_spec.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_set.h", + "utils/hash/unordered_set.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "flags" +type = "std::vector<::FlexFlow::CLIFlagSpec>" + +[[fields]] +name = "positional_arguments" +type = "std::vector<::FlexFlow::CLIPositionalArgumentSpec>" diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 937ed51af2..20ab6ce440 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -74,9 +74,6 @@ bool are_all_same(C const &c); template std::function compare_by(F const &f); -template -typename C::value_type maximum(C const &v); - template T reversed(T const &t); diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 7c0490fa2a..f60ef77cda 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -179,11 +179,6 @@ std::function compare_by(F const &f) { return [=](T const &lhs, T const &rhs) { return f(lhs) < f(rhs); }; } -template -typename C::value_type maximum(C const &v) { - return *std::max_element(v.begin(), v.end()); -} - template std::vector value_all(std::vector> const &v) { return transform(v, [](std::optional const &element) { diff --git a/lib/utils/include/utils/containers/enumerate_vector.h 
b/lib/utils/include/utils/containers/enumerate_vector.h index 11ee8d2352..700106ea3f 100644 --- a/lib/utils/include/utils/containers/enumerate_vector.h +++ b/lib/utils/include/utils/containers/enumerate_vector.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_VECTOR_H #include -#include #include namespace FlexFlow { diff --git a/lib/utils/include/utils/containers/foldl1.h b/lib/utils/include/utils/containers/foldl1.h new file mode 100644 index 0000000000..f542f8cf00 --- /dev/null +++ b/lib/utils/include/utils/containers/foldl1.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDL1_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDL1_H + +#include "utils/exception.h" +#include <vector> + +namespace FlexFlow { + +template <typename T, typename F> +T foldl1(std::vector<T> const &vec, F f) { + if (vec.empty()) { + throw mk_runtime_error(fmt::format( + "foldl1 expected non-empty vector, but received empty vector")); + } + + auto it = vec.cbegin(); + T result = *it; + it++; + + for (; it != vec.cend(); it++) { + result = f(result, *it); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/foldr1.h b/lib/utils/include/utils/containers/foldr1.h new file mode 100644 index 0000000000..4a7e8e098c --- /dev/null +++ b/lib/utils/include/utils/containers/foldr1.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDR1_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDR1_H + +#include "utils/exception.h" +#include <vector> + +namespace FlexFlow { + +template <typename T, typename F> +T foldr1(std::vector<T> const &vec, F f) { + if (vec.empty()) { + throw mk_runtime_error(fmt::format( + "foldr1 expected non-empty vector, but received empty vector")); + } + + auto it = vec.crbegin(); + T result = *it; + it++; + for (; it != vec.crend(); it++) { + result = f(result, *it); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/generate_map.h b/lib/utils/include/utils/containers/generate_map.h index 1afa534a19..53b2a590c5 100644 --- a/lib/utils/include/utils/containers/generate_map.h +++ b/lib/utils/include/utils/containers/generate_map.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GENERATE_MAP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GENERATE_MAP_H -#include "utils/containers/as_vector.h" #include "utils/containers/get_element_type.h" +#include "utils/containers/vector_of.h" #include "utils/containers/vector_transform.h" #include "utils/type_traits_core.h" #include @@ -17,7 +17,7 @@ std::unordered_map generate_map(C const &c, F const &f) { static_assert(is_hashable_v, "Key type should be hashable (but is not)"); auto transformed = - vector_transform(as_vector(c), [&](K const &k) -> std::pair { + vector_transform(vector_of(c), [&](K const &k) -> std::pair { return {k, f(k)}; }); return {transformed.cbegin(), transformed.cend()}; diff --git a/lib/utils/include/utils/containers/get_first.h b/lib/utils/include/utils/containers/get_first.h index ce2a483401..a616c44c20 100644 --- a/lib/utils/include/utils/containers/get_first.h +++ b/lib/utils/include/utils/containers/get_first.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_FIRST_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_FIRST_H +#include <set> #include <unordered_set> namespace FlexFlow { @@ -10,6 +11,11 @@ T get_first(std::unordered_set<T> const &s) { return *s.cbegin(); } +template <typename T> +T get_first(std::set<T> const &s) { + return *s.cbegin(); +} + } // namespace FlexFlow
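A quick usage sketch for the foldl1/foldr1 helpers added above (illustrative only, not part of the patch): both fold a non-empty vector without a separate initial accumulator and throw on an empty input; foldl1 seeds the accumulator with the first element, foldr1 with the last:

    std::vector<int> v = {1, 2, 3, 4};
    int l = foldl1(v, [](int acc, int x) { return acc + x; });  // 10, accumulated left-to-right
    int r = foldr1(v, [](int acc, int x) { return acc + x; });  // 10, accumulated right-to-left
    // foldl1(std::vector<int>{}, ...) throws via mk_runtime_error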
#endif diff --git a/lib/utils/include/utils/containers/maximum.h b/lib/utils/include/utils/containers/maximum.h new file mode 100644 index 0000000000..634bb61bc1 --- /dev/null +++ b/lib/utils/include/utils/containers/maximum.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H + +#include +#include + +namespace FlexFlow { + +template +std::optional maximum(C const &v) { + if (v.empty()) { + return std::nullopt; + } + + return *std::max_element(std::cbegin(v), std::cend(v)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/multiset_union.h b/lib/utils/include/utils/containers/multiset_union.h new file mode 100644 index 0000000000..6f2b2a7889 --- /dev/null +++ b/lib/utils/include/utils/containers/multiset_union.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MULTISET_UNION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MULTISET_UNION_H + +#include +#include + +namespace FlexFlow { + +template +std::unordered_multiset + multiset_union(std::unordered_multiset const &lhs, + std::unordered_multiset const &rhs) { + std::unordered_multiset result = lhs; + + for (T const &t : rhs) { + result.insert(t); + } + + return result; +} + +template +std::multiset multiset_union(std::multiset const &lhs, + std::multiset const &rhs) { + std::multiset result = lhs; + + for (T const &t : rhs) { + result.insert(t); + } + + return result; +} + +template +std::unordered_multiset multiset_union(C const &c) { + std::unordered_multiset result; + for (auto const &s : c) { + for (T const &element : s) { + result.insert(element); + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/require_no_duplicates.h b/lib/utils/include/utils/containers/require_no_duplicates.h new file mode 100644 index 0000000000..0cbe361bdd --- /dev/null +++ b/lib/utils/include/utils/containers/require_no_duplicates.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_NO_DUPLICATES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_NO_DUPLICATES_H + +#include "utils/exception.h" +#include "utils/fmt/multiset.h" +#include "utils/fmt/unordered_multiset.h" +#include +#include +#include + +namespace FlexFlow { + +template +std::unordered_set + require_no_duplicates(std::unordered_multiset const &s) { + std::unordered_set result{s.cbegin(), s.cend()}; + + if (result.size() != s.size()) { + throw mk_runtime_error(fmt::format( + "require_no_duplicates encountered duplicate in set {}", s)); + } + + return result; +} + +template +std::set require_no_duplicates(std::multiset const &s) { + std::set result{s.cbegin(), s.cend()}; + + if (result.size() != s.size()) { + throw mk_runtime_error(fmt::format( + "require_no_duplicates encountered duplicate in set {}", s)); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/reversed.h b/lib/utils/include/utils/containers/reversed.h index 621eee9519..902b247469 100644 --- a/lib/utils/include/utils/containers/reversed.h +++ b/lib/utils/include/utils/containers/reversed.h @@ -1,15 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REVERSED_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REVERSED_H +#include + namespace FlexFlow { template -T reversed(T const &t) { - T r; - for (auto i = t.cend() - 1; i >= t.begin(); i--) { - r.push_back(*i); - } - return r; +std::vector 
reversed(std::vector const &t) { + std::vector result(std::crbegin(t), std::crend(t)); + return result; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/set_minus.h b/lib/utils/include/utils/containers/set_minus.h index 6efa2f0a84..fdd1f11995 100644 --- a/lib/utils/include/utils/containers/set_minus.h +++ b/lib/utils/include/utils/containers/set_minus.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_MINUS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_MINUS_H +#include #include namespace FlexFlow { @@ -15,6 +16,15 @@ std::unordered_set set_minus(std::unordered_set const &l, return result; } +template +std::set set_minus(std::set const &l, std::set const &r) { + std::set result = l; + for (T const &t : r) { + result.erase(t); + } + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/set_of.h b/lib/utils/include/utils/containers/set_of.h new file mode 100644 index 0000000000..14658209aa --- /dev/null +++ b/lib/utils/include/utils/containers/set_of.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_OF_H + +#include + +namespace FlexFlow { + +template +std::set set_of(C const &c) { + std::set result; + for (T const &t : c) { + result.insert(t); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/to_uppercase.h b/lib/utils/include/utils/containers/to_uppercase.h new file mode 100644 index 0000000000..a2dc7786f9 --- /dev/null +++ b/lib/utils/include/utils/containers/to_uppercase.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TO_UPPERCASE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TO_UPPERCASE_H + +#include + +namespace FlexFlow { + +std::string to_uppercase(std::string const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/as_vector.h b/lib/utils/include/utils/containers/vector_of.h similarity index 54% rename from lib/utils/include/utils/containers/as_vector.h rename to lib/utils/include/utils/containers/vector_of.h index fafa1dc799..7fb903b4a8 100644 --- a/lib/utils/include/utils/containers/as_vector.h +++ b/lib/utils/include/utils/containers/vector_of.h @@ -1,12 +1,12 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_AS_VECTOR_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_AS_VECTOR_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_OF_H #include namespace FlexFlow { template -std::vector as_vector(C const &c) { +std::vector vector_of(C const &c) { std::vector result(c.cbegin(), c.cend()); return result; } diff --git a/lib/utils/include/utils/fmt/expected.h b/lib/utils/include/utils/fmt/expected.h index 21a6d28ca2..4170882ae6 100644 --- a/lib/utils/include/utils/fmt/expected.h +++ b/lib/utils/include/utils/fmt/expected.h @@ -1,9 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_EXPECTED_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_EXPECTED_H -#include "fmt/format.h" #include "utils/check_fmtable.h" -#include +#include #include #include @@ -44,15 +43,4 @@ std::ostream &operator<<(std::ostream &s, tl::expected const &t) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(tl::expected const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git 
a/lib/utils/include/utils/fmt/map.h b/lib/utils/include/utils/fmt/map.h index 8e186928fd..46bf9ca8fa 100644 --- a/lib/utils/include/utils/fmt/map.h +++ b/lib/utils/include/utils/fmt/map.h @@ -5,7 +5,6 @@ #include "utils/containers/sorted.h" #include "utils/fmt/pair.h" #include "utils/join_strings.h" -#include #include #include @@ -48,15 +47,4 @@ std::ostream &operator<<(std::ostream &s, std::map const &m) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::map const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/multiset.h b/lib/utils/include/utils/fmt/multiset.h index cff150dc29..616b784aac 100644 --- a/lib/utils/include/utils/fmt/multiset.h +++ b/lib/utils/include/utils/fmt/multiset.h @@ -3,7 +3,6 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" -#include #include #include @@ -42,15 +41,4 @@ std::ostream &operator<<(std::ostream &s, std::multiset const &x) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::multiset const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/optional.h b/lib/utils/include/utils/fmt/optional.h index 45eebc2c58..2364e49568 100644 --- a/lib/utils/include/utils/fmt/optional.h +++ b/lib/utils/include/utils/fmt/optional.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_OPTIONAL_H #include "utils/check_fmtable.h" -#include #include #include @@ -43,15 +42,4 @@ std::ostream &operator<<(std::ostream &s, std::optional const &t) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::optional const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/pair.h b/lib/utils/include/utils/fmt/pair.h index 6f7e6f6b52..ab5ddd4e28 100644 --- a/lib/utils/include/utils/fmt/pair.h +++ b/lib/utils/include/utils/fmt/pair.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H #include "utils/check_fmtable.h" -#include #include #include @@ -40,15 +39,4 @@ std::ostream &operator<<(std::ostream &s, std::pair const &m) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::pair const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/set.h b/lib/utils/include/utils/fmt/set.h index 1f8012f240..a183d37542 100644 --- a/lib/utils/include/utils/fmt/set.h +++ b/lib/utils/include/utils/fmt/set.h @@ -4,7 +4,6 @@ #include "utils/check_fmtable.h" #include "utils/containers/sorted.h" #include "utils/join_strings.h" -#include #include #include #include @@ -43,15 +42,4 @@ std::ostream &operator<<(std::ostream &s, std::set const &x) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::set const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/unordered_map.h b/lib/utils/include/utils/fmt/unordered_map.h index 75bbb4cb8a..876a032fe6 100644 --- a/lib/utils/include/utils/fmt/unordered_map.h +++ b/lib/utils/include/utils/fmt/unordered_map.h @@ -5,7 +5,6 @@ #include "utils/fmt/pair.h" #include "utils/join_strings.h" #include -#include #include #include #include @@ -48,15 +47,4 @@ std::ostream 
&operator<<(std::ostream &s, std::unordered_map const &m) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::unordered_map const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/unordered_multiset.h b/lib/utils/include/utils/fmt/unordered_multiset.h index 41abbc925e..deb03a04d4 100644 --- a/lib/utils/include/utils/fmt/unordered_multiset.h +++ b/lib/utils/include/utils/fmt/unordered_multiset.h @@ -3,7 +3,6 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" -#include #include #include @@ -42,15 +41,4 @@ std::ostream &operator<<(std::ostream &s, std::unordered_multiset const &x) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::unordered_multiset const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/unordered_set.h b/lib/utils/include/utils/fmt/unordered_set.h index 646ef0c7c5..257545af1b 100644 --- a/lib/utils/include/utils/fmt/unordered_set.h +++ b/lib/utils/include/utils/fmt/unordered_set.h @@ -4,7 +4,6 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" #include "utils/type_traits_core.h" -#include #include #include @@ -43,15 +42,4 @@ std::ostream &operator<<(std::ostream &s, std::unordered_set const &x) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::unordered_set const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/variant.h b/lib/utils/include/utils/fmt/variant.h index 867577f72a..06a56417c3 100644 --- a/lib/utils/include/utils/fmt/variant.h +++ b/lib/utils/include/utils/fmt/variant.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_VARIANT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_VARIANT_H -#include #include #include @@ -33,15 +32,4 @@ std::ostream &operator<<(std::ostream &s, std::variant const &v) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::variant const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/vector.h b/lib/utils/include/utils/fmt/vector.h index 96526175a8..5d9ca0aeae 100644 --- a/lib/utils/include/utils/fmt/vector.h +++ b/lib/utils/include/utils/fmt/vector.h @@ -3,7 +3,6 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" -#include #include #include @@ -41,15 +40,4 @@ std::ostream &operator<<(std::ostream &s, std::vector const &v) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::vector const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h new file mode 100644 index 0000000000..2ed0bc02be --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_INCOMING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_INCOMING_EDGES_H + +#include 
"utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_incoming_edges(DataflowGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h index fc372f68aa..afc9c47c1c 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h +++ b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h @@ -6,6 +6,9 @@ namespace FlexFlow { +std::optional + get_cbc_decomposition_with_edge_order_internal( + DiGraphView const &, std::vector const &); std::optional get_cbc_decomposition(DiGraphView const &); diff --git a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h new file mode 100644 index 0000000000..3066886e37 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_COMPLETE_BIPARTITE_COMPOSITE_IS_COMPLETE_BIPARTITE_DIGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_COMPLETE_BIPARTITE_COMPOSITE_IS_COMPLETE_BIPARTITE_DIGRAPH_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +bool is_complete_bipartite_digraph(DiGraphView const &); +bool is_complete_bipartite_digraph(DiGraphView const &, + std::unordered_set const &srcs); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h b/lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h new file mode 100644 index 0000000000..ee533a1180 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_DIGRAPH_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_DIGRAPH_AS_DOT_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +std::string digraph_as_dot( + DiGraphView const &, + std::function const &get_node_label); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/digraph_has_edge.h b/lib/utils/include/utils/graph/digraph/algorithms/digraph_has_edge.h new file mode 100644 index 0000000000..87d0d3143a --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/digraph_has_edge.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_HAS_EDGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_HAS_EDGE_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +bool digraph_has_edge(DiGraphView const &, DirectedEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h b/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h new file mode 100644 index 0000000000..6d98c5c20d --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h @@ -0,0 +1,14 @@ +#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_SUBGRAPH_OUTGOING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_SUBGRAPH_OUTGOING_EDGES_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_outgoing_edges(DiGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_successors.h b/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_successors.h new file mode 100644 index 0000000000..2c48d327c4 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_successors.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_SUBGRAPH_SUCCESSORS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_SUBGRAPH_SUCCESSORS_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_successors(DiGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/transitive_closure.h b/lib/utils/include/utils/graph/digraph/algorithms/transitive_closure.h new file mode 100644 index 0000000000..c9751124c8 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/transitive_closure.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_TRANSITIVE_CLOSURE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_TRANSITIVE_CLOSURE_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +DiGraphView transitive_closure(DiGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/instances/unordered_set_undirected_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_undirected_graph.h new file mode 100644 index 0000000000..db2526f973 --- /dev/null +++ b/lib/utils/include/utils/graph/instances/unordered_set_undirected_graph.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_UNDIRECTED_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_UNDIRECTED_GRAPH_H + +#include "utils/graph/node/node_source.h" +#include "utils/graph/undirected/i_undirected_graph.h" + +namespace FlexFlow { + +struct UnorderedSetUndirectedGraph final : public IUndirectedGraph { +public: + UnorderedSetUndirectedGraph(); + + Node add_node() override; + void add_node_unsafe(Node const &) override; + void remove_node_unsafe(Node const &) override; + void add_edge(UndirectedEdge const &) override; + void remove_edge(UndirectedEdge const &) override; + + std::unordered_set query_nodes(NodeQuery const &) const override; + std::unordered_set + query_edges(UndirectedEdgeQuery const &) const override; + + UnorderedSetUndirectedGraph *clone() const override; + +private: + UnorderedSetUndirectedGraph(NodeSource const &, + std::unordered_set const &, + std::unordered_set const &); + + NodeSource node_source; + std::unordered_set nodes; + std::unordered_set edges; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h index a1d6e9e37a..8306dad1ec 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h +++ 
b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H -#include "utils/containers/as_vector.h" #include "utils/containers/get_all_permutations.h" #include "utils/containers/zip.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h" diff --git a/lib/utils/include/utils/graph/node/node.struct.toml b/lib/utils/include/utils/graph/node/node.struct.toml index 0b6f348ddf..d5c22e5d3d 100644 --- a/lib/utils/include/utils/graph/node/node.struct.toml +++ b/lib/utils/include/utils/graph/node/node.struct.toml @@ -5,6 +5,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/utils/include/utils/graph/serial_parallel/get_serial_parallel_decomposition.h b/lib/utils/include/utils/graph/serial_parallel/get_serial_parallel_decomposition.h deleted file mode 100644 index be6b9ce12c..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/get_serial_parallel_decomposition.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GET_SERIAL_PARALLEL_DECOMPOSITION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GET_SERIAL_PARALLEL_DECOMPOSITION_H - -#include "utils/graph/digraph/digraph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include "utils/optional.h" -#include -#include - -namespace FlexFlow { - -std::optional - get_serial_parallel_decomposition(DiGraphView const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h b/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h deleted file mode 100644 index 6285d7ae1f..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_H - -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -std::variant - flatten_ast(std::variant const &ast); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h deleted file mode 100644 index 7d8efc96f2..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_PARALLEL_DECOMPOSITION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_PARALLEL_DECOMPOSITION_H - -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include - -namespace FlexFlow { - -std::variant internal_to_final_ast( - std::variant const &ast); -SerialParallelDecomposition - to_final_ast(std::variant const &); - -std::unordered_set get_nodes(SerialParallelDecomposition const &sp); -std::unordered_set get_nodes(SerialSplit const &); -std::unordered_set get_nodes(ParallelSplit const &); 
-std::unordered_set get_nodes(Node const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h new file mode 100644 index 0000000000..b1607e7a76 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include + +namespace FlexFlow { + +BinarySPDecompositionTree make_series_split(BinarySPDecompositionTree const &, + BinarySPDecompositionTree const &); +BinarySPDecompositionTree + make_parallel_split(BinarySPDecompositionTree const &, + BinarySPDecompositionTree const &); +BinarySPDecompositionTree make_leaf_node(Node const &); + +bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &); +bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &); + +std::unordered_multiset get_leaves(BinarySPDecompositionTree const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml new file mode 100644 index 0000000000..1241311150 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "BinarySPDecompositionTree" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", +] + +[[fields]] +name = "raw_tree" +type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h new file mode 100644 index 0000000000..42d71ce54e --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h @@ -0,0 +1,63 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FMT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FMT_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" 
+#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include + +namespace FlexFlow { + +template +std::string format_as(GenericBinarySeriesSplit const &s) { + return fmt::format("", + get_left_child(s), + get_right_child(s)); +} + +template +std::ostream &operator<<(std::ostream &s, + GenericBinarySeriesSplit const &x) { + return (s << fmt::to_string(x)); +} + +template +std::string format_as(GenericBinaryParallelSplit const &s) { + return fmt::format("", + get_left_child(s), + get_right_child(s)); +} + +template +std::ostream &operator<<(std::ostream &s, + GenericBinaryParallelSplit const &x) { + return (s << fmt::to_string(x)); +} + +template +std::string format_as(GenericBinarySPDecompositionTree const &tt) { + return visit( + tt, + overload{ + [](GenericBinarySeriesSplit const &s) { + return fmt::format("", s); + }, + [](GenericBinaryParallelSplit const &s) { + return fmt::format("", s); + }, + [](T const &t) { + return fmt::format("", t); + }, + }); +} + +template +std::ostream &operator<<(std::ostream &s, + GenericBinarySPDecompositionTree const &t) { + return (s << fmt::to_string(t)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h new file mode 100644 index 0000000000..74f5ba5d8a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h @@ -0,0 +1,155 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_H + +#include +#include +#include + +namespace FlexFlow { + +template +struct GenericBinarySPDecompositionTree; + +template +struct GenericBinarySeriesSplit { +public: + GenericBinarySeriesSplit() = delete; + explicit GenericBinarySeriesSplit( + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) + : left_child_ptr( + std::make_shared>(lhs)), + right_child_ptr( + std::make_shared>(rhs)) {} + + GenericBinarySeriesSplit(GenericBinarySeriesSplit const &) = default; + + bool operator==(GenericBinarySeriesSplit const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(GenericBinarySeriesSplit const &other) const { + return this->tie() != other.tie(); + } + + bool operator<(GenericBinarySeriesSplit const &other) const { + return this->tie() < other.tie(); + } + +public: + std::shared_ptr> left_child_ptr; + std::shared_ptr> right_child_ptr; + +private: + std::tuple const &, + GenericBinarySPDecompositionTree const &> + tie() const { + return std::tie(*this->left_child_ptr, *this->right_child_ptr); + } + + friend std::hash; +}; + +template +struct GenericBinaryParallelSplit { +public: + GenericBinaryParallelSplit() = delete; + explicit GenericBinaryParallelSplit( + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) + : left_child_ptr( + std::make_shared>(lhs)), + right_child_ptr( 
+ std::make_shared>(rhs)) {} + + GenericBinaryParallelSplit(GenericBinaryParallelSplit const &) = default; + + bool operator==(GenericBinaryParallelSplit const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(GenericBinaryParallelSplit const &other) const { + return this->tie() != other.tie(); + } + + bool operator<(GenericBinaryParallelSplit const &other) const { + return this->tie() < other.tie(); + } + +public: + std::shared_ptr> left_child_ptr; + std::shared_ptr> right_child_ptr; + +private: + std::tuple const &, + GenericBinarySPDecompositionTree const &> + tie() const { + return std::tie(*this->left_child_ptr, *this->right_child_ptr); + } + + friend std::hash; +}; + +template +struct GenericBinarySPDecompositionTree { +public: + GenericBinarySPDecompositionTree() = delete; + explicit GenericBinarySPDecompositionTree( + GenericBinarySeriesSplit const &s) + : root{s} {} + + explicit GenericBinarySPDecompositionTree( + GenericBinaryParallelSplit const &s) + : root{s} {} + + explicit GenericBinarySPDecompositionTree(T const &t) : root{t} {} + + GenericBinarySPDecompositionTree(GenericBinarySPDecompositionTree const &) = + default; + + bool operator==(GenericBinarySPDecompositionTree const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(GenericBinarySPDecompositionTree const &other) const { + return this->tie() != other.tie(); + } + + bool operator<(GenericBinarySPDecompositionTree const &other) const { + return this->tie() < other.tie(); + } + +public: + std::variant, GenericBinaryParallelSplit, T> + root; + +private: + std::tuple tie() const { + return std::tie(this->root); + } + + friend std::hash; +}; + +} // namespace FlexFlow + +// namespace rc { +// +// template <> +// struct Arbitrary<::FlexFlow::BinarySeriesSplit> { +// static Gen<::FlexFlow::BinarySeriesSplit> arbitrary(); +// }; +// +// template <> +// struct Arbitrary<::FlexFlow::GenericBinaryParallelSplit> { +// static Gen<::FlexFlow::GenericBinaryParallelSplit> arbitrary(); +// }; +// +// template <> +// struct Arbitrary<::FlexFlow::GenericBinarySPDecompositionTree> { +// static Gen<::FlexFlow::GenericBinarySPDecompositionTree> arbitrary(); +// }; +// +// } // namespace rc + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h new file mode 100644 index 0000000000..c6c1186d3d --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" + +namespace FlexFlow { + +template +TT const &get(GenericBinarySPDecompositionTree const &t) { + return std::get(t.root); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h new 
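As a usage sketch only (assuming the stripped template parameters above are over the leaf type T, here instantiated with int), a small tree can be built directly from the explicit constructors and its root variant projected back out with get.h:

    #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h"
    #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h"

    using namespace FlexFlow;

    void construction_example() {
      GenericBinarySPDecompositionTree<int> leaf1{1};
      GenericBinarySPDecompositionTree<int> leaf2{2};
      GenericBinarySeriesSplit<int> split{leaf1, leaf2};
      GenericBinarySPDecompositionTree<int> tree{split};

      // Projects the requested alternative of the root variant; std::get
      // throws std::bad_variant_access if the root is not a series split.
      GenericBinarySeriesSplit<int> const &root =
          get<GenericBinarySeriesSplit<int>>(tree);
      (void)root;
    }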
file mode 100644 index 0000000000..51e1e20bac --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H + +#include "utils/containers/multiset_union.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/overload.h" +#include + +namespace FlexFlow { + +template +std::unordered_multiset + get_leaves(GenericBinarySPDecompositionTree const &tt) { + return visit>( + tt, + overload{ + [](T const &t) { return std::unordered_multiset{t}; }, + [](GenericBinarySeriesSplit const &s) { return get_leaves(s); }, + [](GenericBinaryParallelSplit const &p) { return get_leaves(p); }, + }); +} + +template +std::unordered_multiset get_leaves(GenericBinarySeriesSplit const &s) { + return multiset_union(get_leaves(get_left_child(s)), + get_leaves(get_right_child(s))); +} + +template +std::unordered_multiset get_leaves(GenericBinaryParallelSplit const &p) { + return multiset_union(get_leaves(get_left_child(p)), + get_leaves(get_right_child(p))); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h new file mode 100644 index 0000000000..46a460b64e --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h @@ -0,0 +1,44 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H + +#include "utils/exception.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +GenericBinarySPDecompositionTree + get_left_child(GenericBinarySeriesSplit const &s) { + return *s.left_child_ptr; +} + +template +GenericBinarySPDecompositionTree + get_left_child(GenericBinaryParallelSplit const &p) { + return *p.left_child_ptr; +} + +template +GenericBinarySPDecompositionTree + get_left_child(GenericBinarySPDecompositionTree const &tt) { + return visit>( + tt, + overload{ + [](GenericBinarySeriesSplit const &s) { + return get_left_child(s); + }, + [](GenericBinaryParallelSplit const &p) { + return get_left_child(p); + }, + [](T const &t) -> 
GenericBinarySPDecompositionTree { + throw mk_runtime_error( + "get_left_child incorrectly called on leaf node"); + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h new file mode 100644 index 0000000000..883acda480 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +SPDecompositionTreeNodeType + get_node_type(GenericBinarySPDecompositionTree const &tt) { + return visit( + tt, + overload{ + [](GenericBinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](GenericBinaryParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](T const &) { return SPDecompositionTreeNodeType::NODE; }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h new file mode 100644 index 0000000000..7c6d28d7b4 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GET_NUM_TREE_NODES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GET_NUM_TREE_NODES_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +int get_num_tree_nodes(GenericBinarySPDecompositionTree const &tt) { + return visit(tt, + overload{ + [](T const &t) { return 1; }, + [](GenericBinarySeriesSplit const &s) { + return get_num_tree_nodes(s); + }, + [](GenericBinaryParallelSplit const &p) { + return get_num_tree_nodes(p); + }, + }); +} + +template +int get_num_tree_nodes(GenericBinarySeriesSplit const &s) { + return 1 + get_num_tree_nodes(get_left_child(s)) + + get_num_tree_nodes(get_right_child(s)); +} + 
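A sketch of how the traversal helpers fit together on such a tree (again an illustrative assumption with int leaves, not part of the patch):

    #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h"
    #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h"
    #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h"
    #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h"
    #include <unordered_set>

    using namespace FlexFlow;

    void traverse(GenericBinarySPDecompositionTree<int> const &tree) {
      if (get_node_type(tree) == SPDecompositionTreeNodeType::SERIES) {
        // Children are themselves trees, so the same helpers can recurse;
        // calling them on a leaf throws, hence the node-type guard.
        GenericBinarySPDecompositionTree<int> l = get_left_child(tree);
        GenericBinarySPDecompositionTree<int> r = get_right_child(tree);
        (void)l;
        (void)r;
      }
      // Leaves are collected with multiplicity.
      std::unordered_multiset<int> leaves = get_leaves(tree);
      (void)leaves;
    }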
+template +int get_num_tree_nodes(GenericBinaryParallelSplit const &p) { + return 1 + get_num_tree_nodes(get_left_child(p)) + + get_num_tree_nodes(get_right_child(p)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h new file mode 100644 index 0000000000..f0bfba43a2 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h @@ -0,0 +1,44 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H + +#include "utils/exception.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +GenericBinarySPDecompositionTree + get_right_child(GenericBinarySeriesSplit const &s) { + return *s.right_child_ptr; +} + +template +GenericBinarySPDecompositionTree + get_right_child(GenericBinaryParallelSplit const &p) { + return *p.right_child_ptr; +} + +template +GenericBinarySPDecompositionTree + get_right_child(GenericBinarySPDecompositionTree const &tt) { + return visit>( + tt, + overload{ + [](GenericBinarySeriesSplit const &s) { + return get_right_child(s); + }, + [](GenericBinaryParallelSplit const &p) { + return get_right_child(p); + }, + [](T const &t) -> GenericBinarySPDecompositionTree { + throw mk_runtime_error( + "get_right_child incorrectly called on leaf node"); + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h new file mode 100644 index 0000000000..983dc4a572 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_HASH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_HASH_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/hash-utils.h" +#include "utils/hash/tuple.h" + +namespace std { + +template +struct hash<::FlexFlow::GenericBinarySeriesSplit> { + size_t operator()(::FlexFlow::GenericBinarySeriesSplit const &s) const { + return get_std_hash(s.tie()); + } +}; + +template +struct hash<::FlexFlow::GenericBinaryParallelSplit> { + size_t operator()(::FlexFlow::GenericBinaryParallelSplit const &s) const { + return get_std_hash(s.tie()); + } +}; + +template +struct hash<::FlexFlow::GenericBinarySPDecompositionTree> { + size_t operator()( + 
::FlexFlow::GenericBinarySPDecompositionTree const &s) const { + return get_std_hash(s.tie()); + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h new file mode 100644 index 0000000000..8086f38244 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" + +namespace FlexFlow { + +template +bool is_series_split(GenericBinarySPDecompositionTree const &t) { + return std::holds_alternative>(t.root); +} + +template +bool is_parallel_split(GenericBinarySPDecompositionTree const &t) { + return std::holds_alternative>(t.root); +} + +template +bool is_leaf(GenericBinarySPDecompositionTree const &t) { + return std::holds_alternative(t.root); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h new file mode 100644 index 0000000000..3ffa63753a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" + +namespace FlexFlow { + +template +bool is_binary_sp_tree_left_associative( + GenericBinarySPDecompositionTree const &tt) { + return visit( + tt, + overload{ + [](T const &) { return true; }, + [](GenericBinarySeriesSplit const &s) { + return !is_series_split(get_right_child(s)) && + is_binary_sp_tree_left_associative(get_left_child(s)) && + is_binary_sp_tree_left_associative(get_right_child(s)); + }, + [](GenericBinaryParallelSplit const &p) { + return !is_parallel_split(get_right_child(p)) && + is_binary_sp_tree_left_associative(get_left_child(p)) && + 
is_binary_sp_tree_left_associative(get_right_child(p)); + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h new file mode 100644 index 0000000000..d88459b432 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" + +namespace FlexFlow { + +template +bool is_binary_sp_tree_right_associative( + GenericBinarySPDecompositionTree const &tt) { + return visit( + tt, + overload{ + [](T const &t) { return true; }, + [](GenericBinarySeriesSplit const &s) { + return !is_series_split(get_left_child(s)) && + is_binary_sp_tree_right_associative(get_left_child(s)) && + is_binary_sp_tree_right_associative(get_right_child(s)); + }, + [](GenericBinaryParallelSplit const &p) { + return !is_parallel_split(get_left_child(p)) && + is_binary_sp_tree_right_associative(get_left_child(p)) && + is_binary_sp_tree_right_associative(get_right_child(p)); + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h new file mode 100644 index 0000000000..4f1f8266e1 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h @@ -0,0 +1,103 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_JSON_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_JSON_H + +#include "utils/exception.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include 
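For illustration, a small sketch (not from the patch, int leaves assumed) contrasting the two associativity predicates above on ((a ; b) ; c) versus (a ; (b ; c)):

    #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h"
    #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h"

    using namespace FlexFlow;

    void associativity_example() {
      using Tree = GenericBinarySPDecompositionTree<int>;
      Tree a{1}, b{2}, c{3};

      // ((a ; b) ; c): series splits nested on the left.
      Tree left_nested{GenericBinarySeriesSplit<int>{
          Tree{GenericBinarySeriesSplit<int>{a, b}}, c}};

      // (a ; (b ; c)): series splits nested on the right.
      Tree right_nested{GenericBinarySeriesSplit<int>{
          a, Tree{GenericBinarySeriesSplit<int>{b, c}}}};

      bool l = is_binary_sp_tree_left_associative(left_nested);   // expected true
      bool r = is_binary_sp_tree_left_associative(right_nested);  // expected false
      (void)l;
      (void)r;
    }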
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include + +namespace nlohmann { + +template +struct adl_serializer<::FlexFlow::GenericBinarySeriesSplit> { + static ::FlexFlow::GenericBinarySeriesSplit from_json(json const &j) { + return ::FlexFlow::GenericBinarySeriesSplit{ + j.at("left_child") + .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), + j.at("right_child") + .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), + }; + } + + static void to_json(json &j, + ::FlexFlow::GenericBinarySeriesSplit const &v) { + j["__type"] = "GenericBinarySeriesSplit"; + j["left_child"] = get_left_child(v); + j["right_child"] = get_right_child(v); + } +}; + +template +struct adl_serializer<::FlexFlow::GenericBinaryParallelSplit> { + static ::FlexFlow::GenericBinaryParallelSplit from_json(json const &j) { + return ::FlexFlow::GenericBinaryParallelSplit{ + j.at("left_child") + .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), + j.at("right_child") + .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), + }; + } + + static void to_json(json &j, + ::FlexFlow::GenericBinaryParallelSplit const &v) { + j["__type"] = "GenericBinaryParallelSplit"; + j["left_child"] = get_left_child(v); + j["right_child"] = get_right_child(v); + } +}; + +template +struct adl_serializer<::FlexFlow::GenericBinarySPDecompositionTree> { + static ::FlexFlow::GenericBinarySPDecompositionTree + from_json(json const &j) { + std::string key = j.at("type").get(); + + if (key == "series") { + return ::FlexFlow::GenericBinarySPDecompositionTree{ + j.at("value").get<::FlexFlow::GenericBinarySeriesSplit>(), + }; + } else if (key == "parallel") { + return ::FlexFlow::GenericBinarySPDecompositionTree{ + j.at("value").get<::FlexFlow::GenericBinaryParallelSplit>(), + }; + } else if (key == "leaf") { + return ::FlexFlow::GenericBinarySPDecompositionTree{ + j.at("value").get(), + }; + } else { + throw ::FlexFlow::mk_runtime_error( + fmt::format("Unknown json type key: {}", key)); + } + } + + static void + to_json(json &j, + ::FlexFlow::GenericBinarySPDecompositionTree const &v) { + j["__type"] = "GenericBinarySPDecompositionTree"; + ::FlexFlow::visit( + v, + ::FlexFlow::overload{ + [&](::FlexFlow::GenericBinarySeriesSplit const &s) { + j["type"] = "series"; + j["value"] = s; + return std::monostate{}; + }, + [&](::FlexFlow::GenericBinaryParallelSplit const &p) { + j["type"] = "parallel"; + j["value"] = p; + return std::monostate{}; + }, + [&](T const &t) { + j["type"] = "leaf"; + j["value"] = t; + return std::monostate{}; + }, + }); + } +}; + +} // namespace nlohmann + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h new file mode 100644 index 0000000000..f55b71146a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h @@ -0,0 +1,39 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" + 
+namespace FlexFlow { + +template +GenericBinarySPDecompositionTree make_generic_binary_series_split( + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) { + return GenericBinarySPDecompositionTree{ + GenericBinarySeriesSplit{ + lhs, + rhs, + }, + }; +} + +template +GenericBinarySPDecompositionTree make_generic_binary_parallel_split( + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) { + return GenericBinarySPDecompositionTree{ + GenericBinaryParallelSplit{ + lhs, + rhs, + }, + }; +} + +template +GenericBinarySPDecompositionTree make_generic_binary_sp_leaf(T const &t) { + return GenericBinarySPDecompositionTree{t}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h new file mode 100644 index 0000000000..a8de1ee8f8 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h" + +namespace FlexFlow { + +template +GenericBinarySeriesSplit const & + require_series(GenericBinarySPDecompositionTree const &t) { + return get>(t); +} + +template +GenericBinaryParallelSplit const & + require_parallel(GenericBinarySPDecompositionTree const &t) { + return get>(t); +} + +template +T const &require_node(GenericBinarySPDecompositionTree const &t) { + return get(t); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h new file mode 100644 index 0000000000..4d7fa05960 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" + +namespace FlexFlow { + +template > +GenericBinarySPDecompositionTree + 
transform(GenericBinarySPDecompositionTree const &tt, F f) { + return visit>( + tt, + overload{ + [&](GenericBinarySeriesSplit const &s) { + return GenericBinarySPDecompositionTree{ + GenericBinarySeriesSplit{ + transform(get_left_child(s), f), + transform(get_right_child(s), f), + }, + }; + }, + [&](GenericBinaryParallelSplit const &s) { + return GenericBinarySPDecompositionTree{ + GenericBinaryParallelSplit{ + transform(get_left_child(s), f), + transform(get_right_child(s), f), + }, + }; + }, + [&](T const &t) { + return GenericBinarySPDecompositionTree{ + f(t), + }; + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h new file mode 100644 index 0000000000..0d9503e59f --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H + +#include "utils/exception.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" + +namespace FlexFlow { + +template +Result visit(GenericBinarySPDecompositionTree const &tt, F f) { + if (std::holds_alternative>(tt.root)) { + return f(std::get>(tt.root)); + } else if (std::holds_alternative>(tt.root)) { + return f(std::get>(tt.root)); + } else if (std::holds_alternative(tt.root)) { + return f(std::get(tt.root)); + } else { + throw mk_runtime_error( + "Unexpected case in visit(GenericBinarySPDecompositionTree)"); + } + + // return std::visit(tt.root, overload { + // [&](GenericBinarySeriesSplit const &s) -> Result { + // return f(s); + // }, + // [&](GenericBinaryParallelSplit const &p) -> Result { + // return f(p); + // }, + // [&](T const &t) -> Result { + // return f(t); + // }, + // }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h new file mode 100644 index 0000000000..183ece3a89 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEFT_ASSOCIATIVE_BINARY_SP_TREE_FROM_NARY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEFT_ASSOCIATIVE_BINARY_SP_TREE_FROM_NARY_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( + SeriesParallelDecomposition const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h 
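A sketch of transform (illustrative only, int leaves assumed): it relabels every leaf through the supplied function while leaving the series/parallel structure untouched, with visit doing the per-node dispatch underneath.

    #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h"
    #include <string>

    using namespace FlexFlow;

    void transform_example(GenericBinarySPDecompositionTree<int> const &tree) {
      // Result leaf type is deduced from the callable's return type.
      GenericBinarySPDecompositionTree<std::string> relabelled =
          transform(tree, [](int x) { return std::to_string(x); });
      (void)relabelled;
    }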
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h new file mode 100644 index 0000000000..f5174aee56 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_NARY_SP_TREE_FROM_BINARY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_NARY_SP_TREE_FROM_BINARY_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +SeriesParallelDecomposition + nary_sp_tree_from_binary(BinarySPDecompositionTree const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h new file mode 100644 index 0000000000..e01ec0bdde --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_RIGHT_ASSOCIATIVE_BINARY_SP_TREE_FROM_NARY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_RIGHT_ASSOCIATIVE_BINARY_SP_TREE_FROM_NARY_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( + SeriesParallelDecomposition const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h new file mode 100644 index 0000000000..f2a006d899 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_GET_SERIES_PARALLEL_DECOMPOSITION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_GET_SERIES_PARALLEL_DECOMPOSITION_H + +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/optional.h" +#include +#include + +namespace FlexFlow { + +std::optional + get_series_parallel_decomposition(DiGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/graph_generation.h b/lib/utils/include/utils/graph/series_parallel/graph_generation.h similarity index 56% rename from lib/utils/include/utils/graph/serial_parallel/graph_generation.h rename to lib/utils/include/utils/graph/series_parallel/graph_generation.h index fac9c98db2..f18fd63d24 100644 --- a/lib/utils/include/utils/graph/serial_parallel/graph_generation.h +++ b/lib/utils/include/utils/graph/series_parallel/graph_generation.h @@ -1,23 +1,23 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GRAPH_GENERATION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GRAPH_GENERATION_H +#ifndef 
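A sketch of the intended round trip between the n-ary and binary representations (illustrative only, assuming the input graph actually admits a series-parallel decomposition):

    #include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h"
    #include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h"
    #include "utils/graph/series_parallel/get_series_parallel_decomposition.h"
    #include <optional>

    using namespace FlexFlow;

    void decomposition_example(DiGraphView const &g) {
      std::optional<SeriesParallelDecomposition> maybe_sp =
          get_series_parallel_decomposition(g);
      if (maybe_sp.has_value()) {
        BinarySPDecompositionTree binary =
            left_associative_binary_sp_tree_from_nary(maybe_sp.value());
        SeriesParallelDecomposition back = nary_sp_tree_from_binary(binary);
        (void)back;
      }
    }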
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_GRAPH_GENERATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_GRAPH_GENERATION_H #include "utils/graph/dataflow_graph/dataflow_graph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" namespace FlexFlow { void parallel_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext); -void serial_extend(DataflowGraph &g, DataflowGraphView const &ext); +void series_extend(DataflowGraph &g, DataflowGraphView const &ext); -DataflowGraph serial_composition(DataflowGraphView const &g1, +DataflowGraph series_composition(DataflowGraphView const &g1, DataflowGraphView const &g2); DataflowGraph parallel_composition(DataflowGraphView const &g1, DataflowGraphView const &g2); DataflowGraph dataflow_graph_from_sp_decomposition( - SerialParallelDecomposition const &sp_decomposition); + SeriesParallelDecomposition const &sp_decomposition); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.h new file mode 100644 index 0000000000..1283a6df3a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.dtg.h" + +namespace FlexFlow { + +std::variant + flatten_ast(std::variant const &ast); + +std::variant + from_binary_sp_tree(BinarySPDecompositionTree const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.struct.toml similarity index 90% rename from lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml rename to lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.struct.toml index 08f03ed12a..e7666fcd3f 100644 --- a/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.struct.toml @@ -8,7 +8,7 @@ features = [ ] includes = [ - "utils/graph/serial_parallel/split_type.dtg.h", + "utils/graph/series_parallel/split_type.dtg.h", "", "", "utils/graph/node/node.dtg.h", diff --git a/lib/utils/include/utils/graph/serial_parallel/parallel_reduction.h b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h similarity index 70% rename from lib/utils/include/utils/graph/serial_parallel/parallel_reduction.h rename to lib/utils/include/utils/graph/series_parallel/parallel_reduction.h index 71cc5e3998..3fc1347ee5 100644 --- a/lib/utils/include/utils/graph/serial_parallel/parallel_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h @@ -1,8 +1,8 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_PARALLEL_REDUCTION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_PARALLEL_REDUCTION_H +#ifndef 
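For illustration only, a minimal sketch of the renamed graph-generation helpers above; the input views are assumed to come from elsewhere:

    #include "utils/graph/series_parallel/graph_generation.h"

    using namespace FlexFlow;

    void composition_example(DataflowGraphView const &g1,
                             DataflowGraphView const &g2,
                             SeriesParallelDecomposition const &sp) {
      DataflowGraph seq = series_composition(g1, g2);
      DataflowGraph par = parallel_composition(g1, g2);
      DataflowGraph from_sp = dataflow_graph_from_sp_decomposition(sp);
      (void)seq;
      (void)par;
      (void)from_sp;
    }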
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_PARALLEL_REDUCTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_PARALLEL_REDUCTION_H #include "utils/graph/multidigraph/multidigraph.h" -#include "utils/graph/serial_parallel/parallel_reduction.dtg.h" +#include "utils/graph/series_parallel/parallel_reduction.dtg.h" #include namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/serial_parallel/parallel_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.struct.toml similarity index 100% rename from lib/utils/include/utils/graph/serial_parallel/parallel_reduction.struct.toml rename to lib/utils/include/utils/graph/series_parallel/parallel_reduction.struct.toml diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h new file mode 100644 index 0000000000..52d2cb7236 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_SERIES_PARALLEL_DECOMPOSITION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_SERIES_PARALLEL_DECOMPOSITION_H + +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include + +namespace FlexFlow { + +std::variant internal_to_final_ast( + std::variant const &ast); +SeriesParallelDecomposition + to_final_ast(std::variant const &); + +std::unordered_multiset get_nodes(SeriesParallelDecomposition const &sp); +std::unordered_multiset get_nodes(SeriesSplit const &); +std::unordered_multiset get_nodes(ParallelSplit const &); +std::unordered_multiset get_nodes(Node const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.variant.toml similarity index 62% rename from lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml rename to lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.variant.toml index f816abfbb4..921499ebd1 100644 --- a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.variant.toml @@ -1,5 +1,5 @@ namespace = "FlexFlow" -name = "SerialParallelDecomposition" +name = "SeriesParallelDecomposition" features = [ "eq", "hash", @@ -7,12 +7,12 @@ features = [ ] includes = [ - "utils/graph/serial_parallel/serial_parallel_splits.h", + "utils/graph/series_parallel/series_parallel_splits.h", "utils/graph/node/node.dtg.h", ] [[values]] -type = "::FlexFlow::SerialSplit" +type = "::FlexFlow::SeriesSplit" [[values]] type = "::FlexFlow::ParallelSplit" diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h similarity index 59% rename from lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h rename to lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h index 081137e513..18434d2b67 100644 --- a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h @@ -1,5 +1,5 @@ -#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H #include "utils/graph/node/node.dtg.h" #include @@ -7,18 +7,18 @@ namespace FlexFlow { -struct SerialSplit; +struct SeriesSplit; struct ParallelSplit; -struct SerialSplit { +struct SeriesSplit { public: - SerialSplit() = delete; - explicit SerialSplit(std::vector> const &); - explicit SerialSplit( + SeriesSplit() = delete; + explicit SeriesSplit(std::vector> const &); + explicit SeriesSplit( std::initializer_list> const &); - bool operator==(SerialSplit const &) const; - bool operator!=(SerialSplit const &) const; + bool operator==(SeriesSplit const &) const; + bool operator!=(SeriesSplit const &) const; public: std::vector> children; @@ -28,16 +28,16 @@ struct SerialSplit { Tie tie() const; }; -std::string format_as(SerialSplit const &); -std::ostream &operator<<(std::ostream &, SerialSplit const &); +std::string format_as(SeriesSplit const &); +std::ostream &operator<<(std::ostream &, SeriesSplit const &); } // namespace FlexFlow namespace std { template <> -struct hash<::FlexFlow::SerialSplit> { - size_t operator()(::FlexFlow::SerialSplit const &) const; +struct hash<::FlexFlow::SeriesSplit> { + size_t operator()(::FlexFlow::SeriesSplit const &) const; }; } // namespace std @@ -48,15 +48,15 @@ struct ParallelSplit { public: ParallelSplit() = delete; explicit ParallelSplit( - std::unordered_set> const &); + std::unordered_multiset> const &); explicit ParallelSplit( - std::initializer_list> const &); + std::initializer_list> const &); bool operator==(ParallelSplit const &) const; bool operator!=(ParallelSplit const &) const; public: - std::unordered_set> children; + std::unordered_multiset> children; private: using Tie = std::tuple; diff --git a/lib/utils/include/utils/graph/serial_parallel/series_reduction.h b/lib/utils/include/utils/graph/series_parallel/series_reduction.h similarity index 77% rename from lib/utils/include/utils/graph/serial_parallel/series_reduction.h rename to lib/utils/include/utils/graph/series_parallel/series_reduction.h index c9bae58546..a7d53fecfc 100644 --- a/lib/utils/include/utils/graph/serial_parallel/series_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/series_reduction.h @@ -1,9 +1,9 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIES_REDUCTION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIES_REDUCTION_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_SERIES_REDUCTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_SERIES_REDUCTION_H #include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/multidigraph/multidigraph.h" -#include "utils/graph/serial_parallel/series_reduction.dtg.h" +#include "utils/graph/series_parallel/series_reduction.dtg.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/serial_parallel/series_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/series_reduction.struct.toml similarity index 100% rename from lib/utils/include/utils/graph/serial_parallel/series_reduction.struct.toml rename to lib/utils/include/utils/graph/series_parallel/series_reduction.struct.toml diff --git 
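A sketch of the renamed n-ary types (illustrative only; it assumes the dtgen-generated SeriesParallelDecomposition can be constructed directly from a SeriesSplit, and takes the Nodes as parameters rather than guessing their constructor):

    #include "utils/graph/series_parallel/series_parallel_decomposition.h"
    #include "utils/graph/series_parallel/series_parallel_splits.h"
    #include <unordered_set>

    using namespace FlexFlow;

    void nary_example(Node const &a, Node const &b, Node const &c) {
      // Run a and b in parallel, then c afterwards.
      SeriesSplit s{ParallelSplit{a, b}, c};
      SeriesParallelDecomposition sp{s};

      std::unordered_multiset<Node> nodes = get_nodes(sp);  // expected {a, b, c}
      (void)nodes;
    }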
a/lib/utils/include/utils/graph/serial_parallel/sink_settings.enum.toml b/lib/utils/include/utils/graph/series_parallel/sink_settings.enum.toml similarity index 100% rename from lib/utils/include/utils/graph/serial_parallel/sink_settings.enum.toml rename to lib/utils/include/utils/graph/series_parallel/sink_settings.enum.toml diff --git a/lib/utils/include/utils/graph/serial_parallel/source_settings.enum.toml b/lib/utils/include/utils/graph/series_parallel/source_settings.enum.toml similarity index 100% rename from lib/utils/include/utils/graph/serial_parallel/source_settings.enum.toml rename to lib/utils/include/utils/graph/series_parallel/source_settings.enum.toml diff --git a/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.enum.toml b/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.enum.toml new file mode 100644 index 0000000000..2050800cbd --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.enum.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "SPDecompositionTreeNodeType" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "SERIES" + +[[values]] +name = "PARALLEL" + +[[values]] +name = "NODE" diff --git a/lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml b/lib/utils/include/utils/graph/series_parallel/split_type.enum.toml similarity index 90% rename from lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml rename to lib/utils/include/utils/graph/series_parallel/split_type.enum.toml index 96d85f0e12..c1a1cb5978 100644 --- a/lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml +++ b/lib/utils/include/utils/graph/series_parallel/split_type.enum.toml @@ -8,7 +8,7 @@ features = [ ] [[values]] -name = "SERIAL" +name = "SERIES" [[values]] name = "PARALLEL" diff --git a/lib/utils/include/utils/graph/undirected/algorithms/get_edges.h b/lib/utils/include/utils/graph/undirected/algorithms/get_edges.h new file mode 100644 index 0000000000..3e951b1db1 --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/algorithms/get_edges.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_GET_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_GET_EDGES_H + +#include "utils/graph/undirected/undirected_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_edges(UndirectedGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/undirected/algorithms/get_neighboring_nodes.h b/lib/utils/include/utils/graph/undirected/algorithms/get_neighboring_nodes.h new file mode 100644 index 0000000000..bc605360d2 --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/algorithms/get_neighboring_nodes.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_GET_NEIGHBORING_NODES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_GET_NEIGHBORING_NODES_H + +#include "utils/graph/undirected/undirected_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_neighboring_nodes(UndirectedGraphView const &, + Node const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/undirected/i_undirected_graph.h b/lib/utils/include/utils/graph/undirected/i_undirected_graph.h index 1662ec6d8c..4761275031 100644 --- a/lib/utils/include/utils/graph/undirected/i_undirected_graph.h +++ 
b/lib/utils/include/utils/graph/undirected/i_undirected_graph.h @@ -15,7 +15,7 @@ struct IUndirectedGraph : public IUndirectedGraphView { virtual std::unordered_set query_nodes(NodeQuery const &query) const = 0; - virtual IUndirectedGraph *clone() const override = 0; + virtual IUndirectedGraph *clone() const = 0; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge_query.h b/lib/utils/include/utils/graph/undirected/undirected_edge_query.h index 9aa0f189ec..65939acc87 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_edge_query.h +++ b/lib/utils/include/utils/graph/undirected/undirected_edge_query.h @@ -1,11 +1,13 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_QUERY_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_QUERY_H +#include "utils/graph/undirected/undirected_edge.h" #include "utils/graph/undirected/undirected_edge_query.dtg.h" namespace FlexFlow { UndirectedEdgeQuery undirected_edge_query_all(); +bool matches_edge(UndirectedEdgeQuery const &, UndirectedEdge const &); UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &, UndirectedEdgeQuery const &); diff --git a/lib/utils/include/utils/hash/multiset.h b/lib/utils/include/utils/hash/multiset.h new file mode 100644 index 0000000000..4695b89165 --- /dev/null +++ b/lib/utils/include/utils/hash/multiset.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_MULTISET_H + +#include "utils/hash-utils.h" +#include + +namespace std { + +template +struct hash> { + size_t operator()(std::multiset const &s) const { + size_t result = 0; + ::FlexFlow::unordered_container_hash(result, s); + return result; + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/hash/unordered_multiset.h b/lib/utils/include/utils/hash/unordered_multiset.h new file mode 100644 index 0000000000..b19c76bfef --- /dev/null +++ b/lib/utils/include/utils/hash/unordered_multiset.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_UNORDERED_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_UNORDERED_MULTISET_H + +#include "utils/hash-utils.h" +#include + +namespace std { + +template +struct hash> { + size_t operator()(std::unordered_multiset const &s) const { + size_t result = 0; + ::FlexFlow::unordered_container_hash(result, s); + return result; + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/json/check_is_jsonable.h b/lib/utils/include/utils/json/check_is_jsonable.h new file mode 100644 index 0000000000..41a64a1b83 --- /dev/null +++ b/lib/utils/include/utils/json/check_is_jsonable.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSONABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSONABLE_H + +#include "utils/json/is_json_deserializable.h" +#include "utils/json/is_json_serializable.h" + +namespace FlexFlow { + +#define CHECK_IS_JSONABLE(TYPENAME) \ + static_assert(is_json_serializable::value, \ + #TYPENAME " should be json serializeable"); \ + static_assert(is_json_deserializable::value, \ + #TYPENAME " should be json deserializeable") + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/is_json_deserializable.h b/lib/utils/include/utils/json/is_json_deserializable.h new file mode 100644 index 0000000000..9e6625428b --- /dev/null +++ b/lib/utils/include/utils/json/is_json_deserializable.h @@ -0,0 +1,25 @@ +#ifndef 
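The new multiset hash specializations and the CHECK_IS_JSONABLE macro above are straightforward to exercise; a minimal sketch (illustrative only, using int as the element type):

    #include "utils/hash/unordered_multiset.h"
    #include "utils/json/check_is_jsonable.h"
    #include <cstddef>
    #include <unordered_set>

    // Static-asserts that int can be converted both to and from nlohmann::json.
    CHECK_IS_JSONABLE(int);

    void hash_example() {
      std::unordered_multiset<int> m = {1, 1, 2};
      std::size_t h = std::hash<std::unordered_multiset<int>>{}(m);
      (void)h;
    }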
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSON_DESERIALIZABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSON_DESERIALIZABLE_H + +#include "utils/type_traits_core.h" +#include +#include + +namespace FlexFlow { + +template +struct is_json_deserializable : std::false_type {}; + +template +struct is_json_deserializable< + T, + void_t().get())>> + : std::true_type {}; + +template +inline constexpr bool is_json_deserializable_v = + is_json_deserializable::value; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/is_json_serializable.h b/lib/utils/include/utils/json/is_json_serializable.h new file mode 100644 index 0000000000..926a8037d4 --- /dev/null +++ b/lib/utils/include/utils/json/is_json_serializable.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSON_SERIALIZABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSON_SERIALIZABLE_H + +#include "utils/type_traits_core.h" +#include +#include + +namespace FlexFlow { + +template +struct is_json_serializable : std::false_type {}; + +template +struct is_json_serializable< + T, + void_t() = std::declval())>> + : std::true_type {}; + +template +inline constexpr bool is_json_serializable_v = is_json_serializable::value; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/is_jsonable.h b/lib/utils/include/utils/json/is_jsonable.h new file mode 100644 index 0000000000..2c8c103650 --- /dev/null +++ b/lib/utils/include/utils/json/is_jsonable.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSONABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSONABLE_H + +#include "utils/json/is_json_deserializable.h" +#include "utils/json/is_json_serializable.h" + +namespace FlexFlow { + +template +struct is_jsonable + : std::conjunction, is_json_deserializable> {}; + +template +inline constexpr bool is_jsonable_v = is_jsonable::value; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/optional.h b/lib/utils/include/utils/json/optional.h new file mode 100644 index 0000000000..c88dd24a15 --- /dev/null +++ b/lib/utils/include/utils/json/optional.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_OPTIONAL_H + +#include "utils/json/is_jsonable.h" +#include +#include + +namespace nlohmann { + +template +struct adl_serializer< + std::optional, + typename std::enable_if<::FlexFlow::is_jsonable::value>::type> { + static void to_json(json &j, std::optional const &t) { + if (t.has_value()) { + j = t.value(); + } else { + j = nullptr; + } + } + + static void from_json(json const &j, std::optional &t) { + if (j == nullptr) { + t = std::nullopt; + } else { + t = j.get(); + } + } +}; + +} // namespace nlohmann + +#endif diff --git a/lib/utils/include/utils/json/variant.h b/lib/utils/include/utils/json/variant.h new file mode 100644 index 0000000000..fe2c3f3b6c --- /dev/null +++ b/lib/utils/include/utils/json/variant.h @@ -0,0 +1,89 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_VARIANT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_VARIANT_H + +#include "utils/json/is_jsonable.h" +#include + +namespace FlexFlow { + +struct VariantToJsonFunctor { + VariantToJsonFunctor(nlohmann::json &j) : j(j) {} + + nlohmann::json &j; + + template + void operator()(T const &t) { + static_assert(is_jsonable::value, ""); + + j = t; + } +}; + +template +void variant_to_json(json &j, std::variant const &v) { + json jval; + visit(::FlexFlow::VariantToJsonFunctor{jval}, v); + 
j["value"] = jval; + j["index"] = v.index(); +} + +template +std::optional variant_from_json_impl(json const &j) { + using Type = typename std::variant_alternative::type; + + if (j.at("index").get() == Idx) { + return j.at("value").get(); + } + return std::nullopt; +} + +template +std::optional variant_from_json_impl(json const &j, + std::index_sequence) { + // If there were no errors when parsing, all but one element of the array + // will be nullopt. This is because each call to variant_from_json_impl will + // have a unique index and exactly one of them will match the index in the + // json object. + std::array, sizeof...(Is)> results{ + variant_from_json_impl(j)...}; + for (std::optional &maybe : results) { + if (maybe) { + return maybe.value(); + } + } + return std::nullopt; +} + +template +std::variant variant_from_json(json const &j) { + using Variant = std::variant; + std::optional result = variant_from_json_impl( + j, std::make_index_sequence()); + if (!result.has_value()) { + throw ::FlexFlow::mk_runtime_error("Invalid type {} found in json", + j.at("index").get()); + } + return result.value(); +} + +} // namespace FlexFlow + +namespace nlohmann { + +template +struct adl_serializer, + typename std::enable_if<::FlexFlow::elements_satisfy< + ::FlexFlow::is_json_serializable, + std::variant>::value>::type> { + static void to_json(json &j, std::variant const &v) { + return ::FlexFlow::variant_to_json(j, v); + } + + static std::variant from_json(json const &j) { + return ::FlexFlow::variant_from_json(j); + } +}; + +} // namespace nlohmann + +#endif diff --git a/lib/utils/include/utils/json.h b/lib/utils/include/utils/json/visitable.h similarity index 52% rename from lib/utils/include/utils/json.h rename to lib/utils/include/utils/json/visitable.h index f56917e329..abc20065de 100644 --- a/lib/utils/include/utils/json.h +++ b/lib/utils/include/utils/json/visitable.h @@ -1,6 +1,9 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_JSON_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_JSON_H +#include "utils/json/is_json_deserializable.h" +#include "utils/json/is_json_serializable.h" +#include "utils/json/is_jsonable.h" #include "utils/json_core.h" #include "utils/optional.h" #include "utils/sequence.h" @@ -10,33 +13,6 @@ namespace FlexFlow { -template -struct is_json_serializable : std::false_type {}; - -template -struct is_json_serializable< - T, - void_t() = std::declval())>> - : std::true_type {}; - -template -struct is_json_deserializable : std::false_type {}; - -template -struct is_json_deserializable().get())>> - : std::true_type {}; - -template -struct is_jsonable - : conjunction, is_json_deserializable> {}; - -#define CHECK_IS_JSONABLE(TYPENAME) \ - static_assert(is_json_serializable::value, \ - #TYPENAME " should be json serializeable"); \ - static_assert(is_json_deserializable::value, \ - #TYPENAME " should be json deserializeable") - struct json_serialization_visitor { json_serialization_visitor() = delete; json_serialization_visitor(json &j) : j(j) {} @@ -134,66 +110,6 @@ T moveonly_visit_json_deserialize(json const &j) { return visitable_from_tuple(tuple_from_json(j)); } -struct VariantToJsonFunctor { - VariantToJsonFunctor(json &j) : j(j) {} - - json &j; - - template - void operator()(T const &t) { - static_assert(is_jsonable::value, ""); - - j = t; - } -}; - -template -void variant_to_json(json &j, std::variant const &v) { - json jval; - visit(::FlexFlow::VariantToJsonFunctor{jval}, v); - j["value"] = jval; - j["index"] = v.index(); -} - -template -std::optional variant_from_json_impl(json 
const &j) { - using Type = typename std::variant_alternative::type; - - if (j.at("index").get() == Idx) { - return j.at("value").get(); - } - return std::nullopt; -} - -template -std::optional variant_from_json_impl(json const &j, - std::index_sequence) { - // If there were no errors when parsing, all but one element of the array - // will be nullopt. This is because each call to variant_from_json_impl will - // have a unique index and exactly one of them will match the index in the - // json object. - std::array, sizeof...(Is)> results{ - variant_from_json_impl(j)...}; - for (std::optional &maybe : results) { - if (maybe) { - return maybe.value(); - } - } - return std::nullopt; -} - -template -std::variant variant_from_json(json const &j) { - using Variant = std::variant; - std::optional result = variant_from_json_impl( - j, std::make_index_sequence()); - if (!result.has_value()) { - throw ::FlexFlow::mk_runtime_error("Invalid type {} found in json", - j.at("index").get()); - } - return result.value(); -} - } // namespace FlexFlow namespace nlohmann { @@ -231,41 +147,6 @@ struct adl_serializer< } }; -template -struct adl_serializer< - std::optional, - typename std::enable_if<::FlexFlow::is_jsonable::value>::type> { - static void to_json(json &j, std::optional const &t) { - if (t.has_value()) { - to_json(j, t.value()); - } else { - j = nullptr; - } - } - - static void from_json(json const &j, std::optional &t) { - if (j == nullptr) { - t = std::nullopt; - } else { - t = j.get(); - } - } -}; - -template -struct adl_serializer, - typename std::enable_if<::FlexFlow::elements_satisfy< - ::FlexFlow::is_json_serializable, - std::variant>::value>::type> { - static void to_json(json &j, std::variant const &v) { - return ::FlexFlow::variant_to_json(j, v); - } - - static std::variant from_json(json const &j) { - return ::FlexFlow::variant_from_json(j); - } -}; - } // namespace nlohmann #endif diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 3448ec4e0e..3ec165d595 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -7,6 +7,15 @@ namespace FlexFlow { +template +T or_else(std::optional const &o, F &&f) { + if (o.has_value()) { + return o.value(); + } else { + return f(); + } +} + template T const &unwrap(std::optional const &o, F const &f) { if (o.has_value()) { @@ -25,18 +34,4 @@ T const &assert_unwrap(std::optional const &o) { } // namespace FlexFlow -namespace rc { - -template -struct Arbitrary> { - static Gen> arbitrary() { - return gen::map( - gen::maybe(std::move(gen::arbitrary())), [](Maybe &&m) { - return m ? std::optional(std::move(*m)) : std::optional(); - }); - } -}; - -} // namespace rc - #endif diff --git a/lib/utils/include/utils/rapidcheck/optional.h b/lib/utils/include/utils/rapidcheck/optional.h new file mode 100644 index 0000000000..edb28fdb81 --- /dev/null +++ b/lib/utils/include/utils/rapidcheck/optional.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_OPTIONAL_H + +#include +#include + +namespace rc { + +template +struct Arbitrary> { + static Gen> arbitrary() { + return gen::map( + gen::maybe(std::move(gen::arbitrary())), [](Maybe &&m) { + return m ? 
std::optional(std::move(*m)) : std::optional(); + }); + } +}; + +} // namespace rc + +#endif diff --git a/lib/utils/include/utils/required.h b/lib/utils/include/utils/required.h index 9cdd7918dd..d16b67ba86 100644 --- a/lib/utils/include/utils/required.h +++ b/lib/utils/include/utils/required.h @@ -1,9 +1,13 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_REQUIRED_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_REQUIRED_H -#include "utils/json.h" +#include "utils/fmt/vector.h" +#include "utils/json/is_json_deserializable.h" +#include "utils/json/is_json_serializable.h" +#include "utils/json/is_jsonable.h" #include "utils/required_core.h" #include "utils/type_traits.h" +#include namespace FlexFlow { @@ -14,11 +18,11 @@ static_assert(is_list_initializable, int>::value, ""); namespace nlohmann { template struct adl_serializer<::FlexFlow::req> { - static ::FlexFlow::req from_json(json const &j) { + static ::FlexFlow::req from_json(nlohmann::json const &j) { return {j.template get()}; } - static void to_json(json &j, ::FlexFlow::req const &t) { + static void to_json(nlohmann::json &j, ::FlexFlow::req const &t) { j = static_cast(t); } }; diff --git a/lib/utils/include/utils/stack_string.h b/lib/utils/include/utils/stack_string.h index 19743b8301..7a936ebd7b 100644 --- a/lib/utils/include/utils/stack_string.h +++ b/lib/utils/include/utils/stack_string.h @@ -4,9 +4,9 @@ #include "fmt/core.h" #include "stack_vector.h" #include "utils/fmt.h" -#include "utils/json.h" #include "utils/type_traits.h" #include +#include #include #include @@ -70,13 +70,13 @@ template using stack_string = stack_basic_string; template -void to_json(json &j, stack_string const &v) { +void to_json(nlohmann::json &j, stack_string const &v) { std::string as_string = v; j = as_string; } template -void from_json(json const &j, stack_string &v) { +void from_json(nlohmann::json const &j, stack_string &v) { std::string as_string; j.get_to(as_string); v = stack_string{as_string}; diff --git a/lib/utils/include/utils/stack_vector.h b/lib/utils/include/utils/stack_vector.h index 1d654e3415..7a7bce7afc 100644 --- a/lib/utils/include/utils/stack_vector.h +++ b/lib/utils/include/utils/stack_vector.h @@ -3,12 +3,12 @@ #include "utils/hash-utils.h" #include "utils/join_strings.h" -#include "utils/json.h" #include "utils/test_types.h" #include "utils/type_traits.h" #include #include #include +#include #include #include #include @@ -326,13 +326,13 @@ std::ostream &operator<<(std::ostream &s, stack_vector const &v) { } template -void to_json(json &j, stack_vector const &v) { +void to_json(nlohmann::json &j, stack_vector const &v) { std::vector as_vec(v.begin(), v.end()); j = as_vec; } template -void from_json(json const &j, stack_vector &v) { +void from_json(nlohmann::json const &j, stack_vector &v) { std::vector as_vec; j.get_to(as_vec); v = stack_vector{as_vec.begin(), as_vec.end()}; diff --git a/lib/utils/src/utils/cli/cli_get_help_message.cc b/lib/utils/src/utils/cli/cli_get_help_message.cc new file mode 100644 index 0000000000..03c53c9356 --- /dev/null +++ b/lib/utils/src/utils/cli/cli_get_help_message.cc @@ -0,0 +1,101 @@ +#include "utils/cli/cli_get_help_message.h" +#include "utils/containers/concat_vectors.h" +#include "utils/containers/maximum.h" +#include "utils/containers/transform.h" +#include "utils/integer_conversions.h" +#include "utils/join_strings.h" +#include + +namespace FlexFlow { + +std::string cli_get_help_message(std::string const &program_name, + CLISpec const &cli) { + auto render_pos_arg = [](CLIPositionalArgumentSpec const 
&pos_arg_spec) { + if (pos_arg_spec.choices.has_value()) { + return "{" + join_strings(pos_arg_spec.choices.value(), ",") + "}"; + } else { + return pos_arg_spec.name; + } + }; + + auto render_flag_option_column_key = [](CLIFlagSpec const &flag_spec) { + std::ostringstream oss; + if (flag_spec.short_flag.has_value()) { + oss << "-" << flag_spec.short_flag.value() << ", "; + } + oss << "--" << flag_spec.long_flag; + return oss.str(); + }; + + std::ostringstream oss; + + oss << "usage: " << program_name; + for (CLIFlagSpec const &flag_spec : cli.flags) { + if (flag_spec.short_flag.has_value()) { + oss << " [-" << flag_spec.short_flag.value() << "]"; + } else { + oss << " [--" << flag_spec.long_flag << "]"; + } + } + for (CLIPositionalArgumentSpec const &pos_arg_spec : + cli.positional_arguments) { + oss << " " << render_pos_arg(pos_arg_spec); + } + + oss << std::endl; + + std::vector all_arg_columns = concat_vectors(std::vector{ + transform(cli.positional_arguments, render_pos_arg), + transform(cli.flags, render_flag_option_column_key), + }); + std::vector all_arg_column_widths = + transform(all_arg_columns, [](std::string const &s) { return s.size(); }); + + if (!all_arg_columns.empty()) { + int max_column_width = + std::min(int_from_size_t(maximum(all_arg_column_widths).value()), 20); + + auto render_column = [&](std::string const &key, + std::optional const &description) { + if (description.has_value()) { + if (key.size() > max_column_width) { + return " " + key + "\n" + std::string(24, ' ') + description.value(); + } else { + } + return fmt::format( + " {:<{}} {}", key, max_column_width, description.value()); + } else { + return fmt::format(" {}", key); + } + }; + + if (!cli.positional_arguments.empty()) { + oss << std::endl; + oss << "positional arguments:" << std::endl; + + if (!cli.positional_arguments.empty()) { + for (CLIPositionalArgumentSpec const &pos_arg_spec : + cli.positional_arguments) { + oss << render_column(render_pos_arg(pos_arg_spec), + pos_arg_spec.description) + << std::endl; + } + } + } + + if (!cli.flags.empty()) { + oss << std::endl; + oss << "options:" << std::endl; + + for (CLIFlagSpec const &flag_spec : cli.flags) { + oss << render_column(render_flag_option_column_key(flag_spec), + flag_spec.description) + << std::endl; + } + } + } + + return oss.str(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_parse.cc b/lib/utils/src/utils/cli/cli_parse.cc new file mode 100644 index 0000000000..07982c0c2d --- /dev/null +++ b/lib/utils/src/utils/cli/cli_parse.cc @@ -0,0 +1,96 @@ +#include "utils/cli/cli_parse.h" +#include "utils/cli/cli_spec.h" +#include "utils/containers/contains.h" +#include "utils/containers/enumerate.h" +#include "utils/containers/generate_map.h" + +namespace FlexFlow { + +tl::expected cli_parse_flag(CLISpec const &cli, + std::string const &arg) { + for (auto const &[idx, flag_spec] : enumerate(cli.flags)) { + CLIFlagKey key = CLIFlagKey{idx}; + if (("--" + flag_spec.long_flag) == arg) { + return key; + } + + if (flag_spec.short_flag.has_value()) { + if ((std::string{"-"} + flag_spec.short_flag.value()) == arg) { + return key; + } + } + } + + return tl::unexpected(fmt::format("Encountered unknown flag {}", arg)); +} + +tl::expected + cli_parse(CLISpec const &cli, std::vector const &args) { + CLIParseResult result = CLIParseResult{ + generate_map(cli_get_flag_keys(cli), + [](CLIFlagKey const &) { return false; }), + {}, + }; + + int consumed_positional_args = 0; + auto parse_positional_arg = + [&](std::string const &arg) -> 
std::optional { + if (consumed_positional_args >= cli.positional_arguments.size()) { + return fmt::format("Too many positional arguments: expected {}", + cli.positional_arguments.size()); + } + + CLIPositionalArgumentSpec arg_spec = + cli.positional_arguments.at(consumed_positional_args); + + if (arg_spec.choices.has_value() && + !contains(arg_spec.choices.value(), arg)) { + return fmt::format( + "Invalid option for positional argument \"{}\": \"{}\"", + arg_spec.name, + arg); + } + + result.positional_arguments.insert( + {CLIPositionalArgumentKey{consumed_positional_args}, arg}); + consumed_positional_args++; + + return std::nullopt; + }; + + for (int i = 1; i < args.size(); i++) { + std::string arg = args.at(i); + + if (!arg.empty() && arg.at(0) == '-') { + tl::expected parsed_flag = + cli_parse_flag(cli, arg); + + if (parsed_flag.has_value()) { + result.flags.at(parsed_flag.value()) = true; + } + } else { + std::optional maybe_err_msg = parse_positional_arg(arg); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + } + + if (consumed_positional_args != cli.positional_arguments.size()) { + return tl::unexpected( + fmt::format("Not enough positional arguments: found {}, expected {}", + consumed_positional_args, + cli.positional_arguments.size())); + } + + return result; +} + +tl::expected + cli_parse(CLISpec const &cli, int argc, char const *const *argv) { + std::vector args = {argv, argv + argc}; + + return cli_parse(cli, args); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_parse_result.cc b/lib/utils/src/utils/cli/cli_parse_result.cc new file mode 100644 index 0000000000..6682a7a6eb --- /dev/null +++ b/lib/utils/src/utils/cli/cli_parse_result.cc @@ -0,0 +1,14 @@ +#include "utils/cli/cli_parse_result.h" + +namespace FlexFlow { + +bool cli_get_flag(CLIParseResult const &result, CLIArgumentKey const &key) { + return result.flags.at(key.get()); +} + +std::string cli_get_argument(CLIParseResult const &result, + CLIArgumentKey const &key) { + return result.positional_arguments.at(key.get()); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_spec.cc b/lib/utils/src/utils/cli/cli_spec.cc new file mode 100644 index 0000000000..ca51cfe57f --- /dev/null +++ b/lib/utils/src/utils/cli/cli_spec.cc @@ -0,0 +1,37 @@ +#include "utils/cli/cli_spec.h" +#include "utils/containers/count.h" +#include "utils/containers/transform.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +CLISpec empty_cli_spec() { + return CLISpec{{}, {}}; +} + +std::vector cli_get_flag_keys(CLISpec const &cli) { + return transform(count(cli.flags.size()), + [](int idx) { return CLIFlagKey{idx}; }); +} + +CLIArgumentKey cli_add_help_flag(CLISpec &cli) { + CLIFlagSpec help_flag = + CLIFlagSpec{"help", 'h', "show this help message and exit"}; + return cli_add_flag(cli, help_flag); +} + +CLIArgumentKey cli_add_flag(CLISpec &cli, CLIFlagSpec const &flag_spec) { + cli.flags.push_back(flag_spec); + + return CLIArgumentKey{CLIFlagKey{int_from_size_t(cli.flags.size()) - 1}}; +} + +CLIArgumentKey + cli_add_positional_argument(CLISpec &cli, + CLIPositionalArgumentSpec const &arg) { + cli.positional_arguments.push_back(arg); + return CLIArgumentKey{CLIPositionalArgumentKey{ + int_from_size_t(cli.positional_arguments.size()) - 1}}; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/as_vector.cc b/lib/utils/src/utils/containers/as_vector.cc deleted file mode 100644 index 9c7b63ca58..0000000000 --- 
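// Illustrative sketch (not part of this patch): typical use of the CLI helpers
// above is to build a CLISpec, parse argv, and then read flags and positional
// arguments back out of the CLIParseResult. The functions used here are
// implemented in the .cc files above and assumed to be declared in the
// matching utils/cli/ headers; the exact main() wiring is illustrative only.

#include <iostream>
#include "utils/cli/cli_get_help_message.h"
#include "utils/cli/cli_parse.h"
#include "utils/cli/cli_parse_result.h"
#include "utils/cli/cli_spec.h"

int cli_usage_sketch(int argc, char const *const *argv) {
  using namespace FlexFlow;

  CLISpec cli = empty_cli_spec();
  CLIArgumentKey help_key = cli_add_help_flag(cli);

  tl::expected<CLIParseResult, std::string> parsed = cli_parse(cli, argc, argv);
  if (!parsed.has_value()) {
    // parsing failed: report the error together with the generated help text
    std::cerr << parsed.error() << std::endl;
    std::cerr << cli_get_help_message(argv[0], cli);
    return 1;
  }

  if (cli_get_flag(parsed.value(), help_key)) {
    std::cout << cli_get_help_message(argv[0], cli);
    return 0;
  }

  return 0;
}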
a/lib/utils/src/utils/containers/as_vector.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/containers/as_vector.h" diff --git a/lib/utils/src/utils/containers/enumerate_vector.cc b/lib/utils/src/utils/containers/enumerate_vector.cc new file mode 100644 index 0000000000..d4fd131af2 --- /dev/null +++ b/lib/utils/src/utils/containers/enumerate_vector.cc @@ -0,0 +1 @@ +#include "utils/containers/enumerate_vector.h" diff --git a/lib/utils/src/utils/containers/foldl1.cc b/lib/utils/src/utils/containers/foldl1.cc new file mode 100644 index 0000000000..c6cdd0eec9 --- /dev/null +++ b/lib/utils/src/utils/containers/foldl1.cc @@ -0,0 +1 @@ +#include "utils/containers/foldl1.h" diff --git a/lib/utils/src/utils/containers/foldr1.cc b/lib/utils/src/utils/containers/foldr1.cc new file mode 100644 index 0000000000..9d00d81565 --- /dev/null +++ b/lib/utils/src/utils/containers/foldr1.cc @@ -0,0 +1 @@ +#include "utils/containers/foldr1.h" diff --git a/lib/utils/src/utils/containers/get_element_counts.cc b/lib/utils/src/utils/containers/get_element_counts.cc index 9840ed34d8..ac8e289523 100644 --- a/lib/utils/src/utils/containers/get_element_counts.cc +++ b/lib/utils/src/utils/containers/get_element_counts.cc @@ -1,10 +1,10 @@ #include "utils/containers/get_element_counts.h" -#include "utils/containers/as_vector.h" +#include "utils/containers/vector_of.h" namespace FlexFlow { std::unordered_map get_element_counts(std::string const &s) { - return get_element_counts(as_vector(s)); + return get_element_counts(vector_of(s)); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/maximum.cc b/lib/utils/src/utils/containers/maximum.cc new file mode 100644 index 0000000000..51d92cf951 --- /dev/null +++ b/lib/utils/src/utils/containers/maximum.cc @@ -0,0 +1 @@ +#include "utils/containers/maximum.h" diff --git a/lib/utils/src/utils/containers/multiset_union.cc b/lib/utils/src/utils/containers/multiset_union.cc new file mode 100644 index 0000000000..a053d05fa6 --- /dev/null +++ b/lib/utils/src/utils/containers/multiset_union.cc @@ -0,0 +1 @@ +#include "utils/containers/multiset_union.h" diff --git a/lib/utils/src/utils/containers/require_no_duplicates.cc b/lib/utils/src/utils/containers/require_no_duplicates.cc new file mode 100644 index 0000000000..b1d21ad832 --- /dev/null +++ b/lib/utils/src/utils/containers/require_no_duplicates.cc @@ -0,0 +1 @@ +#include "utils/containers/require_no_duplicates.h" diff --git a/lib/utils/src/utils/containers/set_of.cc b/lib/utils/src/utils/containers/set_of.cc new file mode 100644 index 0000000000..3a12ee539d --- /dev/null +++ b/lib/utils/src/utils/containers/set_of.cc @@ -0,0 +1 @@ +#include "utils/containers/set_of.h" diff --git a/lib/utils/src/utils/containers/to_uppercase.cc b/lib/utils/src/utils/containers/to_uppercase.cc new file mode 100644 index 0000000000..6c02b5a109 --- /dev/null +++ b/lib/utils/src/utils/containers/to_uppercase.cc @@ -0,0 +1,10 @@ +#include "utils/containers/to_uppercase.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +std::string to_uppercase(std::string const &s) { + return transform(s, [](char c) -> char { return std::toupper(c); }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/vector_of.cc b/lib/utils/src/utils/containers/vector_of.cc new file mode 100644 index 0000000000..b997076511 --- /dev/null +++ b/lib/utils/src/utils/containers/vector_of.cc @@ -0,0 +1 @@ +#include "utils/containers/vector_of.h" diff --git a/lib/utils/src/utils/graph/algorithms.cc 
b/lib/utils/src/utils/graph/algorithms.cc index 323f444a22..6ed41daf43 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -219,10 +219,6 @@ std::unordered_set get_endpoints(UndirectedEdge const &e) { // return g.query_edges(MultiDiEdgeQuery::all()); // } -std::unordered_set get_edges(UndirectedGraphView const &g) { - return g.query_edges(undirected_edge_query_all()); -} - // std::unordered_set get_edges(OpenMultiDiGraphView const &g) // { // return g.query_edges(OpenMultiDiEdgeQuery::all()); diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc new file mode 100644 index 0000000000..d17a84dd12 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc @@ -0,0 +1,24 @@ +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h" +#include "utils/containers/set_minus.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_incoming_edges(DataflowGraphView const &g, + std::unordered_set const &ns) { + + std::unordered_set all_nodes = get_nodes(g); + query_set src_query = query_set{set_minus(all_nodes, ns)}; + + DataflowEdgeQuery query = DataflowEdgeQuery{ + src_query, + query_set::matchall(), + query_set{ns}, + query_set::matchall(), + }; + + return g.query_edges(query); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc index 011d8b3ed9..8afe7da926 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc @@ -3,8 +3,13 @@ #include "utils/containers/extend.h" #include "utils/containers/get_first.h" #include "utils/containers/set_minus.h" +#include "utils/containers/set_of.h" #include "utils/containers/values.h" +#include "utils/containers/vector_of.h" +#include "utils/fmt/set.h" +#include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" #include "utils/graph/digraph/algorithms/get_incoming_edges.h" #include "utils/graph/digraph/algorithms/get_outgoing_edges.h" #include "utils/graph/digraph/algorithms/get_predecessors.h" @@ -12,23 +17,35 @@ #include "utils/graph/digraph/algorithms/get_weakly_connected_components.h" #include "utils/graph/node/algorithms.h" #include "utils/hash/unordered_set.h" +#include namespace FlexFlow { std::optional - get_cbc_decomposition(DiGraphView const &g) { + get_cbc_decomposition_with_edge_order_internal( + DiGraphView const &g, std::vector const &edge_order) { // implementation of the algorithm from https://doi.org/10.1145/800135.804393 // top left of page 8, second paragraph + std::queue edges_to_process; + for (DirectedEdge const &e : edge_order) { + edges_to_process.push(e); + } + std::unordered_set already_in_a_head = {}; std::unordered_set already_in_a_tail = {}; - std::unordered_set edges_to_process = get_edges(g); + + std::unordered_set already_processed = {}; CompleteBipartiteCompositeDecomposition result = CompleteBipartiteCompositeDecomposition{{}}; while (!edges_to_process.empty()) { - DirectedEdge e = 
get_first(edges_to_process); + DirectedEdge e = edges_to_process.front(); + edges_to_process.pop(); + if (contains(already_processed, e)) { + continue; + } std::unordered_set head = get_predecessors(g, e.dst); std::unordered_set tail = get_successors(g, e.src); @@ -39,6 +56,12 @@ std::optional std::unordered_set from_head_to_tail = g.query_edges(DirectedEdgeQuery{head, tail}); + + DiGraphView subgraph = get_subgraph(g, set_union(head, tail)); + if (!is_complete_bipartite_digraph(subgraph, head)) { + return std::nullopt; + } + if (set_union(values(get_outgoing_edges(g, head))) != from_head_to_tail) { return std::nullopt; } @@ -47,7 +70,7 @@ std::optional } result.subgraphs.insert(BipartiteComponent{head, tail}); - edges_to_process = set_minus(edges_to_process, from_head_to_tail); + already_processed = set_union(already_processed, from_head_to_tail); extend(already_in_a_head, head); extend(already_in_a_tail, tail); } @@ -58,4 +81,10 @@ std::optional return result; } +std::optional + get_cbc_decomposition(DiGraphView const &g) { + std::vector edge_order = vector_of(get_edges(g)); + return get_cbc_decomposition_with_edge_order_internal(g, edge_order); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc new file mode 100644 index 0000000000..2eab8371b2 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc @@ -0,0 +1,29 @@ +#include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" +#include "utils/containers/get_first.h" +#include "utils/containers/set_minus.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +bool is_complete_bipartite_digraph(DiGraphView const &g) { + return is_complete_bipartite_digraph(g, get_sources(g)); +} + +bool is_complete_bipartite_digraph(DiGraphView const &g, + std::unordered_set const &srcs) { + std::unordered_set sinks = set_minus(get_nodes(g), srcs); + + std::unordered_set edges = get_edges(g); + + std::unordered_set expected_edges; + for (Node const &src : srcs) { + for (Node const &sink : sinks) { + expected_edges.insert(DirectedEdge{src, sink}); + } + } + + return edges == expected_edges; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc b/lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc new file mode 100644 index 0000000000..ad7830cc76 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc @@ -0,0 +1,32 @@ +#include "utils/graph/digraph/algorithms/digraph_as_dot.h" +#include "utils/dot_file.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +std::string digraph_as_dot( + DiGraphView const &g, + std::function const &get_node_label) { + std::ostringstream oss; + DotFile dot = DotFile{oss}; + + auto get_node_name = [](Node const &n) { + return fmt::format("n{}", n.raw_uid); + }; + + for (Node const &n : get_nodes(g)) { + RecordFormatter rec; + rec << get_node_label(n); + dot.add_record_node(get_node_name(n), rec); + } + + for (DirectedEdge const &e : get_edges(g)) { + dot.add_edge(get_node_name(e.src), get_node_name(e.dst)); + } + + dot.close(); + return oss.str(); +} + +} // namespace FlexFlow diff --git 
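// Illustrative sketch (not part of this patch): in get_cbc_decomposition above,
// each processed edge e proposes the block
// (head, tail) = (predecessors(e.dst), successors(e.src)), and the block is
// only kept if the edge set between head and tail is exactly head x tail,
// which is what is_complete_bipartite_digraph checks. A small helper mirroring
// its expected_edges loop (the dtgen include path is assumed):

#include <unordered_set>
#include "utils/graph/digraph/directed_edge.dtg.h"

std::unordered_set<FlexFlow::DirectedEdge> complete_bipartite_edges_sketch(
    std::unordered_set<FlexFlow::Node> const &head,
    std::unordered_set<FlexFlow::Node> const &tail) {
  std::unordered_set<FlexFlow::DirectedEdge> result;
  for (FlexFlow::Node const &src : head) {
    for (FlexFlow::Node const &dst : tail) {
      // every head node must point at every tail node for the block to count
      result.insert(FlexFlow::DirectedEdge{src, dst});
    }
  }
  return result;
}
// For example, on the diamond a->b, a->c, b->d, c->d the decomposition
// consists of the two blocks ({a}, {b, c}) and ({b, c}, {d}).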
a/lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc b/lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc new file mode 100644 index 0000000000..5c790abb8c --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc @@ -0,0 +1,13 @@ +#include "utils/graph/digraph/algorithms/digraph_has_edge.h" + +namespace FlexFlow { + +bool digraph_has_edge(DiGraphView const &g, DirectedEdge const &e) { + return !g.query_edges(DirectedEdgeQuery{ + query_set{e.src}, + query_set{e.dst}, + }) + .empty(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc index 2e570cbdf9..34cc7fcc6f 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc @@ -1,5 +1,4 @@ #include "utils/graph/digraph/algorithms/get_imm_dominators_map.h" -#include "utils/containers/as_vector.h" #include "utils/containers/concat_vectors.h" #include "utils/containers/filter_values.h" #include "utils/containers/generate_map.h" @@ -7,6 +6,7 @@ #include "utils/containers/get_only.h" #include "utils/containers/keys.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/graph/digraph/algorithms/get_dominators_map.h" #include "utils/graph/node/algorithms.h" @@ -22,8 +22,8 @@ std::unordered_map> std::unordered_set n_dominators = node_to_its_dominators.at(n); n_dominators.erase(n); std::vector recursive_dominator_list = concat_vectors( - transform(as_vector(n_dominators), [&](Node const &dominator) { - return as_vector(node_to_its_dominators.at(dominator)); + transform(vector_of(n_dominators), [&](Node const &dominator) { + return vector_of(node_to_its_dominators.at(dominator)); })); std::unordered_map dominator_counts = get_element_counts(recursive_dominator_list); diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc new file mode 100644 index 0000000000..f19deb3046 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc @@ -0,0 +1,16 @@ +#include "utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h" +#include "utils/containers/set_minus.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +std::unordered_set get_subgraph_outgoing_edges( + DiGraphView const &g, std::unordered_set const &subgraph_nodes) { + std::unordered_set external_nodes = + set_minus(get_nodes(g), subgraph_nodes); + DirectedEdgeQuery query = DirectedEdgeQuery{query_set{subgraph_nodes}, + query_set{external_nodes}}; + return g.query_edges(query); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_successors.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_successors.cc new file mode 100644 index 0000000000..e860fb11b1 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_successors.cc @@ -0,0 +1,16 @@ +#include "utils/graph/digraph/algorithms/get_subgraph_successors.h" +#include "utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_successors(DiGraphView const &g, + std::unordered_set const &subgraph_nodes) { + std::unordered_set successors = + transform(get_subgraph_outgoing_edges(g, subgraph_nodes), + [](DirectedEdge 
const &e) { return e.dst; }); + + return successors; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/transitive_closure.cc b/lib/utils/src/utils/graph/digraph/algorithms/transitive_closure.cc new file mode 100644 index 0000000000..3efea1c138 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/transitive_closure.cc @@ -0,0 +1,51 @@ +#include "utils/graph/digraph/algorithms/transitive_closure.h" +#include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "utils/containers/vector_of.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/digraph_has_edge.h" +#include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +DiGraphView transitive_closure(DiGraphView const &g) { + // Logic dropped down to raw adjacency matrix for performance. + // The version going through the full graph abstraction was + // incredibly slow (> minutes) for even moderately sized graphs + // (i.e., 200 nodes) without optimization enabled. + + bidict nodes = bidict_from_enumerating(get_nodes(g)); + std::unordered_set edges = get_edges(g); + + int num_nodes = nodes.size(); + + std::vector edge_matrix(num_nodes * num_nodes, false); + + auto has_edge = [&](int src_idx, + int dst_idx) -> std::vector::reference { + return edge_matrix[src_idx * num_nodes + dst_idx]; + }; + + for (DirectedEdge const &e : get_edges(g)) { + has_edge(nodes.at_r(e.src), nodes.at_r(e.dst)) = true; + } + + DiGraph result = materialize_digraph_view(g); + for (int k = 0; k < num_nodes; k++) { + for (int i = 0; i < num_nodes; i++) { + if (has_edge(i, k)) { + for (int j = 0; j < num_nodes; j++) { + if (has_edge(k, j)) { + has_edge(i, j) = true; + result.add_edge(DirectedEdge{nodes.at_l(i), nodes.at_l(j)}); + } + } + } + } + } + + return result; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc b/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc index 10ffe4fc33..97a2439263 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc @@ -1,7 +1,12 @@ #include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/containers/vector_of.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/digraph/algorithms/transitive_closure.h" #include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -24,29 +29,60 @@ DirectedEdgeMaskView *DirectedEdgeMaskView::clone() const { } DiGraphView transitive_reduction(DiGraphView const &g) { - std::unordered_set edge_mask = get_edges(g); + // Logic dropped down to raw adjacency matrix for performance. + // The version going through the full graph abstraction was + // incredibly slow (> minutes) for even moderately sized graphs + // (i.e., 200 nodes) without optimization enabled. 
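// Illustrative sketch (not part of this patch): the closure computed above is
// the reachability variant of Floyd-Warshall over a flat bool matrix; after
// the k/i/j loops, has_edge(i, j) is true whenever j is reachable from i.
// The same loop structure over raw indices, standalone:

#include <cassert>
#include <vector>

void transitive_closure_matrix_sketch() {
  int num_nodes = 3;
  std::vector<bool> edge_matrix(num_nodes * num_nodes, false);
  auto has_edge = [&](int src, int dst) -> std::vector<bool>::reference {
    return edge_matrix[src * num_nodes + dst];
  };

  has_edge(0, 1) = true; // 0 -> 1
  has_edge(1, 2) = true; // 1 -> 2

  for (int k = 0; k < num_nodes; k++) {
    for (int i = 0; i < num_nodes; i++) {
      if (has_edge(i, k)) {
        for (int j = 0; j < num_nodes; j++) {
          if (has_edge(k, j)) {
            has_edge(i, j) = true;
          }
        }
      }
    }
  }

  assert(has_edge(0, 2)); // added by the closure: 0 -> 1 -> 2
  // transitive_reduction (continued below) runs this same closure and then
  // drops any edge i -> k for which some intermediate j gives i -> j -> k.
}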
+ // + // transitive_closure inlined to avoid any drifts in node numbering + // between transitive_closure and transitive_reduction + + bidict nodes = bidict_from_enumerating(get_nodes(g)); + int num_nodes = nodes.size(); + + std::vector edge_matrix(num_nodes * num_nodes, false); + + auto has_edge = [&](int src_idx, + int dst_idx) -> std::vector::reference { + return edge_matrix[src_idx * num_nodes + dst_idx]; + }; + + for (DirectedEdge const &e : get_edges(g)) { + has_edge(nodes.at_r(e.src), nodes.at_r(e.dst)) = true; + } - while (true) { - std::unordered_set new_edge_mask = edge_mask; - for (DirectedEdge const &e1 : edge_mask) { - for (DirectedEdge const &e2 : edge_mask) { - if (e1.dst == e2.src && e1 != e2) { - DirectedEdge trans_edge = DirectedEdge{e1.src, e2.dst}; - if (contains(new_edge_mask, trans_edge)) { - new_edge_mask.erase(trans_edge); + // compute transitive closure + // see https://cs.winona.edu/lin/cs440/ch08-2.pdf slide 8-8 + for (int k = 0; k < num_nodes; k++) { + for (int i = 0; i < num_nodes; i++) { + if (has_edge(i, k)) { + for (int j = 0; j < num_nodes; j++) { + if (has_edge(k, j)) { + has_edge(i, j) = true; } } } } + } - if (new_edge_mask == edge_mask) { - break; - } else { - edge_mask = new_edge_mask; + DiGraph result = materialize_digraph_view(g); + // compute transitive reduction + // see https://stackoverflow.com/a/6702198 + std::unordered_set edge_mask = get_edges(g); + for (int j = 0; j < num_nodes; j++) { + for (int i = 0; i < num_nodes; i++) { + if (has_edge(i, j)) { + for (int k = 0; k < num_nodes; k++) { + if (has_edge(j, k)) { + has_edge(i, k) = false; + result.remove_edge(DirectedEdge{nodes.at_l(i), nodes.at_l(k)}); + } + } + } } } - return DiGraphView::create(g, edge_mask); + return result; } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/instances/adjacency_digraph.cc b/lib/utils/src/utils/graph/instances/adjacency_digraph.cc index 34a8eff503..68ef12c49e 100644 --- a/lib/utils/src/utils/graph/instances/adjacency_digraph.cc +++ b/lib/utils/src/utils/graph/instances/adjacency_digraph.cc @@ -38,11 +38,7 @@ void AdjacencyDiGraph::add_edge(DirectedEdge const &e) { } void AdjacencyDiGraph::remove_edge(DirectedEdge const &e) { - std::unordered_set &m = this->adjacency.at(e.src); - auto iter = m.find(e.dst); - if (iter != m.end()) { - m.erase(iter); - } + this->adjacency.at(e.src).erase(e.dst); } std::unordered_set diff --git a/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc b/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc new file mode 100644 index 0000000000..6f6722f635 --- /dev/null +++ b/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc @@ -0,0 +1,58 @@ +#include "utils/graph/instances/unordered_set_undirected_graph.h" +#include "utils/graph/node/node_query.h" +#include "utils/graph/undirected/undirected_edge_query.h" + +namespace FlexFlow { + +UnorderedSetUndirectedGraph::UnorderedSetUndirectedGraph() {} + +UnorderedSetUndirectedGraph::UnorderedSetUndirectedGraph( + NodeSource const &node_source, + std::unordered_set const &nodes, + std::unordered_set const &edges) + : node_source(node_source), nodes(nodes), edges(edges) {} + +Node UnorderedSetUndirectedGraph::add_node() { + Node new_node = this->node_source.new_node(); + this->nodes.insert(new_node); + return new_node; +} + +void UnorderedSetUndirectedGraph::add_node_unsafe(Node const &n) { + this->nodes.insert(n); +} + +void UnorderedSetUndirectedGraph::remove_node_unsafe(Node const &n) { + this->nodes.erase(n); +} + +void 
UnorderedSetUndirectedGraph::add_edge(UndirectedEdge const &e) { + assert(contains(this->nodes, e.bigger)); + assert(contains(this->nodes, e.smaller)); + this->edges.insert(e); +} + +void UnorderedSetUndirectedGraph::remove_edge(UndirectedEdge const &e) { + this->edges.erase(e); +} + +std::unordered_set + UnorderedSetUndirectedGraph::query_nodes(NodeQuery const &q) const { + return apply_node_query(q, this->nodes); +} + +std::unordered_set UnorderedSetUndirectedGraph::query_edges( + UndirectedEdgeQuery const &q) const { + return filter(this->edges, + [&](UndirectedEdge const &e) { return matches_edge(q, e); }); +} + +UnorderedSetUndirectedGraph *UnorderedSetUndirectedGraph::clone() const { + return new UnorderedSetUndirectedGraph{ + this->node_source, + this->nodes, + this->edges, + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_edge_counts.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_edge_counts.cc index 47096d492c..53497a715d 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_edge_counts.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_edge_counts.cc @@ -1,7 +1,7 @@ #include "utils/graph/multidigraph/algorithms/get_edge_counts.h" -#include "utils/containers/as_vector.h" #include "utils/containers/get_element_counts.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/graph/multidigraph/algorithms/get_directed_edge.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" @@ -10,7 +10,7 @@ namespace FlexFlow { std::unordered_map get_edge_counts(MultiDiGraphView const &g) { return get_element_counts( - transform(as_vector(get_edges(g)), + transform(vector_of(get_edges(g)), [&](MultiDiEdge const &e) { return get_directed_edge(g, e); })); } diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc index d95a9b9565..1dd5353301 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc @@ -2,12 +2,12 @@ #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" #include "utils/bidict/algorithms/left_entries.h" #include "utils/bidict/algorithms/right_entries.h" -#include "utils/containers/as_vector.h" #include "utils/containers/get_all_permutations.h" #include "utils/containers/get_first.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/keys.h" #include "utils/containers/values.h" +#include "utils/containers/vector_of.h" #include "utils/containers/zip.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/node/algorithms.h" @@ -201,7 +201,7 @@ std::unordered_set OpenDataflowGraphView const &dst) { std::unordered_set result; - std::vector src_sink_nodes = as_vector(get_sinks(src)); + std::vector src_sink_nodes = vector_of(get_sinks(src)); std::unordered_set dst_sink_nodes = get_sinks(dst); if (src_sink_nodes.size() != dst_sink_nodes.size()) { @@ -209,7 +209,7 @@ std::unordered_set } std::vector src_unused_graph_inputs = - as_vector(get_unused_open_dataflow_graph_inputs(src)); + vector_of(get_unused_open_dataflow_graph_inputs(src)); std::unordered_set dst_unused_graph_inputs = get_unused_open_dataflow_graph_inputs(dst); diff --git a/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc 
b/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc deleted file mode 100644 index 6384bd9159..0000000000 --- a/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc +++ /dev/null @@ -1,48 +0,0 @@ -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h" -#include "utils/containers/extend.h" - -namespace FlexFlow { - -struct FlattenAST { - void add_flattened_child_to_parent( - IntermediateSpDecompositionTree &parent, - std::variant const &child) { - if (std::holds_alternative(child)) { - parent.children.push_back(child); - return; - } - - IntermediateSpDecompositionTree child_node = - std::get(child); - - if (parent.type == child_node.type) { - extend(parent.children, child_node.children); - } else { - parent.children.push_back(child); - } - } - - std::variant - operator()(IntermediateSpDecompositionTree const &ast_node) { - IntermediateSpDecompositionTree result(ast_node.type, {}); - for (std::variant const &child : - ast_node.children) { - std::variant flattened_child = - flatten_ast(child); - add_flattened_child_to_parent(result, flattened_child); - } - return result; - } - - std::variant - operator()(Node const &ast_node) { - return ast_node; - } -}; - -std::variant flatten_ast( - std::variant const &ast) { - return std::visit(FlattenAST{}, ast); -} - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc new file mode 100644 index 0000000000..18d1f922c6 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -0,0 +1,43 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" + +namespace FlexFlow { + +BinarySPDecompositionTree + make_series_split(BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{ + make_generic_binary_series_split(lhs.raw_tree, rhs.raw_tree), + }; +} + +BinarySPDecompositionTree + make_parallel_split(BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{ + make_generic_binary_parallel_split(lhs.raw_tree, rhs.raw_tree), + }; +} + +BinarySPDecompositionTree make_leaf_node(Node const &n) { + return BinarySPDecompositionTree{ + make_generic_binary_sp_leaf(n), + }; +} + +bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &tt) { + return is_binary_sp_tree_left_associative(tt.raw_tree); +} + +bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &tt) { + return is_binary_sp_tree_right_associative(tt.raw_tree); +} + +std::unordered_multiset get_leaves(BinarySPDecompositionTree const &tt) { + return get_leaves(tt.raw_tree); +} + +} // namespace FlexFlow diff --git 
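// Illustrative sketch (not part of this patch): the helpers above construct
// binary SP decomposition trees directly. A left-nested chain of series splits
// satisfies is_binary_sp_tree_left_associative, and get_leaves returns the
// multiset of leaf nodes (Node{...} construction from a raw id is assumed):

#include <cassert>
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h"

void binary_sp_tree_sketch() {
  using namespace FlexFlow;

  Node a = Node{0};
  Node b = Node{1};
  Node c = Node{2};

  // ((a ; b) ; c), i.e. series splits nested on the left
  BinarySPDecompositionTree t = make_series_split(
      make_series_split(make_leaf_node(a), make_leaf_node(b)),
      make_leaf_node(c));

  assert(is_binary_sp_tree_left_associative(t));
  assert(get_leaves(t).size() == 3);
}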
a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc new file mode 100644 index 0000000000..4cd7206408 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc new file mode 100644 index 0000000000..3a4dbad8ec --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc new file mode 100644 index 0000000000..4ee18af5be --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc new file mode 100644 index 0000000000..71b67acc54 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc new file mode 100644 index 0000000000..227e5bd79c --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc new file mode 100644 index 0000000000..1618128226 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc @@ -0,0 +1 @@ +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc new file mode 100644 index 0000000000..05ec6b5925 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc new file mode 100644 index 0000000000..f168ba1e2f --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc new file mode 100644 index 0000000000..75c472c435 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc new file mode 100644 index 0000000000..3da024743c --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc new file mode 100644 index 0000000000..8fe9397003 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc new file mode 100644 index 0000000000..d202f55964 --- /dev/null +++ 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc new file mode 100644 index 0000000000..b569ff9265 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc new file mode 100644 index 0000000000..fb1532b3ef --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc new file mode 100644 index 0000000000..3fee45fcf5 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc new file mode 100644 index 0000000000..cabd66cff7 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc new file mode 100644 index 0000000000..25409333f2 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..02e541b7e4 --- /dev/null +++ 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -0,0 +1,75 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" +#include "utils/containers/foldl1.h" +#include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" +#include "utils/overload.h" + +namespace FlexFlow { + +BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( + SeriesParallelDecomposition const &nary) { + std::function( + std::variant const &)> + from_series_child; + std::function( + std::variant const &)> + from_parallel_child; + + auto from_node = [](Node const &n) -> GenericBinarySPDecompositionTree { + return GenericBinarySPDecompositionTree{n}; + }; + + auto from_series = + [&](SeriesSplit const &s) -> GenericBinarySPDecompositionTree { + std::vector> children = + transform(s.children, from_series_child); + return foldl1(children, + [](GenericBinarySPDecompositionTree const &accum, + GenericBinarySPDecompositionTree const &x) { + return GenericBinarySPDecompositionTree{ + GenericBinarySeriesSplit{accum, x}, + }; + }); + }; + + auto from_parallel = + [&](ParallelSplit const &s) -> GenericBinarySPDecompositionTree { + std::vector> children = + transform(vector_of(s.children), from_parallel_child); + return foldl1(children, + [](GenericBinarySPDecompositionTree const &accum, + GenericBinarySPDecompositionTree const &x) { + return GenericBinarySPDecompositionTree{ + GenericBinaryParallelSplit{accum, x}}; + }); + }; + + from_parallel_child = [&](std::variant const &v) + -> GenericBinarySPDecompositionTree { + return std::visit(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + }, + v); + }; + + from_series_child = [&](std::variant const &v) + -> GenericBinarySPDecompositionTree { + return std::visit( + overload{ + [&](Node const &n) { return from_node(n); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }, + v); + }; + + return BinarySPDecompositionTree{ + nary.visit>(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc new file mode 100644 index 0000000000..3b8affd16d --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc @@ -0,0 +1,12 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" + +namespace FlexFlow { + +SeriesParallelDecomposition + nary_sp_tree_from_binary(BinarySPDecompositionTree const &binary) { + return to_final_ast(from_binary_sp_tree(binary)); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..673a4118a6 --- /dev/null +++ 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -0,0 +1,72 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "utils/containers/foldr1.h" +#include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" +#include "utils/overload.h" + +namespace FlexFlow { + +BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( + SeriesParallelDecomposition const &nary) { + std::function( + std::variant const &)> + from_series_child; + std::function( + std::variant const &)> + from_parallel_child; + + auto from_node = [](Node const &n) { + return GenericBinarySPDecompositionTree{n}; + }; + + auto from_series = [&](SeriesSplit const &s) { + std::vector> children = + transform(s.children, from_series_child); + return foldr1(children, + [](GenericBinarySPDecompositionTree const &accum, + GenericBinarySPDecompositionTree const &x) { + return GenericBinarySPDecompositionTree{ + GenericBinarySeriesSplit{x, accum}}; + }); + }; + + auto from_parallel = [&](ParallelSplit const &s) { + std::vector> children = + transform(vector_of(s.children), from_parallel_child); + return foldr1(children, + [](GenericBinarySPDecompositionTree const &accum, + GenericBinarySPDecompositionTree const &x) { + return GenericBinarySPDecompositionTree{ + GenericBinaryParallelSplit{x, accum}}; + }); + }; + + from_parallel_child = [&](std::variant const &v) + -> GenericBinarySPDecompositionTree { + return std::visit(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + }, + v); + }; + + from_series_child = [&](std::variant const &v) + -> GenericBinarySPDecompositionTree { + return std::visit( + overload{ + [&](Node const &n) { return from_node(n); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }, + v); + }; + + return BinarySPDecompositionTree{ + nary.visit>(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc similarity index 62% rename from lib/utils/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc rename to lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index 632f5245db..ab231f256c 100644 --- a/lib/utils/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -1,23 +1,28 @@ -#include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" #include "utils/containers/get_only.h" #include "utils/containers/map_values.h" #include "utils/containers/transform.h" #include "utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/serial_parallel/parallel_reduction.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" -#include 
"utils/graph/serial_parallel/series_reduction.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/parallel_reduction.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_reduction.h" namespace FlexFlow { -std::optional - get_serial_parallel_decomposition(DiGraphView const &g) { +std::optional + get_series_parallel_decomposition(DiGraphView const &g) { + + DiGraphView transitively_reduced = transitive_reduction(g); InverseLineGraphResult inverse_line_graph_result = ({ std::optional maybe_line_graph = - get_inverse_line_graph(g); + get_inverse_line_graph(transitively_reduced); if (!maybe_line_graph.has_value()) { return std::nullopt; } @@ -27,14 +32,11 @@ std::optional MultiDiGraph ttsp = MultiDiGraph::materialize_copy_of( inverse_line_graph_result.graph); - std::unordered_map> - ttsp_edge_to_sp_tree = map_values( - inverse_line_graph_result.inverse_edge_to_line_node_bidict - .as_unordered_map(), - [](Node const &n) { - return std::variant{n}; - }); + std::unordered_map + ttsp_edge_to_sp_tree = + map_values(inverse_line_graph_result.inverse_edge_to_line_node_bidict + .as_unordered_map(), + [](Node const &n) { return make_leaf_node(n); }); while (true) { assert(ttsp_edge_to_sp_tree.size() == get_edges(ttsp).size()); @@ -44,11 +46,8 @@ std::optional ParallelReduction parallel_reduction = maybe_parallel_reduction.value(); auto [e1, e2] = parallel_reduction.edges.ordered(); MultiDiEdge merged = apply_parallel_reduction(ttsp, parallel_reduction); - std::variant new_tree = - IntermediateSpDecompositionTree{ - SplitType::PARALLEL, - {ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)}, - }; + BinarySPDecompositionTree new_tree = make_parallel_split( + ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)); ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); ttsp_edge_to_sp_tree.insert({merged, new_tree}); @@ -63,11 +62,8 @@ std::optional MultiDiEdge e1 = series_reduction.first; MultiDiEdge e2 = series_reduction.second; MultiDiEdge merged = apply_series_reduction(ttsp, series_reduction); - std::variant new_tree = - IntermediateSpDecompositionTree{ - SplitType::SERIAL, - {ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)}, - }; + BinarySPDecompositionTree new_tree = make_series_split( + ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)); ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); ttsp_edge_to_sp_tree.insert({merged, new_tree}); @@ -83,7 +79,7 @@ std::optional MultiDiEdge e = get_only(get_edges(ttsp)); if (ttsp.get_multidiedge_src(e) != ttsp.get_multidiedge_dst(e)) { - return to_final_ast(ttsp_edge_to_sp_tree.at(e)); + return nary_sp_tree_from_binary(ttsp_edge_to_sp_tree.at(e)); } } } diff --git a/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc b/lib/utils/src/utils/graph/series_parallel/graph_generation.cc similarity index 79% rename from lib/utils/src/utils/graph/serial_parallel/graph_generation.cc rename to lib/utils/src/utils/graph/series_parallel/graph_generation.cc index 4c9eb9d3ef..7070d04c4a 100644 --- a/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc +++ b/lib/utils/src/utils/graph/series_parallel/graph_generation.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/graph_generation.h" +#include "utils/graph/series_parallel/graph_generation.h" #include 
"utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/node/algorithms.h" @@ -12,7 +12,7 @@ void parallel_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { } } -void serial_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { +void series_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { // TODO(@lockshaw): This function signature is impossible to implement in // general, as there is no guarantee that the graph view ext actually has // source nodes with inputs Either the signature should be changed, or an @@ -22,11 +22,11 @@ void serial_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { NOT_IMPLEMENTED(); } -DataflowGraph serial_composition(DataflowGraphView const &g1, +DataflowGraph series_composition(DataflowGraphView const &g1, DataflowGraphView const &g2) { DataflowGraph g = DataflowGraph::create_copy_of(g1); - serial_extend_unsafe(g, g2); + series_extend_unsafe(g, g2); return g; } @@ -39,8 +39,8 @@ DataflowGraph parallel_composition(DataflowGraphView const &g1, } DataflowGraph dataflow_graph_from_sp_decomposition( - SerialParallelDecomposition const &sp_decomposition) { - // TODO(@lockshaw): see existing concerns about serial_extend_unsafe + SeriesParallelDecomposition const &sp_decomposition) { + // TODO(@lockshaw): see existing concerns about series_extend_unsafe NOT_IMPLEMENTED(); } diff --git a/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc new file mode 100644 index 0000000000..48c936ec39 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc @@ -0,0 +1,84 @@ +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/containers/extend.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/overload.h" + +namespace FlexFlow { + +struct FlattenAST { + void add_flattened_child_to_parent( + IntermediateSpDecompositionTree &parent, + std::variant const &child) { + if (std::holds_alternative(child)) { + parent.children.push_back(child); + return; + } + + IntermediateSpDecompositionTree child_node = + std::get(child); + + if (parent.type == child_node.type) { + extend(parent.children, child_node.children); + } else { + parent.children.push_back(child); + } + } + + std::variant + operator()(IntermediateSpDecompositionTree const &ast_node) { + IntermediateSpDecompositionTree result(ast_node.type, {}); + for (std::variant const &child : + ast_node.children) { + std::variant flattened_child = + flatten_ast(child); + add_flattened_child_to_parent(result, flattened_child); + } + return result; + } + + std::variant + operator()(Node const &ast_node) { + return ast_node; + } +}; + +std::variant flatten_ast( + std::variant const &ast) { + return std::visit(FlattenAST{}, ast); +} + +std::variant + from_binary_sp_tree(GenericBinarySPDecompositionTree const &binary) { + return visit>( + binary, + overload{ + [](Node const &n) { return n; }, + [](GenericBinarySeriesSplit const &s) { + return IntermediateSpDecompositionTree{ + SplitType::SERIES, + { + 
from_binary_sp_tree(get_left_child(s)), + from_binary_sp_tree(get_right_child(s)), + }, + }; + }, + [](GenericBinaryParallelSplit const &p) { + return IntermediateSpDecompositionTree{ + SplitType::PARALLEL, + { + from_binary_sp_tree(get_left_child(p)), + from_binary_sp_tree(get_right_child(p)), + }, + }; + }, + }); +} + +std::variant + from_binary_sp_tree(BinarySPDecompositionTree const &binary) { + return from_binary_sp_tree(binary.raw_tree); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/parallel_reduction.cc b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc similarity index 93% rename from lib/utils/src/utils/graph/serial_parallel/parallel_reduction.cc rename to lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc index 30aa10edd7..12a6630bf0 100644 --- a/lib/utils/src/utils/graph/serial_parallel/parallel_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/parallel_reduction.h" +#include "utils/graph/series_parallel/parallel_reduction.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc similarity index 52% rename from lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc rename to lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc index 666bf40f10..e697533054 100644 --- a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -1,18 +1,20 @@ -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/containers/multiset_union.h" #include "utils/containers/set_union.h" #include "utils/containers/transform.h" -#include "utils/containers/unordered_set_of.h" -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/vector_of.h" +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" #include "utils/hash/unordered_set.h" #include "utils/variant.h" namespace FlexFlow { struct ToFinalAST { - std::variant + std::variant operator()(IntermediateSpDecompositionTree const &node) { - if (node.type == SplitType::SERIAL) { - return SerialSplit{transform( + if (node.type == SplitType::SERIES) { + return SeriesSplit{transform( node.children, [](std::variant const &s) { return narrow>( @@ -20,54 +22,55 @@ struct ToFinalAST { .value(); })}; } else { - return ParallelSplit{unordered_set_of(transform( + return ParallelSplit{unordered_multiset_of(transform( node.children, [](std::variant const &s) { - return narrow>( + return narrow>( internal_to_final_ast(s)) .value(); }))}; } } - std::variant operator()(Node const &node) { + std::variant operator()(Node const &node) { return node; } }; -std::variant internal_to_final_ast( +std::variant internal_to_final_ast( std::variant const &ast) { return std::visit(ToFinalAST{}, flatten_ast(ast)); } -SerialParallelDecomposition to_final_ast( +SeriesParallelDecomposition to_final_ast( std::variant const &ast) { - return std::visit([](auto &&x) { return SerialParallelDecomposition{x}; }, + return std::visit([](auto &&x) { return SeriesParallelDecomposition{x}; }, 
internal_to_final_ast(ast)); } -std::unordered_set get_nodes(SerialParallelDecomposition const &sp) { - return sp.visit>( +std::unordered_multiset get_nodes(SeriesParallelDecomposition const &sp) { + return sp.visit>( [](auto &&t) { return get_nodes(t); }); } -std::unordered_set get_nodes(SerialSplit const &serial) { - return set_union(transform( +std::unordered_multiset get_nodes(SeriesSplit const &serial) { + return multiset_union(transform( serial.children, [](std::variant const &child) - -> std::unordered_set { + -> std::unordered_multiset { return std::visit([](auto &&t) { return get_nodes(t); }, child); })); } -std::unordered_set get_nodes(ParallelSplit const ¶llel) { - return set_union(transform( - parallel.children, [](std::variant const &child) { +std::unordered_multiset get_nodes(ParallelSplit const ¶llel) { + return multiset_union(transform( + vector_of(parallel.children), + [](std::variant const &child) { return std::visit([](auto &&t) { return get_nodes(t); }, child); })); } -std::unordered_set get_nodes(Node const &node) { +std::unordered_multiset get_nodes(Node const &node) { return {node}; } diff --git a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc similarity index 65% rename from lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc rename to lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc index 8fa42d4b22..0e04a4f904 100644 --- a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc @@ -1,47 +1,47 @@ -#include "utils/graph/serial_parallel/serial_parallel_splits.h" -#include "utils/fmt/unordered_set.h" +#include "utils/graph/series_parallel/series_parallel_splits.h" +#include "utils/fmt/unordered_multiset.h" #include "utils/fmt/variant.h" #include "utils/fmt/vector.h" #include "utils/hash-utils.h" -#include "utils/hash/unordered_set.h" +#include "utils/hash/unordered_multiset.h" #include "utils/hash/vector.h" namespace FlexFlow { -SerialSplit::SerialSplit( +SeriesSplit::SeriesSplit( std::vector> const &children) : children(children) {} -SerialSplit::SerialSplit( +SeriesSplit::SeriesSplit( std::initializer_list> const &children) : children(children) {} -bool SerialSplit::operator==(SerialSplit const &other) const { +bool SeriesSplit::operator==(SeriesSplit const &other) const { return this->tie() == other.tie(); } -bool SerialSplit::operator!=(SerialSplit const &other) const { +bool SeriesSplit::operator!=(SeriesSplit const &other) const { return this->tie() != other.tie(); } -SerialSplit::Tie SerialSplit::tie() const { +SeriesSplit::Tie SeriesSplit::tie() const { return std::tie(this->children); } -std::string format_as(SerialSplit const &split) { - return fmt::format("", split.children); +std::string format_as(SeriesSplit const &split) { + return fmt::format("", split.children); } -std::ostream &operator<<(std::ostream &s, SerialSplit const &split) { +std::ostream &operator<<(std::ostream &s, SeriesSplit const &split) { return s << fmt::to_string(split); } ParallelSplit::ParallelSplit( - std::unordered_set> const &children) + std::unordered_multiset> const &children) : children(children) {} ParallelSplit::ParallelSplit( - std::initializer_list> const &children) + std::initializer_list> const &children) : children(children) {} bool ParallelSplit::operator==(ParallelSplit const &other) const { @@ -68,8 +68,8 @@ std::ostream &operator<<(std::ostream &s, 
ParallelSplit const &split) { namespace std { -size_t hash<::FlexFlow::SerialSplit>::operator()( - ::FlexFlow::SerialSplit const &s) const { +size_t hash<::FlexFlow::SeriesSplit>::operator()( + ::FlexFlow::SeriesSplit const &s) const { size_t result = 0; ::FlexFlow::hash_combine(result, s.children); return result; diff --git a/lib/utils/src/utils/graph/serial_parallel/series_reduction.cc b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc similarity index 97% rename from lib/utils/src/utils/graph/serial_parallel/series_reduction.cc rename to lib/utils/src/utils/graph/series_parallel/series_reduction.cc index e26f460e0e..7300c93fb0 100644 --- a/lib/utils/src/utils/graph/serial_parallel/series_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/series_reduction.h" +#include "utils/graph/series_parallel/series_reduction.h" #include "utils/containers/require_same.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" diff --git a/lib/utils/src/utils/graph/undirected/algorithms/get_edges.cc b/lib/utils/src/utils/graph/undirected/algorithms/get_edges.cc new file mode 100644 index 0000000000..8ae825c1ab --- /dev/null +++ b/lib/utils/src/utils/graph/undirected/algorithms/get_edges.cc @@ -0,0 +1,10 @@ +#include "utils/graph/undirected/algorithms/get_edges.h" +#include "utils/graph/undirected/undirected_edge_query.h" + +namespace FlexFlow { + +std::unordered_set get_edges(UndirectedGraphView const &g) { + return g.query_edges(undirected_edge_query_all()); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc new file mode 100644 index 0000000000..3c05b9d5d5 --- /dev/null +++ b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc @@ -0,0 +1,19 @@ +#include "utils/graph/undirected/algorithms/get_neighboring_nodes.h" +#include "utils/containers/vector_of.h" + +namespace FlexFlow { + +std::unordered_set get_neighboring_nodes(UndirectedGraphView const &g, + Node const &n) { + std::unordered_set edges = + g.query_edges(UndirectedEdgeQuery{query_set{n}}); + + std::unordered_set result = + set_union(transform(vector_of(edges), [](UndirectedEdge const &e) { + return std::unordered_set{e.bigger, e.smaller}; + })); + result.erase(n); + return result; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc index 5c41eef7da..3cccf1c6eb 100644 --- a/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc @@ -6,6 +6,10 @@ UndirectedEdgeQuery undirected_edge_query_all() { return UndirectedEdgeQuery{matchall()}; } +bool matches_edge(UndirectedEdgeQuery const &q, UndirectedEdge const &e) { + return includes(q.nodes, e.bigger) && includes(q.nodes, e.smaller); +} + UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &lhs, UndirectedEdgeQuery const &rhs) { return UndirectedEdgeQuery{ diff --git a/lib/utils/src/utils/hash/multiset.cc b/lib/utils/src/utils/hash/multiset.cc new file mode 100644 index 0000000000..d84ca7d614 --- /dev/null +++ b/lib/utils/src/utils/hash/multiset.cc @@ -0,0 +1 @@ +#include "utils/hash/multiset.h" diff --git a/lib/utils/src/utils/hash/unordered_multiset.cc 
b/lib/utils/src/utils/hash/unordered_multiset.cc new file mode 100644 index 0000000000..7f6f73f428 --- /dev/null +++ b/lib/utils/src/utils/hash/unordered_multiset.cc @@ -0,0 +1 @@ +#include "utils/hash/unordered_multiset.h" diff --git a/lib/utils/src/utils/json/check_is_jsonable.cc b/lib/utils/src/utils/json/check_is_jsonable.cc new file mode 100644 index 0000000000..1e78fdb21f --- /dev/null +++ b/lib/utils/src/utils/json/check_is_jsonable.cc @@ -0,0 +1 @@ +#include "utils/json/check_is_jsonable.h" diff --git a/lib/utils/src/utils/json/is_json_deserializable.cc b/lib/utils/src/utils/json/is_json_deserializable.cc new file mode 100644 index 0000000000..17df41433d --- /dev/null +++ b/lib/utils/src/utils/json/is_json_deserializable.cc @@ -0,0 +1 @@ +#include "utils/json/is_json_deserializable.h" diff --git a/lib/utils/src/utils/json/is_json_serializable.cc b/lib/utils/src/utils/json/is_json_serializable.cc new file mode 100644 index 0000000000..883ee9f51a --- /dev/null +++ b/lib/utils/src/utils/json/is_json_serializable.cc @@ -0,0 +1 @@ +#include "utils/json/is_json_serializable.h" diff --git a/lib/utils/src/utils/json/is_jsonable.cc b/lib/utils/src/utils/json/is_jsonable.cc new file mode 100644 index 0000000000..3f819f8556 --- /dev/null +++ b/lib/utils/src/utils/json/is_jsonable.cc @@ -0,0 +1 @@ +#include "utils/json/is_jsonable.h" diff --git a/lib/utils/src/utils/json/optional.cc b/lib/utils/src/utils/json/optional.cc new file mode 100644 index 0000000000..c8f0fd2e3c --- /dev/null +++ b/lib/utils/src/utils/json/optional.cc @@ -0,0 +1 @@ +#include "utils/json/optional.h" diff --git a/lib/utils/src/utils/rapidcheck/optional.cc b/lib/utils/src/utils/rapidcheck/optional.cc new file mode 100644 index 0000000000..6d62532e7e --- /dev/null +++ b/lib/utils/src/utils/rapidcheck/optional.cc @@ -0,0 +1 @@ +#include "utils/rapidcheck/optional.h" diff --git a/lib/utils/test/common/include/test/utils/all.h b/lib/utils/test/common/include/test/utils/all.h deleted file mode 100644 index ced1c9ce38..0000000000 --- a/lib/utils/test/common/include/test/utils/all.h +++ /dev/null @@ -1,2 +0,0 @@ -#include "test/utils/doctest.h" -#include "test/utils/rapidcheck.h" diff --git a/lib/utils/test/common/include/test/utils/doctest.h b/lib/utils/test/common/include/test/utils/doctest/check_without_stringify.h similarity index 100% rename from lib/utils/test/common/include/test/utils/doctest.h rename to lib/utils/test/common/include/test/utils/doctest/check_without_stringify.h diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/expected.h b/lib/utils/test/common/include/test/utils/doctest/fmt/expected.h new file mode 100644 index 0000000000..8333ac4777 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/expected.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_EXPECTED_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_EXPECTED_H + +#include "utils/fmt/expected.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(tl::expected const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/map.h b/lib/utils/test/common/include/test/utils/doctest/fmt/map.h new file mode 100644 index 0000000000..d20dbe6943 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/map.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_MAP_H +#define 
_FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_MAP_H + +#include "utils/fmt/map.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::map const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/multiset.h b/lib/utils/test/common/include/test/utils/doctest/fmt/multiset.h new file mode 100644 index 0000000000..b26eee28ba --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/multiset.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_MULTISET_H + +#include "utils/fmt/multiset.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::multiset const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h b/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h new file mode 100644 index 0000000000..519cde7d74 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_OPTIONAL_H + +#include "utils/fmt/optional.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::optional const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/pair.h b/lib/utils/test/common/include/test/utils/doctest/fmt/pair.h new file mode 100644 index 0000000000..db0ed24f13 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/pair.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_PAIR_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_PAIR_H + +#include "utils/fmt/pair.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::pair const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/set.h b/lib/utils/test/common/include/test/utils/doctest/fmt/set.h new file mode 100644 index 0000000000..3dd386645c --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/set.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_SET_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_SET_H + +#include "utils/fmt/set.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::set const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_map.h b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_map.h new file mode 100644 index 0000000000..4fd5d15009 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_map.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_MAP_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_MAP_H + +#include "utils/fmt/unordered_map.h" +#include + 
+namespace doctest { + +template +struct StringMaker> { + static String convert(std::unordered_map const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_multiset.h b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_multiset.h new file mode 100644 index 0000000000..94dae42239 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_multiset.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_MULTISET_H + +#include "utils/fmt/unordered_multiset.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::unordered_multiset const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_set.h b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_set.h new file mode 100644 index 0000000000..441590365d --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_set.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_SET_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_SET_H + +#include "utils/fmt/unordered_set.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::unordered_set const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/variant.h b/lib/utils/test/common/include/test/utils/doctest/fmt/variant.h new file mode 100644 index 0000000000..c30862274a --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/variant.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_VARIANT_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_VARIANT_H + +#include "utils/fmt/variant.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::variant const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/vector.h b/lib/utils/test/common/include/test/utils/doctest/fmt/vector.h new file mode 100644 index 0000000000..56198a7558 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/vector.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_VECTOR_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_VECTOR_H + +#include "utils/fmt/vector.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::vector const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/src/common.cc b/lib/utils/test/common/src/common.cc deleted file mode 100644 index 51e981b1f5..0000000000 --- a/lib/utils/test/common/src/common.cc +++ /dev/null @@ -1 +0,0 @@ -#include "test/utils/all.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/expected.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/expected.cc new file mode 100644 index 0000000000..1cff2195db --- /dev/null +++ 
b/lib/utils/test/common/src/test/utils/doctest/fmt/expected.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/expected.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/map.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/map.cc new file mode 100644 index 0000000000..976e65cfca --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/map.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/map.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/multiset.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/multiset.cc new file mode 100644 index 0000000000..9c5b2f4d1e --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/multiset.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/multiset.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc new file mode 100644 index 0000000000..8a3f7f158e --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/optional.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc new file mode 100644 index 0000000000..106fb1c900 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/pair.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/set.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/set.cc new file mode 100644 index 0000000000..9ec70698bc --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/set.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/set.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_map.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_map.cc new file mode 100644 index 0000000000..b893e632ed --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_map.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/unordered_map.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_multiset.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_multiset.cc new file mode 100644 index 0000000000..55d2e69056 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_multiset.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/unordered_multiset.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_set.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_set.cc new file mode 100644 index 0000000000..13ad811e63 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_set.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/unordered_set.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/variant.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/variant.cc new file mode 100644 index 0000000000..b6cc4f54e4 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/variant.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/variant.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/vector.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/vector.cc new file mode 100644 index 0000000000..0102cd86da --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/vector.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/vector.h" diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc index a1dd75504e..44f602f3bc 100644 --- 
a/lib/utils/test/src/test_algorithms.cc +++ b/lib/utils/test/src/test_algorithms.cc @@ -1,10 +1,10 @@ -#include "test/utils/doctest.h" #include "utils/graph/algorithms.h" #include "utils/graph/construction.h" #include "utils/graph/hashmap_undirected_graph.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/undirected.h" #include +#include #include #include #include diff --git a/lib/utils/test/src/test_containers.cc b/lib/utils/test/src/test_containers.cc index af7792dc6d..dca500ced5 100644 --- a/lib/utils/test/src/test_containers.cc +++ b/lib/utils/test/src/test_containers.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/containers.h" +#include #include #include #include @@ -275,9 +275,9 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == std::vector({2, 4, 6})); } - TEST_CASE("as_vector") { + TEST_CASE("vector_of") { std::unordered_set s = {1, 2, 3}; - std::vector result = as_vector(s); + std::vector result = vector_of(s); CHECK(result == std::vector({3, 2, 1})); } diff --git a/lib/utils/test/src/test_deduplicated_priority_queue.cc b/lib/utils/test/src/test_deduplicated_priority_queue.cc index 66cfd395bc..048e95acb7 100644 --- a/lib/utils/test/src/test_deduplicated_priority_queue.cc +++ b/lib/utils/test/src/test_deduplicated_priority_queue.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/deduplicated_priority_queue.h" +#include TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("DeduplicatedPriorityQueue push and pop") { diff --git a/lib/utils/test/src/test_disjoint_set.cc b/lib/utils/test/src/test_disjoint_set.cc index 80fcf87d6b..65037be3dd 100644 --- a/lib/utils/test/src/test_disjoint_set.cc +++ b/lib/utils/test/src/test_disjoint_set.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/disjoint_set.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_dot_file.cc b/lib/utils/test/src/test_dot_file.cc index ed4c32bb1c..e409572511 100644 --- a/lib/utils/test/src/test_dot_file.cc +++ b/lib/utils/test/src/test_dot_file.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/dot_file.h" +#include #include TEST_SUITE(FF_TEST_SUITE) { diff --git a/lib/utils/test/src/test_format.cc b/lib/utils/test/src/test_format.cc index eeed2eae81..f0d396a123 100644 --- a/lib/utils/test/src/test_format.cc +++ b/lib/utils/test/src/test_format.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/record_formatter.h" +#include std::string formatRecord(RecordFormatter const &formatter) { std::ostringstream oss; diff --git a/lib/utils/test/src/test_hash.cc b/lib/utils/test/src/test_hash.cc index b38c43fe30..decf405e7a 100644 --- a/lib/utils/test/src/test_hash.cc +++ b/lib/utils/test/src/test_hash.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/hash-utils.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_multidigraph.cc b/lib/utils/test/src/test_multidigraph.cc index 90e1bb2187..cc7ac1de32 100644 --- a/lib/utils/test/src/test_multidigraph.cc +++ b/lib/utils/test/src/test_multidigraph.cc @@ -1,7 +1,7 @@ -#include "test/utils/doctest.h" #include "utils/graph/adjacency_multidigraph.h" #include "utils/graph/multidiedge.h" #include "utils/graph/multidigraph_interfaces.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_random_utils.cc b/lib/utils/test/src/test_random_utils.cc index 88a566a198..2b816eea4f 100644 --- a/lib/utils/test/src/test_random_utils.cc +++ b/lib/utils/test/src/test_random_utils.cc @@ -1,6 +1,6 @@ -#include 
"test/utils/doctest.h" #include "utils/random_utils.h" #include +#include void checkProbabilities(std::vector const &counts, int numIterations, diff --git a/lib/utils/test/src/test_sequence.cc b/lib/utils/test/src/test_sequence.cc index ee72febe05..a758476fd9 100644 --- a/lib/utils/test/src/test_sequence.cc +++ b/lib/utils/test/src/test_sequence.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/sequence.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_stack_map.cc b/lib/utils/test/src/test_stack_map.cc index 21c1b07d1b..f117820c5d 100644 --- a/lib/utils/test/src/test_stack_map.cc +++ b/lib/utils/test/src/test_stack_map.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/stack_map.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_stack_string.cc b/lib/utils/test/src/test_stack_string.cc index a044f85fe3..b89e3277cd 100644 --- a/lib/utils/test/src/test_stack_string.cc +++ b/lib/utils/test/src/test_stack_string.cc @@ -1,6 +1,6 @@ -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" #include "utils/stack_string.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_stack_vector.cc b/lib/utils/test/src/test_stack_vector.cc index 1af43b6993..577e61092c 100644 --- a/lib/utils/test/src/test_stack_vector.cc +++ b/lib/utils/test/src/test_stack_vector.cc @@ -1,6 +1,6 @@ -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" #include "utils/stack_vector.h" +#include #include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_tuple.cc b/lib/utils/test/src/test_tuple.cc index 31308dec2c..96171510a7 100644 --- a/lib/utils/test/src/test_tuple.cc +++ b/lib/utils/test/src/test_tuple.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/tuple.h" +#include #include #include diff --git a/lib/utils/test/src/test_type_index.cc b/lib/utils/test/src/test_type_index.cc index b2d8aea848..e7ce12346a 100644 --- a/lib/utils/test/src/test_type_index.cc +++ b/lib/utils/test/src/test_type_index.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/type_index.h" +#include #include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_undirected_graph.cc b/lib/utils/test/src/test_undirected_graph.cc index 33b102bd3b..ea519478d3 100644 --- a/lib/utils/test/src/test_undirected_graph.cc +++ b/lib/utils/test/src/test_undirected_graph.cc @@ -1,7 +1,8 @@ -#include "test/utils/all.h" +#include "test/utils/rapidcheck.h" #include "test/utils/rapidcheck/visitable.h" #include "utils/graph/hashmap_undirected_graph.h" #include "utils/graph/undirected.h" +#include /* namespace rc { */ diff --git a/lib/utils/test/src/test_variant.cc b/lib/utils/test/src/test_variant.cc index 98b28a48e9..0bd01b8dfe 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/test_variant.cc @@ -1,6 +1,6 @@ -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" #include "utils/variant.h" +#include TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("widen and narrow functions") { diff --git a/lib/utils/test/src/test_vector.cc b/lib/utils/test/src/test_vector.cc index 4bdc724dd8..c6eb0828b8 100644 --- a/lib/utils/test/src/test_vector.cc +++ b/lib/utils/test/src/test_vector.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/vector.h" +#include TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("concat function") { diff --git a/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc 
b/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc index 6e3ac8c155..b5a373e5c9 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc @@ -1,7 +1,7 @@ #include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "test/utils/doctest/fmt/unordered_set.h" #include "utils/bidict/algorithms/left_entries.h" #include "utils/bidict/algorithms/right_entries.h" -#include "utils/fmt/unordered_set.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/bidict/bidict.cc b/lib/utils/test/src/utils/bidict/bidict.cc index 5c2ffd5bba..fed655013f 100644 --- a/lib/utils/test/src/utils/bidict/bidict.cc +++ b/lib/utils/test/src/utils/bidict/bidict.cc @@ -1,6 +1,8 @@ #include "utils/bidict/bidict.h" -#include "test/utils/doctest.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/check_without_stringify.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/vector.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc b/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc index 2eb8f869f9..49fed81b29 100644 --- a/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc +++ b/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc @@ -1,6 +1,6 @@ #include "utils/bidict/try_merge_nondisjoint_bidicts.h" -#include "test/utils/doctest.h" -#include "utils/fmt/optional.h" +#include "test/utils/doctest/fmt/optional.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/cli/cli_get_help_message.cc b/lib/utils/test/src/utils/cli/cli_get_help_message.cc new file mode 100644 index 0000000000..b3ee4d3318 --- /dev/null +++ b/lib/utils/test/src/utils/cli/cli_get_help_message.cc @@ -0,0 +1,519 @@ +#include "utils/cli/cli_get_help_message.h" +#include "utils/join_strings.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("cli_get_help_message(std::string, CLISpec)") { + std::string program_name = "prog_name"; + + SUBCASE("no flags or positional arguments") { + CLISpec cli = CLISpec{ + {}, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name\n"); + + CHECK(result == correct); + } + + SUBCASE("no flags") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "pos-arg-1", + std::nullopt, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name pos-arg-1\n" + "\n" + "positional arguments:\n" + " pos-arg-1\n"); + + CHECK(result == correct); + } + + SUBCASE("no positional arguments") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag-1", + 'f', + std::nullopt, + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [-f]\n" + "\n" + "options:\n" + " -f, --flag-1\n"); + + CHECK(result == correct); + } + + SUBCASE("flag formatting") { + SUBCASE("flag with shortname") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag", + 'f', + std::nullopt, + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [-f]\n" + "\n" + "options:\n" + " -f, --flag\n"); + + CHECK(result == correct); + } + + SUBCASE("flag without shortname") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag", + 
std::nullopt, + std::nullopt, + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [--flag]\n" + "\n" + "options:\n" + " --flag\n"); + + CHECK(result == correct); + } + + SUBCASE("flags are displayed in provided order") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag2", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag1", + std::nullopt, + std::nullopt, + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [--flag2] [--flag1]\n" + "\n" + "options:\n" + " --flag2\n" + " --flag1\n"); + + CHECK(result == correct); + } + } + + SUBCASE("positional argument formatting") { + SUBCASE("without choices") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::nullopt, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name posarg\n" + "\n" + "positional arguments:\n" + " posarg\n"); + + CHECK(result == correct); + } + + SUBCASE("with choices") { + SUBCASE("choices are not empty") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::vector{"red", "blue", "green"}, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name {red,blue,green}\n" + "\n" + "positional arguments:\n" + " {red,blue,green}\n"); + + CHECK(result == correct); + } + + SUBCASE("choices are empty") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::vector{}, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name {}\n" + "\n" + "positional arguments:\n" + " {}\n"); + + CHECK(result == correct); + } + } + + SUBCASE("are displayed in provided order") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + std::nullopt, + }, + CLIPositionalArgumentSpec{ + "posarg1", + std::nullopt, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name posarg2 posarg1\n" + "\n" + "positional arguments:\n" + " posarg2\n" + " posarg1\n"); + + CHECK(result == correct); + } + } + + SUBCASE("flag and positional argument alignment") { + SUBCASE("flags are longer") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + '1', + "flag1 description", + }, + CLIFlagSpec{ + "flag2-is-long", + std::nullopt, + "flag2-is-long description", + }, + }, + { + CLIPositionalArgumentSpec{ + "posarg", + std::nullopt, + "help text for posarg", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = + ("usage: prog_name [-1] [--flag2-is-long] posarg\n" + "\n" + "positional arguments:\n" + " posarg help text for posarg\n" + "\n" + "options:\n" + " -1, --flag1 flag1 description\n" + " --flag2-is-long flag2-is-long description\n"); + + CHECK(result == correct); + } + + SUBCASE("pos args are longer") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + '1', + "flag1 description", + }, + }, + { + CLIPositionalArgumentSpec{ + "posarg1-is-very-long", + std::nullopt, + "help text for posarg1-is-very-long", + }, + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + "help text for posarg2", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = + 
("usage: prog_name [-1] posarg1-is-very-long posarg2\n" + "\n" + "positional arguments:\n" + " posarg1-is-very-long help text for posarg1-is-very-long\n" + " posarg2 help text for posarg2\n" + "\n" + "options:\n" + " -1, --flag1 flag1 description\n"); + + CHECK(result == correct); + } + + SUBCASE("line break behavior") { + SUBCASE("line breaks max out other argument alignments") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag", + 'f', + "flag help text", + }, + }, + { + CLIPositionalArgumentSpec{ + "abcdefghijklmnopqrstuvwxyz0123456789", + std::nullopt, + "long arg help text", + }, + CLIPositionalArgumentSpec{ + "posarg", + std::nullopt, + "posarg help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [-f] " + "abcdefghijklmnopqrstuvwxyz0123456789 posarg\n" + "\n" + "positional arguments:\n" + " abcdefghijklmnopqrstuvwxyz0123456789\n" + " long arg help text\n" + " posarg posarg help text\n" + "\n" + "options:\n" + " -f, --flag flag help text\n"); + + CHECK(result == correct); + } + SUBCASE("positional argument line break behavior") { + SUBCASE("positional arguments cause a line break at or above " + "formatted-length 22") { + std::string arg_name = "aaaaaaaaaaaaaaaaaaaaaa"; + REQUIRE(arg_name.size() == 22); + + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + arg_name, + std::nullopt, + "help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name aaaaaaaaaaaaaaaaaaaaaa\n" + "\n" + "positional arguments:\n" + " aaaaaaaaaaaaaaaaaaaaaa\n" + " help text\n"); + + CHECK(result == correct); + } + + SUBCASE("positional arguments do not cause a line break below " + "formatted-length 22") { + std::string arg_name = "aaaaaaaaaaaaaaaaaaaaa"; + REQUIRE(arg_name.size() == 21); + + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + arg_name, + std::nullopt, + "help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name aaaaaaaaaaaaaaaaaaaaa\n" + "\n" + "positional arguments:\n" + " aaaaaaaaaaaaaaaaaaaaa\n" + " help text\n"); + } + } + + SUBCASE("flag line break behavior") { + SUBCASE("flags cause a line break at or above formatted-length 21") { + std::string arg_name = "bbbbbbbbbbbbbbb"; + { + std::string formatted = "-b, --" + arg_name; + REQUIRE(formatted.size() == 21); + } + + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + arg_name, + 'b', + "flag description", + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = + ("usage: prog_name [-b]\n" + "\n" + "options:\n" + " -b, --bbbbbbbbbbbbbbb\n" + " flag description\n"); + + CHECK(result == correct); + } + + SUBCASE("flags do not cause a line break below formatted-length 21") { + std::string arg_name = "bbbbbbbbbbbbbb"; + { + std::string formatted = "-b, --" + arg_name; + REQUIRE(formatted.size() == 20); + } + + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + arg_name, + 'b', + "flag description", + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = + ("usage: prog_name [-b]\n" + "\n" + "options:\n" + " -b, --bbbbbbbbbbbbbb flag description\n"); + + CHECK(result == correct); + } + } + + SUBCASE("choice line breakpoint formatting") { + SUBCASE( + "choices cause a line break at or above formatted-length 21") { + std::vector choices = { + "a", "b", "c", "d", "e", "fffffffff"}; + { + 
std::string formatted_choices = + "{" + join_strings(choices, ",") + "}"; + REQUIRE(formatted_choices.size() == 21); + } + + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + choices, + "help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name {a,b,c,d,e,fffffffff}\n" + "\n" + "positional arguments:\n" + " {a,b,c,d,e,fffffffff}\n" + " help text\n"); + + CHECK(result == correct); + } + + SUBCASE( + "choices do not cause a line break below formatted-length 21") { + std::vector choices = { + "a", "b", "c", "d", "e", "ffffffff"}; + { + std::string formatted_choices = + "{" + join_strings(choices, ",") + "}"; + REQUIRE(formatted_choices.size() == 20); + } + + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + choices, + "help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name {a,b,c,d,e,ffffffff}\n" + "\n" + "positional arguments:\n" + " {a,b,c,d,e,ffffffff} help text\n"); + + CHECK(result == correct); + } + } + } + } + } +} diff --git a/lib/utils/test/src/utils/cli/cli_parse.cc b/lib/utils/test/src/utils/cli/cli_parse.cc new file mode 100644 index 0000000000..40dea86ae0 --- /dev/null +++ b/lib/utils/test/src/utils/cli/cli_parse.cc @@ -0,0 +1,477 @@ +#include "utils/cli/cli_parse.h" +#include "test/utils/doctest/fmt/expected.h" +#include "test/utils/doctest/fmt/optional.h" +#include "utils/expected.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("cli_parse_flag(CLISpec, std::string)") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag2", + '2', + std::nullopt, + }, + }, + {}, + }; + + CLIFlagKey key_flag1 = CLIFlagKey{0}; + CLIFlagKey key_flag2 = CLIFlagKey{1}; + + SUBCASE("correctly parses short flag") { + std::string input = "-2"; + + tl::expected result = cli_parse_flag(cli, input); + tl::expected correct = key_flag2; + + CHECK(result == correct); + } + + SUBCASE("correctly parses long flag") { + std::string input = "--flag1"; + + tl::expected result = cli_parse_flag(cli, input); + tl::expected correct = key_flag1; + + CHECK(result == correct); + } + + SUBCASE("fails on unknown flag") { + std::string input = "--not-real"; + + tl::expected result = cli_parse_flag(cli, input); + tl::expected correct = + tl::unexpected("Encountered unknown flag --not-real"); + + CHECK(result == correct); + } + + SUBCASE("fails on non-flag") { + std::string input = "-flag1"; + + std::optional result = + optional_from_expected(cli_parse_flag(cli, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + TEST_CASE("cli_parse(CLISpec, std::vector)") { + SUBCASE("works even if cli is empty") { + CLISpec cli = CLISpec{{}, {}}; + std::vector inputs = {"prog_name"}; + + tl::expected result = cli_parse(cli, inputs); + tl::expected correct = + CLIParseResult{{}, {}}; + + CHECK(result == correct); + } + + SUBCASE("flag parsing") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag2", + '2', + std::nullopt, + }, + }, + {}, + }; + CLIFlagKey key_flag1 = CLIFlagKey{0}; + CLIFlagKey key_flag2 = CLIFlagKey{1}; + + SUBCASE("parses flags in any order") { + std::vector inputs = {"prog_name", "-2", "--flag1"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, true}, 
+ {key_flag2, true}, + }, + {}, + }; + + CHECK(result == correct); + } + + SUBCASE("is fine if some are not present") { + std::vector inputs = {"prog_name", "-2"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, false}, + {key_flag2, true}, + }, + {}, + }; + + CHECK(result == correct); + } + + SUBCASE("is fine if none are present") { + std::vector inputs = {"prog_name"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, false}, + {key_flag2, false}, + }, + {}, + }; + + CHECK(result == correct); + } + + SUBCASE("is fine even if the program name is a flag") { + std::vector inputs = {"--flag1", "-2"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, false}, + {key_flag2, true}, + }, + {}, + }; + + CHECK(result == correct); + } + } + + SUBCASE("positional argument parsing") { + SUBCASE("without choices") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg1", + std::nullopt, + std::nullopt, + }, + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + std::nullopt, + }, + }, + }; + + CLIPositionalArgumentKey key_posarg1 = CLIPositionalArgumentKey{0}; + CLIPositionalArgumentKey key_posarg2 = CLIPositionalArgumentKey{1}; + + SUBCASE("can parse multiple positional arguments") { + std::vector inputs = {"prog_name", "hello", "world"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = + CLIParseResult{{}, + { + {key_posarg1, "hello"}, + {key_posarg2, "world"}, + }}; + + CHECK(result == correct); + } + + SUBCASE("requires all positional arguments to be present") { + std::vector inputs = {"prog_name", "hello"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = tl::unexpected( + "Not enough positional arguments: found 1, expected 2"); + + CHECK(result == correct); + } + + SUBCASE("requires no extra positional arguments to be present") { + std::vector inputs = { + "prog_name", "hello", "there", "world"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = + tl::unexpected("Too many positional arguments: expected 2"); + + CHECK(result == correct); + } + + SUBCASE("allows arguments to contain spaces") { + std::vector inputs = { + "prog_name", "hello there", "world"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = + CLIParseResult{{}, + { + {key_posarg1, "hello there"}, + {key_posarg2, "world"}, + }}; + + CHECK(result == correct); + } + + SUBCASE("allows arguments to be empty") { + std::vector inputs = {"prog_name", "hello", ""}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = + CLIParseResult{{}, + { + {key_posarg1, "hello"}, + {key_posarg2, ""}, + }}; + + CHECK(result == correct); + } + } + + SUBCASE("with choices") { + SUBCASE("choices is non-empty") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::vector{"red", "blue", "green"}, + std::nullopt, + }, + }, + }; + + CLIPositionalArgumentKey key_posarg = CLIPositionalArgumentKey{0}; + + SUBCASE( + "succeeds if a positional argument is set to a valid choice") { + std::vector inputs = {"prog_name", "blue"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + {}, + { + {key_posarg, "red"}, + }, + }; + } + + SUBCASE( + "fails if a positional argument is set to an invalid choice") { + std::vector inputs = {"prog_name", " red"}; + + 
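          // Note on this input: " red" carries a leading space, so it is not
          // one of {"red", "blue", "green"}. The expectation below assumes
          // choice validation is an exact string comparison with no trimming.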
tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = tl::unexpected( + "Invalid option for positional argument \"posarg\": \" red\""); + + CHECK(result == correct); + } + } + + SUBCASE("if choices is empty, rejects everything") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::vector{}, + std::nullopt, + }, + }, + }; + + std::vector inputs = {"prog_name", ""}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = tl::unexpected( + "Invalid option for positional argument \"posarg\": \"\""); + + CHECK(result == correct); + } + } + } + + SUBCASE("correctly differentiates mixed arguments/flags") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + 'f', + std::nullopt, + }, + CLIFlagSpec{ + "flag2", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag3", + 'a', + std::nullopt, + }, + }, + { + CLIPositionalArgumentSpec{ + "posarg1", + std::vector{"red", "blue", "green"}, + std::nullopt, + }, + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + std::nullopt, + }, + }, + }; + CLIFlagKey key_flag1 = CLIFlagKey{0}; + CLIFlagKey key_flag2 = CLIFlagKey{1}; + CLIFlagKey key_flag3 = CLIFlagKey{2}; + CLIPositionalArgumentKey key_posarg1 = CLIPositionalArgumentKey{0}; + CLIPositionalArgumentKey key_posarg2 = CLIPositionalArgumentKey{1}; + + SUBCASE("works if flags are before positional arguments") { + std::vector inputs = { + "prog_name", "-f", "--flag3", "red", "world"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, true}, + {key_flag2, false}, + {key_flag3, true}, + }, + { + {key_posarg1, "red"}, + {key_posarg2, "world"}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("works if flags are interspersed") { + std::vector inputs = { + "prog_name", "red", "-f", "world", "--flag3"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, true}, + {key_flag2, false}, + {key_flag3, true}, + }, + { + {key_posarg1, "red"}, + {key_posarg2, "world"}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("detects if posargs are missing instead of treating flags as " + "posarg values") { + std::vector inputs = {"prog_name", "-f", "red", "--flag2"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = tl::unexpected( + "Not enough positional arguments: found 1, expected 2"); + + CHECK(result == correct); + } + } + } + + TEST_CASE("cli_parse(CLISpec, int argc, char const * const *argv)") { + // most cases are checked in the other overload, + // i.e., cli_parse(CLISpec, std::vector), + // so here we just throw in a single check to make sure + // nothing has unexpectedly gone wrong + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + 'f', + std::nullopt, + }, + CLIFlagSpec{ + "flag2", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag3", + 'a', + std::nullopt, + }, + }, + { + CLIPositionalArgumentSpec{ + "posarg1", + std::vector{"red", "blue", "green"}, + std::nullopt, + }, + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + std::nullopt, + }, + }, + }; + CLIFlagKey key_flag1 = CLIFlagKey{0}; + CLIFlagKey key_flag2 = CLIFlagKey{1}; + CLIFlagKey key_flag3 = CLIFlagKey{2}; + CLIPositionalArgumentKey key_posarg1 = CLIPositionalArgumentKey{0}; + CLIPositionalArgumentKey key_posarg2 = CLIPositionalArgumentKey{1}; + + int argc = 5; + char const *argv[] = {"prog_name", "red", "-f", "world", "--flag3"}; + + tl::expected result = + cli_parse(cli, argc, argv); + 
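    // Assumed (not verified by this test) relationship between the two
    // cli_parse overloads: this one presumably just collects argv into a
    // vector and defers to cli_parse(CLISpec, std::vector<std::string>),
    // roughly
    //
    //   std::vector<std::string> args(argv, argv + argc);
    //   return cli_parse(cli, args);
    //
    // which is why a single smoke check is treated as sufficient here.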
tl::expected correct = CLIParseResult{ + { + {key_flag1, true}, + {key_flag2, false}, + {key_flag3, true}, + }, + { + {key_posarg1, "red"}, + {key_posarg2, "world"}, + }, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/contains_key.cc b/lib/utils/test/src/utils/containers/contains_key.cc index acc6551cd4..da099113a6 100644 --- a/lib/utils/test/src/utils/containers/contains_key.cc +++ b/lib/utils/test/src/utils/containers/contains_key.cc @@ -1,8 +1,11 @@ #include "utils/containers/contains_key.h" -#include "test/utils/doctest.h" +#include #include +#include #include +using namespace ::FlexFlow; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("contains_key(std::unordered_map, K)") { std::unordered_map m = { diff --git a/lib/utils/test/src/utils/containers/enumerate.cc b/lib/utils/test/src/utils/containers/enumerate.cc index 2be5f1ef93..c6ce9942e9 100644 --- a/lib/utils/test/src/utils/containers/enumerate.cc +++ b/lib/utils/test/src/utils/containers/enumerate.cc @@ -1,8 +1,12 @@ #include "utils/containers/enumerate.h" -#include "utils/containers/as_vector.h" -#include "utils/fmt/map.h" -#include "utils/fmt/pair.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include "utils/containers/keys.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" +#include "utils/containers/vector_of.h" #include #include @@ -25,7 +29,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("check iteration order") { std::vector> iterated_result = - as_vector(result); + vector_of(result); std::vector> correct_iteration_order = { {0, "zero"}, {1, "one"}, @@ -46,5 +50,17 @@ TEST_SUITE(FF_TEST_SUITE) { {2, "two"}, {3, "three"}, }; + + std::map result = enumerate(input); + + std::unordered_set result_keys = keys(correct); + std::unordered_set result_values = + unordered_set_of(values(correct)); + + std::unordered_set correct_keys = {0, 1, 2, 3}; + std::unordered_set correct_values = input; + + CHECK(result_keys == correct_keys); + CHECK(result_values == correct_values); } } diff --git a/lib/utils/test/src/utils/containers/extend.cc b/lib/utils/test/src/utils/containers/extend.cc index e0d156a3fc..ef2a67725c 100644 --- a/lib/utils/test/src/utils/containers/extend.cc +++ b/lib/utils/test/src/utils/containers/extend.cc @@ -1,6 +1,6 @@ #include "utils/containers/extend.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filter.cc b/lib/utils/test/src/utils/containers/filter.cc index da459094ef..770ad40375 100644 --- a/lib/utils/test/src/utils/containers/filter.cc +++ b/lib/utils/test/src/utils/containers/filter.cc @@ -1,10 +1,10 @@ #include "utils/containers/filter.h" -#include "test/utils/all.h" -#include "utils/fmt/map.h" -#include "utils/fmt/set.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/rapidcheck.h" using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filtermap_keys.cc 
b/lib/utils/test/src/utils/containers/filtermap_keys.cc index 758264627b..582e94392b 100644 --- a/lib/utils/test/src/utils/containers/filtermap_keys.cc +++ b/lib/utils/test/src/utils/containers/filtermap_keys.cc @@ -1,7 +1,7 @@ #include "utils/containers/filtermap_keys.h" -#include "test/utils/doctest.h" -#include "utils/fmt/map.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filtermap_values.cc b/lib/utils/test/src/utils/containers/filtermap_values.cc index d2b6ddd220..8db6d6a964 100644 --- a/lib/utils/test/src/utils/containers/filtermap_values.cc +++ b/lib/utils/test/src/utils/containers/filtermap_values.cc @@ -1,7 +1,7 @@ #include "utils/containers/filtermap_values.h" -#include "test/utils/doctest.h" -#include "utils/fmt/map.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filtrans.cc b/lib/utils/test/src/utils/containers/filtrans.cc index b8bb832b06..cd1c2f896c 100644 --- a/lib/utils/test/src/utils/containers/filtrans.cc +++ b/lib/utils/test/src/utils/containers/filtrans.cc @@ -1,7 +1,7 @@ #include "utils/containers/filtrans.h" -#include "utils/fmt/set.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/foldl1.cc b/lib/utils/test/src/utils/containers/foldl1.cc new file mode 100644 index 0000000000..597aa5e109 --- /dev/null +++ b/lib/utils/test/src/utils/containers/foldl1.cc @@ -0,0 +1,27 @@ +#include "utils/containers/foldl1.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("foldl1(std::vector, F)") { + auto concat = [](std::string const &accum, std::string const &s) { + return accum + s; + }; + + SUBCASE("empty input") { + std::vector input = {}; + CHECK_THROWS(foldl1(input, concat)); + } + + SUBCASE("non-empty input") { + std::vector input = {"a s", "tr", "ing"}; + + std::string result = foldl1(input, concat); + + std::string correct = "a string"; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/foldr1.cc b/lib/utils/test/src/utils/containers/foldr1.cc new file mode 100644 index 0000000000..3c9d9b66ae --- /dev/null +++ b/lib/utils/test/src/utils/containers/foldr1.cc @@ -0,0 +1,27 @@ +#include "utils/containers/foldr1.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("foldr1(std::vector, F)") { + auto concat = [](std::string const &accum, std::string const &s) { + return accum + s; + }; + + SUBCASE("empty input") { + std::vector input = {}; + CHECK_THROWS(foldr1(input, concat)); + } + + SUBCASE("non-empty input") { + std::vector input = {"ing", "tr", "a s"}; + + std::string result = foldr1(input, concat); + + std::string correct = "a string"; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/get_all_permutations.cc b/lib/utils/test/src/utils/containers/get_all_permutations.cc index 5f22266809..cc5edb4075 100644 --- a/lib/utils/test/src/utils/containers/get_all_permutations.cc +++ b/lib/utils/test/src/utils/containers/get_all_permutations.cc @@ -1,8 +1,7 @@ 
#include "utils/containers/get_all_permutations.h" -#include "utils/containers/as_vector.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/vector.h" #include "utils/containers/unordered_multiset_of.h" -#include "utils/fmt/unordered_multiset.h" -#include "utils/fmt/vector.h" #include "utils/hash/vector.h" #include diff --git a/lib/utils/test/src/utils/containers/get_element_counts.cc b/lib/utils/test/src/utils/containers/get_element_counts.cc index 11e2ef7e05..8fc87dba90 100644 --- a/lib/utils/test/src/utils/containers/get_element_counts.cc +++ b/lib/utils/test/src/utils/containers/get_element_counts.cc @@ -1,5 +1,5 @@ #include "utils/containers/get_element_counts.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_map.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/inplace_filter.cc b/lib/utils/test/src/utils/containers/inplace_filter.cc index 7ef9d73339..ac430279b0 100644 --- a/lib/utils/test/src/utils/containers/inplace_filter.cc +++ b/lib/utils/test/src/utils/containers/inplace_filter.cc @@ -1,10 +1,11 @@ #include "utils/containers/inplace_filter.h" -#include "test/utils/all.h" -#include "utils/fmt/map.h" -#include "utils/fmt/set.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/rapidcheck.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/intersection.cc b/lib/utils/test/src/utils/containers/intersection.cc index ac9acf5e2b..52de6ee6d3 100644 --- a/lib/utils/test/src/utils/containers/intersection.cc +++ b/lib/utils/test/src/utils/containers/intersection.cc @@ -1,6 +1,6 @@ #include "utils/containers/intersection.h" -#include "utils/fmt/optional.h" -#include "utils/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/unordered_set.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/maximum.cc b/lib/utils/test/src/utils/containers/maximum.cc new file mode 100644 index 0000000000..71e7395805 --- /dev/null +++ b/lib/utils/test/src/utils/containers/maximum.cc @@ -0,0 +1,60 @@ +#include "utils/containers/maximum.h" +#include "test/utils/doctest/fmt/multiset.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("maximum(T)", + T, + std::vector, + std::unordered_set, + std::unordered_multiset, + std::set, + std::multiset) { + + SUBCASE("input is empty") { + T input = {}; + + std::optional result = maximum(input); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input does not have duplicates") { + T input = {1, 3, 2}; + + std::optional result = maximum(input); + std::optional correct = 3; + + CHECK(result == correct); + } + + SUBCASE("input has duplicates") { + T input = {1, 2, 2, 0}; + + std::optional result = maximum(input); + std::optional correct = 2; + + CHECK(result == correct); + } + } + + 
TEST_CASE("maximum(std::vector)") { + std::vector input = {"hello", "world"}; + + std::optional result = maximum(input); + std::optional correct = "world"; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/multiset_union.cc b/lib/utils/test/src/utils/containers/multiset_union.cc new file mode 100644 index 0000000000..8c40bf55ab --- /dev/null +++ b/lib/utils/test/src/utils/containers/multiset_union.cc @@ -0,0 +1,29 @@ +#include "utils/containers/multiset_union.h" +#include "test/utils/doctest/fmt/multiset.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("multiset_union(std::unordered_multiset, " + "std::unordered_multiset)") { + std::unordered_multiset input_lhs = {1, 2, 2, 3}; + std::unordered_multiset input_rhs = {1, 2, 5}; + + std::unordered_multiset result = multiset_union(input_lhs, input_rhs); + std::unordered_multiset correct = {1, 1, 2, 2, 2, 3, 5}; + + CHECK(result == correct); + } + + TEST_CASE("multiset_union(std::multiset, std::multiset)") { + std::multiset input_lhs = {1, 2, 2, 3}; + std::multiset input_rhs = {1, 2, 5}; + + std::multiset result = multiset_union(input_lhs, input_rhs); + std::multiset correct = {1, 1, 2, 2, 2, 3, 5}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/repeat.cc b/lib/utils/test/src/utils/containers/repeat.cc index 50e4b3e7c5..d8ffe76a64 100644 --- a/lib/utils/test/src/utils/containers/repeat.cc +++ b/lib/utils/test/src/utils/containers/repeat.cc @@ -1,5 +1,5 @@ #include "utils/containers/repeat.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/require_no_duplicates.cc b/lib/utils/test/src/utils/containers/require_no_duplicates.cc new file mode 100644 index 0000000000..67733d791a --- /dev/null +++ b/lib/utils/test/src/utils/containers/require_no_duplicates.cc @@ -0,0 +1,62 @@ +#include "utils/containers/require_no_duplicates.h" +#include "test/utils/doctest/fmt/multiset.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("require_no_duplicates(std::unordered_multiset)") { + SUBCASE("empty") { + std::unordered_multiset input = {}; + + std::unordered_set result = require_no_duplicates(input); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input has duplicates") { + std::unordered_multiset input = {1, 2, 2}; + + CHECK_THROWS(require_no_duplicates(input)); + } + + SUBCASE("input does not have duplicates") { + std::unordered_multiset input = {1, 2, 4}; + + std::unordered_set result = require_no_duplicates(input); + std::unordered_set correct = {1, 2, 4}; + + CHECK(result == correct); + } + } + + TEST_CASE("require_no_duplicates(std::multiset)") { + SUBCASE("empty") { + std::multiset input = {}; + + std::set result = require_no_duplicates(input); + std::set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input has duplicates") { + std::multiset input = {1, 2, 2}; + + CHECK_THROWS(require_no_duplicates(input)); + } + + SUBCASE("input does not have duplicates") { + std::multiset input = {1, 2, 4}; + + std::set result = require_no_duplicates(input); + std::set correct = {1, 2, 4}; + + CHECK(result == correct); + } + } +} diff --git 
a/lib/utils/test/src/utils/containers/reversed.cc b/lib/utils/test/src/utils/containers/reversed.cc new file mode 100644 index 0000000000..834a497152 --- /dev/null +++ b/lib/utils/test/src/utils/containers/reversed.cc @@ -0,0 +1,27 @@ +#include "utils/containers/reversed.h" +#include "test/utils/doctest/fmt/vector.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("reversed(std::vector)") { + SUBCASE("non-empty input") { + std::vector input = {1, 2, 3, 2}; + + std::vector result = reversed(input); + std::vector correct = {2, 3, 2, 1}; + + CHECK(result == correct); + } + + SUBCASE("empty input") { + std::vector input = {}; + + std::vector result = reversed(input); + std::vector correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/to_uppercase.cc b/lib/utils/test/src/utils/containers/to_uppercase.cc new file mode 100644 index 0000000000..9729307304 --- /dev/null +++ b/lib/utils/test/src/utils/containers/to_uppercase.cc @@ -0,0 +1,15 @@ +#include "utils/containers/to_uppercase.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("to_uppercase(std::string)") { + std::string input = "Hello World"; + + std::string result = to_uppercase(input); + std::string correct = "HELLO WORLD"; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/transform.cc b/lib/utils/test/src/utils/containers/transform.cc index 916bc20928..3122c67117 100644 --- a/lib/utils/test/src/utils/containers/transform.cc +++ b/lib/utils/test/src/utils/containers/transform.cc @@ -1,7 +1,7 @@ #include "utils/containers/transform.h" -#include "utils/fmt/optional.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc b/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc index 6aeab4ae6e..b8a7a85f74 100644 --- a/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc +++ b/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc @@ -1,7 +1,7 @@ #include "utils/containers/try_merge_nondisjoint_unordered_maps.h" -#include "test/utils/doctest.h" -#include "utils/fmt/optional.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/unordered_multiset_of.cc b/lib/utils/test/src/utils/containers/unordered_multiset_of.cc index 0ab0ef1446..becb7fdce0 100644 --- a/lib/utils/test/src/utils/containers/unordered_multiset_of.cc +++ b/lib/utils/test/src/utils/containers/unordered_multiset_of.cc @@ -1,5 +1,5 @@ #include "utils/containers/unordered_multiset_of.h" -#include "utils/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include #include diff --git a/lib/utils/test/src/utils/containers/unordered_set_of.cc b/lib/utils/test/src/utils/containers/unordered_set_of.cc index d42b41dd50..b8ca1d1797 100644 --- a/lib/utils/test/src/utils/containers/unordered_set_of.cc +++ b/lib/utils/test/src/utils/containers/unordered_set_of.cc @@ -1,5 +1,5 @@ #include "utils/containers/unordered_set_of.h" -#include "utils/fmt/unordered_set.h" +#include 
"test/utils/doctest/fmt/unordered_set.h" #include #include diff --git a/lib/utils/test/src/utils/containers/vector_of.cc b/lib/utils/test/src/utils/containers/vector_of.cc new file mode 100644 index 0000000000..8b9353e1b0 --- /dev/null +++ b/lib/utils/test/src/utils/containers/vector_of.cc @@ -0,0 +1,17 @@ +#include "utils/containers/vector_of.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("vector_of(std::set)") { + std::set input = {2, 3, 1, 4}; + + std::vector result = vector_of(input); + std::vector correct = {1, 2, 3, 4}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/without_order.cc b/lib/utils/test/src/utils/containers/without_order.cc index 939c6ff108..b4c8663b14 100644 --- a/lib/utils/test/src/utils/containers/without_order.cc +++ b/lib/utils/test/src/utils/containers/without_order.cc @@ -1,5 +1,5 @@ #include "utils/containers/without_order.h" -#include "utils/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include #include diff --git a/lib/utils/test/src/utils/expected.cc b/lib/utils/test/src/utils/expected.cc index 14679e0d13..3e5de13d49 100644 --- a/lib/utils/test/src/utils/expected.cc +++ b/lib/utils/test/src/utils/expected.cc @@ -1,6 +1,6 @@ #include "utils/expected.h" -#include "utils/fmt/expected.h" -#include "utils/fmt/optional.h" +#include "test/utils/doctest/fmt/expected.h" +#include "test/utils/doctest/fmt/optional.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/expected.cc b/lib/utils/test/src/utils/fmt/expected.cc index fb39732761..48df8634db 100644 --- a/lib/utils/test/src/utils/fmt/expected.cc +++ b/lib/utils/test/src/utils/fmt/expected.cc @@ -1,5 +1,6 @@ #include "utils/fmt/expected.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include using namespace ::FlexFlow; @@ -19,24 +20,4 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } } - - TEST_CASE("doctest::toString(tl::expected)") { - SUBCASE("has expected") { - tl::expected input = 3; - - doctest::String result = doctest::toString(input); - doctest::String correct = "expected(3)"; - - CHECK(result == correct); - } - - SUBCASE("has unexpected") { - tl::expected input = tl::make_unexpected("error"); - - doctest::String result = doctest::toString(input); - doctest::String correct = "unexpected(error)"; - - CHECK(result == correct); - } - } } diff --git a/lib/utils/test/src/utils/fmt/map.cc b/lib/utils/test/src/utils/fmt/map.cc index b65b4791ea..19f3a7d5cf 100644 --- a/lib/utils/test/src/utils/fmt/map.cc +++ b/lib/utils/test/src/utils/fmt/map.cc @@ -1,5 +1,5 @@ #include "utils/fmt/map.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/optional.cc b/lib/utils/test/src/utils/fmt/optional.cc index e7815a26ac..1cd79da747 100644 --- a/lib/utils/test/src/utils/fmt/optional.cc +++ b/lib/utils/test/src/utils/fmt/optional.cc @@ -1,5 +1,5 @@ #include "utils/fmt/optional.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/pair.cc b/lib/utils/test/src/utils/fmt/pair.cc index 3d7cc78756..e848eb08c7 100644 --- a/lib/utils/test/src/utils/fmt/pair.cc +++ b/lib/utils/test/src/utils/fmt/pair.cc @@ -1,5 +1,5 @@ #include "utils/fmt/pair.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/set.cc 
b/lib/utils/test/src/utils/fmt/set.cc index 66824f2b2a..e317954b02 100644 --- a/lib/utils/test/src/utils/fmt/set.cc +++ b/lib/utils/test/src/utils/fmt/set.cc @@ -1,5 +1,5 @@ #include "utils/fmt/set.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/unordered_map.cc b/lib/utils/test/src/utils/fmt/unordered_map.cc index 99752d73f4..c980bc1e52 100644 --- a/lib/utils/test/src/utils/fmt/unordered_map.cc +++ b/lib/utils/test/src/utils/fmt/unordered_map.cc @@ -1,6 +1,7 @@ #include "utils/fmt/unordered_map.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/unordered_map.h" #include "utils/containers/get_element_counts.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/unordered_set.cc b/lib/utils/test/src/utils/fmt/unordered_set.cc index 9dc8d236f1..f492ea844d 100644 --- a/lib/utils/test/src/utils/fmt/unordered_set.cc +++ b/lib/utils/test/src/utils/fmt/unordered_set.cc @@ -1,7 +1,7 @@ #include "utils/fmt/unordered_set.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include "utils/containers/unordered_multiset_of.h" -#include "utils/fmt/unordered_multiset.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/variant.cc b/lib/utils/test/src/utils/fmt/variant.cc index 3ada166de9..0c8dca35d7 100644 --- a/lib/utils/test/src/utils/fmt/variant.cc +++ b/lib/utils/test/src/utils/fmt/variant.cc @@ -1,5 +1,5 @@ #include "utils/fmt/variant.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/vector.cc b/lib/utils/test/src/utils/fmt/vector.cc index fee3eb34a5..91ef6c9efc 100644 --- a/lib/utils/test/src/utils/fmt/vector.cc +++ b/lib/utils/test/src/utils/fmt/vector.cc @@ -1,5 +1,5 @@ #include "utils/fmt/vector.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/graph/cow_ptr_t.cc b/lib/utils/test/src/utils/graph/cow_ptr_t.cc index 65088c19de..e6a6f9661e 100644 --- a/lib/utils/test/src/utils/graph/cow_ptr_t.cc +++ b/lib/utils/test/src/utils/graph/cow_ptr_t.cc @@ -1,5 +1,5 @@ #include "utils/graph/cow_ptr_t.h" -#include "test/utils/doctest.h" +#include #include #include #include diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc new file mode 100644 index 0000000000..330628adfd --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc @@ -0,0 +1,43 @@ +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_subgraph_incoming_edges(DataflowGraphView, " + "std::unordered_set") { + DataflowGraph g = DataflowGraph::create(); + + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o1, o2, o1}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o2, 
o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + std::unordered_set input_node_set = {n2, n3}; + + std::unordered_set result = + get_subgraph_incoming_edges(g, input_node_set); + + std::unordered_set correct = { + DataflowEdge{o1, DataflowInput{n2, 0}}, + DataflowEdge{o1, DataflowInput{n3, 0}}, + DataflowEdge{o1, DataflowInput{n3, 2}}, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc index 7e02686dde..779d0a9560 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc @@ -7,7 +7,8 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_outgoing_edges(DataflowGraphView, std::unordered_set") { + TEST_CASE("get_subgraph_outgoing_edges(DataflowGraphView, " + "std::unordered_set") { DataflowGraph g = DataflowGraph::create(); NodeAddedResult n1_added = g.add_node({}, 1); diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc b/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc index cfc912af6b..7a3237d432 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc @@ -1,9 +1,11 @@ -#include "test/utils/doctest.h" #include "utils/graph/dataflow_graph/dataflow_edge_query.h" #include "utils/graph/dataflow_graph/dataflow_graph.h" #include "utils/graph/dataflow_graph/dataflow_output_query.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/node/node_query.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("UnorderedSetDataflowGraph") { diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc index 2ebfe232b6..eca7aa6c79 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc @@ -1,5 +1,8 @@ #include "utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h" +#include "utils/containers/reversed.h" +#include "utils/containers/vector_of.h" #include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_digraph.h" #include @@ -9,6 +12,25 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_cbc_decomposition") { DiGraph g = DiGraph::create(); + // used to check that the cbc decomposition result is the same regardless + // of the order in which the graph edges are processed, as this is a + // property that should hold, and violations of this property have been a + // source of bugs in the past + auto check_cbc_decomposition_is_edge_order_invariant = + [](DiGraphView const &g) { + std::unordered_set edges = get_edges(g); + + std::vector edge_order1 = vector_of(edges); + std::vector edge_order2 = reversed(edge_order1); + + std::optional result1 = + get_cbc_decomposition_with_edge_order_internal(g, edge_order1); + std::optional result2 = + 
get_cbc_decomposition_with_edge_order_internal(g, edge_order2); + + CHECK(result1 == result2); + }; + SUBCASE("six-node diamond graph") { std::vector n = add_nodes(g, 6); add_edges(g, @@ -32,6 +54,8 @@ TEST_SUITE(FF_TEST_SUITE) { }}; CHECK(result == correct); + + check_cbc_decomposition_is_edge_order_invariant(g); } SUBCASE("graph without any edges") { @@ -43,6 +67,27 @@ TEST_SUITE(FF_TEST_SUITE) { CompleteBipartiteCompositeDecomposition{{}}; CHECK(result == correct); + + check_cbc_decomposition_is_edge_order_invariant(g); + } + + SUBCASE("irreducible n-graph (non-cbc graph)") { + std::vector n = add_nodes(g, 4); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + + std::optional result = + get_cbc_decomposition(g); + std::optional correct = + std::nullopt; + + CHECK(result == correct); + + check_cbc_decomposition_is_edge_order_invariant(g); } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_graph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_graph.cc new file mode 100644 index 0000000000..17c8b8da27 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_graph.cc @@ -0,0 +1,175 @@ +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_complete_bipartite_digraph(UndirectedGraphView, " + "std::unordered_set)") { + DiGraph g = DiGraph::create(); + + SUBCASE("simple bipartite graph") { + std::vector n = add_nodes(g, 5); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + }); + + SUBCASE("source group") { + std::unordered_set group1 = {n.at(0), n.at(1), n.at(2)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("sink group") { + std::unordered_set group1 = {n.at(3), n.at(4)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + } + + SUBCASE("missing an edge (i.e., not complete)") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + std::unordered_set group1 = {n.at(0), n.at(1)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("extra edge (i.e., not bipartite)") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + std::unordered_set group1 = {n.at(0), n.at(1)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("flipped edge") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(2), n.at(1)}, + 
DirectedEdge{n.at(1), n.at(3)}, + }); + std::unordered_set group1 = {n.at(0), n.at(1)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("group too small") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + std::unordered_set group1 = {n.at(0)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + } + + TEST_CASE("is_complete_bipartite_digraph(UndirectedGraphView)") { + DiGraph g = DiGraph::create(); + + SUBCASE("simple bipartite graph") { + std::vector n = add_nodes(g, 5); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + }); + + bool result = is_complete_bipartite_digraph(g); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("missing an edge") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + + bool result = is_complete_bipartite_digraph(g); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("extra edge") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + + bool result = is_complete_bipartite_digraph(g); + bool correct = false; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc index fd2f469f93..a635658755 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc @@ -4,6 +4,7 @@ #include "utils/containers/transform.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/multidigraph/algorithms/get_directed_edge.h" #include "utils/graph/multidigraph/algorithms/get_edge_counts.h" @@ -139,5 +140,27 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result_bidict == correct_bidict); } } + + SUBCASE("sp n-graph (inverse line graph does not exist)") { + // Tests that the inverse line graph of the sp n-graph + // + // a-b + // \ + // c-d + // + // does not exist + std::vector n = add_nodes(g, 4); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + + std::optional result = + get_inverse_line_graph(transitive_reduction(g)); + + CHECK_FALSE(result.has_value()); + } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc index 3ad506f40a..e675e6903f 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc @@ -1,4 +1,5 @@ #include 
"utils/graph/digraph/algorithms/is_acyclic.h" +#include "test/utils/doctest/fmt/optional.h" #include "utils/graph/algorithms.h" #include "utils/graph/instances/adjacency_digraph.h" #include diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_closure.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_closure.cc new file mode 100644 index 0000000000..5f72355ed0 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_closure.cc @@ -0,0 +1,50 @@ +#include "utils/graph/digraph/algorithms/transitive_closure.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("transitive_closure(DiGraphView)") { + DiGraph g = DiGraph::create(); + + SUBCASE("maximum number of new edges") { + std::vector n = add_nodes(g, 5); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + }); + + DiGraphView result = transitive_closure(g); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(result); + std::unordered_set correct_nodes = unordered_set_of(n); + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(result); + std::unordered_set correct_edges = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + }; + CHECK(result_edges == correct_edges); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc index b8a35346f4..1f9062a8ed 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc @@ -1,4 +1,5 @@ #include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "test/utils/doctest/fmt/optional.h" #include "utils/graph/algorithms.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/node/algorithms.h" @@ -76,5 +77,66 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result_edges == correct_edges); } } + + SUBCASE("longer paths") { + std::vector n = add_nodes(g, 5); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + }); + + DiGraphView result = transitive_reduction(g); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(result); + std::unordered_set correct_nodes = unordered_set_of(n); + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(result); + std::unordered_set correct_edges = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + }; + CHECK(result_edges == correct_edges); + } + } + + SUBCASE("irreducible sp n-graph") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); 
+ + DiGraphView result = transitive_reduction(g); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(result); + std::unordered_set correct_nodes = unordered_set_of(n); + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(result); + std::unordered_set correct_edges = { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }; + CHECK(result_edges == correct_edges); + } + } } } diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc new file mode 100644 index 0000000000..66b657eaaa --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc @@ -0,0 +1,51 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("fmt GenericBinarySPDecompositionTree") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(5); + + std::string result = fmt::to_string(input); + std::string correct = ""; + + CHECK(result == correct); + } + + SUBCASE("series split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(7)); + + std::string result = fmt::to_string(input); + std::string correct = (" " + "" + ">" + ">"); + + CHECK(result == correct); + } + + SUBCASE("parallel split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(7)); + + std::string result = fmt::to_string(input); + std::string correct = (" " + "" + ">" + ">"); + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc new file mode 100644 index 0000000000..abae9286b6 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -0,0 +1,86 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_leaves(GenericBinarySPDecompositionTree)") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(5); + + std::unordered_multiset result = get_leaves(input); + std::unordered_multiset correct = {5}; + + CHECK(result == correct); + } + + SUBCASE("series split") { + SUBCASE("children are not the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + + std::unordered_multiset result = get_leaves(input); + 
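        // get_leaves is assumed here to gather leaf labels into a multiset by
        // visiting both children of every split and keeping duplicates, i.e.
        //
        //   leaves(leaf(x))     = {x}
        //   leaves(split(l, r)) = leaves(l) + leaves(r)   // multiset union
        //
        // which is what the "children are the same" and "nested" cases below
        // rely on.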
std::unordered_multiset correct = {5, 6}; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(5)); + + std::unordered_multiset result = get_leaves(input); + std::unordered_multiset correct = {5, 5}; + + CHECK(result == correct); + } + } + + SUBCASE("parallel split") { + SUBCASE("children are not the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + + std::unordered_multiset result = get_leaves(input); + std::unordered_multiset correct = {5, 6}; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(5)); + + std::unordered_multiset result = get_leaves(input); + std::unordered_multiset correct = {5, 5}; + + CHECK(result == correct); + } + } + + SUBCASE("nested") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_series_split( + make_generic_binary_sp_leaf(4), + make_generic_binary_series_split( + make_generic_binary_sp_leaf(2), + make_generic_binary_sp_leaf(5))), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(2))); + + std::unordered_multiset result = get_leaves(input); + std::unordered_multiset correct = {2, 2, 4, 4, 5}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc new file mode 100644 index 0000000000..92c556ad28 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc @@ -0,0 +1,41 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_left_child(GenericBinarySPDecompositionTree)") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(5); + + CHECK_THROWS(get_left_child(input)); + } + + SUBCASE("series split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(3)); + + GenericBinarySPDecompositionTree result = get_left_child(input); + GenericBinarySPDecompositionTree correct = + make_generic_binary_sp_leaf(5); + + CHECK(result == correct); + } + + SUBCASE("parallel split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(7)); + + GenericBinarySPDecompositionTree result = get_left_child(input); + GenericBinarySPDecompositionTree correct = + make_generic_binary_sp_leaf(4); + + CHECK(result == correct); + } + } +} diff --git 
a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc new file mode 100644 index 0000000000..3de61d3313 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -0,0 +1,85 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_num_tree_nodes(GenericBinarySPDecompositionTree)") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(5); + + int result = get_num_tree_nodes(input); + int correct = 1; + + CHECK(result == correct); + } + + SUBCASE("series split") { + SUBCASE("children are not the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + + int result = get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(5)); + + int result = get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + } + + SUBCASE("parallel split") { + SUBCASE("children are not the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + + int result = get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(5)); + + int result = get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + } + + SUBCASE("nested") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_series_split( + make_generic_binary_sp_leaf(4), + make_generic_binary_series_split( + make_generic_binary_sp_leaf(2), + make_generic_binary_sp_leaf(5))), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(2))); + + int result = get_num_tree_nodes(input); + int correct = 9; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc new file mode 100644 index 0000000000..33b5d37955 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc @@ -0,0 +1,41 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_right_child(GenericBinarySPDecompositionTree)") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(5); + + CHECK_THROWS(get_right_child(input)); + } + + SUBCASE("series split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(3)); + + GenericBinarySPDecompositionTree result = get_right_child(input); + GenericBinarySPDecompositionTree correct = + make_generic_binary_sp_leaf(3); + + CHECK(result == correct); + } + + SUBCASE("parallel split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(7)); + + GenericBinarySPDecompositionTree result = get_right_child(input); + GenericBinarySPDecompositionTree correct = + make_generic_binary_sp_leaf(7); + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc new file mode 100644 index 0000000000..e7025dbfad --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc @@ -0,0 +1,117 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree leaf_5 = + make_generic_binary_sp_leaf(5); + size_t leaf_5_hash = get_std_hash(leaf_5); + + SUBCASE("leaves with same labels hash to the same value") { + GenericBinarySPDecompositionTree also_leaf_5 = + make_generic_binary_sp_leaf(5); + size_t also_leaf_5_hash = get_std_hash(also_leaf_5); + + CHECK(leaf_5_hash == also_leaf_5_hash); + } + + SUBCASE("leaves with different labels hash to different values") { + GenericBinarySPDecompositionTree leaf_6 = + make_generic_binary_sp_leaf(6); + size_t leaf_6_hash = get_std_hash(leaf_6); + + CHECK(leaf_5_hash != leaf_6_hash); + } + } + + SUBCASE("series split") { + GenericBinarySPDecompositionTree series_5_6 = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + size_t series_5_6_hash = get_std_hash(series_5_6); + + SUBCASE("same children lead to the same hash") { + GenericBinarySPDecompositionTree also_series_5_6 = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + size_t also_series_5_6_hash = get_std_hash(also_series_5_6); + + CHECK(series_5_6_hash == also_series_5_6_hash); + } + + SUBCASE("hash is order dependent") { + GenericBinarySPDecompositionTree series_6_5 = + make_generic_binary_series_split(make_generic_binary_sp_leaf(6), + make_generic_binary_sp_leaf(5)); + size_t series_6_5_hash = get_std_hash(series_6_5); + + CHECK(series_5_6_hash != series_6_5_hash); + } + + SUBCASE("different left child leads to different hash") { + GenericBinarySPDecompositionTree series_4_6 = + 
make_generic_binary_series_split(make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(6)); + size_t series_4_6_hash = get_std_hash(series_4_6); + + CHECK(series_5_6_hash != series_4_6_hash); + } + + SUBCASE("different right child leads to different hash") { + GenericBinarySPDecompositionTree series_5_7 = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(7)); + size_t series_5_7_hash = get_std_hash(series_5_7); + + CHECK(series_5_6_hash != series_5_7_hash); + } + } + + SUBCASE("parallel split") { + GenericBinarySPDecompositionTree parallel_5_6 = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + size_t parallel_5_6_hash = get_std_hash(parallel_5_6); + + SUBCASE("same children lead to the same hash") { + GenericBinarySPDecompositionTree also_parallel_5_6 = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + size_t also_parallel_5_6_hash = get_std_hash(also_parallel_5_6); + + CHECK(parallel_5_6_hash == also_parallel_5_6_hash); + } + + SUBCASE("hash is order dependent") { + GenericBinarySPDecompositionTree parallel_6_5 = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(6), + make_generic_binary_sp_leaf(5)); + size_t parallel_6_5_hash = get_std_hash(parallel_6_5); + + CHECK(parallel_5_6_hash != parallel_6_5_hash); + } + + SUBCASE("different left child leads to different hash") { + GenericBinarySPDecompositionTree parallel_4_6 = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(6)); + size_t parallel_4_6_hash = get_std_hash(parallel_4_6); + + CHECK(parallel_5_6_hash != parallel_4_6_hash); + } + + SUBCASE("different right child leads to different hash") { + GenericBinarySPDecompositionTree parallel_5_7 = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(7)); + size_t parallel_5_7_hash = get_std_hash(parallel_5_7); + + CHECK(parallel_5_6_hash != parallel_5_7_hash); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc new file mode 100644 index 0000000000..7a8756c6cc --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -0,0 +1,102 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_binary_sp_tree_left_associative(" + "GenericBinarySPDecompositionTree)") { + int n1 = 1; + int n2 = 2; + int n3 = 3; + int n4 = 4; + + SUBCASE("input is actually left associative") { + SUBCASE("just node") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(n1); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just series") { + 
GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_series_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_sp_leaf(n3)); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_sp_leaf(n3)); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("nested") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n3), + make_generic_binary_sp_leaf(n4))); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + } + + SUBCASE("input is not left associative") { + SUBCASE("just series") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_series_split( + make_generic_binary_sp_leaf(n2), + make_generic_binary_sp_leaf(n3))); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n2), + make_generic_binary_sp_leaf(n3))); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = false; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc new file mode 100644 index 0000000000..3cf87368ab --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -0,0 +1,102 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_binary_sp_tree_right_associative(" + "GenericBinarySPDecompositionTree)") { + int n1 = 1; + int n2 = 2; + int n3 = 3; + int n4 = 4; + + SUBCASE("input is actually right associative") { + SUBCASE("just node") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(n1); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just series") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_sp_leaf(n1), + 
make_generic_binary_series_split( + make_generic_binary_sp_leaf(n2), + make_generic_binary_sp_leaf(n3))); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n2), + make_generic_binary_sp_leaf(n3))); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("nested") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n3), + make_generic_binary_sp_leaf(n4))); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + } + + SUBCASE("input is not right associative") { + SUBCASE("just series") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_series_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_sp_leaf(n3)); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_sp_leaf(n3)); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = false; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc new file mode 100644 index 0000000000..cc234bacf8 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc @@ -0,0 +1,131 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("adl_serializer>") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree tt = make_generic_binary_sp_leaf(5); + + nlohmann::json tt_json = { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "leaf"}, + {"value", 5}, + }; + + SUBCASE("to_json") { + nlohmann::json result = tt; + nlohmann::json correct = tt_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + GenericBinarySPDecompositionTree result = + tt_json.get>(); + GenericBinarySPDecompositionTree correct = tt; + + CHECK(result == correct); + } + } + + SUBCASE("series split") { + GenericBinarySPDecompositionTree tt = + make_generic_binary_series_split(make_generic_binary_sp_leaf(2), + make_generic_binary_sp_leaf(5)); + + nlohmann::json tt_json = { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "series"}, + { 
+ "value", + { + {"__type", "GenericBinarySeriesSplit"}, + { + "left_child", + { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "leaf"}, + {"value", 5}, + }, + }, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = tt; + nlohmann::json correct = tt_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + GenericBinarySPDecompositionTree result = + tt_json.get>(); + GenericBinarySPDecompositionTree correct = tt; + + CHECK(result == correct); + } + } + + SUBCASE("parallel split") { + GenericBinarySPDecompositionTree tt = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(2), + make_generic_binary_sp_leaf(5)); + + nlohmann::json tt_json = { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "parallel"}, + { + "value", + { + {"__type", "GenericBinaryParallelSplit"}, + { + "left_child", + { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "leaf"}, + {"value", 5}, + }, + }, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = tt; + nlohmann::json correct = tt_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + GenericBinarySPDecompositionTree result = + tt_json.get>(); + GenericBinarySPDecompositionTree correct = tt; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc new file mode 100644 index 0000000000..4ede4e84b5 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc @@ -0,0 +1,28 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("transform(GenericBinarySPDecompositionTree, F)") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_series_split(make_generic_binary_sp_leaf(1), + make_generic_binary_sp_leaf(4)), + make_generic_binary_sp_leaf(2)); + + GenericBinarySPDecompositionTree result = + transform(input, [](int x) { return std::to_string(x); }); + + GenericBinarySPDecompositionTree correct = + make_generic_binary_parallel_split( + make_generic_binary_series_split( + make_generic_binary_sp_leaf(std::string{"1"}), + make_generic_binary_sp_leaf(std::string{"4"})), + make_generic_binary_sp_leaf(std::string{"2"})); + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..1e3217a2de --- /dev/null +++ 
b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -0,0 +1,95 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/rapidcheck.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("left_associative_binary_sp_tree_from_nary(" + "SeriesParallelDecomposition)") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + Node n5 = Node{5}; + Node n6 = Node{6}; + + SUBCASE("only node") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; + + BinarySPDecompositionTree result = + left_associative_binary_sp_tree_from_nary(input); + BinarySPDecompositionTree correct = make_leaf_node(n1); + + CHECK(result == correct); + } + + SUBCASE("only serial") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{n1, n2, n3}, + }; + + BinarySPDecompositionTree result = + left_associative_binary_sp_tree_from_nary(input); + + BinarySPDecompositionTree correct = make_series_split( + make_series_split(make_leaf_node(n1), make_leaf_node(n2)), + make_leaf_node(n3)); + + CHECK(result == correct); + } + + SUBCASE("only parallel") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{n1, n2, n3}, + }; + + BinarySPDecompositionTree result = + left_associative_binary_sp_tree_from_nary(input); + + // we use multiple checks here because SerialParallelDecomposition's + // ParallelSplit is unordered, so there are multiple possible + // left-associative binary SP trees + CHECK(is_binary_sp_tree_left_associative(result)); + + std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset correct_nodes = {n1, n2, n3}; + + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("nested") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{ + n1, + SeriesSplit{ + n2, + n3, + n3, + n5, + }, + SeriesSplit{ + n6, + n4, + }, + n5, + }, + }; + + BinarySPDecompositionTree result = + left_associative_binary_sp_tree_from_nary(input); + + CHECK(is_binary_sp_tree_left_associative(result)); + + std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset correct_nodes = { + n1, n2, n3, n3, n5, n6, n4, n5}; + + CHECK(result_nodes == correct_nodes); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc new file mode 100644 index 0000000000..0befbde5cc --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc @@ -0,0 +1,132 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + 
TEST_CASE("nary_sp_tree_from_binary(BinarySPDecompositionTree)") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + Node n5 = Node{5}; + Node n6 = Node{6}; + + SUBCASE("leaf") { + BinarySPDecompositionTree input = make_leaf_node(n1); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{n1}; + + CHECK(result == correct); + } + + SUBCASE("left associative series") { + BinarySPDecompositionTree input = make_series_split( + make_series_split(make_leaf_node(n2), make_leaf_node(n1)), + make_leaf_node(n3)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{n2, n1, n3}}; + + CHECK(result == correct); + } + + SUBCASE("right associative series") { + BinarySPDecompositionTree input = make_series_split( + make_leaf_node(n2), + make_series_split(make_leaf_node(n1), make_leaf_node(n3))); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{n2, n1, n3}}; + + CHECK(result == correct); + } + + SUBCASE("series with duplicate children") { + BinarySPDecompositionTree input = + make_series_split(make_leaf_node(n1), make_leaf_node(n1)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{n1, n1}}; + + CHECK(get_nodes(result).size() == 2); + CHECK(result == correct); + } + + SUBCASE("left associative parallel") { + BinarySPDecompositionTree input = make_parallel_split( + make_parallel_split(make_leaf_node(n2), make_leaf_node(n1)), + make_leaf_node(n3)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{ParallelSplit{n2, n1, n3}}; + + CHECK(result == correct); + } + + SUBCASE("right associative parallel") { + BinarySPDecompositionTree input = make_parallel_split( + make_leaf_node(n2), + make_parallel_split(make_leaf_node(n1), make_leaf_node(n3))); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{ParallelSplit{n2, n1, n3}}; + + CHECK(result == correct); + } + + SUBCASE("parallel with duplicate children") { + BinarySPDecompositionTree input = + make_parallel_split(make_leaf_node(n1), make_leaf_node(n1)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{ParallelSplit{n1, n1}}; + + CHECK(get_nodes(result).size() == 2); + CHECK(result == correct); + } + + SUBCASE("nested") { + BinarySPDecompositionTree input = make_parallel_split( + make_parallel_split( + make_parallel_split( + make_leaf_node(n1), + make_series_split( + make_series_split(make_series_split(make_leaf_node(n2), + make_leaf_node(n3)), + make_leaf_node(n3)), + make_leaf_node(n5))), + make_series_split(make_leaf_node(n6), make_leaf_node(n4))), + make_leaf_node(n5)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + ParallelSplit{ + n1, + SeriesSplit{ + n2, + n3, + n3, + n5, + }, + SeriesSplit{ + n6, + n4, + }, + n5, + }, + }; + + CHECK(result == correct); + } + } +} diff --git 
a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..db1b440481 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -0,0 +1,93 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("right_associative_binary_sp_tree_from_nary(" + "SeriesParallelDecomposition)") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + Node n5 = Node{5}; + Node n6 = Node{6}; + + SUBCASE("only node") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; + + BinarySPDecompositionTree result = + right_associative_binary_sp_tree_from_nary(input); + BinarySPDecompositionTree correct = make_leaf_node(n1); + + CHECK(result == correct); + } + + SUBCASE("only serial") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{n1, n2, n3}, + }; + + BinarySPDecompositionTree result = + right_associative_binary_sp_tree_from_nary(input); + + BinarySPDecompositionTree correct = make_series_split( + make_leaf_node(n1), + make_series_split(make_leaf_node(n2), make_leaf_node(n3))); + + CHECK(result == correct); + } + + SUBCASE("only parallel") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{n1, n2, n3}, + }; + + BinarySPDecompositionTree result = + right_associative_binary_sp_tree_from_nary(input); + + // we use multiple checks here because SerialParallelDecomposition's + // ParallelSplit is unordered, so there are multiple possible + // right-associative binary SP trees + CHECK(is_binary_sp_tree_right_associative(result)); + + std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset correct_nodes = {n1, n2, n3}; + + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("nested") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{ + n1, + SeriesSplit{ + n2, + n3, + n3, + n5, + }, + SeriesSplit{ + n6, + n4, + }, + n5, + }, + }; + + BinarySPDecompositionTree result = + right_associative_binary_sp_tree_from_nary(input); + + CHECK(is_binary_sp_tree_right_associative(result)); + + std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset correct_nodes = { + n1, n2, n3, n3, n5, n6, n4, n5}; + + CHECK(result_nodes == correct_nodes); + } + } +} diff --git a/lib/utils/test/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc similarity index 50% rename from lib/utils/test/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc rename to lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index 04d82bf1d8..45f796c824 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -1,4 
+1,4 @@ -#include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" #include "utils/graph/algorithms.h" #include "utils/graph/instances/adjacency_digraph.h" #include @@ -6,47 +6,47 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_serial_parallel_decomposition (base case)") { + TEST_CASE("get_series_parallel_decomposition (base case)") { DiGraph g = DiGraph::create(); Node n = g.add_node(); - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{n}; + std::optional result = + get_series_parallel_decomposition(g); + std::optional correct = + SeriesParallelDecomposition{n}; CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (parallel)") { + TEST_CASE("get_series_parallel_decomposition (parallel)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 2); - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{ParallelSplit{ + std::optional result = + get_series_parallel_decomposition(g); + std::optional correct = + SeriesParallelDecomposition{ParallelSplit{ n.at(0), n.at(1), }}; CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (serial)") { + TEST_CASE("get_series_parallel_decomposition (serial)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 2); g.add_edge(DirectedEdge{n.at(0), n.at(1)}); - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{SerialSplit{ + std::optional result = + get_series_parallel_decomposition(g); + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{ n.at(0), n.at(1), }}; CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (composite)") { + TEST_CASE("get_series_parallel_decomposition (composite)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 3); add_edges(g, @@ -55,11 +55,11 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n.at(0), n.at(2)}, }); - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{ - SerialSplit{ + std::optional result = + get_series_parallel_decomposition(g); + std::optional correct = + SeriesParallelDecomposition{ + SeriesSplit{ n.at(0), ParallelSplit{ n.at(1), @@ -70,7 +70,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (diamond graph)") { + TEST_CASE("get_series_parallel_decomposition (diamond graph)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 6); @@ -85,15 +85,15 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n.at(4), n.at(5)}, }); - std::optional correct = - SerialParallelDecomposition{SerialSplit{ + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{ n.at(0), ParallelSplit{ - SerialSplit{ + SeriesSplit{ n.at(1), n.at(3), }, - SerialSplit{ + SeriesSplit{ n.at(2), n.at(4), }, @@ -101,13 +101,13 @@ TEST_SUITE(FF_TEST_SUITE) { n.at(5), }}; - std::optional result = - get_serial_parallel_decomposition(g); + std::optional result = + get_series_parallel_decomposition(g); CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (all-to-all connection)") { + TEST_CASE("get_series_parallel_decomposition (all-to-all connection)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); @@ -120,9 +120,9 @@ TEST_SUITE(FF_TEST_SUITE) { 
DirectedEdge{n.at(1), n.at(3)}, }); - std::optional correct = - SerialParallelDecomposition{ - SerialSplit{ + std::optional correct = + SeriesParallelDecomposition{ + SeriesSplit{ ParallelSplit{ n.at(0), n.at(1), @@ -134,13 +134,13 @@ TEST_SUITE(FF_TEST_SUITE) { }, }; - std::optional result = - get_serial_parallel_decomposition(g); + std::optional result = + get_series_parallel_decomposition(g); CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (non-sp graph)") { + TEST_CASE("get_series_parallel_decomposition (non-sp graph)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); @@ -153,9 +153,39 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n.at(1), n.at(3)}, }); - std::optional correct = std::nullopt; - std::optional result = - get_serial_parallel_decomposition(g); + std::optional correct = std::nullopt; + std::optional result = + get_series_parallel_decomposition(g); + + CHECK(result == correct); + } + + TEST_CASE( + "get_series_parallel_decomposition (requires transitive reduction)") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + + std::optional correct = + SeriesParallelDecomposition{ + SeriesSplit{ + n.at(0), + n.at(1), + n.at(2), + n.at(3), + }, + }; + std::optional result = + get_series_parallel_decomposition(g); CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/test/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc similarity index 83% rename from lib/utils/test/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc rename to lib/utils/test/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc index 4560f95ff7..3a486c7094 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" #include "utils/fmt/variant.h" #include @@ -8,11 +8,11 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("flatten_ast") { std::variant input = IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{1}, IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{2}, Node{3}, @@ -25,7 +25,7 @@ TEST_SUITE(FF_TEST_SUITE) { flatten_ast(input); std::variant correct = IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{1}, Node{2}, diff --git a/lib/utils/test/src/utils/graph/serial_parallel/parallel_reduction.cc b/lib/utils/test/src/utils/graph/series_parallel/parallel_reduction.cc similarity index 99% rename from lib/utils/test/src/utils/graph/serial_parallel/parallel_reduction.cc rename to lib/utils/test/src/utils/graph/series_parallel/parallel_reduction.cc index 8259d256d3..a62f528bcf 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/parallel_reduction.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/parallel_reduction.h" +#include "utils/graph/series_parallel/parallel_reduction.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include 
"utils/graph/multidigraph/algorithms/add_edges.h" #include "utils/graph/multidigraph/algorithms/add_nodes.h" diff --git a/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc similarity index 66% rename from lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc rename to lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc index 7cf17c3fee..f5766c9fdd 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -1,5 +1,5 @@ -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" -#include "utils/fmt/unordered_set.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include using namespace ::FlexFlow; @@ -7,20 +7,20 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("to_final_ast (base case)") { std::variant input = Node{1}; - SerialParallelDecomposition result = to_final_ast(input); - SerialParallelDecomposition correct = SerialParallelDecomposition{Node{1}}; + SeriesParallelDecomposition result = to_final_ast(input); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{Node{1}}; CHECK(result == correct); } TEST_CASE("to_final_ast (serial)") { std::variant input = IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, {Node{1}, Node{2}}, }; - SerialParallelDecomposition result = to_final_ast(input); - SerialParallelDecomposition correct = SerialParallelDecomposition{ - SerialSplit{{ + SeriesParallelDecomposition result = to_final_ast(input); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{ Node{1}, Node{2}, }}, @@ -30,11 +30,11 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("to_final_ast (composite)") { std::variant input = IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{0}, IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{1}, IntermediateSpDecompositionTree{ @@ -55,9 +55,9 @@ TEST_SUITE(FF_TEST_SUITE) { Node{5}, }}; - SerialParallelDecomposition result = to_final_ast(input); - SerialParallelDecomposition correct = - SerialParallelDecomposition{SerialSplit{{ + SeriesParallelDecomposition result = to_final_ast(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{{ Node{0}, Node{1}, ParallelSplit{{ @@ -70,55 +70,55 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - TEST_CASE("get_nodes(SerialParallelDecomposition)") { - SerialParallelDecomposition input = - SerialParallelDecomposition{SerialSplit{{ + TEST_CASE("get_nodes(SeriesParallelDecomposition)") { + SeriesParallelDecomposition input = + SeriesParallelDecomposition{SeriesSplit{{ ParallelSplit{{ Node{1}, Node{2}, }}, - Node{3}, + Node{2}, ParallelSplit{{ Node{4}, Node{5}, }}, }}}; - std::unordered_set result = get_nodes(input); - std::unordered_set correct = { + std::unordered_multiset result = get_nodes(input); + std::unordered_multiset correct = { Node{1}, Node{2}, - Node{3}, + Node{2}, Node{4}, Node{5}, }; CHECK(result == correct); } - TEST_CASE("get_nodes(SerialSplit)") { + TEST_CASE("get_nodes(SeriesSplit)") { ParallelSplit input = ParallelSplit{{ Node{1}, - SerialSplit{{ + SeriesSplit{{ Node{2}, ParallelSplit{{ Node{3}, Node{4}, }}, }}, - SerialSplit{{ - Node{5}, + 
SeriesSplit{{ + Node{1}, Node{6}, }}, Node{7}, }}; - std::unordered_set<Node> result = get_nodes(input); - std::unordered_set<Node> correct = { + std::unordered_multiset<Node> result = get_nodes(input); + std::unordered_multiset<Node> correct = { Node{1}, Node{2}, Node{3}, Node{4}, - Node{5}, + Node{1}, Node{6}, Node{7}, }; @@ -129,9 +129,9 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_nodes(ParallelSplit)") { ParallelSplit input = ParallelSplit{{ Node{1}, - SerialSplit{{ + SeriesSplit{{ Node{2}, - Node{3}, + Node{4}, ParallelSplit{{ Node{4}, Node{5}, @@ -139,11 +139,11 @@ }}, }}; - std::unordered_set<Node> result = get_nodes(input); - std::unordered_set<Node> correct = { + std::unordered_multiset<Node> result = get_nodes(input); + std::unordered_multiset<Node> correct = { Node{1}, Node{2}, - Node{3}, + Node{4}, Node{4}, Node{5}, }; @@ -153,8 +153,8 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_nodes(Node)") { Node input = Node{5}; - std::unordered_set<Node> result = get_nodes(input); - std::unordered_set<Node> correct = {input}; + std::unordered_multiset<Node> result = get_nodes(input); + std::unordered_multiset<Node> correct = {input}; CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/graph/serial_parallel/series_reduction.cc b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc similarity index 99% rename from lib/utils/test/src/utils/graph/serial_parallel/series_reduction.cc rename to lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc index e4d53b4136..c6b45ec6ce 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/series_reduction.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/series_reduction.h" +#include "utils/graph/series_parallel/series_reduction.h" #include "utils/containers/set_minus.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/algorithms/add_edges.h" diff --git a/lib/utils/test/src/utils/hash/multiset.cc b/lib/utils/test/src/utils/hash/multiset.cc new file mode 100644 index 0000000000..5c2e01fda8 --- /dev/null +++ b/lib/utils/test/src/utils/hash/multiset.cc @@ -0,0 +1,34 @@ +#include "utils/hash/multiset.h" +#include <doctest/doctest.h> + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash<std::multiset<int>>") { + std::multiset<int> input = {1, 2, 2, 1, 5}; + size_t input_hash = get_std_hash(input); + + SUBCASE("same values have the same hash") { + std::multiset<int> also_input = {2, 1, 2, 5, 1}; + size_t also_input_hash = get_std_hash(also_input); + + CHECK(input_hash == also_input_hash); + } + + SUBCASE("different values have different hashes") { + SUBCASE("different number of duplicates") { + std::multiset<int> other = {1, 2, 2, 1, 5, 5}; + size_t other_hash = get_std_hash(other); + + CHECK(input_hash != other_hash); + } + + SUBCASE("different elements") { + std::multiset<int> other = {1, 2, 2, 1, 6}; + size_t other_hash = get_std_hash(other); + + CHECK(input_hash != other_hash); + } + } + } +} diff --git a/lib/utils/test/src/utils/hash/unordered_multiset.cc b/lib/utils/test/src/utils/hash/unordered_multiset.cc new file mode 100644 index 0000000000..6c730fad3c --- /dev/null +++ b/lib/utils/test/src/utils/hash/unordered_multiset.cc @@ -0,0 +1,34 @@ +#include "utils/hash/unordered_multiset.h" +#include <doctest/doctest.h> + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash<std::unordered_multiset<int>>") { + std::unordered_multiset<int> input = {1, 2, 2, 1, 5}; + size_t input_hash = get_std_hash(input); + + SUBCASE("same values have the same hash") { + std::unordered_multiset<int>
also_input = {2, 1, 2, 5, 1}; + size_t also_input_hash = get_std_hash(also_input); + + CHECK(input_hash == also_input_hash); + } + + SUBCASE("different values have different hashes") { + SUBCASE("different number of duplicates") { + std::unordered_multiset<int> other = {1, 2, 2, 1, 5, 5}; + size_t other_hash = get_std_hash(other); + + CHECK(input_hash != other_hash); + } + + SUBCASE("different elements") { + std::unordered_multiset<int> other = {1, 2, 2, 1, 6}; + size_t other_hash = get_std_hash(other); + + CHECK(input_hash != other_hash); + } + } + } +} diff --git a/lib/utils/test/src/utils/json/optional.cc b/lib/utils/test/src/utils/json/optional.cc new file mode 100644 index 0000000000..61f5868c53 --- /dev/null +++ b/lib/utils/test/src/utils/json/optional.cc @@ -0,0 +1,49 @@ +#include "utils/json/optional.h" +#include "test/utils/doctest/fmt/optional.h" +#include <doctest/doctest.h> + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("adl_serializer<std::optional<int>>") { + SUBCASE("to_json") { + SUBCASE("has value") { + std::optional<int> input = 5; + + nlohmann::json result = input; + nlohmann::json correct = 5; + + CHECK(result == correct); + } + + SUBCASE("has nullopt") { + std::optional<int> input = std::nullopt; + + nlohmann::json result = input; + nlohmann::json correct = nullptr; + + CHECK(result == correct); + } + } + + SUBCASE("from_json") { + SUBCASE("has value") { + nlohmann::json input = 5; + + std::optional<int> result = input; + std::optional<int> correct = 5; + + CHECK(result == correct); + } + + SUBCASE("has nullopt") { + nlohmann::json input = nullptr; + + std::optional<int> result = input.get<std::optional<int>>(); + std::optional<int> correct = std::nullopt; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/optional.cc b/lib/utils/test/src/utils/rapidcheck/optional.cc similarity index 67% rename from lib/utils/test/src/utils/optional.cc rename to lib/utils/test/src/utils/rapidcheck/optional.cc index 16c9e964cb..96b17a5400 100644 --- a/lib/utils/test/src/utils/optional.cc +++ b/lib/utils/test/src/utils/rapidcheck/optional.cc @@ -1,7 +1,8 @@ -#include "utils/rapidcheck/optional.h" -#include "test/utils/doctest.h" +#include "utils/rapidcheck/optional.h" #include "test/utils/rapidcheck.h" -#include +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE_TEMPLATE( From 64554159ab66da117ee243fdac0592b7aeabd613 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 16 Sep 2024 11:17:19 -0700 Subject: [PATCH 2/3] Add interface for differentiating inputs and weights in CG & PCG (#1493) * Add interface for differentiating inputs and weights in CG/PCG * Format * Address Reyna PR comments * fix bugs from merge * Format --- lib/kernels/include/kernels/legion_dim.h | 2 +- .../include/local-execution/serialization.h | 2 +- .../src/local_cost_estimator.cc | 5 +- .../src/models/transformer/transformer.cc | 9 +- .../op-attrs/computation_graph_op_attrs.h | 3 + .../op-attrs/{ => dim_ordered}/dim_ordered.h | 0 .../include/op-attrs/dim_ordered/enumerate.h | 2 +- .../op-attrs/dim_ordered/ff_ordered_of.h | 2 +- .../include/op-attrs/dim_ordered/get_idxs.h | 2 +- .../include/op-attrs/dim_ordered/slice.h | 2 +- .../include/op-attrs/dim_ordered/transform.h | 2 +- .../include/op-attrs/dim_ordered/zip.h | 2 +- .../op-attrs/get_incoming_tensor_roles.h | 17 ++ .../op-attrs/incoming_tensor_role.enum.toml | 14 + lib/op-attrs/include/op-attrs/ops/attention.h | 20 ++ .../include/op-attrs/ops/batch_matmul.h | 3 + lib/op-attrs/include/op-attrs/ops/conv_2d.h | 4 + .../include/op-attrs/ops/layer_norm.h | 4 +
lib/op-attrs/include/op-attrs/ops/linear.h | 10 +- lib/op-attrs/include/op-attrs/ops/topk.h | 2 + .../op-attrs/ops/transpose_attrs.struct.toml | 2 +- .../op-attrs/parallel_tensor_dims.struct.toml | 2 +- .../include/op-attrs/pcg_operator_attrs.h | 4 +- .../op-attrs/pcg_operator_attrs.variant.toml | 5 + .../include/op-attrs/tensor_dims.struct.toml | 2 +- .../op-attrs/computation_graph_op_attrs.cc | 18 ++ .../src/op-attrs/get_incoming_tensor_roles.cc | 103 +++++++ lib/op-attrs/src/op-attrs/ops/attention.cc | 19 ++ lib/op-attrs/src/op-attrs/ops/conv_2d.cc | 14 + lib/op-attrs/src/op-attrs/ops/layer_norm.cc | 13 + lib/op-attrs/src/op-attrs/ops/linear.cc | 23 +- lib/op-attrs/src/op-attrs/ops/topk.cc | 4 + .../src/op-attrs/pcg_operator_attrs.cc | 57 +--- .../op-attrs/computation_graph_op_attrs.cc | 15 + .../test/src/{ => op-attrs}/datatype.cc | 0 .../dim_ordered/dim_ordered.cc} | 2 +- .../src/{ => op-attrs}/dim_ordered/slice.cc | 0 .../src/op-attrs/get_incoming_tensor_roles.cc | 26 ++ .../test/src/{ => op-attrs}/ops/attention.cc | 49 ++++ .../src/{ => op-attrs}/ops/batch_matmul.cc | 0 .../test/src/op-attrs/ops/batch_norm_attrs.cc | 15 + .../test/src/{ => op-attrs}/ops/cast.cc | 0 .../test/src/{ => op-attrs}/ops/combine.cc | 0 .../test/src/{ => op-attrs}/ops/conv_2d.cc | 42 +++ .../src/{ => op-attrs}/ops/element_binary.cc | 0 .../src/{ => op-attrs}/ops/element_unary.cc | 0 .../test/src/{ => op-attrs}/ops/embedding.cc | 0 .../test/src/op-attrs/ops/layer_norm.cc | 36 +++ .../test/src/{ => op-attrs}/ops/linear.cc | 86 ++++-- .../test/src/{ => op-attrs}/ops/reduction.cc | 0 .../src/{ => op-attrs}/ops/repartition.cc | 0 .../test/src/{ => op-attrs}/ops/replicate.cc | 0 .../test/src/op-attrs/pcg_operator_attrs.cc | 17 ++ .../regularizer_attrs.cc} | 0 lib/op-attrs/test/src/test_operator_attrs.cc | 37 --- lib/pcg/include/pcg/computation_graph.h | 6 + .../include/pcg/computation_graph_builder.h | 39 +-- .../parallel_computation_graph.h | 11 +- .../parallel_computation_graph_builder.h | 5 +- .../include/pcg/strided_rectangle.struct.toml | 2 +- lib/pcg/src/pcg/computation_graph.cc | 41 +++ lib/pcg/src/pcg/computation_graph_builder.cc | 270 ++++++++---------- .../parallel_computation_graph.cc | 48 +++- .../parallel_computation_graph_builder.cc | 29 +- lib/pcg/test/src/pcg/computation_graph.cc | 206 +++++++++++++ .../parallel_computation_graph.cc | 227 +++++++++++++++ .../parallel_computation_graph_builder.cc | 60 ++-- .../operator_pattern/get_attribute.h | 60 ++-- .../operator_attribute_value.variant.toml | 3 + .../operator_pattern/get_attribute.cc | 12 + .../test/src/substitutions/pcg_pattern.cc | 18 +- .../perform_shape_inference.cc | 2 +- 72 files changed, 1352 insertions(+), 385 deletions(-) rename lib/op-attrs/include/op-attrs/{ => dim_ordered}/dim_ordered.h (100%) create mode 100644 lib/op-attrs/include/op-attrs/get_incoming_tensor_roles.h create mode 100644 lib/op-attrs/include/op-attrs/incoming_tensor_role.enum.toml create mode 100644 lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc create mode 100644 lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc rename lib/op-attrs/test/src/{ => op-attrs}/datatype.cc (100%) rename lib/op-attrs/test/src/{test_dim_ordered.cc => op-attrs/dim_ordered/dim_ordered.cc} (89%) rename lib/op-attrs/test/src/{ => op-attrs}/dim_ordered/slice.cc (100%) create mode 100644 lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc rename lib/op-attrs/test/src/{ => op-attrs}/ops/attention.cc (87%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/batch_matmul.cc 
(100%) create mode 100644 lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc rename lib/op-attrs/test/src/{ => op-attrs}/ops/cast.cc (100%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/combine.cc (100%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/conv_2d.cc (85%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/element_binary.cc (100%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/element_unary.cc (100%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/embedding.cc (100%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/linear.cc (74%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/reduction.cc (100%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/repartition.cc (100%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/replicate.cc (100%) create mode 100644 lib/op-attrs/test/src/op-attrs/pcg_operator_attrs.cc rename lib/op-attrs/test/src/{test_regularizer_attrs.cc => op-attrs/regularizer_attrs.cc} (100%) delete mode 100644 lib/op-attrs/test/src/test_operator_attrs.cc create mode 100644 lib/pcg/test/src/pcg/computation_graph.cc diff --git a/lib/kernels/include/kernels/legion_dim.h b/lib/kernels/include/kernels/legion_dim.h index aafbd2cdcb..e4dd9723b8 100644 --- a/lib/kernels/include/kernels/legion_dim.h +++ b/lib/kernels/include/kernels/legion_dim.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_KERNELS_INCLUDE_KERNELS_LEGION_DIM_H #include "kernels/legion_dim_t.dtg.h" -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/serialization.h b/lib/local-execution/include/local-execution/serialization.h index a260519a55..2fc4b4b706 100644 --- a/lib/local-execution/include/local-execution/serialization.h +++ b/lib/local-execution/include/local-execution/serialization.h @@ -3,7 +3,7 @@ #include "kernels/device.h" #include "kernels/nccl.h" -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" #include "utils/required.h" #include "utils/strong_typedef.h" #include "utils/type_traits.h" diff --git a/lib/local-execution/src/local_cost_estimator.cc b/lib/local-execution/src/local_cost_estimator.cc index 5203991f25..b42aec10bb 100644 --- a/lib/local-execution/src/local_cost_estimator.cc +++ b/lib/local-execution/src/local_cost_estimator.cc @@ -69,7 +69,10 @@ CostDetails LocalCostEstimator::estimate_cost( std::vector output_tensor_ids = cg_builder.add_layer(layer_attrs, input_tensor_ids, - get_vector_piece_attrs(weights), + transform(get_vector_piece_attrs(weights), + [&](TensorAttrs const &a) { + return cg_builder.create_weight(a); + }), get_vector_piece_attrs(outputs)); LocalTrainingBacking local_backing(allocator, diff --git a/lib/models/src/models/transformer/transformer.cc b/lib/models/src/models/transformer/transformer.cc index e179359940..173a1b291c 100644 --- a/lib/models/src/models/transformer/transformer.cc +++ b/lib/models/src/models/transformer/transformer.cc @@ -42,7 +42,8 @@ tensor_guid_t create_transformer_encoder_layer(ComputationGraphBuilder &cgb, config.num_heads, kdim, vdim, - config.dropout); + config.dropout, + /*bias=*/false); assert(are_tensor_guid_shapes_equivalent( cgb.computation_graph, input, self_attention)); @@ -88,7 +89,8 @@ tensor_guid_t config.num_heads, kdim, vdim, - config.dropout); + config.dropout, + /*bias=*/false); assert(are_tensor_guid_shapes_equivalent( cgb.computation_graph, input, self_attention)); @@ -107,7 +109,8 @@ tensor_guid_t config.num_heads, kdim, vdim, - config.dropout); + config.dropout, + /*bias=*/false); 
assert(are_tensor_guid_shapes_equivalent(cgb.computation_graph, input, mha)); tensor_guid_t mha_normalized = diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h index 03f38bb8f9..52e6e12a8c 100644 --- a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h @@ -2,12 +2,15 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_H #include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" #include "utils/record_formatter.h" namespace FlexFlow { OperatorType get_op_type(ComputationGraphOpAttrs const &); RecordFormatter as_dot(ComputationGraphOpAttrs const &); +ComputationGraphOpAttrs + compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h similarity index 100% rename from lib/op-attrs/include/op-attrs/dim_ordered.h rename to lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h b/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h index f9f6d00532..38e7da4bb2 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ENUMERATE_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ENUMERATE_H -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" #include "utils/bidict/bidict.h" #include "utils/containers/count.h" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_of.h b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_of.h index c843ed3842..8cc1bf3a51 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_of.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_of.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_OF_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_OF_H -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h b/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h index 560862677e..7343dc0e69 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_GET_IDXS_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_GET_IDXS_H -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" #include "utils/containers/count.h" #include "utils/containers/transform.h" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h index d39bac1bde..23b971da6b 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" #include "utils/containers/subvec.h" #include "utils/containers/transform.h" #include "utils/containers/vector_of.h" diff --git 
a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h index ae6e552243..4fd3df0abb 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" #include "utils/containers/vector_of.h" #include "utils/containers/vector_transform.h" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/zip.h b/lib/op-attrs/include/op-attrs/dim_ordered/zip.h index 023dcfc586..cc8b050f50 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/zip.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/zip.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ZIP_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ZIP_H -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" #include "utils/containers/vector_of.h" #include "utils/containers/zip.h" diff --git a/lib/op-attrs/include/op-attrs/get_incoming_tensor_roles.h b/lib/op-attrs/include/op-attrs/get_incoming_tensor_roles.h new file mode 100644 index 0000000000..b395736773 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/get_incoming_tensor_roles.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_GET_INCOMING_TENSOR_ROLES_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_GET_INCOMING_TENSOR_ROLES_H + +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "op-attrs/incoming_tensor_role.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" + +namespace FlexFlow { + +std::vector + get_incoming_tensor_roles(ComputationGraphOpAttrs const &, int num_inputs); +std::vector + get_incoming_tensor_roles(PCGOperatorAttrs const &, int num_inputs); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/incoming_tensor_role.enum.toml b/lib/op-attrs/include/op-attrs/incoming_tensor_role.enum.toml new file mode 100644 index 0000000000..427701c801 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/incoming_tensor_role.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "IncomingTensorRole" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "INPUT" + +[[values]] +name = "WEIGHT" diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index 40f57d08af..e06d795c04 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_ATTENTION_ATTRS_H #define _FLEXFLOW_ATTENTION_ATTRS_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" #include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" #include "op-attrs/ops/attention_attrs.dtg.h" @@ -37,6 +38,9 @@ int get_kvSeqLength(MultiHeadAttentionInputs const &); int get_num_samples(MultiHeadAttentionParallelInputs const &); int get_num_samples(MultiHeadAttentionInputs const &); +std::vector + get_attention_incoming_tensor_roles(MultiHeadAttentionAttrs const &); + tl::expected get_weights_shape(MultiHeadAttentionAttrs const &, TensorShape const &input_q, @@ -58,6 +62,22 @@ tl::expected TensorShape const &input_k, TensorShape const &input_v); +tl::expected + get_weights_parallel_dims(MultiHeadAttentionAttrs const &, + 
ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); +tl::expected + get_input_bias_parallel_dims(MultiHeadAttentionAttrs const &, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); +tl::expected + get_output_bias_parallel_dims(MultiHeadAttentionAttrs const &, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); + tl::expected get_weights_shape(MultiHeadAttentionAttrs const &, ParallelTensorShape const &input_q, diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index 57760d1110..574b4ef579 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -2,12 +2,15 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_H #include "op-attrs/ops/batch_matmul.dtg.h" +#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include namespace FlexFlow { +CHECK_VALID_OP_ATTR(BatchMatmulAttrs); + bool is_valid(BatchMatmulAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d.h b/lib/op-attrs/include/op-attrs/ops/conv_2d.h index 72d1123c39..ae9f9249c6 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_CONV_2D_ATTRS_H #define _FLEXFLOW_CONV_2D_ATTRS_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/conv_2d_attrs.dtg.h" #include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.h" @@ -10,6 +11,9 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(Conv2DAttrs); +std::vector + get_conv2d_incoming_tensor_roles(Conv2DAttrs const &); + TensorShape get_kernel_shape(Conv2DAttrs const &attrs, TensorShape const &input); TensorShape get_bias_shape(Conv2DAttrs const &attrs, TensorShape const &input); diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h index 29b0b2f514..0fbadae2a1 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H #define _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/core.h" #include "op-attrs/ops/layer_norm_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" @@ -8,6 +9,9 @@ namespace FlexFlow { +std::vector + get_layer_norm_incoming_tensor_roles(LayerNormAttrs const &); + tl::expected get_output_shape(LayerNormAttrs const &, TensorShape const &); tl::expected diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index 795ba19ae8..065cc7e38e 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LINEAR_ATTRS_H #define _FLEXFLOW_LINEAR_ATTRS_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/core.h" #include "op-attrs/ops/linear_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" @@ -10,20 +11,23 @@ namespace FlexFlow { +std::vector + get_linear_incoming_tensor_roles(LinearAttrs const &); + CHECK_VALID_OP_ATTR(LinearAttrs); RecordFormatter as_dot(LinearAttrs const &); tl::expected - get_kernel_shape(LinearAttrs const 
&attrs, TensorShape const &input); + get_projection_shape(LinearAttrs const &attrs, TensorShape const &input); tl::expected get_bias_shape(LinearAttrs const &attrs, TensorShape const &input); tl::expected get_output_shape(LinearAttrs const &attrs, TensorShape const &input); tl::expected - get_kernel_shape(LinearAttrs const &attrs, - ParallelTensorShape const &input); + get_projection_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input); tl::expected get_bias_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); tl::expected diff --git a/lib/op-attrs/include/op-attrs/ops/topk.h b/lib/op-attrs/include/op-attrs/ops/topk.h index bd11f0ae91..d6de90903a 100644 --- a/lib/op-attrs/include/op-attrs/ops/topk.h +++ b/lib/op-attrs/include/op-attrs/ops/topk.h @@ -4,11 +4,13 @@ #include "op-attrs/ops/core.h" #include "op-attrs/ops/topk_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(TopKAttrs); +TensorShape get_output_shape(TopKAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(TopKAttrs const &attrs, ParallelTensorShape const &input_shape); diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml index 756091f653..0dc30d9a79 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml @@ -12,7 +12,7 @@ features = [ includes = [ "op-attrs/ff_dim.h", "op-attrs/ff_dim.dtg.h", - "op-attrs/dim_ordered.h", + "op-attrs/dim_ordered/dim_ordered.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml index ae6eab1e58..f24fa12309 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/dim_ordered.h", + "op-attrs/dim_ordered/dim_ordered.h", "op-attrs/shard_parallel_dim.dtg.h", "op-attrs/replica_parallel_dim_set.dtg.h", "", diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h index 08167fe3d9..723c05298d 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h @@ -8,8 +8,8 @@ namespace FlexFlow { bool is_parallel_op(PCGOperatorAttrs const &); OperatorType get_op_type(PCGOperatorAttrs const &); -ComputationGraphOpAttrs - compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &); +PCGOperatorAttrs + pcg_op_attrs_from_compgraph_op_attrs(ComputationGraphOpAttrs const &); RecordFormatter as_dot(PCGOperatorAttrs const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml index 8617c5fd64..a44d712dbf 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml @@ -13,6 +13,7 @@ includes = [ "op-attrs/ops/attention_attrs.dtg.h", "op-attrs/ops/batch_matmul.dtg.h", "op-attrs/ops/batch_norm_attrs.dtg.h", + "op-attrs/ops/broadcast_attrs.dtg.h", "op-attrs/ops/cast_attrs.dtg.h", "op-attrs/ops/combine_attrs.dtg.h", "op-attrs/ops/concat_attrs.dtg.h", @@ -49,6 +50,10 @@ key = "batch_matmul" type = "::FlexFlow::BatchNormAttrs" key = "batch_norm" +[[values]] +type 
= "::FlexFlow::BroadcastAttrs" +key = "broadcast" + [[values]] type = "::FlexFlow::CastAttrs" key = "cast" diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml b/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml index cff8e08b0f..b262dd32b6 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml +++ b/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml @@ -9,7 +9,7 @@ features = [ "fmt", ] includes = [ - "op-attrs/dim_ordered.h", + "op-attrs/dim_ordered/dim_ordered.h", ] [[fields]] diff --git a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc index 054930cebd..c4ae7b31e5 100644 --- a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc +++ b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc @@ -23,4 +23,22 @@ RecordFormatter as_dot(ComputationGraphOpAttrs const &attrs) { }); } +ComputationGraphOpAttrs + compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &op) { + auto fail_on_parallel_op = [](auto const &attrs) -> ComputationGraphOpAttrs { + throw mk_runtime_error( + fmt::format("Encountered parallel operator in " + "compgraph_op_attrs_from_pcg_op_attrs: {}", + attrs)); + }; + + return op.visit(overload{ + [&](CombineAttrs const &attrs) { return fail_on_parallel_op(attrs); }, + [&](ReductionAttrs const &attrs) { return fail_on_parallel_op(attrs); }, + [&](RepartitionAttrs const &attrs) { return fail_on_parallel_op(attrs); }, + [&](ReplicateAttrs const &attrs) { return fail_on_parallel_op(attrs); }, + [](auto const &attrs) { return ComputationGraphOpAttrs{attrs}; }, + }); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc new file mode 100644 index 0000000000..c7febde1d6 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc @@ -0,0 +1,103 @@ +#include "op-attrs/get_incoming_tensor_roles.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/layer_norm.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/pcg_operator_attrs.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::vector get_incoming_tensor_roles( + ComputationGraphOpAttrs const &comp_graph_op_attrs, int num_incoming) { + return get_incoming_tensor_roles( + pcg_op_attrs_from_compgraph_op_attrs(comp_graph_op_attrs), num_incoming); +} + +std::vector + get_incoming_tensor_roles(PCGOperatorAttrs const &pcg_op_attrs, + int num_incoming) { + return pcg_op_attrs.visit>(overload{ + [](BatchMatmulAttrs const &) { + return std::vector{IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT}; + }, + [](BatchNormAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](BroadcastAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](CastAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, + [](CombineAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [&](ConcatAttrs const &) { + return std::vector(num_incoming, IncomingTensorRole::INPUT); + }, + [](Conv2DAttrs const &attrs) { + return get_conv2d_incoming_tensor_roles(attrs); + }, + [](DropoutAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ElementBinaryAttrs const &) { + return std::vector{IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT}; + }, + [](ElementUnaryAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](EmbeddingAttrs const &) { + return 
std::vector{IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT}; + }, + [](FlatAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, + [](GatherAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](InputAttrs const &) { return std::vector{}; }, + [](LayerNormAttrs const &attrs) { + return get_layer_norm_incoming_tensor_roles(attrs); + }, + [](LinearAttrs const &attrs) { + return get_linear_incoming_tensor_roles(attrs); + }, + [](MultiHeadAttentionAttrs const &attrs) { + return get_attention_incoming_tensor_roles(attrs); + }, + [](NoopAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, + [](Pool2DAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ReduceAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ReductionAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](RepartitionAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ReplicateAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ReverseAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ReshapeAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](SplitAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, + [](SoftmaxAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](TopKAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, + [](TransposeAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](WeightAttrs const &) { return std::vector{}; }, + }); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc index 036daa6e67..483d832fee 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -3,6 +3,7 @@ #include "op-attrs/ops/attention/multihead_attention_parallel_inputs.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" +#include "utils/containers/extend.h" #include "utils/integer_conversions.h" namespace FlexFlow { @@ -91,6 +92,24 @@ int get_num_samples(MultiHeadAttentionInputs const &inputs) { return inputs.batch_size; } +std::vector + get_attention_incoming_tensor_roles(MultiHeadAttentionAttrs const &attrs) { + + std::vector roles = std::vector{ + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + if (attrs.bias) { + extend(roles, + std::vector{IncomingTensorRole::WEIGHT, IncomingTensorRole::WEIGHT}); + } + + return roles; +} + tl::expected get_output_shape(MultiHeadAttentionAttrs const &attrs, TensorShape const &input_q, diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index 03ae18a1d9..f77daf451f 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -5,6 +5,20 @@ namespace FlexFlow { +std::vector + get_conv2d_incoming_tensor_roles(Conv2DAttrs const &attrs) { + std::vector result = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + if (attrs.use_bias) { + result.push_back(IncomingTensorRole::WEIGHT); + } + + return result; +} + TensorShape get_kernel_shape(Conv2DAttrs const &attrs, TensorShape const &raw_input_shape) { assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc index 
b9603d7850..d3c00efbb9 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -6,10 +6,23 @@ #include "utils/containers/all_of.h" #include "utils/containers/any_of.h" #include "utils/containers/contains.h" +#include "utils/containers/extend.h" #include "utils/containers/filter.h" namespace FlexFlow { +std::vector + get_layer_norm_incoming_tensor_roles(LayerNormAttrs const &attrs) { + std::vector result = {IncomingTensorRole::INPUT}; + + if (attrs.elementwise_affine) { + extend(result, + std::vector{IncomingTensorRole::WEIGHT, IncomingTensorRole::WEIGHT}); + } + + return result; +} + static std::optional check_input_shape(LayerNormAttrs const &attrs, TensorShape const &input_shape) { diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index 24a8250690..feac647216 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -8,6 +8,20 @@ namespace FlexFlow { +std::vector + get_linear_incoming_tensor_roles(LinearAttrs const &attrs) { + std::vector result = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + if (attrs.use_bias) { + result.push_back(IncomingTensorRole::WEIGHT); + } + + return result; +} + RecordFormatter as_dot(LinearAttrs const &attrs) { RecordFormatter r; @@ -25,7 +39,8 @@ RecordFormatter as_dot(LinearAttrs const &attrs) { } tl::expected - get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { + get_projection_shape(LinearAttrs const &attrs, + TensorShape const &input_shape) { size_t in_channels = dim_at_idx(input_shape, ff_dim_t{-1}); return TensorShape{ @@ -56,11 +71,11 @@ tl::expected } tl::expected - get_kernel_shape(LinearAttrs const &attrs, - ParallelTensorShape const &input) { + get_projection_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input) { TensorShape unpar = ({ tl::expected result_unpar = - get_kernel_shape(attrs, get_reduced_shape(input)); + get_projection_shape(attrs, get_reduced_shape(input)); if (!result_unpar.has_value()) { return tl::unexpected(result_unpar.error()); } diff --git a/lib/op-attrs/src/op-attrs/ops/topk.cc b/lib/op-attrs/src/op-attrs/ops/topk.cc index 9d2fd35a94..7a6868340b 100644 --- a/lib/op-attrs/src/op-attrs/ops/topk.cc +++ b/lib/op-attrs/src/op-attrs/ops/topk.cc @@ -2,6 +2,10 @@ namespace FlexFlow { +TensorShape get_output_shape(TopKAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + ParallelTensorShape get_output_shape(TopKAttrs const &attrs, ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc index 0bb134da6b..4fe01c2c1a 100644 --- a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc +++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc @@ -15,56 +15,6 @@ OperatorType get_op_type(PCGOperatorAttrs const &attrs) { [](auto const &x) { return get_op_type(x); }); } -ComputationGraphOpAttrs - compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &op) { - return op.visit(overload{ - [](BatchMatmulAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](BatchNormAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](CastAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](ConcatAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](Conv2DAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](DropoutAttrs const &attrs) { return 
ComputationGraphOpAttrs{attrs}; }, - [](ElementBinaryAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](ElementUnaryAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](EmbeddingAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](FlatAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](GatherAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](InputAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](LayerNormAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](LinearAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](MultiHeadAttentionAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](NoopAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](Pool2DAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](ReduceAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](ReverseAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](ReshapeAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](SplitAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](SoftmaxAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](TopKAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](TransposeAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](auto const &attrs) -> ComputationGraphOpAttrs { - throw mk_runtime_error(fmt::format( - "Cannot convert parallel op to non-parallel, received {}", attrs)); - }, - }); -} - RecordFormatter as_dot(PCGOperatorAttrs const &attrs) { return attrs.visit(overload{ [](LinearAttrs const &l) { return as_dot(l); }, @@ -76,4 +26,11 @@ RecordFormatter as_dot(PCGOperatorAttrs const &attrs) { }); } +PCGOperatorAttrs pcg_op_attrs_from_compgraph_op_attrs( + ComputationGraphOpAttrs const &cg_attrs) { + return cg_attrs.visit(overload{ + [](auto const &attrs) { return PCGOperatorAttrs{attrs}; }, + }); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc new file mode 100644 index 0000000000..42ea07e6b5 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc @@ -0,0 +1,15 @@ +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("ComputationGraphOpAttrs to/from json") { + ComputationGraphOpAttrs correct = + ComputationGraphOpAttrs{BatchNormAttrs{true}}; + nlohmann::json j = correct; + auto result = j.get(); + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/datatype.cc b/lib/op-attrs/test/src/op-attrs/datatype.cc similarity index 100% rename from lib/op-attrs/test/src/datatype.cc rename to lib/op-attrs/test/src/op-attrs/datatype.cc diff --git a/lib/op-attrs/test/src/test_dim_ordered.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/dim_ordered.cc similarity index 89% rename from lib/op-attrs/test/src/test_dim_ordered.cc rename to lib/op-attrs/test/src/op-attrs/dim_ordered/dim_ordered.cc index ac05767800..d7901a0c53 100644 --- a/lib/op-attrs/test/src/test_dim_ordered.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/dim_ordered.cc @@ -1,5 +1,5 @@ +#include "op-attrs/dim_ordered/dim_ordered.h" #include "doctest/doctest.h" -#include "op-attrs/dim_ordered.h" #include "test/utils/rapidcheck.h" using namespace FlexFlow; diff --git 
a/lib/op-attrs/test/src/dim_ordered/slice.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/slice.cc similarity index 100% rename from lib/op-attrs/test/src/dim_ordered/slice.cc rename to lib/op-attrs/test/src/op-attrs/dim_ordered/slice.cc diff --git a/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc new file mode 100644 index 0000000000..60dedfe70a --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc @@ -0,0 +1,26 @@ +#include "op-attrs/get_incoming_tensor_roles.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE( + "get_incoming_tensor_roles(ComputationGraphOpAttrs, int num_incoming)") { + SUBCASE("Concat") { + int num_incoming = 4; + ComputationGraphOpAttrs attrs = + ComputationGraphOpAttrs{ConcatAttrs{ff_dim_t{0}, num_incoming}}; + + std::vector result = + get_incoming_tensor_roles(attrs, num_incoming); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/ops/attention.cc b/lib/op-attrs/test/src/op-attrs/ops/attention.cc similarity index 87% rename from lib/op-attrs/test/src/ops/attention.cc rename to lib/op-attrs/test/src/op-attrs/ops/attention.cc index 2fb804ca8c..eca8559b21 100644 --- a/lib/op-attrs/test/src/ops/attention.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/attention.cc @@ -7,6 +7,55 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_attention_incoming_tensor_roles(MultiHeadAttentionAttrs)") { + auto make_attrs = [](bool bias) { + return MultiHeadAttentionAttrs{ + /*embed_dim=*/32, + /*num_heads=*/10, + /*kdim=*/32, + /*vdim=*/32, + /*dropout=*/0.0, + /*bias=*/bias, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }; + }; + + SUBCASE("without bias") { + MultiHeadAttentionAttrs attrs = make_attrs(/*bias=*/false); + + tl::expected, std::string> result = + get_attention_incoming_tensor_roles(attrs); + tl::expected, std::string> correct = + std::vector{ + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + + SUBCASE("with bias") { + MultiHeadAttentionAttrs attrs = make_attrs(/*bias=*/true); + + tl::expected, std::string> result = + get_attention_incoming_tensor_roles(attrs); + tl::expected, std::string> correct = + std::vector{ + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + } + TEST_CASE("get_output_shape(MultiHeadAttentionAttrs, TensorShape, " "TensorShape, TensorShape)") { int embed_dim = 32; diff --git a/lib/op-attrs/test/src/ops/batch_matmul.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc similarity index 100% rename from lib/op-attrs/test/src/ops/batch_matmul.cc rename to lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc new file mode 100644 index 0000000000..df436da66c --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc @@ -0,0 +1,15 @@ +#include "op-attrs/ops/batch_norm_attrs.dtg.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("BatchNormAttrs to/from json") { + BatchNormAttrs correct 
= BatchNormAttrs{true}; + + nlohmann::json j = correct; + BatchNormAttrs result = j.get(); + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/ops/cast.cc b/lib/op-attrs/test/src/op-attrs/ops/cast.cc similarity index 100% rename from lib/op-attrs/test/src/ops/cast.cc rename to lib/op-attrs/test/src/op-attrs/ops/cast.cc diff --git a/lib/op-attrs/test/src/ops/combine.cc b/lib/op-attrs/test/src/op-attrs/ops/combine.cc similarity index 100% rename from lib/op-attrs/test/src/ops/combine.cc rename to lib/op-attrs/test/src/op-attrs/ops/combine.cc diff --git a/lib/op-attrs/test/src/ops/conv_2d.cc b/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc similarity index 85% rename from lib/op-attrs/test/src/ops/conv_2d.cc rename to lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc index c4462eb7ec..152df09eca 100644 --- a/lib/op-attrs/test/src/ops/conv_2d.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc @@ -5,6 +5,48 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_conv2d_incoming_tensor_roles(Conv2DAttrs") { + auto make_attrs = [](bool use_bias) { + return Conv2DAttrs{/*out_channels=*/4, + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*groups=*/1, + /*activation=*/std::nullopt, + /*use_bias=*/use_bias}; + }; + + SUBCASE("with bias") { + Conv2DAttrs attrs = make_attrs(/*use_bias=*/true); + + std::vector result = + get_conv2d_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + + SUBCASE("without bias") { + Conv2DAttrs attrs = make_attrs(/*use_bias=*/false); + + std::vector result = + get_conv2d_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + } + TEST_CASE("Conv2D shape inference") { int out_channels = 4; int kernel_h = 3; diff --git a/lib/op-attrs/test/src/ops/element_binary.cc b/lib/op-attrs/test/src/op-attrs/ops/element_binary.cc similarity index 100% rename from lib/op-attrs/test/src/ops/element_binary.cc rename to lib/op-attrs/test/src/op-attrs/ops/element_binary.cc diff --git a/lib/op-attrs/test/src/ops/element_unary.cc b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc similarity index 100% rename from lib/op-attrs/test/src/ops/element_unary.cc rename to lib/op-attrs/test/src/op-attrs/ops/element_unary.cc diff --git a/lib/op-attrs/test/src/ops/embedding.cc b/lib/op-attrs/test/src/op-attrs/ops/embedding.cc similarity index 100% rename from lib/op-attrs/test/src/ops/embedding.cc rename to lib/op-attrs/test/src/op-attrs/ops/embedding.cc diff --git a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc index cbcebdbce1..f45ea91dac 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc @@ -8,6 +8,42 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_layer_norm_incoming_tensor_roles(LayerNormAttrs)") { + auto make_attrs = [](bool elementwise_affine) { + return LayerNormAttrs{ + /*axes=*/{ff_dim_t{0}, ff_dim_t{2}}, + elementwise_affine, + /*eps=*/1.0, + }; + }; + + SUBCASE("elementwise_affine = true") { + LayerNormAttrs attrs = make_attrs(/*elementwise_affine=*/true); + + std::vector result = + get_layer_norm_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + 
IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + + SUBCASE("elementwise_affine = false") { + LayerNormAttrs attrs = make_attrs(/*elementwise_affine=*/false); + + std::vector result = + get_layer_norm_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + }; + + CHECK(result == correct); + } + } + TEST_CASE("shape inference (LayerNorm)") { LayerNormAttrs attrs_affine_true = LayerNormAttrs{ /*axes=*/{ff_dim_t{1}, ff_dim_t{3}}, diff --git a/lib/op-attrs/test/src/ops/linear.cc b/lib/op-attrs/test/src/op-attrs/ops/linear.cc similarity index 74% rename from lib/op-attrs/test/src/ops/linear.cc rename to lib/op-attrs/test/src/op-attrs/ops/linear.cc index f838ff4285..191515b062 100644 --- a/lib/op-attrs/test/src/ops/linear.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/linear.cc @@ -7,6 +7,45 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_linear_incoming_tensor_roles(LinearAttrs)") { + auto make_attrs = [](bool use_bias) { + return LinearAttrs{ + /*out_channels=*/16, + /*use_bias=*/use_bias, + /*data_type=*/DataType::FLOAT, + /*activation=*/Activation::RELU, + /*regularizer=*/std::nullopt, + }; + }; + + SUBCASE("use_bias = true") { + LinearAttrs attrs = make_attrs(/*use_bias=*/true); + + std::vector result = + get_linear_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + + SUBCASE("use_bias = false") { + LinearAttrs attrs = make_attrs(/*use_bias=*/false); + + std::vector result = + get_linear_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + } + TEST_CASE("Linear shape inference") { int out_channels = 16; LinearAttrs attrs = LinearAttrs{ @@ -43,7 +82,7 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - TensorShape kernel = TensorShape{ + TensorShape projection = TensorShape{ TensorDims{ FFOrdered{ in_channels, @@ -72,10 +111,10 @@ TEST_SUITE(FF_TEST_SUITE) { // get_weight_shape { - tl::expected kernel_result = - get_kernel_shape(attrs, input); - tl::expected kernel_correct = kernel; - CHECK(kernel_result == kernel_correct); + tl::expected projection_result = + get_projection_shape(attrs, input); + tl::expected projection_correct = projection; + CHECK(projection_result == projection_correct); } // get_bias_shape @@ -104,12 +143,12 @@ TEST_SUITE(FF_TEST_SUITE) { output, o_sum, o_eq, FFOrdered{o_batch, o_extra_dim, o_channel}); }; - auto make_kernel = [&](SumDegree o_sum, - DiscardCopyDegree o_eq, - int o_inchannel, - int o_outchannel) { + auto make_projection = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_inchannel, + int o_outchannel) { return lift_to_parallel_with_degrees( - kernel, o_sum, o_eq, FFOrdered{o_inchannel, o_outchannel}); + projection, o_sum, o_eq, FFOrdered{o_inchannel, o_outchannel}); }; auto make_bias = @@ -143,12 +182,13 @@ TEST_SUITE(FF_TEST_SUITE) { { tl::expected result = - get_kernel_shape(attrs, par_input); - tl::expected correct = make_kernel( - SumDegree{1}, - DiscardCopyDegree{input_sum_degree * degree * extra_dim_degree}, - 1, - 1); + get_projection_shape(attrs, par_input); + tl::expected correct = + make_projection( + SumDegree{1}, + DiscardCopyDegree{input_sum_degree * degree * extra_dim_degree}, + 1, + 1); CHECK(result == correct); } @@ -184,9 +224,10 @@ TEST_SUITE(FF_TEST_SUITE) { { tl::expected result = - get_kernel_shape(attrs, par_input); - 
tl::expected correct = make_kernel( - SumDegree{1}, DiscardCopyDegree{input_sum_degree}, degree, 1); + get_projection_shape(attrs, par_input); + tl::expected correct = + make_projection( + SumDegree{1}, DiscardCopyDegree{input_sum_degree}, degree, 1); CHECK(result == correct); } @@ -216,9 +257,10 @@ TEST_SUITE(FF_TEST_SUITE) { { tl::expected result = - get_kernel_shape(attrs, par_input); - tl::expected correct = make_kernel( - SumDegree{1}, DiscardCopyDegree{input_sum_degree}, 1, degree); + get_projection_shape(attrs, par_input); + tl::expected correct = + make_projection( + SumDegree{1}, DiscardCopyDegree{input_sum_degree}, 1, degree); CHECK(result == correct); } diff --git a/lib/op-attrs/test/src/ops/reduction.cc b/lib/op-attrs/test/src/op-attrs/ops/reduction.cc similarity index 100% rename from lib/op-attrs/test/src/ops/reduction.cc rename to lib/op-attrs/test/src/op-attrs/ops/reduction.cc diff --git a/lib/op-attrs/test/src/ops/repartition.cc b/lib/op-attrs/test/src/op-attrs/ops/repartition.cc similarity index 100% rename from lib/op-attrs/test/src/ops/repartition.cc rename to lib/op-attrs/test/src/op-attrs/ops/repartition.cc diff --git a/lib/op-attrs/test/src/ops/replicate.cc b/lib/op-attrs/test/src/op-attrs/ops/replicate.cc similarity index 100% rename from lib/op-attrs/test/src/ops/replicate.cc rename to lib/op-attrs/test/src/op-attrs/ops/replicate.cc diff --git a/lib/op-attrs/test/src/op-attrs/pcg_operator_attrs.cc b/lib/op-attrs/test/src/op-attrs/pcg_operator_attrs.cc new file mode 100644 index 0000000000..ebeaec4d19 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/pcg_operator_attrs.cc @@ -0,0 +1,17 @@ +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("PCGOperatorAttrs to/from json") { + PCGOperatorAttrs correct = PCGOperatorAttrs{RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{1}, + /*repartition_degree=*/4, + }}; + nlohmann::json j = correct; + auto result = j.get(); + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/test_regularizer_attrs.cc b/lib/op-attrs/test/src/op-attrs/regularizer_attrs.cc similarity index 100% rename from lib/op-attrs/test/src/test_regularizer_attrs.cc rename to lib/op-attrs/test/src/op-attrs/regularizer_attrs.cc diff --git a/lib/op-attrs/test/src/test_operator_attrs.cc b/lib/op-attrs/test/src/test_operator_attrs.cc deleted file mode 100644 index 20825f5d73..0000000000 --- a/lib/op-attrs/test/src/test_operator_attrs.cc +++ /dev/null @@ -1,37 +0,0 @@ -#include "op-attrs/computation_graph_op_attrs.dtg.h" -#include "op-attrs/pcg_operator_attrs.dtg.h" -#include -#include -#include -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("BatchNormAttrs to/from json") { - BatchNormAttrs correct = BatchNormAttrs{true}; - nlohmann::json j = correct; - BatchNormAttrs result = j.get(); - CHECK(result == correct); - } - - TEST_CASE("ComputationGraphAttrs to/from json") { - ComputationGraphOpAttrs correct = - ComputationGraphOpAttrs{BatchNormAttrs{true}}; - nlohmann::json j = correct; - ComputationGraphOpAttrs result = j.get(); - - CHECK(result == correct); - } - - TEST_CASE("PCGOperatorAttrs to/from json") { - PCGOperatorAttrs correct = PCGOperatorAttrs{RepartitionAttrs{ - /*repartition_dim=*/ff_dim_t{1}, - /*repartition_degree=*/4, - }}; - nlohmann::json j = correct; - PCGOperatorAttrs result = j.get(); - - CHECK(result == correct); - } -} diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h 
index 499b26af89..f70d9f7404 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_H #define _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "pcg/computation_graph.dtg.h" #include "pcg/computation_graph/computation_graph_edge.dtg.h" #include "pcg/computation_graph/layer_added_result.dtg.h" @@ -31,6 +32,11 @@ std::vector get_outgoing_tensors(ComputationGraph const &cg, std::vector get_incoming_tensors(ComputationGraph const &cg, layer_guid_t n); +std::vector get_incoming_inputs(ComputationGraph const &, + layer_guid_t const &); +std::vector get_incoming_weights(ComputationGraph const &, + layer_guid_t const &); + std::unordered_set get_subgraph_incoming_edges(ComputationGraph const &, std::unordered_set const &); diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index a35763cacc..11e591545d 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -231,21 +231,26 @@ struct ComputationGraphBuilder { tensor_guid_t create_input(TensorShape const &, CreateGrad, - std::optional const &maybe_name = std::nullopt); + std::optional const &name = std::nullopt); + tensor_guid_t create_weight( TensorShape const &, CreateGrad create_grad = CreateGrad::YES, std::optional const &initializer = std::nullopt, std::optional sync_type = std::nullopt, std::optional const &name = std::nullopt); + tensor_guid_t + create_weight(TensorAttrs const &, + std::optional const &name = std::nullopt); std::vector get_outputs(LayerAttrs const &) const; tensor_guid_t get_output(LayerAttrs const &, int idx) const; - std::vector add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs); + std::vector + add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs); private: TensorShape get_shape(tensor_guid_t const &) const; @@ -255,30 +260,6 @@ struct ComputationGraphBuilder { tensor_guid_t as_type(tensor_guid_t const &, DataType, std::string const &); - tensor_guid_t add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - TensorAttrs const &output); - - std::vector add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs); - - tensor_guid_t add_layer(LayerAttrs const &layer, - std::vector const &inputs, - TensorShape const &output); - - tensor_guid_t add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - TensorShape const &output); - - tensor_guid_t add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - TensorShape const &output); - TensorDims get_broadcast_target_dims(std::vector const &); TensorDims get_broadcast_target_dims(std::vector const &); diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 9150681070..d7248afde4 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -22,12 +22,19 @@ ParallelLayerAddedResult std::vector const &output_labels); std::vector - 
get_layer_inputs(ParallelComputationGraph const &, - parallel_layer_guid_t const &); + get_incoming_tensors(ParallelComputationGraph const &, + parallel_layer_guid_t const &); std::vector get_layer_outputs(ParallelComputationGraph const &, parallel_layer_guid_t const &); +std::vector + get_incoming_inputs(ParallelComputationGraph const &, + parallel_layer_guid_t const &); +std::vector + get_incoming_weights(ParallelComputationGraph const &, + parallel_layer_guid_t const &); + ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &, parallel_layer_guid_t const &); ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &, diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h index 20e947ad58..3a7f67dcf0 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h @@ -13,7 +13,7 @@ struct ParallelComputationGraphBuilder { parallel_tensor_guid_t create_input_tensor( ParallelTensorShape const &shape, - bool create_grad = true, + CreateGrad create_grad = CreateGrad::YES, std::optional const &name = std::nullopt); parallel_tensor_guid_t @@ -54,7 +54,8 @@ struct ParallelComputationGraphBuilder { std::optional activation = std::nullopt, bool use_bias = true, DataType data_type = DataType::FLOAT, - std::optional const &kernel_initializer = std::nullopt, + std::optional const &projection_initializer = + std::nullopt, std::optional const &bias_initializer = std::nullopt, std::optional const &name = std::nullopt); diff --git a/lib/pcg/include/pcg/strided_rectangle.struct.toml b/lib/pcg/include/pcg/strided_rectangle.struct.toml index 3dfd90e296..577825238d 100644 --- a/lib/pcg/include/pcg/strided_rectangle.struct.toml +++ b/lib/pcg/include/pcg/strided_rectangle.struct.toml @@ -11,7 +11,7 @@ features = [ includes = [ "pcg/strided_rectangle_side.dtg.h", - "op-attrs/dim_ordered.h", + "op-attrs/dim_ordered/dim_ordered.h", ] [[fields]] diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index cf4b1496cf..a69e54fd93 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -1,5 +1,7 @@ #include "pcg/computation_graph.h" #include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/get_incoming_tensor_roles.h" +#include "utils/containers/filtrans.h" #include "utils/containers/get_only.h" #include "utils/containers/reversed.h" #include "utils/containers/transform.h" @@ -80,6 +82,45 @@ std::vector get_incoming_tensors(ComputationGraph const &cg, [](DataflowOutput const &o) { return tensor_guid_t{o}; }); } +static std::vector + get_incoming_tensors_with_role(ComputationGraph const &cg, + layer_guid_t const &l, + IncomingTensorRole desired_role) { + ComputationGraphOpAttrs attrs = get_layer_attrs(cg, l).attrs; + + std::vector incoming_tensors = get_incoming_tensors(cg, l); + + std::vector incoming_tensor_roles = + get_incoming_tensor_roles(attrs, incoming_tensors.size()); + + assert(incoming_tensors.size() == incoming_tensor_roles.size()); + + std::vector result = + filtrans(zip(incoming_tensors, incoming_tensor_roles), + [&](std::pair const &p) + -> std::optional { + tensor_guid_t tensor = p.first; + IncomingTensorRole role = p.second; + + if (role == desired_role) { + return tensor; + } else { + return std::nullopt; + } + }); + return 
result; +} + +std::vector get_incoming_inputs(ComputationGraph const &cg, + layer_guid_t const &l) { + return get_incoming_tensors_with_role(cg, l, IncomingTensorRole::INPUT); +} + +std::vector get_incoming_weights(ComputationGraph const &cg, + layer_guid_t const &l) { + return get_incoming_tensors_with_role(cg, l, IncomingTensorRole::WEIGHT); +} + std::unordered_set get_subgraph_incoming_edges( ComputationGraph const &cg, std::unordered_set const &subgraph_nodes) { diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index e0b6935a6d..a4f61cff98 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -1,5 +1,6 @@ #include "pcg/computation_graph_builder.h" #include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/get_op_type.h" #include "op-attrs/get_output_shapes.h" #include "op-attrs/ops/attention.h" @@ -55,112 +56,63 @@ tensor_guid_t ComputationGraphBuilder::create_input( maybe_name, }; - return this->add_layer(layer_attrs, {}, {}, tensor_attrs); + return get_only(this->add_layer(layer_attrs, {}, {}, {tensor_attrs})); } tensor_guid_t ComputationGraphBuilder::create_weight( - TensorShape const &shape, - CreateGrad create_grad, - std::optional const &initializer, - std::optional param_sync, + TensorAttrs const &tensor_attrs, std::optional const &maybe_name) { - TensorAttrs tensor_attrs = - TensorAttrs{shape, initializer, param_sync, create_grad}; LayerAttrs layer_attrs = LayerAttrs{ ComputationGraphOpAttrs{InputAttrs{}}, maybe_name, }; - return this->add_layer(layer_attrs, {}, {}, tensor_attrs); + return get_only(this->add_layer(layer_attrs, + std::vector{}, + std::vector{}, + {tensor_attrs})); } -std::vector ComputationGraphBuilder::add_layer( - LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs) { - std::vector raw_weight_tensors; - for (auto const &kv : enumerate_vector(weights)) { - int weight_idx = kv.first; - TensorAttrs weight_tensor_attrs = kv.second; - - std::optional weight_name = - transform(layer.name, [&](std::string const &layer_name) { - return fmt::format("{}.weights[{}]", layer_name, weight_idx); - }); - LayerAttrs weight_layer_attrs = LayerAttrs{ - ComputationGraphOpAttrs{WeightAttrs{weight_tensor_attrs.shape}}, - weight_name, - }; - std::vector weight_layer_inputs = {}; - std::vector weight_output_attrs = {weight_tensor_attrs}; - raw_weight_tensors.push_back(get_only(this->computation_graph.raw_graph - .add_node(weight_layer_attrs, - weight_layer_inputs, - weight_output_attrs) - .outputs)); - } +tensor_guid_t ComputationGraphBuilder::create_weight( + TensorShape const &shape, + CreateGrad create_grad, + std::optional const &initializer, + std::optional param_sync, + std::optional const &maybe_name) { + TensorAttrs tensor_attrs = + TensorAttrs{shape, initializer, param_sync, create_grad}; - std::vector raw_inputs = transform( - inputs, [](tensor_guid_t const &t) { return t.raw_graph_output; }); - std::vector raw_outputs = - this->computation_graph.raw_graph - .add_node( - layer, concat_vectors(raw_inputs, raw_weight_tensors), outputs) - .outputs; - return transform(raw_outputs, - [](DataflowOutput const &o) { return tensor_guid_t{o}; }); + return this->create_weight(tensor_attrs, maybe_name); } -tensor_guid_t - ComputationGraphBuilder::add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - TensorAttrs 
const &output) { - std::vector outputs = {output}; - return get_only(this->add_layer(layer, inputs, weights, outputs)); -} +static void check_incoming_tensor_roles(LayerAttrs const &layer, + int num_inputs, + int num_weights) { + std::vector correct = + get_incoming_tensor_roles(layer.attrs, num_inputs + num_weights); + std::vector current = concat_vectors( + std::vector(num_inputs, IncomingTensorRole::INPUT), + std::vector(num_weights, IncomingTensorRole::WEIGHT)); -std::vector ComputationGraphBuilder::add_layer( - LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs) { - return this->add_layer( - layer, inputs, weights, transform(outputs, make_output_attrs)); + if (correct != current) { + throw mk_runtime_error( + fmt::format("check_incoming_tensor_roles found deviation in incoming " + "tensors: expected {}, received {}", + correct, + current)); + } } -tensor_guid_t ComputationGraphBuilder::add_layer( +std::vector ComputationGraphBuilder::add_layer( LayerAttrs const &layer, std::vector const &inputs, std::vector const &weights, - TensorShape const &output_shape) { - - TensorAttrs output_attrs = make_output_attrs(output_shape); - LayerAddedResult added = - ::FlexFlow::add_layer(this->computation_graph, - layer, - concat_vectors(inputs, weights), - {output_attrs}); - return get_only(added.outputs); -} - -tensor_guid_t - ComputationGraphBuilder::add_layer(LayerAttrs const &layer, - std::vector const &inputs, - TensorShape const &output_shape) { - - std::vector weights = {}; - return this->add_layer(layer, inputs, weights, output_shape); -} + std::vector const &outputs) { + check_incoming_tensor_roles(layer, inputs.size(), weights.size()); -tensor_guid_t - ComputationGraphBuilder::add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - TensorShape const &output) { - return get_only(this->add_layer( - layer, inputs, weights, std::vector{output})); + LayerAddedResult added = ::FlexFlow::add_layer( + this->computation_graph, layer, concat_vectors(inputs, weights), outputs); + return added.outputs; } tensor_guid_t ComputationGraphBuilder::as_type(tensor_guid_t const &x, @@ -201,7 +153,8 @@ tensor_guid_t ComputationGraphBuilder::broadcast(tensor_guid_t const &input, TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); - return this->add_layer(layer, {input}, output_shape); + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } tensor_guid_t @@ -238,7 +191,8 @@ tensor_guid_t ComputationGraphBuilder::element_unary( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); - return this->add_layer(layer, {input}, output_shape); + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::element_binary( @@ -271,7 +225,8 @@ tensor_guid_t ComputationGraphBuilder::element_binary( TensorShape output_shape = throw_if_unexpected(get_output_shape( attrs, this->get_shape(lhs_input), this->get_shape(rhs_input))); - return this->add_layer(layer, {lhs_input, rhs_input}, output_shape); + return get_only(this->add_layer( + layer, {lhs_input, rhs_input}, {}, {make_output_attrs(output_shape)})); } tensor_guid_t @@ -461,7 +416,12 @@ tensor_guid_t ComputationGraphBuilder::conv2d( bias_initializer)); } - return this->add_layer(layer, {input}, weights, output_shape); + return get_only(this->add_layer( + layer, + {input}, + 
transform(weights, + [&](TensorAttrs const &a) { return this->create_weight(a); }), + {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::dropout( @@ -479,7 +439,8 @@ tensor_guid_t ComputationGraphBuilder::dropout( TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); - return this->add_layer(layer, {input}, output_shape); + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::embedding( @@ -507,7 +468,10 @@ tensor_guid_t ComputationGraphBuilder::embedding( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); - return this->add_layer(layer, {input}, {weight_attrs}, output_shape); + return get_only(this->add_layer(layer, + {input}, + {this->create_weight(weight_attrs)}, + {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::gather( @@ -531,41 +495,9 @@ tensor_guid_t ComputationGraphBuilder::gather( TensorShape output_shape = get_output_shape(attrs, this->get_shape(input), this->get_shape(index)); - return this->add_layer(layer, {input}, output_shape); -} - -/* std::vector - * ComputationGraphBuilder::get_shapes(std::vector const &ts) - * const { */ -/* return transform(ts, [&](tensor_guid_t const &t) { return - * this->get_shape(t); }); */ -/* } */ - -// tensor_guid_t ComputationGraphBuilder::aggregate( -// tensor_guid_t const &gate_preds, -// tensor_guid_t const &gate_assign, -// tensor_guid_t const &true_gate_assign, -// tensor_guid_t const &full_gate_gradients, -// std::vector const &exp_preds, -// int n, -// float lambda_bal, -// std::optional const &maybe_name) { -// AggregateAttrs attrs = {n, lambda_bal}; -// std::string name = maybe_name.value_or(get_default_name(attrs)); - -// LayerAttrs layer = {attrs, name}; -// TensorShape output_shape = get_output_shape(attrs, -// this->get_shape(gate_preds), -// this->get_shape(gate_assign), -// this->get_shape(true_gate_assign), -// this->get_shape(full_gate_gradients), -// this->get_shape(exp_preds)); - -// std::vector inputs = { -// gate_preds, gate_assign, true_gate_assign, full_gate_gradients}; -// extend(inputs, exp_preds); -// return this->add_layer(layer, inputs, {}, output_shape); -// } + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); +} tensor_guid_t ComputationGraphBuilder::batch_norm( tensor_guid_t const &input, @@ -579,7 +511,8 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); - return this->add_layer(layer, {input}, output_shape); + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::multihead_attention( @@ -597,6 +530,20 @@ tensor_guid_t ComputationGraphBuilder::multihead_attention( std::optional initializer, std::optional const &maybe_name) { + if (add_bias_kv) { + throw mk_runtime_error( + "ComputationGraphBuilder::multihead_attention received currently " + "unsupported argument add_bias_kv=true. " + "If you need this functionality, please create an issue."); + } + + if (add_zero_attn) { + throw mk_runtime_error( + "ComputationGraphBuilder::multihead_attention received currently " + "unsupported argument add_zero_attn=true. 
" + "If you need this functionality, please create an issue."); + } + MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{embed_dim, num_heads, kdim, @@ -609,24 +556,48 @@ tensor_guid_t ComputationGraphBuilder::multihead_attention( std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + TensorShape query_shape = this->get_shape(query); + TensorShape key_shape = this->get_shape(key); + TensorShape value_shape = this->get_shape(value); + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = - throw_if_unexpected(get_output_shape(attrs, - this->get_shape(query), - this->get_shape(key), - this->get_shape(value))); + TensorShape output_shape = throw_if_unexpected( + get_output_shape(attrs, query_shape, key_shape, value_shape)); + + std::vector weights; + + TensorShape weights_shape = throw_if_unexpected( + get_weights_shape(attrs, query_shape, key_shape, value_shape)); + weights.push_back(make_weight_attrs(weights_shape, initializer)); + + if (bias) { + TensorShape input_bias_shape = throw_if_unexpected( + get_input_bias_shape(attrs, query_shape, key_shape, value_shape)); + // initializer chosen based on + // https://github.com/pytorch/pytorch/blob/31c4e0d37d8efc37a0697159e5b9121ec34d5141/torch/nn/modules/activation.py#L1120-L1121 + InitializerAttrs input_bias_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; - TensorShape weights_shape = - throw_if_unexpected(get_weights_shape(attrs, - this->get_shape(query), - this->get_shape(key), - this->get_shape(value))); - TensorAttrs weight_attrs = make_weight_attrs(weights_shape, initializer); + weights.push_back( + make_weight_attrs(input_bias_shape, input_bias_initializer)); - return this->add_layer(layer, - std::vector{query, key, value}, - {weight_attrs}, - output_shape); + TensorShape output_bias_shape = throw_if_unexpected( + get_output_bias_shape(attrs, query_shape, key_shape, value_shape)); + // initializer chosen based on + // https://github.com/pytorch/pytorch/blob/31c4e0d37d8efc37a0697159e5b9121ec34d5141/torch/nn/modules/activation.py#L1120-L1121 + InitializerAttrs output_bias_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + + weights.push_back( + make_weight_attrs(output_bias_shape, output_bias_initializer)); + } + + return get_only(this->add_layer( + layer, + {query, key, value}, + transform(weights, + [&](TensorAttrs const &a) { return this->create_weight(a); }), + {make_output_attrs(output_shape)})); } TensorDims ComputationGraphBuilder::get_broadcast_target_dims( @@ -676,7 +647,7 @@ tensor_guid_t ComputationGraphBuilder::dense( std::vector weights; TensorShape projection_shape = - throw_if_unexpected(get_kernel_shape(attrs, this->get_shape(input))); + throw_if_unexpected(get_projection_shape(attrs, this->get_shape(input))); tensor_guid_t projection_weights = this->create_weight(projection_shape, @@ -699,7 +670,8 @@ tensor_guid_t ComputationGraphBuilder::dense( weights.push_back(bias_weights); } - return this->add_layer(layer, {input}, weights, output_shape); + return get_only(this->add_layer( + layer, {input}, weights, {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::layer_norm( @@ -752,7 +724,12 @@ tensor_guid_t ComputationGraphBuilder::layer_norm( weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); } - return this->add_layer(layer, {input}, weights, output_shape); + return get_only(this->add_layer( + layer, + {input}, + 
transform(weights, + [&](TensorAttrs const &a) { return this->create_weight(a); }), + {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::softmax( @@ -781,7 +758,8 @@ tensor_guid_t ComputationGraphBuilder::softmax( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); - return this->add_layer(layer, {input}, output_shape); + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 5b178160cd..b04d9d37b3 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -1,4 +1,6 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "op-attrs/get_incoming_tensor_roles.h" +#include "utils/containers/filtrans.h" #include "utils/containers/get_only.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/algorithms.h" @@ -41,8 +43,8 @@ ParallelLayerAddedResult } std::vector - get_layer_inputs(ParallelComputationGraph const &pcg, - parallel_layer_guid_t const &l) { + get_incoming_tensors(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { return transform( get_input_values(pcg.raw_graph, l.raw_graph_node), [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); @@ -56,6 +58,48 @@ std::vector [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); } +static std::vector + get_incoming_tensors_with_role(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l, + IncomingTensorRole desired_role) { + PCGOperatorAttrs attrs = get_parallel_layer_attrs(pcg, l).op_attrs; + + std::vector incoming_tensors = + get_incoming_tensors(pcg, l); + + std::vector incoming_tensor_roles = + get_incoming_tensor_roles(attrs, incoming_tensors.size()); + + assert(incoming_tensors.size() == incoming_tensor_roles.size()); + + std::vector result = filtrans( + zip(incoming_tensors, incoming_tensor_roles), + [&](std::pair const &p) + -> std::optional { + parallel_tensor_guid_t tensor = p.first; + IncomingTensorRole role = p.second; + + if (role == desired_role) { + return tensor; + } else { + return std::nullopt; + } + }); + return result; +} + +std::vector + get_incoming_inputs(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + return get_incoming_tensors_with_role(pcg, l, IncomingTensorRole::INPUT); +} + +std::vector + get_incoming_weights(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + return get_incoming_tensors_with_role(pcg, l, IncomingTensorRole::WEIGHT); +} + parallel_layer_guid_t get_source_layer(ParallelComputationGraph const &g, parallel_tensor_guid_t const &t) { return parallel_layer_guid_t{t.raw_graph_output.node}; diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 8290a2ff94..620dc035fc 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -1,4 +1,5 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "op-attrs/get_incoming_tensor_roles.h" #include 
"op-attrs/ops/weight_attrs.dtg.h" #include "op-attrs/parallel_op_attrs.h" #include "op-attrs/pcg_operator_attrs.h" @@ -35,13 +36,13 @@ ParallelComputationGraphBuilder::ParallelComputationGraphBuilder() parallel_tensor_guid_t ParallelComputationGraphBuilder::create_input_tensor( ParallelTensorShape const &shape, - bool create_grad, + CreateGrad create_grad, std::optional const &name) { ParallelTensorAttrs tensor_attrs = ParallelTensorAttrs{ /*shape=*/shape, /*sync_type=*/std::nullopt, /*initializer=*/std::nullopt, - /*create_gradients=*/(create_grad ? CreateGrad::YES : CreateGrad::NO), + /*create_gradients=*/create_grad, }; ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ PCGOperatorAttrs{InputAttrs{}}, @@ -205,7 +206,7 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::dense( { ParallelTensorShape kernel_shape = - throw_if_unexpected(get_kernel_shape(attrs, input_shape)); + throw_if_unexpected(get_projection_shape(attrs, input_shape)); weights.push_back(make_weight_attrs(kernel_shape, kernel_initializer)); } @@ -580,11 +581,32 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::add_weight( return parallel_tensor_guid_t{current_raw_weight_tensor}; } +static void check_incoming_tensor_roles(ParallelLayerAttrs const &layer, + int num_inputs, + int num_weights) { + std::vector correct = + get_incoming_tensor_roles(layer.op_attrs, num_inputs + num_weights); + std::vector current = concat_vectors( + std::vector(num_inputs, IncomingTensorRole::INPUT), + std::vector(num_weights, IncomingTensorRole::WEIGHT)); + + if (correct != current) { + throw mk_runtime_error( + fmt::format("check_incoming_tensor_roles found deviation in incoming " + "tensors: expected {}, received {}", + correct, + current)); + } +} + std::vector ParallelComputationGraphBuilder::add_layer( ParallelLayerAttrs const &layer, std::vector const &inputs, std::vector const &weights, std::vector const &outputs) { + + check_incoming_tensor_roles(layer, inputs.size(), weights.size()); + std::vector raw_weight_tensors; for (auto const &kv : enumerate_vector(weights)) { int weight_idx = kv.first; @@ -603,6 +625,7 @@ std::vector ParallelComputationGraphBuilder::add_layer( transform(inputs, [](parallel_tensor_guid_t const &t) { return t.raw_graph_output; }); + std::vector raw_outputs = this->pcg.raw_graph .add_node( diff --git a/lib/pcg/test/src/pcg/computation_graph.cc b/lib/pcg/test/src/pcg/computation_graph.cc new file mode 100644 index 0000000000..e2ed51b2f1 --- /dev/null +++ b/lib/pcg/test/src/pcg/computation_graph.cc @@ -0,0 +1,206 @@ +#include "pcg/computation_graph.h" +#include "pcg/computation_graph_builder.h" +#include "utils/containers/get_only.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_incoming_inputs(ComputationGraph, layer_guid_t)") { + SUBCASE("layer has no inputs") { + std::string input_name = "input"; + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = + b.create_input(input_shape, CreateGrad::YES, input_name); + + return b.computation_graph; + }(); + + layer_guid_t input_layer = get_layer_by_name(cg, input_name); + + std::vector result = get_incoming_inputs(cg, input_layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs but no weights") { + std::string layer_name = "my op"; + + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 
10, + 12, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + b.relu(input, layer_name); + + ComputationGraph cg = b.computation_graph; + + layer_guid_t layer = get_layer_by_name(cg, layer_name); + + std::vector result = get_incoming_inputs(cg, layer); + std::vector correct = {input}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs and weights") { + std::string layer_name = "my op"; + + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + b.dense(input, + /*outDim=*/14, + /*activation=*/Activation::RELU, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/layer_name); + + ComputationGraph cg = b.computation_graph; + + layer_guid_t dense_layer = get_layer_by_name(cg, layer_name); + + std::vector result = get_incoming_inputs(cg, dense_layer); + std::vector correct = { + input, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("get_incoming_weights(ComputationGraph, layer_guid_t)") { + SUBCASE("layer has no inputs or weights") { + std::string input_name = "input"; + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = + b.create_input(input_shape, CreateGrad::YES, input_name); + + return b.computation_graph; + }(); + + layer_guid_t input_layer = get_layer_by_name(cg, input_name); + + std::vector result = get_incoming_weights(cg, input_layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs but no weights") { + std::string layer_name = "my op"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + b.relu(input, layer_name); + + return b.computation_graph; + }(); + + layer_guid_t layer = get_layer_by_name(cg, layer_name); + + std::vector result = get_incoming_weights(cg, layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs and weights") { + std::string layer_name = "my op"; + std::string projection_name = "my projection weight"; + std::string bias_name = "my bias weight"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + b.dense(input, + /*outDim=*/14, + /*activation=*/Activation::RELU, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/layer_name, + /*projection_name=*/projection_name, + /*bias_name=*/bias_name); + + return b.computation_graph; + }(); + + layer_guid_t dense_layer = get_layer_by_name(cg, layer_name); + + layer_guid_t projection_weight_layer = + get_layer_by_name(cg, projection_name); + tensor_guid_t projection_weight = + get_only(get_outgoing_tensors(cg, projection_weight_layer)); + + layer_guid_t bias_weight_layer = get_layer_by_name(cg, bias_name); + tensor_guid_t bias_weight = + get_only(get_outgoing_tensors(cg, bias_weight_layer)); + + std::vector result = 
get_incoming_weights(cg, dense_layer); + std::vector correct = { + projection_weight, + bias_weight, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 188447da92..77d938e08a 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -1,4 +1,5 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "test/utils/rapidcheck.h" #include "utils/containers/get_only.h" @@ -35,4 +36,230 @@ TEST_SUITE(FF_TEST_SUITE) { // std::vector correct = {layer1, layer2, layer3}; // CHECK(result == correct); } + + TEST_CASE( + "get_incoming_inputs(ParallelComputationGraph, parallel_layer_guid_t)") { + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("layer has no inputs") { + std::string input_name = "my input"; + ParallelComputationGraph pcg = [&] { + ParallelComputationGraphBuilder b; + + b.create_input_tensor(input_shape, CreateGrad::YES, input_name); + + return b.pcg; + }(); + + parallel_layer_guid_t input_layer = + get_parallel_layer_by_name(pcg, input_name); + + std::vector result = + get_incoming_inputs(pcg, input_layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs and weights") { + std::string my_op_name = "my op"; + + ParallelComputationGraphBuilder b; + + parallel_tensor_guid_t input = + b.create_input_tensor(input_shape, CreateGrad::YES); + b.dense(input, + /*outDim=*/14, + /*activation=*/Activation::RELU, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/my_op_name); + + ParallelComputationGraph pcg = b.pcg; + + parallel_layer_guid_t my_op_layer = + get_parallel_layer_by_name(pcg, my_op_name); + + std::vector result = + get_incoming_inputs(pcg, my_op_layer); + std::vector correct = {input}; + + CHECK(result == correct); + } + } + + TEST_CASE( + "get_incoming_weights(ParallelComputationGraph, parallel_layer_guid_t)") { + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("layer has no inputs or weights") { + std::string input_name = "my input"; + ParallelComputationGraph pcg = [&] { + ParallelComputationGraphBuilder b; + + b.create_input_tensor(input_shape, CreateGrad::YES, input_name); + + return b.pcg; + }(); + + parallel_layer_guid_t input_layer = + get_parallel_layer_by_name(pcg, input_name); + + std::vector result = + get_incoming_weights(pcg, input_layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs but no weights") { + std::string my_op_name = "my op"; + ParallelComputationGraph pcg = [&] { + ParallelComputationGraphBuilder b; + + parallel_tensor_guid_t input = + b.create_input_tensor(input_shape, CreateGrad::YES); + b.relu(input, my_op_name); + + return b.pcg; + }(); + + parallel_layer_guid_t 
my_op_layer = + get_parallel_layer_by_name(pcg, my_op_name); + + std::vector result = + get_incoming_weights(pcg, my_op_layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs and weights, and weights are separate by " + "parallel ops") { + std::string my_op_name = "my op"; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + LinearAttrs op_attrs = LinearAttrs{ + /*out_channels=*/14, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*activation=*/Activation::RELU, + /*regularizer=*/std::nullopt, + }; + + ParallelLayerAddedResult input_added = [&] { + ParallelLayerAttrs input_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{InputAttrs{}}, + std::nullopt, + }; + ParallelTensorAttrs input_tensor_attrs = + ParallelTensorAttrs{input_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + CreateGrad::YES}; + + return add_parallel_layer(pcg, input_attrs, {}, {input_tensor_attrs}); + }(); + parallel_tensor_guid_t input = get_only(input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = [&] { + ParallelTensorShape projection_weight_shape = + throw_if_unexpected(get_projection_shape(op_attrs, input_shape)); + + TensorShape unpar_projection_shape = + get_reduced_shape(projection_weight_shape); + ParallelTensorShape raw_projection_weight_shape = + lift_to_parallel(unpar_projection_shape); + + ParallelLayerAttrs raw_projection_weight_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{WeightAttrs{unpar_projection_shape}}, + std::nullopt, + }; + ParallelTensorAttrs raw_projection_tensor_attrs = + ParallelTensorAttrs{raw_projection_weight_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + CreateGrad::YES}; + + ParallelLayerAddedResult raw_weight_added = + add_parallel_layer(pcg, + raw_projection_weight_attrs, + {}, + {raw_projection_tensor_attrs}); + + ReplicateAttrs replicate_attrs = ReplicateAttrs{/*degree=*/2}; + ParallelLayerAttrs replicate_layer_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{replicate_attrs}, + std::nullopt, + }; + ParallelTensorAttrs replicated_projection_tensor_attrs = + ParallelTensorAttrs{ + get_output_shape(replicate_attrs, raw_projection_weight_shape), + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + CreateGrad::YES}; + return add_parallel_layer(pcg, + replicate_layer_attrs, + {}, + {replicated_projection_tensor_attrs}); + }(); + parallel_tensor_guid_t projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult my_op_added = [&] { + ParallelTensorShape output_shape = + throw_if_unexpected(get_output_shape(op_attrs, input_shape)); + + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{op_attrs}, + std::nullopt, + }; + ParallelTensorAttrs output_tensor_attrs = + ParallelTensorAttrs{output_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + CreateGrad::YES}; + + return add_parallel_layer(pcg, + layer_attrs, + {input, projection_weight}, + {output_tensor_attrs}); + }(); + + parallel_layer_guid_t my_op_layer = my_op_added.parallel_layer; + + std::vector result = + get_incoming_weights(pcg, my_op_layer); + std::vector correct = {projection_weight}; + + CHECK(result == correct); + } + } } diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index f46f267859..c445085635 100644 --- 
a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -15,6 +15,12 @@ using namespace ::FlexFlow; +// Stylistically these tests are not great (they're rather complicated +// and hard to read) and should not be used as a model for other FlexFlow +// tests. +// +// Improving them is being tracked in +// https://github.com/flexflow/FlexFlow/issues/1474 TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ParallelComputationGraphBuilder::add") { ParallelComputationGraphBuilder b; @@ -44,9 +50,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t out = b.add(lhs, rhs); parallel_layer_guid_t layer = get_source_layer(out); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {lhs, rhs}; CHECK(result == correct); } @@ -107,9 +113,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t out = b.batch_matmul(a_tensor, b_tensor); parallel_layer_guid_t layer = get_source_layer(out); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {a_tensor, b_tensor}; CHECK(result == correct); } @@ -150,9 +156,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.cast(input, output_datatype); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } @@ -260,20 +266,20 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelTensorShape correct_bias_shape = get_bias_shape(correct_attrs, input_shape); - std::vector conv_inputs = - get_layer_inputs(b.pcg, conv_guid); + std::vector conv_incoming = + get_incoming_tensors(b.pcg, conv_guid); - parallel_tensor_guid_t conv_input = conv_inputs.at(0); + parallel_tensor_guid_t conv_input = conv_incoming.at(0); ParallelTensorShape conv_input_shape = get_parallel_tensor_attrs(b.pcg, conv_input).shape; CHECK(conv_input_shape == input_shape); - parallel_tensor_guid_t conv_kernel = conv_inputs.at(1); + parallel_tensor_guid_t conv_kernel = conv_incoming.at(1); ParallelTensorShape conv_kernel_shape = get_parallel_tensor_attrs(b.pcg, conv_kernel).shape; CHECK(conv_kernel_shape == correct_kernel_shape); - parallel_tensor_guid_t conv_bias = conv_inputs.at(2); + parallel_tensor_guid_t conv_bias = conv_incoming.at(2); ParallelTensorShape conv_bias_shape = get_parallel_tensor_attrs(b.pcg, conv_bias).shape; CHECK(conv_bias_shape == correct_bias_shape); @@ -315,9 +321,9 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); CHECK(result.at(0) == input); CHECK(result.size() == 3); @@ -358,9 +364,9 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); CHECK(result.at(0) == input); CHECK(result.size() == 2); @@ -408,9 +414,9 @@ TEST_SUITE(FF_TEST_SUITE) { b.multihead_attention(query, key, value, embed_dim, num_heads); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + 
SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); CHECK(result.at(0) == query); CHECK(result.at(1) == key); CHECK(result.at(2) == value); @@ -449,9 +455,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.relu(input); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } @@ -488,9 +494,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.parallel_partition(input, ff_dim_t{0}, 2); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } @@ -527,9 +533,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.parallel_combine(input, ff_dim_t{0}, 2); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } @@ -566,9 +572,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.parallel_replicate(input, 2); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } @@ -605,9 +611,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.parallel_reduce(input, 2); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } diff --git a/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h b/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h index d4d38af228..a5f0cc6fdc 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h +++ b/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h @@ -10,64 +10,66 @@ namespace FlexFlow { std::optional get_attribute(PCGOperatorAttrs const &, OperatorAttributeKey); -std::optional get_attribute(BatchMatmulAttrs const &p, +std::optional get_attribute(BatchMatmulAttrs const &, OperatorAttributeKey); -std::optional get_attribute(BatchNormAttrs const &p, +std::optional get_attribute(BatchNormAttrs const &, OperatorAttributeKey); -std::optional get_attribute(CastAttrs const &p, +std::optional get_attribute(BroadcastAttrs const &, OperatorAttributeKey); -std::optional get_attribute(CombineAttrs const &p, +std::optional get_attribute(CastAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ConcatAttrs const &p, +std::optional get_attribute(CombineAttrs const &, OperatorAttributeKey); -std::optional get_attribute(Conv2DAttrs const &p, +std::optional get_attribute(ConcatAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ElementBinaryAttrs const &p, +std::optional get_attribute(Conv2DAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ElementUnaryAttrs const &p, +std::optional get_attribute(ElementBinaryAttrs const &, OperatorAttributeKey); -std::optional 
get_attribute(DropoutAttrs const &p, +std::optional get_attribute(ElementUnaryAttrs const &, OperatorAttributeKey); -std::optional get_attribute(EmbeddingAttrs const &p, +std::optional get_attribute(DropoutAttrs const &, OperatorAttributeKey); -std::optional get_attribute(FlatAttrs const &p, +std::optional get_attribute(EmbeddingAttrs const &, OperatorAttributeKey); -std::optional get_attribute(GatherAttrs const &p, +std::optional get_attribute(FlatAttrs const &, OperatorAttributeKey); -std::optional get_attribute(InputAttrs const &p, +std::optional get_attribute(GatherAttrs const &, OperatorAttributeKey); -std::optional get_attribute(LayerNormAttrs const &p, +std::optional get_attribute(InputAttrs const &, OperatorAttributeKey); -std::optional get_attribute(LinearAttrs const &p, +std::optional get_attribute(LayerNormAttrs const &, + OperatorAttributeKey); +std::optional get_attribute(LinearAttrs const &, OperatorAttributeKey); std::optional - get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey); + get_attribute(MultiHeadAttentionAttrs const &, OperatorAttributeKey); -std::optional get_attribute(NoopAttrs const &p, +std::optional get_attribute(NoopAttrs const &, OperatorAttributeKey); -std::optional get_attribute(Pool2DAttrs const &p, +std::optional get_attribute(Pool2DAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ReduceAttrs const &p, +std::optional get_attribute(ReduceAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ReductionAttrs const &p, +std::optional get_attribute(ReductionAttrs const &, OperatorAttributeKey); -std::optional get_attribute(RepartitionAttrs const &p, +std::optional get_attribute(RepartitionAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ReplicateAttrs const &p, +std::optional get_attribute(ReplicateAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ReshapeAttrs const &p, +std::optional get_attribute(ReshapeAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ReverseAttrs const &p, +std::optional get_attribute(ReverseAttrs const &, OperatorAttributeKey); -std::optional get_attribute(SplitAttrs const &p, +std::optional get_attribute(SplitAttrs const &, OperatorAttributeKey); -std::optional get_attribute(SoftmaxAttrs const &p, +std::optional get_attribute(SoftmaxAttrs const &, OperatorAttributeKey); -std::optional get_attribute(TopKAttrs const &p, +std::optional get_attribute(TopKAttrs const &, OperatorAttributeKey); -std::optional get_attribute(TransposeAttrs const &p, +std::optional get_attribute(TransposeAttrs const &, OperatorAttributeKey); -// optional get_attribute(FusedParallelOpAttrs const &p, +// optional get_attribute(FusedParallelOpAttrs const &, // OperatorAttributeKey); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml index ad36f1bc4b..02a856f59a 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml @@ -69,5 +69,8 @@ type = "::FlexFlow::PoolOp" [[values]] type = "::FlexFlow::TensorShape" +[[values]] +type = "::FlexFlow::TensorDims" + [[values]] type = "::FlexFlow::DataType" diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc 
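A rough usage sketch of the attribute-query path touched here: the variant
change above lets an OperatorAttributeValue carry a TensorDims, and the hunk
below adds the corresponding BroadcastAttrs case. The get<...> accessor on the
returned value is assumed, and attrs is a hypothetical BroadcastAttrs value:

    std::optional<OperatorAttributeValue> v =
        get_attribute(attrs, OperatorAttributeKey::TARGET_DIMS);
    if (v.has_value()) {
      TensorDims target_dims = v.value().get<TensorDims>();
    }
    // any key not handled by the BroadcastAttrs overload falls through to
    // std::nullopt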
index a18737085a..d5d735ef59 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -26,6 +26,18 @@ std::optional get_attribute(BatchNormAttrs const &p, } } +std::optional get_attribute(BroadcastAttrs const &p, + OperatorAttributeKey key) { + switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); + case OperatorAttributeKey::TARGET_DIMS: + return p.target_dims; + default: + return std::nullopt; + } +} + std::optional get_attribute(CastAttrs const &p, OperatorAttributeKey key) { switch (key) { diff --git a/lib/substitutions/test/src/substitutions/pcg_pattern.cc b/lib/substitutions/test/src/substitutions/pcg_pattern.cc index 4f56a76d0d..d9273b4bcf 100644 --- a/lib/substitutions/test/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/test/src/substitutions/pcg_pattern.cc @@ -35,7 +35,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::string a_name = "a"; parallel_tensor_guid_t a_tensor = - builder.create_input_tensor(a_shape, /*create_grad=*/true, a_name); + builder.create_input_tensor(a_shape, CreateGrad::YES, a_name); int outDim = 16; std::string x_matmul_name = "x_matmul"; @@ -65,14 +65,14 @@ TEST_SUITE(FF_TEST_SUITE) { get_parallel_layer_by_name(pcg, x_matmul_name); parallel_layer_guid_t y_matmul = get_parallel_layer_by_name(pcg, y_matmul_name); - std::vector x_inputs = - get_layer_inputs(pcg, x_matmul); - REQUIRE(x_inputs.size() == 2); - parallel_tensor_guid_t x_weights = x_inputs.at(1); - std::vector y_inputs = - get_layer_inputs(pcg, y_matmul); - REQUIRE(y_inputs.size() == 2); - parallel_tensor_guid_t y_weights = y_inputs.at(1); + std::vector x_incoming = + get_incoming_tensors(pcg, x_matmul); + REQUIRE(x_incoming.size() == 2); + parallel_tensor_guid_t x_weights = x_incoming.at(1); + std::vector y_incoming = + get_incoming_tensors(pcg, y_matmul); + REQUIRE(y_incoming.size() == 2); + parallel_tensor_guid_t y_weights = y_incoming.at(1); LabelledOpenDataflowGraph g = LabelledOpenDataflowGraph Date: Tue, 17 Sep 2024 11:50:37 -0700 Subject: [PATCH 3/3] Add Inception-v3 model (#1495) * inception v3 initial implementation * Add parallel shape inference for concat and pool2d * Format * Respond to PR comments * Fix model bugs * Update batch norm to match pytorch interface for inception v3 * Finishing touches for inception, re-add relu flag for batchnorm * Format * Document adaptive pool2d formula simplification --------- Co-authored-by: Pietro Max Marsella Co-authored-by: Colin Unger --- .../src/export_model_arch.cc | 6 +- ...ion_graph_series_parallel_decomposition.cc | 11 + lib/local-execution/src/ops/concat.cc | 4 +- .../models/inception_v3/inception_v3.h | 23 + .../inception_v3_config.struct.toml | 23 + .../inception_v3_output.struct.toml | 25 + .../src/models/inception_v3/inception_v3.cc | 750 ++++++++++++++++++ .../src/models/inception_v3/inception_v3.cc | 19 + .../models/{ => transformer}/transformer.cc | 0 .../include/op-attrs/dim_ordered/concat.h | 34 + .../op-attrs/dim_ordered/dim_ordered.h | 6 +- .../dim_ordered/ff_ordered_from_map.h | 29 + .../include/op-attrs/dim_ordered/slice.h | 7 + .../include/op-attrs/ops/batch_norm.h | 35 +- .../op-attrs/ops/batch_norm_attrs.struct.toml | 22 + lib/op-attrs/include/op-attrs/ops/concat.h | 9 +- .../op-attrs/ops/concat_attrs.struct.toml | 4 - lib/op-attrs/include/op-attrs/ops/flat.h | 8 +- .../op-attrs/ops/flat_attrs.struct.toml | 21 +- lib/op-attrs/include/op-attrs/ops/pool_2d.h | 20 +- 
.../op-attrs/ops/pool_2d_attrs.struct.toml | 9 +- .../parallel_tensor_dim_degrees.struct.toml | 28 + .../include/op-attrs/parallel_tensor_dims.h | 3 + .../include/op-attrs/parallel_tensor_shape.h | 6 + .../src/op-attrs/dim_ordered/concat.cc | 1 + .../dim_ordered/ff_ordered_from_map.cc | 1 + .../src/op-attrs/get_incoming_tensor_roles.cc | 5 +- .../src/op-attrs/get_output_shapes.cc | 10 +- lib/op-attrs/src/op-attrs/ops/batch_norm.cc | 255 +++++- lib/op-attrs/src/op-attrs/ops/concat.cc | 135 +++- lib/op-attrs/src/op-attrs/ops/conv_2d.cc | 8 +- lib/op-attrs/src/op-attrs/ops/flat.cc | 120 +-- lib/op-attrs/src/op-attrs/ops/layer_norm.cc | 2 +- lib/op-attrs/src/op-attrs/ops/pool_2d.cc | 202 ++++- .../src/op-attrs/parallel_tensor_dims.cc | 8 + .../src/op-attrs/parallel_tensor_shape.cc | 13 + .../op-attrs/computation_graph_op_attrs.cc | 8 +- .../test/src/op-attrs/dim_ordered/concat.cc | 66 ++ .../dim_ordered/ff_ordered_from_map.cc | 66 ++ .../src/op-attrs/get_incoming_tensor_roles.cc | 2 +- .../test/src/op-attrs/ops/batch_norm.cc | 404 ++++++++++ .../test/src/op-attrs/ops/batch_norm_attrs.cc | 7 +- lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc | 8 +- lib/op-attrs/test/src/op-attrs/ops/flat.cc | 244 ++++++ lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc | 400 ++++++++++ .../include/pcg/computation_graph_builder.h | 24 +- .../parallel_computation_graph_builder.h | 5 +- lib/pcg/src/pcg/computation_graph_builder.cc | 166 +++- .../parallel_computation_graph_builder.cc | 44 +- .../operator_attribute_key.enum.toml | 3 +- .../operator_pattern/get_attribute.cc | 12 +- .../test/src/substitutions/substitution.cc | 2 +- .../include/utils/containers/are_all_same.h | 23 + .../utils/containers/require_all_same1.h | 31 + lib/utils/include/utils/containers/subvec.h | 5 + lib/utils/include/utils/containers/sum.h | 17 + lib/utils/include/utils/optional.h | 4 +- .../src/utils/containers/are_all_same.cc | 1 + .../src/utils/containers/require_all_same1.cc | 1 + lib/utils/src/utils/containers/sum.cc | 1 + .../test/src/utils/containers/are_all_same.cc | 36 + .../src/utils/containers/require_all_same1.cc | 54 ++ lib/utils/test/src/utils/containers/sum.cc | 27 + 63 files changed, 3354 insertions(+), 169 deletions(-) create mode 100644 lib/models/include/models/inception_v3/inception_v3.h create mode 100644 lib/models/include/models/inception_v3/inception_v3_config.struct.toml create mode 100644 lib/models/include/models/inception_v3/inception_v3_output.struct.toml create mode 100644 lib/models/src/models/inception_v3/inception_v3.cc create mode 100644 lib/models/test/src/models/inception_v3/inception_v3.cc rename lib/models/test/src/models/{ => transformer}/transformer.cc (100%) create mode 100644 lib/op-attrs/include/op-attrs/dim_ordered/concat.h create mode 100644 lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml create mode 100644 lib/op-attrs/src/op-attrs/dim_ordered/concat.cc create mode 100644 lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc create mode 100644 lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc create mode 100644 lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc create mode 100644 lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc create mode 100644 lib/op-attrs/test/src/op-attrs/ops/flat.cc create mode 100644 lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc create mode 100644 lib/utils/include/utils/containers/are_all_same.h create mode 100644 
lib/utils/include/utils/containers/require_all_same1.h create mode 100644 lib/utils/include/utils/containers/sum.h create mode 100644 lib/utils/src/utils/containers/are_all_same.cc create mode 100644 lib/utils/src/utils/containers/require_all_same1.cc create mode 100644 lib/utils/src/utils/containers/sum.cc create mode 100644 lib/utils/test/src/utils/containers/are_all_same.cc create mode 100644 lib/utils/test/src/utils/containers/require_all_same1.cc create mode 100644 lib/utils/test/src/utils/containers/sum.cc diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc index ccc720ed14..98b7a003ce 100644 --- a/bin/export-model-arch/src/export_model_arch.cc +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -1,6 +1,7 @@ #include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" #include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" #include "export_model_arch/json_sp_model_export.dtg.h" +#include "models/inception_v3/inception_v3.h" #include "models/split_test/split_test.h" #include "models/transformer/transformer.h" #include "op-attrs/computation_graph_op_attrs.h" @@ -59,6 +60,9 @@ tl::expected get_model_computation_graph(std::string const &model_name) { if (model_name == "transformer") { return get_default_transformer_computation_graph(); + } else if (model_name == "inception_v3") { + return get_inception_v3_computation_graph( + get_default_inception_v3_training_config()); } else if (model_name == "split_test") { int batch_size = 8; return get_split_test_computation_graph(batch_size); @@ -132,7 +136,7 @@ int main(int argc, char **argv) { "for preprocessed to help check series-parallel structure"}); std::vector model_options = { - "transformer", "split_test", "single_operator"}; + "transformer", "inception_v3", "split_test", "single_operator"}; CLIArgumentKey key_model_name = cli_add_positional_argument( cli, CLIPositionalArgumentSpec{ diff --git a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc index ab537e73de..c9d84a8948 100644 --- a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc +++ b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc @@ -1,4 +1,5 @@ #include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "models/inception_v3/inception_v3.h" #include "models/split_test/split_test.h" #include "models/transformer/transformer.h" #include "pcg/computation_graph.h" @@ -291,6 +292,16 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(sp_decomposition.has_value()); } + + SUBCASE("inception_v3") { + ComputationGraph cg = get_inception_v3_computation_graph( + get_default_inception_v3_training_config()); + + std::optional sp_decomposition = + get_computation_graph_series_parallel_decomposition(cg); + + CHECK(sp_decomposition.has_value()); + } } } diff --git a/lib/local-execution/src/ops/concat.cc b/lib/local-execution/src/ops/concat.cc index 35f663b1cd..4c3462e694 100644 --- a/lib/local-execution/src/ops/concat.cc +++ b/lib/local-execution/src/ops/concat.cc @@ -50,7 +50,7 @@ static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto output = acc.get_tensor(OUTPUT); auto inputs = acc.get_variadic_tensor(INPUTS); - assert(attrs.num_inputs <= MAX_NUM_INPUTS); + 
assert(inputs.size() <= MAX_NUM_INPUTS); return profile(forward_kernel, profiling, @@ -68,7 +68,7 @@ static std::optional auto input_grads = acc.get_variadic_tensor_grad(INPUTS); auto output_grad = acc.get_tensor_grad(OUTPUT); - assert(attrs.num_inputs <= MAX_NUM_INPUTS); + assert(input_grads.size() <= MAX_NUM_INPUTS); return profile(backward_kernel, profiling, diff --git a/lib/models/include/models/inception_v3/inception_v3.h b/lib/models/include/models/inception_v3/inception_v3.h new file mode 100644 index 0000000000..5c4754e441 --- /dev/null +++ b/lib/models/include/models/inception_v3/inception_v3.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_INCEPTION_V3 +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_INCEPTION_V3 + +#include "models/inception_v3/inception_v3_config.dtg.h" +#include "pcg/computation_graph.dtg.h" + +namespace FlexFlow { + +/** + * @brief Get the default training config from https://arxiv.org/abs/1512.00567. + */ +InceptionV3Config get_default_inception_v3_training_config(); + +/** + * @brief Get a computation graph for Inception-v3 as described in + * https://arxiv.org/abs/1512.00567. + */ +ComputationGraph + get_inception_v3_computation_graph(InceptionV3Config const &config); + +} // namespace FlexFlow + +#endif diff --git a/lib/models/include/models/inception_v3/inception_v3_config.struct.toml b/lib/models/include/models/inception_v3/inception_v3_config.struct.toml new file mode 100644 index 0000000000..a2a75c83bb --- /dev/null +++ b/lib/models/include/models/inception_v3/inception_v3_config.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "InceptionV3Config" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "num_classes" +type = "int" + +[[fields]] +name = "batch_size" +type = "int" + +[[fields]] +name = "aux_logits" +type = "bool" diff --git a/lib/models/include/models/inception_v3/inception_v3_output.struct.toml b/lib/models/include/models/inception_v3/inception_v3_output.struct.toml new file mode 100644 index 0000000000..066e6df02b --- /dev/null +++ b/lib/models/include/models/inception_v3/inception_v3_output.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "InceptionV3Output" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "pcg/tensor_guid_t.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "standard_logits" +type = "::FlexFlow::tensor_guid_t" + +[[fields]] +name = "aux_logits" +type = "std::optional<::FlexFlow::tensor_guid_t>" diff --git a/lib/models/src/models/inception_v3/inception_v3.cc b/lib/models/src/models/inception_v3/inception_v3.cc new file mode 100644 index 0000000000..f540eae629 --- /dev/null +++ b/lib/models/src/models/inception_v3/inception_v3.cc @@ -0,0 +1,750 @@ +#include "models/inception_v3/inception_v3.h" +#include "models/inception_v3/inception_v3_output.dtg.h" +#include "op-attrs/tensor_shape.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph_builder.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +struct CheckShape { + CheckShape(ComputationGraphBuilder const &cgb, + InceptionV3Config const &config) + : cgb(cgb), config(config) {} + + ComputationGraphBuilder const &cgb; + InceptionV3Config const &config; + + void operator()(tensor_guid_t t, int c, int h, int w) const { + TensorShape current_shape = cgb.get_shape(t); + TensorShape expected_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(config.batch_size), + size_t_from_int(c), + 
size_t_from_int(h), + size_t_from_int(w), + }}, + DataType::FLOAT, + }; + + if (current_shape != expected_shape) { + throw mk_runtime_error(fmt::format( + "Expected activation shape {}, but found activation shape {}", + expected_shape, + current_shape)); + } + } + + void operator()(tensor_guid_t t, int c) const { + TensorShape current_shape = cgb.get_shape(t); + TensorShape expected_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(config.batch_size), + size_t_from_int(c), + }}, + DataType::FLOAT, + }; + + if (current_shape != expected_shape) { + throw mk_runtime_error(fmt::format( + "Expected activation shape {}, but found activation shape {}", + expected_shape, + current_shape)); + } + } +}; + +InceptionV3Config get_default_inception_v3_training_config() { + return InceptionV3Config{ + /*num_classes=*/1000, + + // see section 8 of https://arxiv.org/abs/1512.00567 for the source of the + // batch size + /*batch_size=*/32, + + // see section 4 of https://arxiv.org/abs/1512.00567 for a discussion of + // auxiliary logits. they are used by default in training + /*aux_logits=*/true, + }; +} + +static tensor_guid_t create_conv_block(ComputationGraphBuilder &cgb, + tensor_guid_t const &input, + int filters, + int kernel_size_h, + int kernel_size_w, + int stride_h = 1, + int stride_w = 1, + int padding_h = 0, + int padding_w = 0, + bool use_bias = false) { + tensor_guid_t conv = cgb.conv2d(input, + /*outChannels=*/filters, + /*kernelH=*/kernel_size_h, + /*kernelW=*/kernel_size_w, + /*strideH=*/stride_h, + /*strideW=*/stride_w, + /*paddingH=*/padding_h, + /*paddingW=*/padding_w, + /*activation=*/std::nullopt, + /*groups=*/1, + /*use_bias=*/use_bias); + return cgb.batch_norm(conv, + /*affine=*/true, + /*activation=*/Activation::RELU, + /*eps=*/1e-5, + /*momentum=*/0.1); +} + +static tensor_guid_t create_inception_module_a(ComputationGraphBuilder &cgb, + tensor_guid_t const &input, + int pool_features) { + tensor_guid_t branch1x1 = create_conv_block(cgb, + input, + /*filters=*/64, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + + tensor_guid_t branch5x5 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/48, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/5, + /*kernel_size_w=*/5, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/2, + /*padding_w=*/2); + return t; + }(); + + tensor_guid_t branch3x3dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + return t; + }(); + + tensor_guid_t branch_pool = [&] { + tensor_guid_t t = input; + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/1, + /*paddingW=*/1, + /*type=*/PoolOp::AVG); + t = create_conv_block(cgb, + t, + /*filters=*/pool_features, + /*kernel_stride_h=*/1, + /*kernel_stride_w=*/1); + return t; + }(); + + return cgb.concat({branch1x1, branch5x5, branch3x3dbl, branch_pool}, + /*axis=*/1); +} + +static tensor_guid_t create_inception_module_b(ComputationGraphBuilder &cgb, + tensor_guid_t const &input) { + tensor_guid_t branch3x3 = create_conv_block(cgb, + 
input, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + + tensor_guid_t branch3x3dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_stride_h=*/3, + /*kernel_stride_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + return t; + }(); + + tensor_guid_t branch_pool = cgb.pool2d(input, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + + return cgb.concat({branch3x3, branch3x3dbl, branch_pool}, /*axis=*/1); +} + +static tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, + tensor_guid_t const &input, + int channels_7x7) { + tensor_guid_t branch1x1 = create_conv_block(cgb, + input, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + check_shape(branch1x1, 192, 17, 17); + + tensor_guid_t branch7x7 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + return t; + }(); + check_shape(branch7x7, 192, 17, 17); + + tensor_guid_t branch7x7dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + return t; + }(); + check_shape(branch7x7dbl, 192, 17, 17); + + tensor_guid_t branch_pool = [&] { + tensor_guid_t t = input; + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/1, + /*paddingW=*/1, + /*type=*/PoolOp::AVG); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + return t; + }(); + check_shape(branch_pool, 192, 17, 17); + + return cgb.concat({branch1x1, branch7x7, branch7x7dbl, branch_pool}, + /*axis=*/1); +} + +static tensor_guid_t create_inception_module_d(ComputationGraphBuilder &cgb, + tensor_guid_t const &input) { + tensor_guid_t branch3x3 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, t, 320, 3, 3, 2, 2); + return t; + }(); + + 
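// Every branch in these inception modules is assembled from create_conv_block
// above, i.e. conv2d with bias disabled followed by batch_norm with a ReLU
// (mirroring torchvision's BasicConv2d helper). As a rough sketch, a 3x3
// same-padded block with 64 filters expands to:
//
//   tensor_guid_t t = cgb.conv2d(input,
//                                /*outChannels=*/64,
//                                /*kernelH=*/3,
//                                /*kernelW=*/3,
//                                /*strideH=*/1,
//                                /*strideW=*/1,
//                                /*paddingH=*/1,
//                                /*paddingW=*/1,
//                                /*activation=*/std::nullopt,
//                                /*groups=*/1,
//                                /*use_bias=*/false);
//   t = cgb.batch_norm(t,
//                      /*affine=*/true,
//                      /*activation=*/Activation::RELU,
//                      /*eps=*/1e-5,
//                      /*momentum=*/0.1);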
tensor_guid_t branch7x7x3 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + return t; + }(); + + tensor_guid_t branch_pool = cgb.pool2d(input, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + + return cgb.concat({branch3x3, branch7x7x3, branch_pool}, /*axis=*/1); +} + +static tensor_guid_t create_inception_module_e(ComputationGraphBuilder &cgb, + tensor_guid_t const &input) { + tensor_guid_t branch1x1 = create_conv_block(cgb, + input, + /*filters=*/320, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + + tensor_guid_t branch3x3 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + tensor_guid_t t_1 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/1); + tensor_guid_t t_2 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/0); + t = cgb.concat({t_1, t_2}, /*axis=*/1); + return t; + }(); + + tensor_guid_t branch3x3dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/448, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + tensor_guid_t t_1 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/1); + tensor_guid_t t_2 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/0); + t = cgb.concat({t_1, t_2}, /*axis=*/1); + return t; + }(); + + tensor_guid_t branch_pool = [&] { + tensor_guid_t t = input; + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/1, + /*paddingW=*/1, + /*type=*/PoolOp::AVG); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + return t; + }(); + + return cgb.concat({branch1x1, branch3x3, branch3x3dbl, branch_pool}, + /*axis=*/1); +} + +static tensor_guid_t create_initial_layers(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, + tensor_guid_t const &input) { + tensor_guid_t t = input; + + check_shape(t, 3, 299, 299); + + // Conv2d_1a_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/32, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + check_shape(t, 32, 149, 149); + + // Conv2d_2a_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/32, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3); + check_shape(t, 32, 147, 147); + + // Conv2d_2b_3x3 + t = 
create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + check_shape(t, 64, 147, 147); + + // maxpool1 + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + check_shape(t, 64, 73, 73); + + // Conv2d_3b_1x1 + t = create_conv_block(cgb, + t, + /*filters=*/80, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + check_shape(t, 80, 73, 73); + + // Conv2d_4a_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3); + check_shape(t, 192, 71, 71); + + // maxpool2 + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + check_shape(t, 192, 35, 35); + + return t; +} + +static tensor_guid_t create_final_layers(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, + tensor_guid_t const &input, + size_t num_classes) { + // avgpool + tensor_guid_t x = cgb.pool2d(input, + /*kernelH=*/8, + /*kernelW=*/8, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::AVG); + check_shape(x, 2048, 1, 1); + + // dropout + x = cgb.dropout(x, + /*rate=*/0.5); + check_shape(x, 2048, 1, 1); + + x = cgb.flat(x, + /*start_dim=*/1); + check_shape(x, 2048); + + // fc + x = cgb.dense(x, + /*outDim=*/num_classes); + check_shape(x, num_classes); + + // softmax (not in pytorch model, but shown in Table 1 on p6 of + // https://arxiv.org/abs/1512.00567) + x = cgb.softmax(x); + check_shape(x, num_classes); + + return x; +} + +static tensor_guid_t create_inception_aux(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, + tensor_guid_t const &input, + size_t num_classes) { + tensor_guid_t x = input; + check_shape(x, 768, 17, 17); + + x = cgb.pool2d(x, + /*kernelH=*/5, + /*kernelW=*/5, + /*strideH=*/3, + /*strideW=*/3, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::AVG); + check_shape(x, 768, 5, 5); + + // conv0 + x = create_conv_block(cgb, + x, + /*filters=*/128, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + check_shape(x, 128, 5, 5); + + // conv1 + x = create_conv_block(cgb, + x, + /*filters=*/768, + /*kernel_size_h=*/5, + /*kernel_size_w=*/5); + check_shape(x, 768, 1, 1); + + x = cgb.adaptive_pool2d(x, + /*output_h=*/1, + /*output_w=*/1); + check_shape(x, 768, 1, 1); + + x = cgb.flat(x, + /*start_dim=*/1); + check_shape(x, 768); + + // fc + x = cgb.dense(x, + /*outDim=*/num_classes); + check_shape(x, num_classes); + + return x; +} + +static InceptionV3Output create_inception_v3(ComputationGraphBuilder &cgb, + InceptionV3Config const &config, + tensor_guid_t const &input) { + // NOTE: the shapes for check_shape (as well as the layer names in comments) + // are pulled from + // https://github.com/pytorch/vision/blob/6d7851bd5e2bedc294e40e90532f0e375fcfee04/torchvision/models/inception.py#L103-L155 + CheckShape check_shape = CheckShape{ + /*cgb=*/cgb, + /*config=*/config, + }; + + tensor_guid_t x = create_initial_layers(cgb, check_shape, input); + check_shape(x, 192, 35, 35); + + // Mixed_5b + x = create_inception_module_a(cgb, x, 32); + check_shape(x, 256, 35, 35); + + // Mixed_5c + x = create_inception_module_a(cgb, x, 64); + check_shape(x, 288, 35, 35); + + // Mixed_5d + x = create_inception_module_a(cgb, x, 64); + check_shape(x, 288, 35, 35); + + // Mixed_6a + x = create_inception_module_b(cgb, x); + check_shape(x, 768, 17, 
17); + + // Mixed_6b + x = create_inception_module_c(cgb, check_shape, x, 128); + check_shape(x, 768, 17, 17); + + // Mixed_6c + x = create_inception_module_c(cgb, check_shape, x, 160); + check_shape(x, 768, 17, 17); + + // Mixed_6d + x = create_inception_module_c(cgb, check_shape, x, 160); + check_shape(x, 768, 17, 17); + + // Mixed_6e + x = create_inception_module_c(cgb, check_shape, x, 192); + check_shape(x, 768, 17, 17); + + std::optional aux; + if (config.aux_logits) { + aux = create_inception_aux(cgb, check_shape, x, config.num_classes); + check_shape(aux.value(), config.num_classes); + } + + // Mixed_7a + x = create_inception_module_d(cgb, x); + check_shape(x, 1280, 8, 8); + + // Mixed_7b + x = create_inception_module_e(cgb, x); + check_shape(x, 2048, 8, 8); + + // Mixed_7c + x = create_inception_module_e(cgb, x); + check_shape(x, 2048, 8, 8); + + x = create_final_layers(cgb, check_shape, x, config.num_classes); + check_shape(x, config.num_classes); + + return InceptionV3Output{ + x, + aux, + }; +} + +ComputationGraph + get_inception_v3_computation_graph(InceptionV3Config const &config) { + ComputationGraphBuilder cgb; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(config.batch_size), + 3, + 299, + 299, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = cgb.create_input(input_shape, CreateGrad::YES); + InceptionV3Output output = create_inception_v3(cgb, config, input); + + return cgb.computation_graph; +} + +} // namespace FlexFlow diff --git a/lib/models/test/src/models/inception_v3/inception_v3.cc b/lib/models/test/src/models/inception_v3/inception_v3.cc new file mode 100644 index 0000000000..2b0fe82fd6 --- /dev/null +++ b/lib/models/test/src/models/inception_v3/inception_v3.cc @@ -0,0 +1,19 @@ +#include "models/inception_v3/inception_v3.h" +#include "pcg/computation_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_inception_v3_computation_graph") { + InceptionV3Config config = get_default_inception_v3_training_config(); + + ComputationGraph result = get_inception_v3_computation_graph(config); + + SUBCASE("num layers") { + int result_num_layers = get_layers(result).size(); + int correct_num_layers = 522; + CHECK(result_num_layers == correct_num_layers); + } + } +} diff --git a/lib/models/test/src/models/transformer.cc b/lib/models/test/src/models/transformer/transformer.cc similarity index 100% rename from lib/models/test/src/models/transformer.cc rename to lib/models/test/src/models/transformer/transformer.cc diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/concat.h b/lib/op-attrs/include/op-attrs/dim_ordered/concat.h new file mode 100644 index 0000000000..9b9eaf9b93 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/concat.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_CONCAT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_CONCAT_H + +#include "op-attrs/dim_ordered/dim_ordered.h" +#include "utils/containers/concat_vectors.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +template +FFOrdered concat(FFOrdered const &l, FFOrdered const &r) { + std::vector l_vec = std::vector(l.cbegin(), l.cend()); + std::vector r_vec = std::vector(r.cbegin(), r.cend()); + + std::vector raw_result = concat_vectors(l_vec, r_vec); + + return FFOrdered(raw_result.cbegin(), raw_result.cend()); +} + +template +FFOrdered concat(std::vector> const &inputs) { + std::vector> vec_inputs = + transform(inputs, [](FFOrdered const &input) 
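// A rough usage sketch of the model entry points added above; the config
// values here are hypothetical, not defaults from this patch:
//
//   InceptionV3Config config = InceptionV3Config{
//       /*num_classes=*/10,
//       /*batch_size=*/8,
//       /*aux_logits=*/false,
//   };
//   ComputationGraph cg = get_inception_v3_computation_graph(config);
//
// With get_default_inception_v3_training_config() the resulting graph has 522
// layers, which is what the new test above checks.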
{ + return std::vector(input.cbegin(), input.cend()); + }); + + std::vector raw_result = concat_vectors(vec_inputs); + + return FFOrdered(raw_result.cbegin(), raw_result.cend()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h index 34d186e74e..6aa23d40fc 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h @@ -10,7 +10,7 @@ namespace FlexFlow { template struct DimOrdered { - DimOrdered() = delete; + DimOrdered() {} DimOrdered(std::initializer_list const &l) : contents(l.begin(), l.end()) {} @@ -138,6 +138,10 @@ struct DimOrdered { return this->contents.size(); } + size_t empty() const { + return this->contents.empty(); + } + size_t num_dims() const { return this->size(); } diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h new file mode 100644 index 0000000000..79d4929797 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_FROM_MAP_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_FROM_MAP_H + +#include "op-attrs/dim_ordered/dim_ordered.h" +#include "op-attrs/dim_ordered/ff_ordered_of.h" + +namespace FlexFlow { + +template +FFOrdered ff_ordered_from_map(std::map const &m) { + std::vector raw; + for (int i = 0; i < m.size(); i++) { + raw.push_back(m.at(ff_dim_t{i})); + } + return ff_ordered_of(raw); +} + +template +FFOrdered ff_ordered_from_map(std::unordered_map const &m) { + std::vector raw; + for (int i = 0; i < m.size(); i++) { + raw.push_back(m.at(ff_dim_t{i})); + } + return ff_ordered_of(raw); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h index 23b971da6b..e4c0e8e275 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h @@ -21,6 +21,13 @@ DimOrdered nonoverloaded_slice(DimOrdered const &d, subvec(vector_of(d), to_raw_idx(start), to_raw_idx(end))}; } +template +FFOrdered slice(FFOrdered const &d, + std::optional const &start, + std::optional const &end) { + return nonoverloaded_slice(d, start, end); +} + template DimOrdered slice(DimOrdered const &d, std::optional const &start, diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index 8afcbb06b1..f2e95690d1 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -1,15 +1,42 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/batch_norm_attrs.dtg.h" #include "op-attrs/ops/core.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { -TensorShape get_output_shape(BatchNormAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(BatchNormAttrs const &, - ParallelTensorShape const &); +std::vector + get_batch_norm_incoming_tensor_roles(BatchNormAttrs const &); + +tl::expected get_output_shape(BatchNormAttrs const &, 
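// ff_ordered_from_map above assumes its argument is keyed by the contiguous
// dimensions ff_dim_t{0} .. ff_dim_t{n-1}; a missing key throws from
// std::map::at / std::unordered_map::at. A rough usage sketch with
// hypothetical values:
//
//   std::map<ff_dim_t, int> degrees = {
//       {ff_dim_t{0}, 2},
//       {ff_dim_t{1}, 1},
//       {ff_dim_t{2}, 4},
//   };
//   FFOrdered<int> ordered = ff_ordered_from_map(degrees);
//   // ordered == FFOrdered<int>{2, 1, 4}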
+ TensorShape const &); +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &, TensorShape const &); +tl::expected + get_beta_weights_shape(BatchNormAttrs const &, TensorShape const &); + +tl::expected + get_output_parallel_dim_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &); +tl::expected + get_gamma_weights_parallel_dim_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &); +tl::expected + get_beta_weights_parallel_dim_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &); + +tl::expected + get_output_shape(BatchNormAttrs const &, ParallelTensorShape const &); +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &, + ParallelTensorShape const &); +tl::expected + get_beta_weights_shape(BatchNormAttrs const &, ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchNormAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml index bc82f3c743..fdc3bce1fe 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml @@ -10,6 +10,28 @@ features = [ "fmt", ] +includes = [ + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", +] + [[fields]] name = "relu" type = "bool" + +[[fields]] +name = "affine" +type = "bool" + +[[fields]] +name = "eps" +type = "float" + +[[fields]] +name = "momentum" +type = "std::optional" diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index f3ac8494c0..f07f06df85 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -10,10 +10,11 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(ConcatAttrs); -TensorShape get_output_shape(ConcatAttrs const &, - std::vector const &); -ParallelTensorShape get_output_shape(ConcatAttrs const &, - std::vector const &); +tl::expected + get_output_shape(ConcatAttrs const &, std::vector const &); +tl::expected + get_output_shape(ConcatAttrs const &, + std::vector const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml index 4faa870bc4..fab8132993 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml @@ -17,7 +17,3 @@ includes = [ [[fields]] name = "axis" type = "::FlexFlow::ff_dim_t" - -[[fields]] -name = "num_inputs" -type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index 676d21c59b..710cbdb44b 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -3,6 +3,7 @@ #include "op-attrs/ops/core.h" #include "op-attrs/ops/flat_attrs.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" @@ -11,8 +12,11 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(FlatAttrs); TensorShape get_output_shape(FlatAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(FlatAttrs const &, - ParallelTensorShape const &); +tl::expected + get_output_parallel_dim_degrees(FlatAttrs const &, + ParallelTensorDimDegrees const &); +tl::expected + get_output_shape(FlatAttrs const &, ParallelTensorShape const &); } // namespace FlexFlow diff --git 
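A rough construction sketch for the expanded BatchNormAttrs above, which now
follows the pytorch BatchNorm2d-style interface (affine, eps, optional
momentum) mentioned in the commit message. The values mirror what
create_conv_block in the Inception-v3 model passes to the builder; mapping
Activation::RELU onto the relu flag is an assumption:

    BatchNormAttrs attrs = BatchNormAttrs{
        /*relu=*/true,
        /*affine=*/true,
        /*eps=*/1e-5f,
        /*momentum=*/0.1f,
    };

With affine=true the op presumably also carries learnable gamma/beta weight
tensors, which is why get_gamma_weights_shape and get_beta_weights_shape now
exist alongside get_output_shape.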
a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml index e445535e29..7349e2a8c4 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml @@ -8,4 +8,23 @@ features = [ "rapidcheck", "fmt", ] -fields = [] + +includes = [ + "", + "op-attrs/ff_dim.dtg.h", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", + "op-attrs/ff_dim.h", +] + +[[fields]] +name = "start_dim" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "end_dim" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index 505fdd9f8c..1af22ad022 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -3,6 +3,7 @@ #include "op-attrs/ops/core.h" #include "op-attrs/ops/pool_2d_attrs.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" @@ -10,9 +11,22 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(Pool2DAttrs); -TensorShape get_output_shape(Pool2DAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(Pool2DAttrs const &, - ParallelTensorShape const &); +tl::expected + make_adaptive_pool2d_attrs(TensorDims const &input_dims, + int output_h, + int output_w, + PoolOp pool_type, + std::optional const &activation); + +tl::expected get_output_shape(Pool2DAttrs const &, + TensorShape const &); + +tl::expected + get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &); + +tl::expected + get_output_parallel_dim_degrees(Pool2DAttrs const &, + ParallelTensorDimDegrees const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml index 56bf682f50..20ca7deabc 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml @@ -12,6 +12,13 @@ features = [ includes = [ "op-attrs/pool_op.dtg.h", "op-attrs/activation.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] @@ -44,4 +51,4 @@ type = "::FlexFlow::PoolOp" [[fields]] name = "activation" -type = "::FlexFlow::Activation" +type = "std::optional<::FlexFlow::Activation>" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml new file mode 100644 index 0000000000..974b27d2a7 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "ParallelTensorDimDegrees" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "op-attrs/parallel_tensor_shape/sum_degree.dtg.h", + "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", + "op-attrs/dim_ordered/dim_ordered.h", +] + +[[fields]] +name = "sum_degree" +type = "::FlexFlow::SumDegree" + +[[fields]] +name = "discard_copy_degree" +type = "::FlexFlow::DiscardCopyDegree" + +[[fields]] +name = "shard_degrees" +type = "::FlexFlow::FFOrdered" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h index 8e02e3607b..7a89b4bd78 100644 --- 
a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_H #include "op-attrs/parallel_dim.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_dims.dtg.h" #include "op-attrs/tensor_dims.dtg.h" @@ -14,6 +15,8 @@ std::unordered_set replica_dims(ParallelTensorDims const &); /* size_t get_volume(ParallelTensorDims const &); */ size_t num_shard_dims(ParallelTensorDims const &); +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorDims const &); + int total_replica_degree(ParallelTensorDims const &); int total_shard_degree(ParallelTensorDims const &); int total_parallel_degree(ParallelTensorDims const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index 76356b39d4..806a5f0de7 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -1,6 +1,7 @@ #ifndef _OP_META_PARALLEL_TENSOR_SHAPE_H #define _OP_META_PARALLEL_TENSOR_SHAPE_H +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/replica_parallel_dim.dtg.h" #include "op-attrs/tensor_shape.h" @@ -17,12 +18,17 @@ FFOrdered ff_ordered_shard_degrees(ParallelTensorShape const &); std::optional try_get_shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t); +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorShape const &); + ParallelTensorShape lift_to_parallel(TensorShape const &); ParallelTensorShape lift_to_parallel_with_degrees(TensorShape const &, SumDegree sum_degree, DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees); +ParallelTensorShape + lift_to_parallel_with_degrees(TensorShape const &, + ParallelTensorDimDegrees const &); std::unordered_set replica_dims(ParallelTensorShape const &); diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/concat.cc b/lib/op-attrs/src/op-attrs/dim_ordered/concat.cc new file mode 100644 index 0000000000..cb29f708a3 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/dim_ordered/concat.cc @@ -0,0 +1 @@ +#include "op-attrs/dim_ordered/concat.h" diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc b/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc new file mode 100644 index 0000000000..2de88f38c8 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc @@ -0,0 +1 @@ +#include "op-attrs/dim_ordered/ff_ordered_from_map.h" diff --git a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc index c7febde1d6..21efc26466 100644 --- a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc +++ b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc @@ -1,5 +1,6 @@ #include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/ops/attention.h" +#include "op-attrs/ops/batch_norm.h" #include "op-attrs/ops/conv_2d.h" #include "op-attrs/ops/layer_norm.h" #include "op-attrs/ops/linear.h" @@ -22,8 +23,8 @@ std::vector return std::vector{IncomingTensorRole::INPUT, IncomingTensorRole::INPUT}; }, - [](BatchNormAttrs const &) { - return std::vector{IncomingTensorRole::INPUT}; + [](BatchNormAttrs const &attrs) { + return get_batch_norm_incoming_tensor_roles(attrs); }, [](BroadcastAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; diff --git 
a/lib/op-attrs/src/op-attrs/get_output_shapes.cc b/lib/op-attrs/src/op-attrs/get_output_shapes.cc index d91d1a1eca..0058ee35a2 100644 --- a/lib/op-attrs/src/op-attrs/get_output_shapes.cc +++ b/lib/op-attrs/src/op-attrs/get_output_shapes.cc @@ -14,6 +14,7 @@ #include "op-attrs/ops/input.h" #include "op-attrs/ops/layer_norm.h" #include "op-attrs/ops/linear.h" +#include "op-attrs/ops/pool_2d.h" #include "op-attrs/ops/replicate.h" #include "op-attrs/ops/weight.h" #include "utils/overload.h" @@ -29,7 +30,7 @@ std::vector get_output_shape(attrs, inputs.at(0), inputs.at(1)))}; }, [&](BatchNormAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, inputs.at(0))}; + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](CastAttrs const &attrs) -> std::vector { return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; @@ -38,7 +39,7 @@ std::vector return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](ConcatAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, inputs)}; + return {throw_if_unexpected(get_output_shape(attrs, inputs))}; }, [&](Conv2DAttrs const &attrs) -> std::vector { return {get_output_shape(attrs, inputs.at(0))}; @@ -57,7 +58,7 @@ std::vector return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](FlatAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, inputs.at(0))}; + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](GatherAttrs const &attrs) -> std::vector { return {get_output_shape(attrs, inputs.at(0), inputs.at(1))}; @@ -71,6 +72,9 @@ std::vector [&](LinearAttrs const &attrs) -> std::vector { return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, + [&](Pool2DAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, [&](ReplicateAttrs const &attrs) -> std::vector { return {get_output_shape(attrs, inputs.at(0))}; }, diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc index b75c3521c6..f394bb8473 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc @@ -1,15 +1,260 @@ #include "op-attrs/ops/batch_norm.h" +#include "op-attrs/dim_ordered/concat.h" +#include "op-attrs/dim_ordered/slice.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.h" +#include "utils/containers/any_of.h" +#include "utils/containers/extend.h" namespace FlexFlow { -TensorShape get_output_shape(BatchNormAttrs const &, - TensorShape const &input_shape) { +std::vector + get_batch_norm_incoming_tensor_roles(BatchNormAttrs const &attrs) { + std::vector result = {IncomingTensorRole::INPUT}; + + if (attrs.affine) { + extend(result, + std::vector{IncomingTensorRole::WEIGHT, IncomingTensorRole::WEIGHT}); + } + + return result; +} + +static std::optional + check_input_shape(BatchNormAttrs const &, TensorShape const &input_shape) { + if (num_dims(input_shape) < 2) { + return fmt::format( + "BatchNormAttrs expected input dims >= 2, but received input shape {}", + input_shape); + } + + if (input_shape.data_type != DataType::FLOAT) { + return fmt::format("BatchNormAttrs currently only supports data_type = " + "FLOAT, but received input data_type {}. 
" + "If you need this feature, please create an issue.", + input_shape.data_type); + } + + return std::nullopt; +} + +tl::expected + get_output_shape(BatchNormAttrs const &attrs, + TensorShape const &input_shape) { + { + std::optional maybe_err_msg = + check_input_shape(attrs, input_shape); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + return input_shape; } -ParallelTensorShape get_output_shape(BatchNormAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &attrs, + TensorShape const &input_shape) { + { + std::optional maybe_err_msg = + check_input_shape(attrs, input_shape); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + if (!attrs.affine) { + return tl::unexpected("No gamma weights exist for attrs.affine = false"); + } + + size_t num_channels = dim_at_idx(input_shape, ff_dim_t{1}); + + return TensorShape{ + TensorDims{FFOrdered{ + num_channels, + }}, + DataType::FLOAT, + }; +} + +tl::expected + get_beta_weights_shape(BatchNormAttrs const &attrs, + TensorShape const &input_shape) { + + if (!attrs.affine) { + return tl::unexpected("No beta weights exist for attrs.affine = false"); + } + + return get_gamma_weights_shape(attrs, input_shape); +} + +static std::optional + check_input_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &input_degrees) { + if (input_degrees.shard_degrees.size() < 2) { + return fmt::format("BatchNormAttrs expected input dims >= 2, but received " + "input degrees {}", + input_degrees); + } + + if (input_degrees.sum_degree != SumDegree{1}) { + return fmt::format("Expected sum degree 1, but receieved sum degree {}", + input_degrees.sum_degree); + } + + if (input_degrees.discard_copy_degree != DiscardCopyDegree{1}) { + return fmt::format( + "Expected discard copy degree 1, but receieved discard copy degree {}", + input_degrees.discard_copy_degree); + } + + FFOrdered non_channel_degrees = + concat(slice(input_degrees.shard_degrees, ff_dim_t{0}, ff_dim_t{1}), + slice(input_degrees.shard_degrees, ff_dim_t{2}, std::nullopt)); + + if (any_of(non_channel_degrees, [](int degree) { return degree != 1; })) { + return fmt::format("Expected parallel degree of all non-channel dimensions " + "to be 1, but received input with degrees {}", + input_degrees); + } + + return std::nullopt; +} + +tl::expected + get_output_parallel_dim_degrees( + BatchNormAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + { + std::optional maybe_err_msg = + check_input_degrees(attrs, input_degrees); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + return input_degrees; +} + +tl::expected + get_gamma_weights_parallel_dim_degrees( + BatchNormAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + { + std::optional maybe_err_msg = + check_input_degrees(attrs, input_degrees); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + if (!attrs.affine) { + return tl::unexpected("No gamma weights exist for attrs.affine = false"); + } + + ff_dim_t channel_dim = ff_dim_t{1}; + + return ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{input_degrees.shard_degrees.at(channel_dim)}, + }; +} + +tl::expected + get_beta_weights_parallel_dim_degrees( + BatchNormAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + { + std::optional maybe_err_msg = + 
check_input_degrees(attrs, input_degrees); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + if (!attrs.affine) { + return tl::unexpected("No beta weights exist for attrs.affine = false"); + } + + return get_gamma_weights_parallel_dim_degrees(attrs, input_degrees); +} + +tl::expected + get_output_shape(BatchNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + TensorShape unpar = ({ + tl::expected returned = + get_output_shape(attrs, get_reduced_shape(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + get_output_parallel_dim_degrees(attrs, + get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} + +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + + TensorShape unpar = ({ + tl::expected returned = + get_gamma_weights_shape(attrs, get_reduced_shape(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + get_gamma_weights_parallel_dim_degrees( + attrs, get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} + +tl::expected + get_beta_weights_shape(BatchNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + + TensorShape unpar = ({ + tl::expected returned = + get_beta_weights_shape(attrs, get_reduced_shape(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + get_beta_weights_parallel_dim_degrees( + attrs, get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/concat.cc b/lib/op-attrs/src/op-attrs/ops/concat.cc index 02fee70bea..74295f279e 100644 --- a/lib/op-attrs/src/op-attrs/ops/concat.cc +++ b/lib/op-attrs/src/op-attrs/ops/concat.cc @@ -1,24 +1,129 @@ #include "op-attrs/ops/concat.h" +#include "op-attrs/dim_ordered/enumerate.h" +#include "op-attrs/dim_ordered/ff_ordered_from_map.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_dims.h" +#include "op-attrs/tensor_shape.h" +#include "utils/containers/all_of.h" +#include "utils/containers/are_all_same.h" +#include "utils/containers/require_all_same1.h" +#include "utils/containers/sum.h" +#include "utils/containers/transform.h" +#include "utils/fmt/map.h" namespace FlexFlow { -/* bool ConcatAttrs::is_valid( */ -/* std::vector const &input) const { */ -/* bool valid = true; */ -/* for (auto p : input) { */ -/* valid &= p.is_valid(); */ -/* } */ -/* return valid; */ -/* } */ - -TensorShape get_output_shape(ConcatAttrs const &, - std::vector const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_shape(ConcatAttrs const &attrs, + std::vector const &inputs) { + auto get_non_axis_dims = [&](TensorShape const &s) { + std::map dim_sizes = enumerate(ff_ordered(s.dims)); + 
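+    // drop the concat axis from the map: the remaining (non-axis) dimensions
+    // must agree across all inputs, while the axis dimension is summed below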
dim_sizes.erase(attrs.axis); + return dim_sizes; + }; + + if (inputs.size() <= 1) { + return tl::unexpected(fmt::format("get_output_shape for Concat expected 2 " + "or more input, but receieved {}", + inputs)); + } + + if (attrs.axis.value < 0) { + return tl::unexpected(fmt::format("ConcatAttrs requires axis >= 0")); + } + + if (!are_all_same(transform( + inputs, [](TensorShape const &s) { return num_dims(s); }))) { + return tl::unexpected( + fmt::format("get_output_shape for Concat expected all inputs to have " + "the same number of dimensions, but receieved {}", + inputs)); + } + + std::map non_axis_dims = ({ + tl::expected, std::string> returned = + require_all_same1(transform(inputs, get_non_axis_dims)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + std::vector axis_dim_sizes = transform( + inputs, [&](TensorShape const &s) { return dim_at_idx(s, attrs.axis); }); + + size_t output_axis_dim_size = sum(axis_dim_sizes); + + non_axis_dims.insert({attrs.axis, output_axis_dim_size}); + + DataType datatype = ({ + tl::expected returned = require_all_same1( + transform(inputs, [](TensorShape const &s) { return s.data_type; })); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return TensorShape{ + TensorDims{ + ff_ordered_from_map(non_axis_dims), + }, + datatype, + }; } -ParallelTensorShape get_output_shape(ConcatAttrs const &, - std::vector const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_shape(ConcatAttrs const &attrs, + std::vector const &inputs) { + TensorShape unpar = ({ + tl::expected returned = + get_output_shape(attrs, transform(inputs, get_reduced_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + SumDegree sum_degree = ({ + tl::expected returned = + require_all_same1(transform(inputs, get_sum_degree)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + SumDegree{returned.value()}; + }); + + DiscardCopyDegree discard_copy_degree = ({ + tl::expected returned = + require_all_same1(transform(inputs, get_discard_copy_degree)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + DiscardCopyDegree{returned.value()}; + }); + + if (!all_of(inputs, [&](ParallelTensorShape const &s) { + return shard_dim_at_idx(s, attrs.axis).degree == 1; + })) { + return tl::unexpected(fmt::format( + "get_output_shape for Concat expected input tensors to have parallel " + "degree 1 in the concat axis dimension, but received {}", + inputs)); + } + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + require_all_same1(transform(inputs, [](ParallelTensorShape const &s) { + return get_parallel_degrees(s); + })); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index f77daf451f..eac756cc15 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -54,11 +54,11 @@ TensorShape get_output_shape(Conv2DAttrs const &attrs, Conv2DInputShape input = parse_input_shape(raw_input_shape); size_t out_height = - (input.height - (2 * attrs.padding_h) - (attrs.kernel_h - 1)) / - attrs.stride_h; + (input.height + (2 * attrs.padding_h) - attrs.kernel_h) / attrs.stride_h + + 1; size_t out_width 
= - (input.width - (2 * attrs.padding_w) - (attrs.kernel_w - 1)) / - attrs.stride_w; + (input.width + (2 * attrs.padding_w) - attrs.kernel_w) / attrs.stride_w + + 1; assert(attrs.out_channels > 0); diff --git a/lib/op-attrs/src/op-attrs/ops/flat.cc b/lib/op-attrs/src/op-attrs/ops/flat.cc index 5d318207ee..e9833d5e3f 100644 --- a/lib/op-attrs/src/op-attrs/ops/flat.cc +++ b/lib/op-attrs/src/op-attrs/ops/flat.cc @@ -1,57 +1,85 @@ #include "op-attrs/ops/flat.h" +#include "op-attrs/dim_ordered/concat.h" +#include "op-attrs/dim_ordered/slice.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_dims.h" +#include "utils/containers/any_of.h" +#include "utils/containers/product.h" #include namespace FlexFlow { -TensorShape get_output_shape(FlatAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); +TensorShape get_output_shape(FlatAttrs const &attrs, + TensorShape const &input_shape) { + FFOrdered leading_dims = + slice(ff_ordered(input_shape.dims), ff_dim_t{0}, attrs.start_dim); + FFOrdered flattened_dims = + slice(ff_ordered(input_shape.dims), attrs.start_dim, attrs.end_dim); + FFOrdered trailing_dims = + slice(ff_ordered(input_shape.dims), attrs.end_dim, std::nullopt); + + if (flattened_dims.empty()) { + return input_shape; + } + + return TensorShape{ + TensorDims{ + concat(std::vector{ + leading_dims, + {product(flattened_dims)}, + trailing_dims, + }), + }, + input_shape.data_type, + }; } -ParallelTensorShape get_output_shape(FlatAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_parallel_dim_degrees( + FlatAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { + FFOrdered flattened_dim_degrees = + slice(input_degrees.shard_degrees, attrs.start_dim, attrs.end_dim); + + if (flattened_dim_degrees.empty()) { + return input_degrees; + } + + if (any_of(flattened_dim_degrees, [](int degree) { return degree != 1; })) { + return tl::unexpected( + fmt::format("get_output_parallel_dim_degrees for {} expected all shard " + "degrees of flattened dimensions to be 1, but received {}", + attrs, + input_degrees)); + } + + return ParallelTensorDimDegrees{ + /*sum_degree=*/input_degrees.sum_degree, + /*discard_copy_degree=*/input_degrees.discard_copy_degree, + /*shard_degrees=*/ + concat(std::vector{ + slice(input_degrees.shard_degrees, ff_dim_t{0}, attrs.start_dim), + {product(flattened_dim_degrees)}, + slice(input_degrees.shard_degrees, attrs.end_dim, std::nullopt), + }), + }; } -// namespace Input { -// constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, -// REPLICA = 4; -// } -// -// namespace Output { -// constexpr int NUMDIM = 3, CHANNEL = 0, SAMPLE = 1, REPLICA = 2; -// } -// -/* bool FlatAttrs::is_valid(ParallelTensorShape const &input) const { */ -/* ParallelTensorShape output_shape = this->calculate_output_shape(input); */ - -/* bool is_valid = true; */ -/* is_valid &= input.is_valid(); */ -/* is_valid &= output_shape.is_valid(); */ -/* is_valid &= (input.at(Input::WIDTH).degree == 1); */ - -/* return is_valid; */ -/* } */ - -/* ParallelTensorShape FlatAttrs::calculate_output_shape(ParallelTensorShape - * const &input) const { */ -/* assert (input.num_dims() == Input::NUMDIM); */ -/* ParallelTensorShape output_dims; */ -/* output_dims.data_type = input.data_type; */ - -/* output_dims.at(Output::REPLICA) = input.at(Input::REPLICA); */ -/* output_dims.at(Output::SAMPLE) = input.at(Input::SAMPLE); */ - -/* output_dims.at(Output::CHANNEL).degree = input.at(Input::CHANNEL).degree; - */ -/* assert 
(input.at(Input::HEIGHT).degree == 1); */ -/* assert (input.at(Input::WIDTH).degree == 1); */ - -/* output_dims.at(Output::CHANNEL).size = input.at(Input::CHANNEL).size * - * input.at(Input::HEIGHT).size * input.at(Input::WIDTH).size; */ -/* output_dims.at(Output::CHANNEL).parallel_idx = - * input.at(Input::CHANNEL).parallel_idx; */ - -/* return output_dims; */ -/* } */ +tl::expected + get_output_shape(FlatAttrs const &attrs, + ParallelTensorShape const &input_shape) { + TensorShape unpar = get_output_shape(attrs, get_reduced_shape(input_shape)); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + get_output_parallel_dim_degrees(attrs, + get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc index d3c00efbb9..0dd9ac7a17 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -112,7 +112,7 @@ static std::optional if (get_discard_copy_degree(input_shape) != 1) { return fmt::format( - "Expected discard copy degree 1, but received discartd copy degree {}", + "Expected discard copy degree 1, but received discard copy degree {}", get_discard_copy_degree(input_shape)); } diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc index e1917efd89..95bcd8b336 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -1,62 +1,184 @@ #include "op-attrs/ops/pool_2d.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_dims.h" +#include "op-attrs/tensor_shape.h" +#include "utils/integer_conversions.h" namespace FlexFlow { -TensorShape get_output_shape(Pool2DAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); -} +tl::expected + make_adaptive_pool2d_attrs(TensorDims const &input_dims, + int output_h, + int output_w, + PoolOp pool_type, + std::optional const &activation) { + // AdaptivePool2D semantics pulled from + // https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work/63603993 -ParallelTensorShape get_output_shape(Pool2DAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); -} + if (num_dims(input_dims) != 4) { + return tl::unexpected( + fmt::format("make_adaptive_pool2d_attrs expected input tensor to " + "have 4 dims, but received dims {}", + input_dims)); + } -} // namespace FlexFlow + size_t num_samples = dim_at_idx(input_dims, ff_dim_t{0}); + size_t num_channels = dim_at_idx(input_dims, ff_dim_t{1}); + size_t input_h = dim_at_idx(input_dims, ff_dim_t{2}); + size_t input_w = dim_at_idx(input_dims, ff_dim_t{3}); -/* -#include "op-attrs/ops/pool_2d.h" -#include "parallel_dim_mapping_record.h" -#include "parallel_dim_mapping_record_solver.h" + if (input_h % output_h != 0) { + return tl::unexpected(fmt::format( + "Currently make_adaptive_pool2d_attrs only supports input_h % output_h " + "== 0, but received input_h={} and output_h={} (input_dims={}). If you " + "need input_h % output_h != 0 supported, please create an issue.", + input_h, + output_h, + input_dims)); + } -namespace FlexFlow { + if (input_w % output_w != 0) { + return tl::unexpected(fmt::format( + "Currently make_adaptive_pool2d_attrs only supports input_w % output_w " + "== 0, but received input_w={} and output_w={} (input_dims={}). 
If you " + "need input_w % output_w != 0 supported, please create an issue.", + input_w, + output_w, + input_dims)); + } -namespace Input { -constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, - REPLICA = 4; -}; + // Note that for some reason the stack overflow post linked above states that + // `kernel_size = ind - (outd-1)*stride`, but some simplification yields + // `kernel_size` = `ind - (outd - 1)*stride` + // = `ind - (outd - 1) * (ind / outd)` + // = `ind - ind + (ind /outd)` + // = `ind / outd` + // = `stride` -namespace Output { -constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, - REPLICA = 4; -}; + int kernel_h = input_h / output_h; + int kernel_w = input_w / output_w; -bool Pool2DAttrs::is_valid(ParallelTensorShape const &input) const { - ParallelTensorShape output_shape = this->calculate_output_shape(input); + int stride_h = kernel_h; + int stride_w = kernel_w; - return output_shape.is_valid() && (input.at(Input::REPLICA).degree == 1); -} + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/kernel_h, + /*kernel_w=*/kernel_w, + /*stride_h=*/stride_h, + /*stride_w=*/stride_w, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/pool_type, + /*activation=*/activation, + }; + + TensorShape expected_ouput_shape = TensorShape{ + TensorDims{FFOrdered{ + num_samples, + num_channels, + size_t_from_int(output_h), + size_t_from_int(output_w), + }}, + DataType::FLOAT, + }; -static std::vector - construct_mappings(ParallelTensorShape const &input_shape) { - auto const outputMappings = construct_output_parallel_dims({ - {Input::REPLICA, MappingOperation::PARTITION, Output::REPLICA}, - {Input::SAMPLE, MappingOperation::PARTITION, Output::SAMPLE}, - {Input::CHANNEL, MappingOperation::PARTITION, Output::CHANNEL}, - {Input::HEIGHT, MappingOperation::PARTITION, Output::HEIGHT}, - {Input::WIDTH, MappingOperation::PARTITION, Output::WIDTH}, + TensorShape output_shape = ({ + tl::expected result = + get_output_shape(attrs, TensorShape{input_dims, DataType::FLOAT}); + if (!result.has_value()) { + return tl::unexpected(result.error()); + } + result.value(); }); - return outputMappings; + if (output_shape != expected_ouput_shape) { + return tl::unexpected( + fmt::format("Result of make_adaptive_pool_2d (i.e., {}) should produce " + "expected output shape {}, but produced {}. 
This is a bug " + "in FlexFlow, Please create an issue.", + attrs, + expected_ouput_shape, + output_shape)); + } + + return attrs; } -static ParallelDimMappingSolution - solve_mappings(ParallelTensorShape const &input) { - return solve_parallel_dim_mappings(construct_mappings(input), {input}, 0, 1); +tl::expected + get_output_shape(Pool2DAttrs const &attrs, TensorShape const &input_shape) { + if (num_dims(input_shape) != 4) { + return tl::unexpected( + fmt::format("get_output_shape for Pool2DAttrs expected input tensor to " + "have 4 dims, but received shape {}", + input_shape)); + } + + size_t num_samples = dim_at_idx(input_shape, ff_dim_t{0}); + size_t num_channels = dim_at_idx(input_shape, ff_dim_t{1}); + size_t input_height = dim_at_idx(input_shape, ff_dim_t{2}); + size_t input_width = dim_at_idx(input_shape, ff_dim_t{3}); + + size_t output_height = + (input_height + 2 * attrs.padding_h - attrs.kernel_h) / attrs.stride_h + + 1; + + size_t output_width = + (input_width + 2 * attrs.padding_w - attrs.kernel_w) / attrs.stride_w + 1; + + return TensorShape{TensorDims{FFOrdered{ + num_samples, + num_channels, + output_height, + output_width, + }}, + input_shape.data_type}; } -ParallelTensorShape Pool2DAttrs::calculate_output_shape(ParallelTensorShape -const &input) const { return solve_mappings(input).output_shapes.at(0); +tl::expected + get_output_shape(Pool2DAttrs const &attrs, + ParallelTensorShape const &input_shape) { + TensorShape unpar = ({ + tl::expected result_unpar = + get_output_shape(attrs, get_reduced_shape(input_shape)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + result_unpar.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected result_degrees = + get_output_parallel_dim_degrees(attrs, + get_parallel_degrees(input_shape)); + if (!result_degrees.has_value()) { + return tl::unexpected(result_degrees.error()); + } + result_degrees.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} + +tl::expected + get_output_parallel_dim_degrees( + Pool2DAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + if (input_degrees.sum_degree.value > 1) { + if (attrs.pool_type == PoolOp::MAX) { + return tl::unexpected(fmt::format( + "get_output_parallel_dim_degrees for Pool2DAttrs with PoolOp::MAX " + "expected input sum degree == 1, but received {}", + input_degrees)); + } else if (attrs.activation.has_value()) { + return tl::unexpected(fmt::format( + "get_output_parallel_dim_degrees for Pool2DAttrs with activation={} " + "expected input sum degree == 1, but received {}", + attrs.activation.value(), + input_degrees)); + } + } + + return input_degrees; } } // namespace FlexFlow -*/ diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 4bce5449f4..61062b84b0 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -29,6 +29,14 @@ size_t num_shard_dims(ParallelTensorDims const &dims) { return dims.shard_dims.size(); } +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorDims const &d) { + return ParallelTensorDimDegrees{ + d.replica_dims.sum_degree, + d.replica_dims.discard_copy_degree, + ff_ordered_shard_degrees(d), + }; +} + int total_replica_degree(ParallelTensorDims const &dims) { return dims.replica_dims.discard_copy_degree.value * dims.replica_dims.sum_degree.value; diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc 
b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 10bf5027a4..3cd0f47a5d 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -59,6 +59,10 @@ std::optional } } +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorShape const &s) { + return get_parallel_degrees(s.dims); +} + ParallelTensorShape lift_to_parallel(TensorShape const &s) { return ParallelTensorShape{lift_to_parallel(s.dims), s.data_type}; } @@ -75,6 +79,15 @@ ParallelTensorShape }; } +ParallelTensorShape + lift_to_parallel_with_degrees(TensorShape const &s, + ParallelTensorDimDegrees const °rees) { + return lift_to_parallel_with_degrees(s, + degrees.sum_degree, + degrees.discard_copy_degree, + degrees.shard_degrees); +} + TensorShape require_not_parallel(ParallelTensorShape const &s) { int total_degree = get_total_parallel_degree(s); if (total_degree != 1) { diff --git a/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc index 42ea07e6b5..84f1861f0b 100644 --- a/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc +++ b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc @@ -5,8 +5,12 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ComputationGraphOpAttrs to/from json") { - ComputationGraphOpAttrs correct = - ComputationGraphOpAttrs{BatchNormAttrs{true}}; + ComputationGraphOpAttrs correct = ComputationGraphOpAttrs{BatchNormAttrs{ + /*relu=*/false, + /*affine=*/true, + /*eps=*/1e-5, + /*momentum=*/0.1, + }}; nlohmann::json j = correct; auto result = j.get(); diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc new file mode 100644 index 0000000000..2ac641cfc2 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc @@ -0,0 +1,66 @@ +#include "op-attrs/dim_ordered/concat.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("concat(FFOrdered, FFOrdered)") { + SUBCASE("inputs have elements") { + FFOrdered l_input = FFOrdered{1, 3, 1}; + FFOrdered r_input = FFOrdered{2, 1}; + + FFOrdered result = concat(l_input, r_input); + FFOrdered correct = {1, 3, 1, 2, 1}; + + CHECK(result == correct); + } + + SUBCASE("inputs are empty") { + FFOrdered l_input = FFOrdered{}; + FFOrdered r_input = FFOrdered{}; + + FFOrdered result = concat(l_input, r_input); + FFOrdered correct = {}; + + CHECK(result == correct); + } + } + + TEST_CASE("concat(std::vector>)") { + SUBCASE("inputs have elements") { + std::vector> input = { + {1}, + {2, 1}, + {1}, + }; + + FFOrdered result = concat(input); + FFOrdered correct = { + 1, + 2, + 1, + 1, + }; + + CHECK(result == correct); + } + + SUBCASE("no inputs") { + std::vector> input = {}; + + FFOrdered result = concat(input); + FFOrdered correct = {}; + + CHECK(result == correct); + } + + SUBCASE("inputs are empty") { + std::vector> input = {{}, {}, {}}; + + FFOrdered result = concat(input); + FFOrdered correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc new file mode 100644 index 0000000000..7bc1695e5c --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc @@ -0,0 +1,66 @@ +#include "op-attrs/dim_ordered/ff_ordered_from_map.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + 
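+  // ff_ordered_from_map requires the keys to be exactly ff_dim_t{0} .. ff_dim_t{n-1};
+  // the subcases below cover an empty map, maps with missing or negative keys
+  // (which are expected to throw), and a fully-populated map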
TEST_CASE_TEMPLATE("ff_ordered_from_map", + T, + std::map, + std::unordered_map) { + SUBCASE("input is empty") { + T m = {}; + + FFOrdered result = ff_ordered_from_map(m); + FFOrdered correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input is missing keys") { + SUBCASE("missing key is in middle") { + T m = { + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 2}, + {ff_dim_t{3}, 5}, + }; + + CHECK_THROWS(ff_ordered_from_map(m)); + } + + SUBCASE("missing key is 0 idx") { + T m = { + {ff_dim_t{1}, 2}, + {ff_dim_t{2}, 7}, + {ff_dim_t{3}, 5}, + }; + + CHECK_THROWS(ff_ordered_from_map(m)); + } + } + + SUBCASE("input has negative keys") { + T m = { + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 5}, + {ff_dim_t{-1}, 2}, + }; + + CHECK_THROWS(ff_ordered_from_map(m)); + } + + SUBCASE("input is valid") { + T m = { + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 5}, + {ff_dim_t{2}, 2}, + {ff_dim_t{3}, 7}, + }; + + FFOrdered result = ff_ordered_from_map(m); + FFOrdered correct = {4, 5, 2, 7}; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc index 60dedfe70a..33cc00c6a1 100644 --- a/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc +++ b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc @@ -9,7 +9,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("Concat") { int num_incoming = 4; ComputationGraphOpAttrs attrs = - ComputationGraphOpAttrs{ConcatAttrs{ff_dim_t{0}, num_incoming}}; + ComputationGraphOpAttrs{ConcatAttrs{ff_dim_t{0}}}; std::vector result = get_incoming_tensor_roles(attrs, num_incoming); diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc new file mode 100644 index 0000000000..4196394d00 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc @@ -0,0 +1,404 @@ +#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include "utils/fmt/optional.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_batch_norm_incoming_tensor_roles(BatchNormAttrs)") { + auto make_attrs = [](bool affine) { + return BatchNormAttrs{ + /*relu=*/false, + /*affine=*/affine, + /*eps=*/1.0, + /*momentum=*/0.1, + }; + }; + + SUBCASE("affine = true") { + BatchNormAttrs attrs = make_attrs(/*affine=*/true); + + std::vector result = + get_batch_norm_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + BatchNormAttrs attrs = make_attrs(/*affine=*/false); + + std::vector result = + get_batch_norm_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("shape inference (BatchNorm)") { + BatchNormAttrs attrs_affine_true = BatchNormAttrs{ + /*relu=*/false, + /*affine=*/true, + /*eps=*/1.0, + /*momentum=*/0.1, + }; + + BatchNormAttrs attrs_affine_false = [&] { + BatchNormAttrs attrs = attrs_affine_true; + attrs.affine = false; + return attrs; + }(); + + TensorShape input = TensorShape{ + TensorDims{FFOrdered{ + 12, + 14, + 16, + 18, + }}, + DataType::FLOAT, + }; + + TensorShape output = input; + + TensorShape gamma = TensorShape{ + TensorDims{FFOrdered{ + 14, + }}, + DataType::FLOAT, + }; + + TensorShape beta = gamma; + + SUBCASE("get_output_shape(BatchNormAttrs, TensorShape)") { 
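+      // batch norm is shape-preserving, so the output shape should match the
+      // input shape exactly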
+ tl::expected result = + get_output_shape(attrs_affine_true, input); + tl::expected correct = output; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_shape(BatchNormAttrs, TensorShape)") { + SUBCASE("affine = true") { + tl::expected result = + get_gamma_weights_shape(attrs_affine_true, input); + tl::expected correct = gamma; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = optional_from_expected( + get_gamma_weights_shape(attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("get_beta_weights_shape(BatchNormAttrs, TensorShape)") { + SUBCASE("affine = true") { + tl::expected result = + get_beta_weights_shape(attrs_affine_true, input); + tl::expected correct = beta; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = optional_from_expected( + get_beta_weights_shape(attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + TEST_CASE("parallel dim degree inference (BatchNormAttrs)") { + BatchNormAttrs attrs_affine_true = BatchNormAttrs{ + /*relu=*/false, + /*affine=*/true, + /*eps=*/1.0, + /*momentum=*/0.1, + }; + + BatchNormAttrs attrs_affine_false = [&] { + BatchNormAttrs attrs = attrs_affine_true; + attrs.affine = false; + return attrs; + }(); + + SUBCASE("partition parallelism (in channel dim)") { + int degree = 2; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + degree, + 1, + 1, + }, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + tl::expected result = + get_output_parallel_dim_degrees(attrs_affine_true, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + SUBCASE("affine = true") { + tl::expected result = + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input); + tl::expected correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{degree}, + }; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = + optional_from_expected(get_gamma_weights_parallel_dim_degrees( + attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + SUBCASE("affine = true") { + tl::expected result = + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input); + tl::expected correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{degree}, + }; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = + optional_from_expected(get_beta_weights_parallel_dim_degrees( + attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + SUBCASE("partition parallelism (not in channel dim)") { + int degree = 2; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, degree, 1}, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == 
correct); + } + + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("sum parallelism") { + SumDegree sum_degree = SumDegree{2}; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + sum_degree, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1, 1}, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("discard copy parallelism") { + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{2}; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + discard_copy_degree, + FFOrdered{1, 1, 1, 1}, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + TEST_CASE("parallel shape inference (BatchNormAttrs)") { + // since most of the edge cases are already tested in the above test cases + // (i.e., shape inference and parallel degree inference) + // here we just do a basic check that they compose + + BatchNormAttrs attrs = BatchNormAttrs{ + /*relu=*/true, + /*affine=*/true, + /*eps=*/1.0, + /*momentum=*/0.1, + }; + + ParallelTensorShape input = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 1}, + ShardParallelDim{14, 2}, + ShardParallelDim{16, 1}, + ShardParallelDim{18, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("get_output_shape(BatchNormAttrs, ParallelTensorShape)") { + tl::expected 
result = + get_output_shape(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_shape(BatchNormAttrs, ParallelTensorShape)") { + tl::expected result = + get_gamma_weights_shape(attrs, input); + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_shape(BatchNormAttrs, ParallelTensorShape)") { + tl::expected result = + get_beta_weights_shape(attrs, input); + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc index df436da66c..3d86576279 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc @@ -5,7 +5,12 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("BatchNormAttrs to/from json") { - BatchNormAttrs correct = BatchNormAttrs{true}; + BatchNormAttrs correct = BatchNormAttrs{ + /*relu=*/false, + /*affine=*/true, + /*eps=*/1e-5, + /*momentum=*/0.1, + }; nlohmann::json j = correct; BatchNormAttrs result = j.get(); diff --git a/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc index 152df09eca..7abb98f3e3 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc @@ -73,8 +73,8 @@ TEST_SUITE(FF_TEST_SUITE) { }; size_t num_samples = 7; - size_t input_channels = 6; - size_t input_height = 10; + size_t input_channels = 4; + size_t input_height = 11; size_t input_width = 15; TensorShape input = TensorShape{ @@ -87,8 +87,8 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - size_t output_height = 3; - size_t output_width = 6; + size_t output_height = 6; + size_t output_width = 8; TensorShape output = TensorShape{ TensorDims{FFOrdered{ diff --git a/lib/op-attrs/test/src/op-attrs/ops/flat.cc b/lib/op-attrs/test/src/op-attrs/ops/flat.cc new file mode 100644 index 0000000000..d81ab95c35 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/flat.cc @@ -0,0 +1,244 @@ +#include "op-attrs/ops/flat.h" +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include "utils/fmt/optional.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(FlatAttrs, TensorShape)") { + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 2, + 4, + 2, + 3, + }}, + DataType::FLOAT, + }; + + SUBCASE("flatten all dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{0}, + /*end_dim=*/ff_dim_t{4}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2 * 4 * 2 * 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten trailing dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{4}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2, + 4, + 2 * 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten leading 
dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{0}, + /*end_dim=*/ff_dim_t{2}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2 * 4, + 2, + 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten middle dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2, + 4 * 2, + 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten no dims (start_dim == end_dim)") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{2}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = input_shape; + + CHECK(result == correct); + } + + SUBCASE("flatten no dims (start_dim < end_dim)") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{1}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = input_shape; + + CHECK(result == correct); + } + } + + TEST_CASE( + "get_output_parallel_dim_degrees(FlatAttrs, ParallelTensorDimDegrees)") { + FlatAttrs attrs = FlatAttrs{/*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3}}; + + SUBCASE("allows shard parallelism in non-flattened dims") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{2, 1, 1, 3}, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{2, 1, 3}, + }; + + CHECK(result == correct); + } + + SUBCASE("does not allow shard parallelism in flattened dims") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 2, 1}, + }; + + std::optional result = + optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("allows sum parallelism") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1, 1}, + }; + + std::optional result = + optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = + ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1}, + }; + + CHECK(result == correct); + } + + SUBCASE("allows discard copy parallelism") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{2}, + FFOrdered{1, 1, 1, 1}, + }; + + std::optional result = + optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{2}, + FFOrdered{1, 1, 1}, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("get_output_shape(FlatAttrs, ParallelTensorShape)") { + // since most of the edge cases are already tested in + // get_output_shape(FlatAttrs, TensorShape) and + // get_output_parallel_dim_degrees(FlatAttrs, ParallelTensorDimDegrees), + // here we just do a basic check that they compose + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{4, 2}, + ShardParallelDim{8, 1}, + ShardParallelDim{6, 1}, + ShardParallelDim{9, 3}, + }, 
+ ReplicaParallelDimSet{ + SumDegree{7}, + DiscardCopyDegree{5}, + }, + }, + DataType::FLOAT, + }; + + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3}, + }; + + tl::expected result = + get_output_shape(attrs, input_shape); + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{4, 2}, + ShardParallelDim{8 * 6, 1}, + ShardParallelDim{9, 3}, + }, + ReplicaParallelDimSet{ + SumDegree{7}, + DiscardCopyDegree{5}, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc new file mode 100644 index 0000000000..0c14c0fc2a --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc @@ -0,0 +1,400 @@ +#include "op-attrs/ops/pool_2d.h" +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include "utils/fmt/optional.h" +#include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("make_adaptive_pool2d") { + size_t input_n = 10; + size_t input_c = 11; + size_t input_h = 15; + size_t input_w = 20; + Activation activation = Activation::RELU; + PoolOp op = PoolOp::AVG; + + TensorDims input_dims = + TensorDims{FFOrdered{input_n, input_c, input_h, input_w}}; + + SUBCASE("input_h divisible by output_h && input_w divisible by output_w") { + int output_h = 5; + int output_w = 2; + + Pool2DAttrs correct_attrs = Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/10, + /*stride_h=*/3, + /*stride_w=*/10, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/op, + /*activation=*/activation, + }; + + SUBCASE("returns correct attrs") { + tl::expected result = + make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation); + tl::expected correct = correct_attrs; + + CHECK(result == correct); + } + + SUBCASE( + "confirm that output shape is as expected for the expected attrs") { + TensorShape input_shape = TensorShape{input_dims, DataType::FLOAT}; + + tl::expected result = + get_output_shape(correct_attrs, input_shape); + tl::expected correct = TensorShape{ + TensorDims{FFOrdered{ + input_n, + input_c, + size_t_from_int(output_h), + size_t_from_int(output_w), + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + } + + SUBCASE("input_h not divisible by output_h") { + int output_h = 6; + int output_w = 2; + + std::optional result = + optional_from_expected(make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input_w not divisible by output_w") { + int output_h = 5; + int output_w = 3; + + std::optional result = + optional_from_expected(make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input_h == output_h and input_w == output_w") { + int output_h = input_h; + int output_w = input_w; + + Pool2DAttrs correct_attrs = Pool2DAttrs{ + /*kernel_h=*/1, + /*kernel_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/op, + /*activation=*/activation, + }; + + SUBCASE("returns correct attrs") { + tl::expected result = + make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation); + tl::expected correct = correct_attrs; + + CHECK(result == correct); + } + + SUBCASE( + "confirm that output shape is as expected for the expected attrs") { + 
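+        // with kernel_h = kernel_w = 1, stride 1, and no padding, pooling is
+        // the identity on the spatial dims, so the output shape should equal
+        // the input shape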
TensorShape input_shape = TensorShape{input_dims, DataType::FLOAT}; + + tl::expected result = + get_output_shape(correct_attrs, input_shape); + tl::expected correct = input_shape; + + CHECK(result == correct); + } + } + } + + TEST_CASE("get_output_shape(Pool2DAttrs, TensorShape)") { + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*pool_type=*/PoolOp::MAX, + /*activation=*/std::nullopt, + }; + + SUBCASE("fails on non-4d inputs") { + TensorShape input = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + 14, + }}, + DataType::FLOAT, + }; + + std::optional result = + optional_from_expected(get_output_shape(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("4d input") { + TensorShape input = TensorShape{ + TensorDims{FFOrdered{11, 13, 12, 6}}, + DataType::FLOAT, + }; + + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = TensorShape{ + TensorDims{FFOrdered{11, 13, 6, 4}}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("get_output_parallel_dim_degrees(Pool2DAttrs, " + "ParallelTensorDimDegrees)") { + auto make_attrs = [](PoolOp pool_type, + std::optional const &activation) { + return Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*pool_type=*/pool_type, + /*activation=*/activation, + }; + }; + + SUBCASE("allows data parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 4, + 1, + 1, + 1, + }, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("allows arbitrary input sharding parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 4, + 2, + 5, + 6, + }, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("allows discard copy parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{3}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("sum parallelism") { + SUBCASE("without activation") { + SUBCASE("PoolOp::MAX does not allow sum parallelism") { + Pool2DAttrs attrs = + make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, + }; + + std::optional result = + optional_from_expected( + get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("PoolOp::AVG does allow sum parallelism") { + Pool2DAttrs attrs = + make_attrs(PoolOp::AVG, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, + }; + + 
tl::expected<ParallelTensorDimDegrees, std::string> result =
+              get_output_parallel_dim_degrees(attrs, input);
+          tl::expected<ParallelTensorDimDegrees, std::string> correct = input;
+
+          CHECK(result == correct);
+        }
+      }
+
+      SUBCASE("with activation does not allow sum parallelism") {
+        Pool2DAttrs attrs =
+            make_attrs(PoolOp::AVG, /*activation=*/Activation::RELU);
+
+        ParallelTensorDimDegrees input = ParallelTensorDimDegrees{
+            SumDegree{2},
+            DiscardCopyDegree{1},
+            FFOrdered<int>{
+                1,
+                1,
+                1,
+                1,
+            },
+        };
+
+        std::optional<ParallelTensorDimDegrees> result = optional_from_expected(
+            get_output_parallel_dim_degrees(attrs, input));
+        std::optional<ParallelTensorDimDegrees> correct = std::nullopt;
+
+        CHECK(result == correct);
+      }
+    }
+  }
+
+  TEST_CASE("get_output_shape(Pool2DAttrs, ParallelTensorShape)") {
+    // this function is mostly covered by the tests above, so we
+    // just do a single test to make sure it works/exists
+
+    Pool2DAttrs attrs = Pool2DAttrs{
+        /*kernel_h=*/3,
+        /*kernel_w=*/2,
+        /*stride_h=*/2,
+        /*stride_w=*/2,
+        /*padding_h=*/1,
+        /*padding_w=*/1,
+        /*pool_type=*/PoolOp::MAX,
+        /*activation=*/std::nullopt,
+    };
+
+    SUBCASE("valid parallelism") {
+      ParallelTensorShape input = ParallelTensorShape{
+          ParallelTensorDims{
+              FFOrdered<ShardParallelDim>{
+                  ShardParallelDim{14, 7},
+                  ShardParallelDim{16, 8},
+                  ShardParallelDim{12, 3},
+                  ShardParallelDim{6, 2},
+              },
+              ReplicaParallelDimSet{
+                  SumDegree{1},
+                  DiscardCopyDegree{2},
+              },
+          },
+          DataType::FLOAT,
+      };
+
+      tl::expected<ParallelTensorShape, std::string> result =
+          get_output_shape(attrs, input);
+      tl::expected<ParallelTensorShape, std::string> correct =
+          ParallelTensorShape{
+              ParallelTensorDims{
+                  FFOrdered<ShardParallelDim>{
+                      ShardParallelDim{14, 7},
+                      ShardParallelDim{16, 8},
+                      ShardParallelDim{6, 3},
+                      ShardParallelDim{4, 2},
+                  },
+                  ReplicaParallelDimSet{
+                      SumDegree{1},
+                      DiscardCopyDegree{2},
+                  },
+              },
+              DataType::FLOAT,
+          };
+
+      CHECK(result == correct);
+    }
+
+    SUBCASE("invalid parallelism") {
+      ParallelTensorShape input = ParallelTensorShape{
+          ParallelTensorDims{
+              FFOrdered<ShardParallelDim>{
+                  ShardParallelDim{14, 1},
+                  ShardParallelDim{16, 1},
+                  ShardParallelDim{12, 1},
+                  ShardParallelDim{6, 1},
+              },
+              ReplicaParallelDimSet{
+                  SumDegree{2},
+                  DiscardCopyDegree{1},
+              },
+          },
+          DataType::FLOAT,
+      };
+
+      std::optional<ParallelTensorShape> result =
+          optional_from_expected(get_output_shape(attrs, input));
+      std::optional<ParallelTensorShape> correct = std::nullopt;
+
+      CHECK(result == correct);
+    }
+  }
+}
diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h
index 11e591545d..45cde0de57 100644
--- a/lib/pcg/include/pcg/computation_graph_builder.h
+++ b/lib/pcg/include/pcg/computation_graph_builder.h
@@ -137,6 +137,13 @@ struct ComputationGraphBuilder {
       PoolOp type = PoolOp::MAX,
       std::optional<Activation> const &activation = std::nullopt,
       std::optional<std::string> const &name = std::nullopt);
+  tensor_guid_t adaptive_pool2d(
+      tensor_guid_t const &input,
+      int output_h,
+      int output_w,
+      PoolOp type = PoolOp::MAX,
+      std::optional<Activation> const &activation = std::nullopt,
+      std::optional<std::string> const &name = std::nullopt);
   tensor_guid_t
       layer_norm(tensor_guid_t const &input,
                  std::vector<int> const &axes,
@@ -145,7 +152,10 @@ struct ComputationGraphBuilder {
                  std::optional<std::string> const &name = std::nullopt);
 
   tensor_guid_t batch_norm(tensor_guid_t const &input,
-                           bool relu = true,
+                           bool affine,
+                           std::optional<Activation> const &activation,
+                           float eps,
+                           std::optional<float> const &momentum,
                            std::optional<std::string> const &name = std::nullopt);
 
   tensor_guid_t batch_matmul(tensor_guid_t const &A,
@@ -170,11 +180,9 @@ struct ComputationGraphBuilder {
                 DataType dtype,
                 std::optional<std::string> const &name = std::nullopt);
   // Add a concat layer
-  tensor_guid_t
-      concat(int n,
-             std::vector<tensor_guid_t> const &tensors,
-             int axis,
-             std::optional<std::string> const &maybe_name = std::nullopt);
+
tensor_guid_t concat(std::vector const &tensors, + int axis, + std::optional const &name = std::nullopt); // Add a mean layer tensor_guid_t mean(tensor_guid_t const &input, std::vector const &dims, @@ -188,6 +196,8 @@ struct ComputationGraphBuilder { std::optional const &name = std::nullopt); // Add a flat layer tensor_guid_t flat(tensor_guid_t const &input, + int start_dim = 0, + std::optional const &end_dim = std::nullopt, std::optional const &name = std::nullopt); // Add a softmax layer tensor_guid_t softmax(tensor_guid_t const &input, @@ -252,9 +262,9 @@ struct ComputationGraphBuilder { std::vector const &weights, std::vector const &outputs); -private: TensorShape get_shape(tensor_guid_t const &) const; +private: tensor_guid_t broadcast(tensor_guid_t const &, TensorDims const &, std::string const &); diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h index 3a7f67dcf0..019b120936 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h @@ -87,7 +87,10 @@ struct ParallelComputationGraphBuilder { parallel_tensor_guid_t batch_norm(parallel_tensor_guid_t const &input, - bool relu = true, + bool affine, + std::optional const &activation, + float eps, + std::optional const &momentum, std::optional const &name = std::nullopt); parallel_tensor_guid_t diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index a4f61cff98..4a565476bd 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -6,14 +6,17 @@ #include "op-attrs/ops/attention.h" #include "op-attrs/ops/batch_norm.h" #include "op-attrs/ops/broadcast.h" +#include "op-attrs/ops/concat.h" #include "op-attrs/ops/conv_2d.h" #include "op-attrs/ops/dropout.h" #include "op-attrs/ops/element_binary.h" #include "op-attrs/ops/element_unary.h" #include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/flat.h" #include "op-attrs/ops/gather.h" #include "op-attrs/ops/layer_norm.h" #include "op-attrs/ops/linear.h" +#include "op-attrs/ops/pool_2d.h" #include "op-attrs/ops/softmax.h" #include "op-attrs/ops/weight_attrs.dtg.h" #include "op-attrs/tensor_dims.h" @@ -498,21 +501,130 @@ tensor_guid_t ComputationGraphBuilder::gather( return get_only( this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } +tensor_guid_t ComputationGraphBuilder::pool2d( + tensor_guid_t const &x, + int kernelH, + int kernelW, + int strideH, + int strideW, + int paddingH, + int paddingW, + PoolOp type, + std::optional const &activation, + std::optional const &maybe_name) { + + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/kernelH, + /*kernel_w=*/kernelW, + /*stride_h=*/strideH, + /*stride_w=*/strideW, + /*padding_h=*/paddingH, + /*padding_w=*/paddingW, + /*pool_type=*/type, + /*activation=*/activation, + }; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + tensor_guid_t input = + this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); +} + +tensor_guid_t 
ComputationGraphBuilder::adaptive_pool2d( + tensor_guid_t const &uncasted_input, + int output_h, + int output_w, + PoolOp type, + std::optional const &activation, + std::optional const &maybe_name) { + + TensorDims input_dims = this->get_shape(uncasted_input).dims; + + Pool2DAttrs attrs = throw_if_unexpected(make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, type, activation)); + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + tensor_guid_t casted_input = + this->as_type(uncasted_input, DataType::FLOAT, name + "input_pre_cast"); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + TensorShape output_shape = throw_if_unexpected( + get_output_shape(attrs, this->get_shape(casted_input))); + + return get_only(this->add_layer( + layer, {casted_input}, {}, {make_output_attrs(output_shape)})); +} tensor_guid_t ComputationGraphBuilder::batch_norm( tensor_guid_t const &input, - bool relu, + bool affine, + std::optional const &activation, + float eps, + std::optional const &momentum, std::optional const &maybe_name) { - BatchNormAttrs attrs = BatchNormAttrs{relu}; + + if (activation.has_value() && activation.value() != Activation::RELU) { + throw mk_runtime_error(fmt::format( + "batch_norm currently only supports (1) no activation function, or (2) " + "relu activation function, but received {}. " + "If you need support for additional activation functions, please " + "create an issue.", + activation)); + } + + BatchNormAttrs attrs = BatchNormAttrs{ + /*relu=*/activation.has_value(), + /*affine=*/affine, + /*eps=*/eps, + /*momentum=*/momentum, + }; + std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + TensorShape input_shape = this->get_shape(input); + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, input_shape)); - return get_only( - this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); + std::vector weights; + + if (affine) { + // initializers chosen to match those of + // https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html + + TensorShape gamma_shape = + throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); + InitializerAttrs gamma_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; + weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); + + TensorShape beta_shape = + throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); + InitializerAttrs beta_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); + } + + return get_only(this->add_layer( + layer, + {input}, + transform(weights, + [&](TensorAttrs const &a) { return this->create_weight(a); }), + {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::multihead_attention( @@ -674,6 +786,50 @@ tensor_guid_t ComputationGraphBuilder::dense( layer, {input}, weights, {make_output_attrs(output_shape)})); } +tensor_guid_t ComputationGraphBuilder::concat( + std::vector const &inputs, + int axis, + std::optional const &maybe_name) { + + ConcatAttrs attrs = ConcatAttrs{ff_dim_t{axis}}; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, 
name}; + + std::vector input_shapes = transform( + inputs, [&](tensor_guid_t const &i) { return this->get_shape(i); }); + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, input_shapes)); + + return get_only( + this->add_layer(layer, inputs, {}, {make_output_attrs(output_shape)})); +} + +tensor_guid_t ComputationGraphBuilder::flat( + tensor_guid_t const &input, + int start_dim, + std::optional const &end_dim, + std::optional const &maybe_name) { + int input_num_dims = num_dims(this->get_shape(input)); + + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{start_dim}, + /*end_dim=*/ff_dim_t{end_dim.value_or(input_num_dims)}, + }; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); +} + tensor_guid_t ComputationGraphBuilder::layer_norm( tensor_guid_t const &input, std::vector const &axes, diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 620dc035fc..ce00ea62f4 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -331,18 +331,56 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::multihead_attention( parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_norm( parallel_tensor_guid_t const &input, - bool relu, + bool affine, + std::optional const &activation, + float eps, + std::optional const &momentum, std::optional const &maybe_name) { - BatchNormAttrs attrs = BatchNormAttrs{relu}; + if (activation.has_value() && activation.value() != Activation::RELU) { + throw mk_runtime_error(fmt::format( + "batch_norm currently only supports (1) no activation function, or (2) " + "relu activation function, but received {}. 
" + "If you need support for additional activation functions, please " + "create an issue.", + activation)); + } + + BatchNormAttrs attrs = BatchNormAttrs{ + /*relu=*/activation.has_value(), + /*affine=*/affine, + /*eps=*/eps, + /*momentum=*/momentum, + }; std::string name = maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + ParallelTensorShape input_shape = this->get_shape(input); + ParallelTensorShape output_shape = - get_output_shape(attrs, this->get_shape(input)); + throw_if_unexpected(get_output_shape(attrs, input_shape)); + + std::vector weights; + + if (attrs.affine) { + // initializers chosen to match those of + // https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html + + ParallelTensorShape gamma_shape = + throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); + InitializerAttrs gamma_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; + weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); + + ParallelTensorShape beta_shape = + throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); + InitializerAttrs beta_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); + } return this->add_layer(layer, {input}, {}, {output_shape}); } diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml index 59e913750e..eb758ea4fc 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml @@ -55,7 +55,8 @@ values = [ { name = "SHOULD_BROADCAST_LHS" }, { name = "SHOULD_BROADCAST_RHS" }, { name = "DIM" }, - { name = "ELEMENTWISE_AFFINE" }, + { name = "AFFINE" }, + { name = "MOMENTUM" }, { name = "REGULARIZER" }, { name = "SHAPE" }, { name = "SPLITS" }, diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index d5d735ef59..442d3345a1 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -19,8 +19,12 @@ std::optional get_attribute(BatchNormAttrs const &p, switch (key) { case OperatorAttributeKey::OP_TYPE: return get_op_type(p); - case OperatorAttributeKey::RELU: - return p.relu; + case OperatorAttributeKey::EPSILON: + return p.eps; + case OperatorAttributeKey::AFFINE: + return p.affine; + case OperatorAttributeKey::MOMENTUM: + return p.momentum; default: return std::nullopt; } @@ -189,6 +193,10 @@ std::optional get_attribute(LayerNormAttrs const &p, switch (key) { case OperatorAttributeKey::OP_TYPE: return get_op_type(p); + case OperatorAttributeKey::AFFINE: + return p.elementwise_affine; + case OperatorAttributeKey::AXES: + return vector_of(p.axes); default: return std::nullopt; } diff --git a/lib/substitutions/test/src/substitutions/substitution.cc b/lib/substitutions/test/src/substitutions/substitution.cc index 87ffc01f0b..1718b03b5c 100644 --- a/lib/substitutions/test/src/substitutions/substitution.cc +++ b/lib/substitutions/test/src/substitutions/substitution.cc @@ -21,7 +21,7 @@ TEST_SUITE(FF_TEST_SUITE) { // } 
TEST_CASE("evaluate_substitution_output(SubParallelComputationGraph, " - "Substituion, PCGPatternMatch)") { + "Substitution, PCGPatternMatch)") { // Currently Substitution creation is very verbose. // This is being addressed in // https://github.com/flexflow/FlexFlow/issues/1473. diff --git a/lib/utils/include/utils/containers/are_all_same.h b/lib/utils/include/utils/containers/are_all_same.h new file mode 100644 index 0000000000..37b1838146 --- /dev/null +++ b/lib/utils/include/utils/containers/are_all_same.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_SAME_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_SAME_H + +namespace FlexFlow { + +template +bool are_all_same(C const &c) { + if (c.empty()) { + return true; + } + + auto const &first = *c.cbegin(); + for (auto const &v : c) { + if (v != first) { + return false; + } + } + return true; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/require_all_same1.h b/lib/utils/include/utils/containers/require_all_same1.h new file mode 100644 index 0000000000..2f42243857 --- /dev/null +++ b/lib/utils/include/utils/containers/require_all_same1.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ALL_SAME1_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ALL_SAME1_H + +#include +#include + +namespace FlexFlow { + +template +tl::expected require_all_same1(C const &c) { + if (c.empty()) { + return tl::unexpected(fmt::format( + "require_all_same1 expected non-empty container, but received {}", c)); + } + + T const &first = *c.cbegin(); + for (T const &v : c) { + if (v != first) { + return tl::unexpected(fmt::format("require_all_same1 found non-same " + "elements {} and {} in containers {}", + first, + v, + c)); + } + } + return first; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/subvec.h b/lib/utils/include/utils/containers/subvec.h index 52368f94ad..5ae90ec5ba 100644 --- a/lib/utils/include/utils/containers/subvec.h +++ b/lib/utils/include/utils/containers/subvec.h @@ -25,10 +25,15 @@ std::vector subvec(std::vector const &v, if (maybe_start.has_value()) { begin_iter += resolve_loc(maybe_start.value()); } + if (maybe_end.has_value()) { end_iter = v.cbegin() + resolve_loc(maybe_end.value()); } + if (end_iter < begin_iter) { + end_iter = begin_iter; + } + std::vector output(begin_iter, end_iter); return output; } diff --git a/lib/utils/include/utils/containers/sum.h b/lib/utils/include/utils/containers/sum.h new file mode 100644 index 0000000000..5dbd620781 --- /dev/null +++ b/lib/utils/include/utils/containers/sum.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUM_H + +namespace FlexFlow { + +template +T sum(C const &c) { + T result = 0; + for (T const &t : c) { + result += t; + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 3ec165d595..377561d70c 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_OPTIONAL_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_OPTIONAL_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_OPTIONAL_H #include "utils/exception.h" #include "utils/fmt/optional.h" diff --git a/lib/utils/src/utils/containers/are_all_same.cc 
b/lib/utils/src/utils/containers/are_all_same.cc new file mode 100644 index 0000000000..c515bceee2 --- /dev/null +++ b/lib/utils/src/utils/containers/are_all_same.cc @@ -0,0 +1 @@ +#include "utils/containers/are_all_same.h" diff --git a/lib/utils/src/utils/containers/require_all_same1.cc b/lib/utils/src/utils/containers/require_all_same1.cc new file mode 100644 index 0000000000..295339a91d --- /dev/null +++ b/lib/utils/src/utils/containers/require_all_same1.cc @@ -0,0 +1 @@ +#include "utils/containers/require_all_same1.h" diff --git a/lib/utils/src/utils/containers/sum.cc b/lib/utils/src/utils/containers/sum.cc new file mode 100644 index 0000000000..088b5f1983 --- /dev/null +++ b/lib/utils/src/utils/containers/sum.cc @@ -0,0 +1 @@ +#include "utils/containers/sum.h" diff --git a/lib/utils/test/src/utils/containers/are_all_same.cc b/lib/utils/test/src/utils/containers/are_all_same.cc new file mode 100644 index 0000000000..fd8b321439 --- /dev/null +++ b/lib/utils/test/src/utils/containers/are_all_same.cc @@ -0,0 +1,36 @@ +#include "utils/containers/are_all_same.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("are_all_same(std::vector)") { + SUBCASE("input is empty") { + std::vector input = {}; + + bool result = are_all_same(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("input elements are all same") { + std::vector input = {1, 1, 1}; + + bool result = are_all_same(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("input elements are not all same") { + std::vector input = {1, 1, 2, 1}; + + bool result = are_all_same(input); + bool correct = false; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/require_all_same1.cc b/lib/utils/test/src/utils/containers/require_all_same1.cc new file mode 100644 index 0000000000..48c1ab0b99 --- /dev/null +++ b/lib/utils/test/src/utils/containers/require_all_same1.cc @@ -0,0 +1,54 @@ +#include "utils/containers/require_all_same1.h" +#include "test/utils/doctest/fmt/expected.h" +#include "test/utils/doctest/fmt/multiset.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include "utils/expected.h" +#include +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("require_all_same1(T)", + T, + std::vector, + std::unordered_set, + std::unordered_multiset, + std::set, + std::multiset) { + SUBCASE("input is empty") { + T input = {}; + + std::optional result = + optional_from_expected(require_all_same1(input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input elements are all the same") { + T input = {1, 1, 1}; + + tl::expected result = require_all_same1(input); + tl::expected correct = 1; + + CHECK(result == correct); + } + + SUBCASE("input elements are not all the same") { + T input = {1, 1, 2, 1}; + + std::optional result = + optional_from_expected(require_all_same1(input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/sum.cc b/lib/utils/test/src/utils/containers/sum.cc new file mode 100644 index 0000000000..32d8cd32a3 --- /dev/null +++ b/lib/utils/test/src/utils/containers/sum.cc @@ -0,0 +1,27 @@ +#include "utils/containers/sum.h" +#include 
+#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("sum(std::vector)") { + SUBCASE("input is empty") { + std::vector input = {}; + + int result = sum(input); + int correct = 0; + + CHECK(result == correct); + } + + SUBCASE("input is not empty") { + std::vector input = {1, 3, 2}; + + int result = sum(input); + int correct = 6; + + CHECK(result == correct); + } + } +}
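For reviewers who want to see the new ComputationGraphBuilder entry points from this patch exercised in one place, here is a minimal usage sketch. It is illustrative only and not part of the patch: it assumes a builder `b` and a 4-dimensional (N, C, H, W) float input tensor `x` obtained elsewhere, and it only calls APIs introduced above (batch_norm with affine/eps/momentum, adaptive_pool2d, and the start_dim/end_dim overload of flat).

#include "pcg/computation_graph_builder.h"

using namespace ::FlexFlow;

// Illustrative sketch, not part of the patch: `b` and `x` are assumed to
// already exist, with `x` a 4-d NCHW float tensor.
tensor_guid_t small_cnn_head(ComputationGraphBuilder &b, tensor_guid_t x) {
  // batch_norm now takes affine/activation/eps/momentum instead of a bare
  // `relu` flag; only no activation or RELU is accepted.
  x = b.batch_norm(x,
                   /*affine=*/true,
                   /*activation=*/Activation::RELU,
                   /*eps=*/1e-5,
                   /*momentum=*/0.1);

  // adaptive_pool2d derives kernel/stride from the requested output size,
  // mirroring make_adaptive_pool2d_attrs tested above.
  x = b.adaptive_pool2d(x, /*output_h=*/4, /*output_w=*/4, PoolOp::AVG);

  // flat flattens the dims in [start_dim, end_dim); with no end_dim it
  // flattens everything after the batch dimension.
  x = b.flat(x, /*start_dim=*/1);

  return x;
}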