From d03f18799c42fd9b56be64a728b652a471d2dcb1 Mon Sep 17 00:00:00 2001
From: Mark Shields
Date: Mon, 24 Jan 2022 15:10:58 -0800
Subject: [PATCH] ** Collage v2 sketch ***

- Enable cudnn, get rid of support for op-predicate based BYOC integrations
- Enable cublas
- And yet another go at pruning unnecessary candidates.
- Another go at pruning unnecessary candidates
- Fix CompositePartitionRule use
- Fix a few bugs with new TensorRT pattern-based integration
- Rework RemoveSubCandidatesCombinerRule for soundness
- Better logging
- Bug fixes
- Implement critical nodes idea for avoiding obviously unnecessary candidates
- Promote DataflowGraph from alias to class so it can cache the downstream index set
- Quick check to avoid unioning candidates which would create a cycle
- Hoist out CandidatePartitionIndex and add rules to avoid small candidates subsumed by
  containing candidates
- GetFunction can legitimately return nullptr
- rename tuning log
- Support for int64 literals
- Switch GPT2 to plain model
- Fix library clobbering issue for cutlass
- actually check in 'built-in' tuning log (covers mnist & gpt2 only)
- trying to debug gpt2
- Update TargetKind attribute name
- working through gpt2 issues
- check in tuning records for MNIST (with hack to not retry failed winograd)
- Autotvm tuning disabled if log file empty (default)
- Autotvm tuning during search working
- tune during search (but does not load tuned records after search!)
- About to add tuning to estimate_seconds
- Split out the combiner rules & make them FFI friendly
- Rework comments
- Estimate IRModule instead of Function (closer to the meta_schedule interface)
- Add 'host' as first-class partitioning spec
  (Avoids special casing for the 'leave behind for the VM' case)
- Move CollagePartitioner to very start of VM compiler flow (not changing legacy)
- Fix bugs etc with new SubGraph::Rewrite approach
  Ready for updating RFC to focus on partitioning instead of fusion.
- Working again after partition<->fusion split.
- Add PrimitivePartitionRule
- Refactor SubGraph Extract/Rewrite
  *** CAUTION: Almost certainly broken ***
- Rename kernel->partition, fusion->partition
- Next: make nesting in "Primitive" an explicit transform
- respect existing target constraints from device planner
- make 'compiler' and 'fusion_rule' attributes available on all target kinds
- moved design to tvm-rfcs, https://github.com/apache/tvm-rfcs/pull/62
- incorporate comments
- avoid repeated fusion
- fix trt type checking
- better logs
- pretty print primitive rules
- fix tensorrt
- multiple targets per spec
- don't extract candidate function until cost is needed
  Need to bring CombineByPrimitives back under control since the depth limit was lost.
- cleaned up fusion rule names
- added 'fuse anything touching' for BYOC
- Finish dd example
- Add notion of 'MustLower': even if a candidate fires we may still need to consider
  leaving the node behind for the VM (especially for constants).
- starting example
- finished all the dd sections
- documentation checkpoint
- docs checkpoint
- more design
- starting on dd
- runs MNIST with TVM+CUTLASS+TRT
- cutlass function-at-a-time build
- need to account for build_cutlass_kernels_vm
- move cutlass tuning into relay.ext.cutlass path to avoid special case
- add utils
- don't fuse non-scalar constants for tvm target.
- stuck on cuda mem failure on conv2d, suspect bug in main
- where do the cutlass attrs come from?
- running, roughly
- pretty printing, signs of life
- wire things up again
- Switch SubGraph and CandidateKernel to TVM objects
- naive CombineByKindFusionRule, just to see what we're up against
  Will switch to Object/ObjectRef for SubGraph and CandidateKernel to avoid excess copying.
- preparing to mimic FuseOps
- rework SubGraph to use IndexSet
- rough cut at MaximalFusion
- split SubGraph and IndexSet in preparation for caching input/output/entry/exit sets in SubGraph.
- top-down iterative handling of sub-sub-graphs
- about to give up on one-pass extraction with 'sub-sub-graphs'
- Add notion of 'labels' to sub-graphs
- Rework FusionRules to be more compositional
- partway through reworking fusion rules, broken
- SubGraph::IsValid, but still need to add no_taps check
- dataflow rework, preparing for SubGraph::IsValid
- explode into subdir
- mnist with one fusion rule (which fires twice) working
- switch to CandidateKernelIndex
- Confirm can measure 'pre-annotated' primitive functions
- checkpoint
- stuff
- more sketching
- dominator logging
---
 CMakeLists.txt | 1 +
 collage_autotvm.tuninglog | 15 +
 include/tvm/ir/expr.h | 3 +-
 include/tvm/relay/expr.h | 21 +
 include/tvm/relay/expr_functor.h | 2 +
 include/tvm/relay/function.h | 2 +-
 include/tvm/relay/op_attr_types.h | 40 +-
 include/tvm/relay/transform.h | 5 +
 include/tvm/target/compilation_config.h | 2 +
 include/tvm/target/target.h | 10 +
 include/tvm/target/target_kind.h | 14 +-
 python/tvm/auto_scheduler/dispatcher.py | 2 +-
 python/tvm/autotvm/task/dispatcher.py | 26 +-
 python/tvm/contrib/cc.py | 2 +
 python/tvm/contrib/cutlass/build.py | 341 ++--
 python/tvm/relay/__init__.py | 1 +
 python/tvm/relay/backend/vm.py | 101 +-
 python/tvm/relay/collage/__init__.py | 18 +
 .../tvm/relay/collage/collage_partitioner.py | 192 +++
 python/tvm/relay/op/contrib/cutlass.py | 15 +-
 src/ir/expr.cc | 4 +-
 src/parser/parser.cc | 33 +-
 src/parser/tokenizer.h | 107 +-
 src/printer/doc.cc | 3 -
 src/printer/relay_text_printer.cc | 22 +-
 src/relay/backend/build_module.cc | 4 +-
 src/relay/backend/contrib/cutlass/codegen.cc | 14 +-
 src/relay/backend/contrib/tensorrt/codegen.cc | 1 +
 src/relay/backend/te_compiler.cc | 2 +-
 src/relay/backend/utils.cc | 10 +-
 src/relay/backend/utils.h | 4 +-
 src/relay/backend/vm/compiler.cc | 35 +-
 src/relay/backend/vm/compiler.h | 4 +-
 src/relay/collage/README.md | 9 +
 src/relay/collage/candidate_partition.cc | 207 +++
 src/relay/collage/candidate_partition.h | 178 ++
 .../collage/candidate_partition_index.cc | 92 +
 src/relay/collage/candidate_partition_index.h | 90 +
 src/relay/collage/candidate_set.cc | 68 +
 src/relay/collage/candidate_set.h | 83 +
 src/relay/collage/capture_index_in_spans.cc | 174 ++
 src/relay/collage/capture_index_in_spans.h | 44 +
 src/relay/collage/collage_partitioner.cc | 312 ++++
 src/relay/collage/collage_partitioner.h | 48 +
 src/relay/collage/combiner_rule.cc | 322 ++++
 src/relay/collage/combiner_rule.h | 198 +++
 src/relay/collage/cost.cc | 45 +
 src/relay/collage/cost.h | 103 ++
 src/relay/collage/cost_estimator.cc | 69 +
 src/relay/collage/cost_estimator.h | 79 +
 src/relay/collage/dataflow_graph.cc | 48 +
 src/relay/collage/dataflow_graph.h | 73 +
 src/relay/collage/gather_partition_specs.cc | 246 +++
 src/relay/collage/gather_partition_specs.h | 77 +
 src/relay/collage/index_set.cc | 231 +++
 src/relay/collage/index_set.h | 125 ++
 src/relay/collage/name_supply.cc | 87 +
 src/relay/collage/name_supply.h | 57 +
 src/relay/collage/partition_rule.cc | 385 +++++
 src/relay/collage/partition_rule.h | 468 +++++
src/relay/collage/partition_spec.cc | 85 + src/relay/collage/partition_spec.h | 104 ++ src/relay/collage/priority_queue.h | 73 + src/relay/collage/prune_candidates.cc | 216 +++ src/relay/collage/prune_candidates.h | 68 + .../collage/recover_virtual_device_map.cc | 54 + .../collage/recover_virtual_device_map.h | 42 + src/relay/collage/sub_graph.cc | 958 +++++++++++ src/relay/collage/sub_graph.h | 425 +++++ src/relay/collage/utils.cc | 137 ++ src/relay/collage/utils.h | 77 + src/relay/ir/dataflow_matcher.cc | 114 +- src/relay/ir/dataflow_matcher_impl.h | 19 +- src/relay/ir/expr.cc | 39 + src/relay/ir/indexed_graph.cc | 451 +++-- src/relay/ir/indexed_graph.h | 171 +- src/relay/op/nn/nn.cc | 1 + src/relay/transforms/fuse_ops.cc | 202 ++- src/relay/transforms/type_infer.cc | 14 + .../contrib/tensorrt/tensorrt_builder.cc | 27 +- .../contrib/tensorrt/tensorrt_calibrator.h | 2 +- src/runtime/contrib/tensorrt/tensorrt_ops.cc | 95 +- src/runtime/contrib/tensorrt/tensorrt_ops.h | 7 +- .../contrib/tensorrt/tensorrt_runtime.cc | 4 +- src/runtime/cuda/cuda_device_api.cc | 10 +- src/runtime/dso_library.cc | 18 +- src/runtime/library_module.cc | 6 +- src/runtime/logging.cc | 1 - src/runtime/module.cc | 1 + src/runtime/vm/pooled_allocator.h | 13 +- src/runtime/vm/vm.cc | 9 +- src/target/compilation_config.cc | 45 + src/target/target.cc | 22 + tests/python/contrib/test_cutlass.py | 6 +- .../relay/collage/test_collage_partitioner.py | 1508 +++++++++++++++++ tests/python/relay/collage/test_sub_graph.py | 374 ++++ tests/python/relay/test_dataflow_pattern.py | 94 +- tests/python/relay/test_pass_fuse_ops.py | 382 +++-- 98 files changed, 9962 insertions(+), 886 deletions(-) create mode 100644 collage_autotvm.tuninglog create mode 100644 python/tvm/relay/collage/__init__.py create mode 100644 python/tvm/relay/collage/collage_partitioner.py create mode 100644 src/relay/collage/README.md create mode 100644 src/relay/collage/candidate_partition.cc create mode 100644 src/relay/collage/candidate_partition.h create mode 100644 src/relay/collage/candidate_partition_index.cc create mode 100644 src/relay/collage/candidate_partition_index.h create mode 100644 src/relay/collage/candidate_set.cc create mode 100644 src/relay/collage/candidate_set.h create mode 100644 src/relay/collage/capture_index_in_spans.cc create mode 100644 src/relay/collage/capture_index_in_spans.h create mode 100644 src/relay/collage/collage_partitioner.cc create mode 100644 src/relay/collage/collage_partitioner.h create mode 100644 src/relay/collage/combiner_rule.cc create mode 100644 src/relay/collage/combiner_rule.h create mode 100644 src/relay/collage/cost.cc create mode 100644 src/relay/collage/cost.h create mode 100644 src/relay/collage/cost_estimator.cc create mode 100644 src/relay/collage/cost_estimator.h create mode 100644 src/relay/collage/dataflow_graph.cc create mode 100644 src/relay/collage/dataflow_graph.h create mode 100644 src/relay/collage/gather_partition_specs.cc create mode 100644 src/relay/collage/gather_partition_specs.h create mode 100644 src/relay/collage/index_set.cc create mode 100644 src/relay/collage/index_set.h create mode 100644 src/relay/collage/name_supply.cc create mode 100644 src/relay/collage/name_supply.h create mode 100644 src/relay/collage/partition_rule.cc create mode 100644 src/relay/collage/partition_rule.h create mode 100644 src/relay/collage/partition_spec.cc create mode 100644 src/relay/collage/partition_spec.h create mode 100644 src/relay/collage/priority_queue.h create mode 100644 
src/relay/collage/prune_candidates.cc create mode 100644 src/relay/collage/prune_candidates.h create mode 100644 src/relay/collage/recover_virtual_device_map.cc create mode 100644 src/relay/collage/recover_virtual_device_map.h create mode 100644 src/relay/collage/sub_graph.cc create mode 100644 src/relay/collage/sub_graph.h create mode 100644 src/relay/collage/utils.cc create mode 100644 src/relay/collage/utils.h create mode 100644 tests/python/relay/collage/test_collage_partitioner.py create mode 100644 tests/python/relay/collage/test_sub_graph.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 1564a6820719e..9f4a759adae46 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -291,6 +291,7 @@ tvm_file_glob(GLOB_RECURSE RELAY_OP_SRCS ) tvm_file_glob(GLOB_RECURSE RELAY_PASS_SRCS src/relay/analysis/*.cc + src/relay/collage/*.cc src/relay/transforms/*.cc src/relay/quantize/*.cc ) diff --git a/collage_autotvm.tuninglog b/collage_autotvm.tuninglog new file mode 100644 index 0000000000000..466090ae64af6 --- /dev/null +++ b/collage_autotvm.tuninglog @@ -0,0 +1,15 @@ +{"input": ["cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", "conv2d_nchw_winograd.cuda", [["TENSOR", [1, 1, 32, 32], "float32"], ["TENSOR", [8, 1, 5, 5], "float32"], [1, 1], [0, 0, 0, 0], [1, 1], "float32"], {}], "config": {"index": 968, "code_hash": null, "entity": [["tile_b", "sp", [-1, 1, 1, 1]], ["tile_y", "sp", [-1, 2, 4, 1]], ["tile_x", "sp", [-1, 1, 7, 7]], ["tile_rc", "sp", [-1, 1]], ["auto_unroll_max_step", "ot", 128], ["unroll_explicit", "ot", 1]]}, "result": [[1000000000.0], 6, 10, 1648166365.035291], "version": 0.2, "tvm_version": "0.9.dev0"} +{"input": ["cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", "conv2d_nchw.cuda", [["TENSOR", [1, 1, 32, 32], "float32"], ["TENSOR", [8, 1, 5, 5], "float32"], [1, 1], [0, 0, 0, 0], [1, 1], "float32"], {}], "config": {"index": 748547, "code_hash": null, "entity": [["tile_f", "sp", [-1, 1, 4, 1]], ["tile_y", "sp", [-1, 1, 1, 4]], ["tile_x", "sp", [-1, 1, 14, 1]], ["tile_rc", "sp", [-1, 1]], ["tile_ry", "sp", [-1, 5]], ["tile_rx", "sp", [-1, 5]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]]}, "result": [[2.1807114592422733e-06, 2.182203281316585e-06, 2.183491385782991e-06], 0, 1.8035461902618408, 1648233194.5253587], "version": 0.2, "tvm_version": "0.9.dev0"} +{"input": ["cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", "conv2d_nchw_winograd.cuda", [["TENSOR", [1, 8, 18, 18], "float32"], ["TENSOR", [16, 8, 5, 5], "float32"], [1, 1], [0, 0, 0, 0], [1, 1], "float32"], {}], "config": {"index": 7905, "code_hash": null, "entity": [["tile_b", "sp", [-1, 1, 1, 1]], ["tile_y", "sp", [-1, 1, 4, 4]], ["tile_x", "sp", [-1, 1, 49, 1]], ["tile_rc", "sp", [-1, 4]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]]}, "result": [[1.4285206158127155e-05, 1.4285846107313532e-05, 1.4331592281168714e-05], 0, 7.421089172363281, 1648237434.129], "version": 0.2, "tvm_version": "0.9.dev0"} +{"input": ["cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", "conv2d_nchw.cuda", [["TENSOR", [1, 8, 18, 18], "float32"], ["TENSOR", [16, 8, 5, 5], "float32"], [1, 1], [0, 0, 0, 0], [1, 1], "float32"], {}], "config": {"index": 714012, "code_hash": null, "entity": [["tile_f", "sp", [-1, 1, 8, 1]], ["tile_y", "sp", [-1, 1, 1, 1]], ["tile_x", "sp", [-1, 1, 7, 2]], ["tile_rc", "sp", [-1, 8]], ["tile_ry", "sp", [-1, 5]], ["tile_rx", "sp", [-1, 5]], ["auto_unroll_max_step", 
"ot", 512], ["unroll_explicit", "ot", 1]]}, "result": [[2.5586838960333487e-06, 2.5701070606157226e-06, 2.572374535019662e-06], 0, 3.1794843673706055, 1648239614.7956486], "version": 0.2, "tvm_version": "0.9.dev0"} +{"input": ["cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", "dense_small_batch.gpu", [["TENSOR", [1, 256], "float32"], ["TENSOR", [10, 256], "float32"], null, "float32"], {}], "config": {"index": 4, "code_hash": null, "entity": [["tile_k", "sp", [-1, 16]]]}, "result": [[2.158152404676017e-06, 2.1645748896629425e-06, 2.1784918293729133e-06], 0, 1.6369056701660156, 1648241555.184448], "version": 0.2, "tvm_version": "0.9.dev0"} +{"input": ["cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", "dense_large_batch.gpu", [["TENSOR", [1600, 768], "float32"], ["TENSOR", [2304, 768], "float32"], null, "float32"], {}], "config": {"index": 61851361, "code_hash": null, "entity": [["tile_x", "sp", [-1, 2, 2, 8]], ["tile_y", "sp", [-1, 1, 2, 9]], ["tile_k", "sp", [-1, 2, 4]]]}, "result": [[0.004074227972972973, 0.0040861373243243244, 0.004086151648648648], 0, 3.037601947784424, 1648251189.6885986], "version": 0.2, "tvm_version": "0.9.dev0"} +{"input": ["cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", "dense_small_batch.gpu", [["TENSOR", [1600, 768], "float32"], ["TENSOR", [2304, 768], "float32"], null, "float32"], {}], "config": {"index": 5, "code_hash": null, "entity": [["tile_k", "sp", [-1, 8]]]}, "result": [[0.0268318398, 0.026832641350000002, 0.02683273135], 0, 4.179340600967407, 1648254281.8060668], "version": 0.2, "tvm_version": "0.9.dev0"} +{"input": ["cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", "batch_matmul.cuda", [["TENSOR", [600, 32, 64], "float32"], ["TENSOR", [600, 32, 64], "float32"], [600, 32, 32], "float32", 0, 1], {}], "config": {"index": 20386, "code_hash": null, "entity": [["tile_y", "sp", [-1, 2, 8]], ["tile_x", "sp", [-1, 16, 1]], ["tile_k", "sp", [-1, 16]], ["auto_unroll_max_step", "ot", 32], ["unroll_explicit", "ot", 1]]}, "result": [[3.258110773592547e-05, 3.258372944511948e-05, 3.261549426218442e-05], 0, 2.397996664047241, 1648255266.3718677], "version": 0.2, "tvm_version": "0.9.dev0"} +{"input": ["cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", "batch_matmul.cuda", [["TENSOR", [600, 32, 32], "float32"], ["TENSOR", [600, 64, 32], "float32"], [600, 32, 64], "float32", 0, 1], {}], "config": {"index": 5980, "code_hash": null, "entity": [["tile_y", "sp", [-1, 2, 8]], ["tile_x", "sp", [-1, 16, 1]], ["tile_k", "sp", [-1, 16]], ["auto_unroll_max_step", "ot", 16], ["unroll_explicit", "ot", 0]]}, "result": [[3.199404780823732e-05, 3.199749384187525e-05, 3.200219666269368e-05], 0, 2.3573713302612305, 1648257050.9987426], "version": 0.2, "tvm_version": "0.9.dev0"} +{"input": ["cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", "dense_large_batch.gpu", [["TENSOR", [1600, 768], "float32"], ["TENSOR", [768, 768], "float32"], null, "float32"], {}], "config": {"index": 13482935, "code_hash": null, "entity": [["tile_x", "sp", [-1, 5, 16, 1]], ["tile_y", "sp", [-1, 4, 16, 2]], ["tile_k", "sp", [-1, 12, 2]]]}, "result": [[0.00026185516898148144, 0.00026186912731481486, 0.0002643642638888889], 0, 5.9183220863342285, 1648262140.4419408], "version": 0.2, "tvm_version": "0.9.dev0"} +{"input": ["cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", "dense_small_batch.gpu", [["TENSOR", [1600, 768], 
"float32"], ["TENSOR", [768, 768], "float32"], null, "float32"], {}], "config": {"index": 9, "code_hash": null, "entity": [["tile_k", "sp", [-1, 32]]]}, "result": [[0.0022258066376811595, 0.0022258676666666666, 0.0022260689855072464], 0, 1.6845574378967285, 1648264221.272429], "version": 0.2, "tvm_version": "0.9.dev0"} +{"input": ["cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", "dense_large_batch.gpu", [["TENSOR", [1600, 768], "float32"], ["TENSOR", [3072, 768], "float32"], null, "float32"], {}], "config": {"index": 75386735, "code_hash": null, "entity": [["tile_x", "sp", [-1, 5, 16, 1]], ["tile_y", "sp", [-1, 2, 16, 4]], ["tile_k", "sp", [-1, 2, 12]]]}, "result": [[0.0009476383928571428, 0.0009476764880952381, 0.0009480008333333333], 0, 3.346571207046509, 1648271350.9854434], "version": 0.2, "tvm_version": "0.9.dev0"} +{"input": ["cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", "dense_small_batch.gpu", [["TENSOR", [1600, 768], "float32"], ["TENSOR", [3072, 768], "float32"], null, "float32"], {}], "config": {"index": 17, "code_hash": null, "entity": [["tile_k", "sp", [-1, 768]]]}, "result": [[1000000000.0], 4, 4.362995386123657, 1648274146.1389868], "version": 0.2, "tvm_version": "0.9.dev0"} +{"input": ["cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", "dense_large_batch.gpu", [["TENSOR", [1600, 3072], "float32"], ["TENSOR", [768, 3072], "float32"], null, "float32"], {}], "config": {"index": 15171048, "code_hash": null, "entity": [["tile_x", "sp", [-1, 5, 4, 20]], ["tile_y", "sp", [-1, 1, 192, 2]], ["tile_k", "sp", [-1, 8, 2]]]}, "result": [[1000000000.0], 1, 1.2985179424285889, 1648274382.1135368], "version": 0.2, "tvm_version": "0.9.dev0"} +{"input": ["cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", "dense_small_batch.gpu", [["TENSOR", [1600, 3072], "float32"], ["TENSOR", [768, 3072], "float32"], null, "float32"], {}], "config": {"index": 9, "code_hash": null, "entity": [["tile_k", "sp", [-1, 32]]]}, "result": [[1000000000.0], 4, 4.3437583446502686, 1648274480.7225487], "version": 0.2, "tvm_version": "0.9.dev0"} diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 4a00de802c61e..b54a067e1c941 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -260,9 +260,10 @@ class GlobalVarNode : public RelayExprNode { */ class GlobalVar : public RelayExpr { public: - TVM_DLL explicit GlobalVar(String name_hint, Type type = {}); + TVM_DLL explicit GlobalVar(String name_hint, Type type = {}, Span span = {}); TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode); }; // PrimExprs that are useful as runtime containers. diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index fe570806922fd..c5f94be699738 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -39,6 +39,12 @@ #include "./type.h" namespace tvm { + +GlobalVar WithFields(GlobalVar global_var, Optional opt_name_hint = {}, + Optional opt_type = {}, + Optional opt_virtual_device = {}, + Optional opt_span = {}); + namespace relay { using Expr = tvm::RelayExpr; @@ -97,8 +103,23 @@ class Constant : public Expr { TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode); }; +/*! + * \brief Returns the constant with given properties. A null property denotes 'no change'. 
+ * Returns this if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param constant The constant to copy + * \param op_data The (optional) data for the copied constant. If none, ret_constant->data = + * constant->data. + * \param opt_virtual_device The (optional) virtual_device for the copied constant. If none, + * ret_constant->virtual_device = constant->virtual_device. + * \param opt_span The (optional) span for the copied constant. If none, + * ret_constant->span = constant->span. + */ +Constant WithFields(Constant constant, Optional opt_data = {}, + Optional opt_virtual_device = {}, Optional opt_span = {}); + /*! \brief Tuple of multiple Exprs */ class Tuple; /*! \brief Tuple container */ diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index d8f575dfdf485..280a1f8a6c29c 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -240,6 +240,8 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor { */ explicit MixedModeVisitor(int visit_limit = 1); + using ExprVisitor::VisitExpr_; + /*! * \brief VisitExpr is finalized to preserve call expansion of dataflow regions */ diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 5869f878aa856..50c3cffba9179 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -173,7 +173,7 @@ namespace attr { /*! \brief Mark the function as a primitive function. */ constexpr const char* kPrimitive = "Primitive"; /*! - * \brief Indicate the compiler that should be used for building this function. + * \brief Indicate the BYOC compiler that should be used for building this function. * When this is unset or set to "default", the default compilation pipeline will be used. */ constexpr const char* kCompiler = "Compiler"; diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 97a3d5e2a01f4..d4efdae4ccd74 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -41,24 +41,40 @@ using tir::BijectiveLayoutNode; using tir::Layout; using tir::LayoutAxis; -/*! \brief operator pattern used in graph fusion */ +/*! + * \brief Operator pattern used to guide fusion. + * + * + * + */ enum OpPatternKind { - // Elementwise operation + // Elementwise operator, eg relu. + // \code + // out[i, j, k] = op(in[i, j, k]) + // \endcode + // The underlying scalar op can always be moved to the point the input tensor was created. kElemWise = 0, - // Broadcasting operator, can always map output axis to the input in order. - // for example :code:`out[i, ax1, j, ax2] = input[i, j]`. - // Note that the axis need to be in order so transpose is not a bcast operator. + // Broadcasting operator, eg add. + // As for kElemWise, but some output axes may be broadcasted, and the remaining must correspond + // to input axes in order. + // \code + // out[i, j, k] = op(in[i, j]) + // \endcode + // (So transpose is not a kBroadcast). kBroadcast = 1, - // Injective operator, can always injectively map output axis to a single input axis. - // All injective operator can still be safely fused to injective and reduction. + // Injective operator, eg concat. + // Can always injectively map output axis to a single input axis. + // All kInjecting operators can be fused to kInjective and kCommReduce operators. + // Eg: concatenate kInjective = 2, - // Communicative reduction operator. + // Communicative reduction operator, eg sum. 
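+  // (Illustrative shape, following the kElemWise/kBroadcast examples above; the reduced
+  // axes are assumed to be the trailing k:)
+  // \code
+  //   out[j] = op(in[j, k]) accumulated over all k, eg out[j] = sum(in[j, :])
+  // \endcode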
   kCommReduce = 3,
-  // Complex operation, can still fuse elemwise operations into its output.
-  // but cannot chain another complex op
+  // Complex operation, eg conv2d. Often called the fused sub-graph's 'anchor node'.
+  // Can fuse kElemWise operations into its output, but cannot fuse additional kOutEWiseFusable
+  // operations.
   kOutEWiseFusable = 4,
-  // The pattern for tuple nodes. Can fuse into subsequent injective ops,
-  // but treated specially
+  // A tuple.
+  // Can fuse into subsequent injective ops, but treated specially.
   kTuple = 7,
   // Opaque operation, cannot fuse anything.
   kOpaque = 8
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 4a6b06f14f947..ffd0a87a16160 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -273,6 +273,11 @@ TVM_DLL Pass InferType();
  */
 TVM_DLL Type InferTypeLocal(const Expr& expr);
 
+/*!
+ * \brief Infer the types of all sub-expressions of expr.
+ */
+TVM_DLL Expr InferTypeExpr(const Expr& expr);
+
 /*!
  * \brief Search and eliminate common subexpression. For example, if there are
  * two expressions evaluated to an identical value, a single variable is created
diff --git a/include/tvm/target/compilation_config.h b/include/tvm/target/compilation_config.h
index 1c47a0f806a34..525effb69fda0 100644
--- a/include/tvm/target/compilation_config.h
+++ b/include/tvm/target/compilation_config.h
@@ -171,6 +171,8 @@ class CompilationConfig : public ObjectRef {
   TVM_DLL CompilationConfig(const transform::PassContext& pass_ctx, TargetMap legacy_target_map_arg,
                             Target optional_host_target_arg);
 
+  TVM_DLL CompilationConfig(const transform::PassContext& pass_ctx, Array<Target> targets);
+
   TVM_DEFINE_OBJECT_REF_METHODS(CompilationConfig, ObjectRef, CompilationConfigNode);
 };
 
diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h
index 21760bdc8dbf8..5df4551ac8c8e 100644
--- a/include/tvm/target/target.h
+++ b/include/tvm/target/target.h
@@ -177,7 +177,17 @@ class Target : public ObjectRef {
   */
  static Target WithHost(const Target& target, const Target& host);
 
+  /*!
+   * \brief Returns true if \p this is a 'refinement of' \p that, i.e. \p this and \p that are
+   * structurally equivalent except \p this may also have 'compiler' and/or 'partition_rule'
+   * attributes.
+   */
+  bool IsRefinementOf(const Target& that) const;
+
  private:
+  Target(TargetKind kind, Optional<ObjectRef> host, String tag, Array<String> keys,
+         Map<String, ObjectRef> attrs);
+
   // enable with syntax.
   friend class TargetInternal;
   friend class With<Target>;
diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h
index e802a3088d2d9..993d28f0fe5b7 100644
--- a/include/tvm/target/target_kind.h
+++ b/include/tvm/target/target_kind.h
@@ -384,6 +384,16 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_name() {
 #define TVM_TARGET_KIND_REGISTER_VAR_DEF \
   static DMLC_ATTRIBUTE_UNUSED ::tvm::TargetKindRegEntry& __make_##TargetKind
 
+/* Special attributes on all target kinds:
+ * "compiler": If set, the BYOC toolchain name this target is specialized to. This name appears:
+ *  - In the BYOC lowering function registered as "relay.ext.<compiler>".
+ *  - As the "Compiler" attribute on "Primitive" functions.
+ *  - In the operator predicate bound to the operator attribute "target.<compiler>".
+ *  - In a @register_pattern_table("<compiler>") annotation.
+ * "partition_rule": If set, the PartitionRule to use for this target in the CollagePartitioner
+ * pass. If missing, use built-in rules to derive the required PartitionSpec.
+ */
+
 /*!
 * \def TVM_REGISTER_TARGET_KIND
 * \brief Register a new target kind, or set attribute of the corresponding target kind.
@@ -412,7 +422,9 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_name() {
       .add_attr_option<String>("model")                                   \
       .add_attr_option<Array<String>>("libs")                             \
       .add_attr_option<Target>("host")                                    \
-      .add_attr_option<Integer>("from_device")
+      .add_attr_option<Integer>("from_device")                            \
+      .add_attr_option<String>("compiler")                                \
+      .add_attr_option<String>("partition_rule")
 
 }  // namespace tvm
 
diff --git a/python/tvm/auto_scheduler/dispatcher.py b/python/tvm/auto_scheduler/dispatcher.py
index cc1e76b9faa8f..22d094dae0368 100644
--- a/python/tvm/auto_scheduler/dispatcher.py
+++ b/python/tvm/auto_scheduler/dispatcher.py
@@ -332,7 +332,7 @@ class ApplyHistoryBestOrSample(ApplyHistoryBest):
     """
 
     def __init__(
-        self, records, sample_simple_workloads=False, cost_model_file=None, num_measure=-1
+        self, records, sample_simple_workloads=False, cost_model_file=None, num_measure=-1
     ):
         self.sample_simple_workloads = sample_simple_workloads
         self.num_measure = num_measure
diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py
index bed02581270ec..5239054398f80 100644
--- a/python/tvm/autotvm/task/dispatcher.py
+++ b/python/tvm/autotvm/task/dispatcher.py
@@ -55,6 +55,9 @@ class DispatchContext(object):
     def __init__(self):
         self._old_ctx = DispatchContext.current
 
+    def contains(self, target, workload):
+        raise NotImplementedError()
+
     def query(self, target, workload):
         """
         Query the context to get the specific config for a template.
@@ -227,9 +230,11 @@ def load(self, records):
         counter = 0
         for inp, res in records:
+            # logger.info(f"inp={inp}, res={res}")
             counter += 1
-            if res.error_no != 0:
-                continue
+            # TODO(mbs): Cache error
+            # if res.error_no != 0:
+            #     continue
 
             # use target keys in tvm target system as key to build best map
             for k in inp.target.keys:
@@ -251,7 +256,12 @@ def load(self, records):
                 if np.mean(other_res.costs) > np.mean(res.costs):
                     best_by_model[key] = (inp, res)
 
-        logger.debug("Finish loading %d records", counter)
+        # logger.info("Finished loading %d records", counter)
+
+    def contains(self, target, workload):
+        # logger.info(
+        #     f"look for match with {target} and {workload} with {len(self._best_user_defined)} user-defined, {len(self.best_by_model)} model and {len(self.best_by_targetkey)} target entries")
+        return self._query_inside(target, workload) is not None
 
     def _query_inside(self, target, workload):
         if target is None:
@@ -311,8 +321,8 @@ def _query_inside(self, target, workload):
 
         if not _env.GLOBAL_SCOPE.silent:
             msg = (
-                "Cannot find config for target=%s, workload=%s. A fallback configuration "
-                "is used, which may bring great performance regression." % (target, workload)
+                "Cannot find config for target=%s, workload=%s. A fallback configuration "
+                "is used, which may bring great performance regression." % (target, workload)
             )
             if msg not in DispatchContext.warning_messages:
                 DispatchContext.warning_messages.add(msg)
@@ -426,9 +436,9 @@ def _query_inside(self, target, workload):
         key = (str(target), workload)
         if key not in self._global_cfg_dict:
             msg = (
-                "Config for target=%s, workload=%s is missing in ApplyGraphBest context. "
-                "A fallback configuration is used, which may bring great performance "
-                "regression." % (target, workload)
+                "Config for target=%s, workload=%s is missing in ApplyGraphBest context. "
+                "A fallback configuration is used, which may bring great performance "
+                "regression."
% (target, workload) ) logger.warning(msg) cfg = FallbackConfigEntity() diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index 867cbd6012563..5e9dd0f94ea97 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -19,6 +19,7 @@ import sys import os import subprocess +import logging from .._ffi.base import py_str @@ -238,6 +239,7 @@ def _linux_compile(output, objects, options, compile_cmd, compile_shared=False): cmd += objects if options: cmd += options + logging.info(f"invoking '{cmd}'") proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) (out, _) = proc.communicate() if proc.returncode != 0: diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index bd372572c403b..ac6dcf8f85275 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -17,11 +17,13 @@ # pylint: disable=invalid-name, dangerous-default-value """Driver for partitioning and building a Relay module for CUTLASS offload.""" import logging +import tempfile import os import multiprocessing import tvm from tvm import runtime, relay from tvm.contrib.nvcc import get_cuda_version +from tvm._ffi.registry import register_func from .gen_gemm import CutlassGemmProfiler from .gen_conv2d import CutlassConv2DProfiler from .library import ConvKind @@ -77,7 +79,7 @@ def __init__(self): def visit_call(self, call): op = call.op - if isinstance(op, relay.Function) and "PartitionedFromPattern" in op.attrs: + if isinstance(op, relay.Function) and "Composite" in op.attrs: self.signature["op_type"] = op.attrs["Composite"] for i, arg in enumerate(op.params): self.signature["arg%d_shape" % i] = arg.checked_type.shape @@ -94,18 +96,18 @@ def visit_call(self, call): def select_gemm_kernel( - cutlass_profiler, - op_type, - MM, - KK, - NN, - out_dtype, - arg0_dtype, - arg1_dtype, - use_3xtf32, - batched, - find_first_valid, - use_multiprocessing, + cutlass_profiler, + op_type, + MM, + KK, + NN, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + batched, + find_first_valid, + use_multiprocessing, ): """Run CUTLASS profiler to select the best kernel, or return the default one for dynamic workloads.""" @@ -138,16 +140,16 @@ def select_gemm_kernel( def handle_batch_matmul( - cutlass_profiler, - op_type, - arg0_shape, - arg1_shape, - out_dtype, - arg0_dtype, - arg1_dtype, - use_3xtf32, - find_first_valid, - use_multiprocessing, + cutlass_profiler, + op_type, + arg0_shape, + arg1_shape, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + find_first_valid, + use_multiprocessing, ): """Profile and select a kernel for batch_matmul op workload.""" MM = arg0_shape[1] @@ -183,16 +185,16 @@ def handle_batch_matmul( def handle_dense( - cutlass_profiler, - op_type, - arg0_shape, - arg1_shape, - out_dtype, - arg0_dtype, - arg1_dtype, - use_3xtf32, - find_first_valid, - use_multiprocessing, + cutlass_profiler, + op_type, + arg0_shape, + arg1_shape, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + find_first_valid, + use_multiprocessing, ): """Profile and select a kernel for dense op workload.""" MM = arg0_shape[0] @@ -226,21 +228,21 @@ def handle_dense( def handle_conv2d( - cutlass_profiler, - op_type, - d_shape, - w_shape, - padding, - strides, - dilation, - out_dtype, - data_dtype, - weight_dtype, - use_3xtf32, - split_k_slices, - profile_all_alignments, - find_first_valid, - use_multiprocessing, + cutlass_profiler, + op_type, + d_shape, + w_shape, + padding, + strides, + dilation, + out_dtype, + data_dtype, + weight_dtype, + 
use_3xtf32, + split_k_slices, + profile_all_alignments, + find_first_valid, + use_multiprocessing, ): """Profile and select a kernel for conv2d op workload.""" if "conv2d_transpose" in op_type: @@ -286,14 +288,14 @@ def handle_conv2d( def tune_cutlass_kernels( - mod, - sm, - use_3xtf32=True, - split_k_slices=[1], - profile_all_alignments=False, - find_first_valid=False, - use_multiprocessing=False, - tmp_dir="./tmp", + mod, + sm, + use_3xtf32=True, + split_k_slices=[1], + profile_all_alignments=False, + find_first_valid=False, + use_multiprocessing=False, + tmp_dir="/tmp", ): """Given a module partitioned for CUTLASS offloading, profile each workload to select which kernels to emit. @@ -340,110 +342,124 @@ def tune_cutlass_kernels( num_cutlass_partition : int The number of partitioned functions created for CUTLASS. """ - gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir) - conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), tmp_dir) num_cutlass_partition = 0 for var in mod.get_global_vars(): fun_name = var.name_hint func = mod[fun_name] - annotator = OpAnnotator() if "cutlass" in fun_name: num_cutlass_partition += 1 - annotator.visit(func) - out_shape = annotator.signature["ret_shape"] - out_dtype = annotator.signature["ret_dtype"] - op_type = annotator.signature["op_type"] - - new_attrs = {"op_type": op_type} - new_attrs.update(annotator.signature) - new_attrs.update(func.attrs) - arg0_shape = new_attrs["arg0_shape"] - arg1_shape = new_attrs["arg1_shape"] - arg0_dtype = new_attrs["arg0_dtype"] - arg1_dtype = new_attrs["arg1_dtype"] - - if "conv2d" in op_type: - new_attrs["padding"] = annotator.op_attrs.padding - new_attrs["strides"] = annotator.op_attrs.strides - new_attrs["dilation"] = annotator.op_attrs.dilation - - if "conv2d_transpose" in op_type: - d_shape = out_shape - w_shape = arg1_shape - elif "conv2d_backward_weight" in op_type: - d_shape = arg1_shape - w_shape = out_shape - else: - d_shape = arg0_shape - w_shape = arg1_shape - - new_attrs.update( - handle_conv2d( - conv2d_profiler, - op_type, - d_shape, - w_shape, - annotator.op_attrs.padding, - annotator.op_attrs.strides, - annotator.op_attrs.dilation, - out_dtype, - arg0_dtype, - arg1_dtype, - use_3xtf32, - split_k_slices, - profile_all_alignments, - find_first_valid, - use_multiprocessing, - ) - ) - elif "batch_matmul" in op_type: - new_attrs.update( - handle_batch_matmul( - gemm_profiler, - op_type, - arg0_shape, - arg1_shape, - out_dtype, - arg0_dtype, - arg1_dtype, - use_3xtf32, - find_first_valid, - use_multiprocessing, - ) - ) - elif "dense" in op_type: - new_attrs.update( - handle_dense( - gemm_profiler, - op_type, - arg0_shape, - arg1_shape, - out_dtype, - arg0_dtype, - arg1_dtype, - use_3xtf32, - find_first_valid, - use_multiprocessing, - ) - ) - else: - raise ValueError("%s unsupported composite" % op_type) - - new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs) - new_func = relay.Function( - func.params, - func.body, - ret_type=func.ret_type, - type_params=func.type_params, - attrs=new_attrs, - ) + new_func = tune_cutlass_function(func, sm, use_3xtf32, split_k_slices, profile_all_alignments, + find_first_valid, use_multiprocessing, tmp_dir) mod.update_func(var, new_func) return mod, num_cutlass_partition +def tune_cutlass_function( + func, + sm, + use_3xtf32=True, + split_k_slices=[1], + profile_all_alignments=False, + find_first_valid=False, + use_multiprocessing=False, + tmp_dir="/tmp", +): + gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir) + conv2d_profiler 
= CutlassConv2DProfiler(sm, _get_cutlass_path(), tmp_dir) + annotator = OpAnnotator() + annotator.visit(func) + out_shape = annotator.signature["ret_shape"] + out_dtype = annotator.signature["ret_dtype"] + op_type = annotator.signature["op_type"] + + new_attrs = {"op_type": op_type} + new_attrs.update(annotator.signature) + new_attrs.update(func.attrs) + arg0_shape = new_attrs["arg0_shape"] + arg1_shape = new_attrs["arg1_shape"] + arg0_dtype = new_attrs["arg0_dtype"] + arg1_dtype = new_attrs["arg1_dtype"] + + if "conv2d" in op_type: + new_attrs["padding"] = annotator.op_attrs.padding + new_attrs["strides"] = annotator.op_attrs.strides + new_attrs["dilation"] = annotator.op_attrs.dilation + + if "conv2d_transpose" in op_type: + d_shape = out_shape + w_shape = arg1_shape + elif "conv2d_backward_weight" in op_type: + d_shape = arg1_shape + w_shape = out_shape + else: + d_shape = arg0_shape + w_shape = arg1_shape + + new_attrs.update( + handle_conv2d( + conv2d_profiler, + op_type, + d_shape, + w_shape, + annotator.op_attrs.padding, + annotator.op_attrs.strides, + annotator.op_attrs.dilation, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + split_k_slices, + profile_all_alignments, + find_first_valid, + use_multiprocessing, + ) + ) + elif "batch_matmul" in op_type: + new_attrs.update( + handle_batch_matmul( + gemm_profiler, + op_type, + arg0_shape, + arg1_shape, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + find_first_valid, + use_multiprocessing, + ) + ) + elif "dense" in op_type: + new_attrs.update( + handle_dense( + gemm_profiler, + op_type, + arg0_shape, + arg1_shape, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + find_first_valid, + use_multiprocessing, + ) + ) + else: + raise ValueError("%s unsupported composite" % op_type) + + new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs) + return relay.Function( + func.params, + func.body, + ret_type=func.ret_type, + type_params=func.type_params, + attrs=new_attrs, + ) + + def build_cutlass_kernels( - lib, sm, tmp_dir="./tmp", lib_path="compile.so", threads=-1, use_fast_math=False + lib, sm, tmp_dir="/tmp", lib_path="compile.so", threads=-1, use_fast_math=False ): """Compile CUTLASS kernels in lib and return the runtime module ready to run. @@ -480,13 +496,13 @@ def build_cutlass_kernels( def build_cutlass_kernels_vm( - vm_exec, - sm, - tmp_dir="./tmp", - lib_path="compile.so", - vmcode_path="vmcode.ro", - threads=-1, - use_fast_math=False, + vm_exec, + sm, + tmp_dir="/tmp", + lib_path="compile.so", + vmcode_path="vmcode.ro", + threads=-1, + use_fast_math=False, ): """Compile CUTLASS kernels in vm_exec and return a VM executable ready to run. 
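# A minimal end-to-end sketch of how the helpers above compose (assumptions: `mod` is a Relay
# module already partitioned for CUTLASS, e.g. via partition_for_cutlass; the device is an
# Ampere GPU, sm=80; "compile.so" is an illustrative output path):
#
#   mod, num_cutlass_partition = tune_cutlass_kernels(mod, sm=80, tmp_dir="/tmp")
#   with tvm.transform.PassContext(opt_level=3):
#       lib = relay.build(mod, target="cuda")
#   lib = build_cutlass_kernels(lib, sm=80, tmp_dir="/tmp", lib_path="compile.so")
#
# tune_cutlass_kernels rewrites each "cutlass" function's attributes with the selected kernel,
# and build_cutlass_kernels compiles the generated CUDA sources into the returned runtime module.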
@@ -530,3 +546,18 @@ def build_cutlass_kernels_vm(
     lib = tvm.runtime.load_module(lib_path)
     code = bytearray(open(vmcode_path, "rb").read())
     return tvm.runtime.vm.Executable.load_exec(code, lib)
+
+
+_create_c_source_module = tvm._ffi.get_global_func("relay.ext.cutlass.create_c_source_module")
+
+
+@register_func("relay.ext.cutlass")
+def cutlass_compiler(function):
+    # TODO(mbs): Get cutlass options from target annotation on function
+    sm = 80
+    name = function.attrs["global_symbol"]
+    function = tune_cutlass_function(function, sm)
+    mod = _create_c_source_module(function)
+    tmp_dir = tempfile.mkdtemp()
+    lib_path = tmp_dir + f"/{name}.so"
+    return build_cutlass_kernels(mod, sm, tmp_dir=tmp_dir, lib_path=lib_path)
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 89c8fcb17d731..97842738e5cd4 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -32,6 +32,7 @@
 from . import transform
 from . import analysis
+from . import collage
 from .build_module import build, create_executor, optimize
 from .transform import build_config
 from . import debug
diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py
index 25744408d87b3..544bb894f8087 100644
--- a/python/tvm/relay/backend/vm.py
+++ b/python/tvm/relay/backend/vm.py
@@ -65,14 +65,7 @@ def compile(mod, target=None, target_host=None, params=None):
     exec : tvm.runtime.vm.Executable
         The VM executable that contains both library code and bytecode.
     """
-    if target_host is not None:
-        warnings.warn(
-            "target_host parameter is going to be deprecated. "
-            "Please pass in tvm.target.Target(target, host=target_host) instead."
-        )
-    target, target_host = Target.check_and_update_host_consist(
-        target, target_host, target_is_dict_key=False
-    )
+    assert target_host is None
     compiler = VMCompiler()
     if params:
         compiler.set_params(params)
@@ -139,20 +132,12 @@ def lower(self, mod, target=None, target_host=None):
             By default, llvm is used if it is enabled, otherwise a stackvm intepreter is used.
         """
-        if target_host is not None:
-            warnings.warn(
-                "target_host parameter is going to be deprecated. "
-                "Please pass in tvm.target.Target(target, host=target_host) instead."
-            )
-        target = self._update_target(target)
-        target_host = self._update_target_host(target, target_host)
-        target, target_host = Target.check_and_update_host_consist(
-            target, target_host, target_is_dict_key=False
-        )
-
-        tophub_context = self._tophub_context(target)
-        with tophub_context:
-            self._lower(mod, target, target_host)
+        assert target_host is None
+        if isinstance(target, dict):
+            target = [t for _, t in target.items()]
+        elif isinstance(target, Target):
+            target = [target]
+        self._lower(mod, target)
 
     def codegen(self):
         """Generate the kernel library."""
@@ -185,20 +170,14 @@ def optimize(self, mod, target=None, target_host=None, params=None):
         params : dict
             The parameters of the final module.
         """
-        if target_host is not None:
-            warnings.warn(
-                "target_host parameter is going to be deprecated. "
-                "Please pass in tvm.target.Target(target, host=target_host) instead."
-            )
-        target = self._update_target(target)
-        target_host = self._update_target_host(target, target_host)
-        target, target_host = Target.check_and_update_host_consist(
-            target, target_host, target_is_dict_key=False
-        )
-
+        assert target_host is None
+        if isinstance(target, dict):
+            target = [t for _, t in target.items()]
+        elif isinstance(target, Target):
+            target = [target]
         if params:
             self.set_params(params)
-        return self._optimize(mod, target, target_host), self.get_params()
+        return self._optimize(mod, target), self.get_params()
 
     def get_exec(self):
         """Get the VM executable.
@@ -210,60 +189,6 @@ def get_exec(self):
         """
         return vm_rt.Executable(self._get_exec())
 
-    def _update_target(self, target):
-        """Update target."""
-        target = target if target else tvm.target.Target.current()
-        if target is None:
-            raise ValueError("Target is not set in env or passed as argument.")
-
-        if isinstance(target, str):
-            target = {target: target}
-        elif isinstance(target, tvm.target.Target):
-            target = {target.kind.name: target}
-        elif not isinstance(target, dict):
-            raise TypeError(
-                "target is expected to be str, tvm.target.Target, "
-                + "or dict of str to str/tvm.target.Target, but received "
-                + "{}".format(type(target))
-            )
-
-        tgts = {}
-        for dev, tgt in target.items():
-            dev_type = tvm.tir.IntImm("int32", tvm.nd.device(dev).device_type)
-            if isinstance(tgt, str):
-                tgt = tvm.target.Target(tgt)
-
-            tgts[dev_type] = tgt
-
-        return tgts
-
-    def _update_target_host(self, target, target_host):
-        """Update target host."""
-        target_host = None if target_host == "" else target_host
-        if not target_host:
-            for _, tgt in target.items():
-                if tgt.host is not None:
-                    return tgt.host
-            for device_type, tgt in target.items():
-                if device_type.value == tvm.nd.cpu(0).device_type:
-                    target_host = tgt
-                    break
-        if not target_host:
-            target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
-        if isinstance(target_host, str):
-            target_host = tvm.target.Target(target_host)
-        return target_host
-
-    def _tophub_context(self, target):
-        """Get the autotvm context."""
-        # If current dispatch context is fallback context (the default root context),
-        # then load pre-tuned parameters from TopHub
-        if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
-            tophub_context = autotvm.tophub.context(list(target.values()))
-        else:
-            tophub_context = autotvm.utils.EmptyContext()
-        return tophub_context
-
 
 class VMExecutor(Executor):
     """
diff --git a/python/tvm/relay/collage/__init__.py b/python/tvm/relay/collage/__init__.py
new file mode 100644
index 0000000000000..bb77f69a7c2cb
--- /dev/null
+++ b/python/tvm/relay/collage/__init__.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
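+#
+# A minimal usage sketch (assumptions: the "relay.collage.*" pass-context options referenced
+# in collage_partitioner.py are registered on the C++ side, and `mod`/`target` are the user's
+# module and target):
+#
+#   with tvm.transform.PassContext(
+#           config={"relay.collage.enable_collage": True,
+#                   "relay.collage.autotvm_log_filename": "collage_autotvm.tuninglog"}):
+#       exe = tvm.relay.vm.compile(mod, target=target)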
+# pylint: disable=wildcard-import
+from .collage_partitioner import *
diff --git a/python/tvm/relay/collage/collage_partitioner.py b/python/tvm/relay/collage/collage_partitioner.py
new file mode 100644
index 0000000000000..e47ff84a6b39b
--- /dev/null
+++ b/python/tvm/relay/collage/collage_partitioner.py
@@ -0,0 +1,192 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Search for optimal partitionings over Relay models."""
+
+import tvm
+import numpy as np
+from tvm._ffi.registry import register_func
+import logging
+import os
+import shutil
+
+AUTOTVM_NUM_TRIALS = 2000
+AUTOTVM_EARLY_STOPPING = 600
+MEASURE_NUMBER = 20
+MEASURE_REPEAT = 5
+
+
+def arg_for(type, device):
+    """Returns a test argument of type on device"""
+    if isinstance(type, tvm.ir.TensorType):
+        return tvm.nd.array(
+            np.random.uniform(-1.0, 1.0, size=type.concrete_shape).astype(type.dtype),
+            device=device)
+    elif isinstance(type, tvm.ir.TupleType):
+        return tuple([arg_for(field_type, device) for field_type in type.fields])
+    else:
+        assert False, "unexpected argument type"
+
+
+def get_tuning_log_filename():
+    """Returns the autotvm tuning log filename from the current pass context. If the filename
+    is empty, then autotvm tuning is disabled and only default schedules will be used."""
+    pass_ctx = tvm.transform.PassContext.current()
+    if "relay.collage.autotvm_log_filename" in pass_ctx.config:
+        return pass_ctx.config["relay.collage.autotvm_log_filename"]
+    else:
+        return ""
+
+
+def is_already_tuned(task, log_filename):
+    """Returns True if we already have a tuning record for task in the tuning logs in
+    log_filename"""
+    if not os.path.exists(log_filename):
+        return False
+
+    dispatch_context = tvm.autotvm.apply_history_best(log_filename)
+    return dispatch_context.contains(task.target, task.workload)
+
+
+def tune_autotvm_tasks(tasks, log_filename):
+    """Appends to log_filename the best strategies for tasks"""
+    if len(tasks) == 0:
+        if not os.path.exists(log_filename):
+            # Ensure we always have a log file, even if empty.
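+            # (Assumption: an empty log is benign downstream; autotvm's apply_history_best
+            # simply loads zero records from it.)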
+ with open(log_filename, 'w'): + pass + return + + measure_option = tvm.autotvm.measure_option( + builder=tvm.autotvm.LocalBuilder(timeout=10), + runner=tvm.autotvm.LocalRunner( + number=MEASURE_NUMBER, repeat=MEASURE_REPEAT, timeout=4, min_repeat_ms=150), + ) + + logging.info( + f"Using autotvm tuning for {len(tasks)} tasks with {AUTOTVM_NUM_TRIALS} trials, logging to {log_filename}") + + # create tmp log file, starting with contents from existing log file + tmp_log_filename = log_filename + ".tmp" + if os.path.exists(tmp_log_filename): + os.remove(tmp_log_filename) + if os.path.exists(log_filename): + logging.info(f"Copying existing log {log_filename} to {tmp_log_filename}") + shutil.copy(log_filename, tmp_log_filename) + + for i, task in enumerate(reversed(tasks)): + prefix = "[Task %2d/%2d] " % (i + 1, len(tasks)) + logging.info(f"Considering task {task.name} {prefix}") + if is_already_tuned(task, tmp_log_filename): + logging.info(f"Re-using existing record for {task.name}") + continue + + logging.info(f"Using autotvm to tune {task.name}") + tuner_obj = tvm.autotvm.tuner.XGBTuner(task, loss_type="rank") + if os.path.exists(tmp_log_filename): + tuner_obj.load_history(tvm.autotvm.record.load_from_file(tmp_log_filename)) + + # do tuning + n_trial = min(AUTOTVM_NUM_TRIALS, len(task.config_space)) + tuner_obj.tune( + n_trial=n_trial, + early_stopping=AUTOTVM_EARLY_STOPPING, + measure_option=measure_option, + callbacks=[ + tvm.autotvm.callback.progress_bar(n_trial, prefix=prefix), + tvm.autotvm.callback.log_to_file(tmp_log_filename), + ], + ) + + # pick best records and copy back to main log file + tvm.autotvm.record.pick_best(tmp_log_filename, log_filename) + os.remove(tmp_log_filename) + + logging.info("Done with autotvm tuning") + + +@register_func("tvm.relay.collage.estimate_seconds") +def estimate_seconds(mod, target): + """Returns the mean execution time of "main" in mod on target with params. The module + may contain "Primitive" functions, possibly with "Compiler" attributes.""" + device = tvm.device(target.kind.device_type) + + # Though nothing goes wrong, it makes debugging hard if we recursively invoke the + # CollagePartitioner when trying to compile a candidate partition. So just disable it. + config = {} + for k, v in tvm.transform.PassContext.current().config.items(): + config[k] = v + config["relay.collage.enable_collage"] = False + with tvm.transform.PassContext(config=config): + log_filename = get_tuning_log_filename() + if log_filename == "": + logging.info("Not tuning with autotvm since disabled") + dispatch_context = tvm.autotvm.task.FallbackContext() + else: + # Extract and tune any TVM kernels. BYOC partitions will have no tasks extracted. + tasks = tvm.autotvm.task.extract_from_program(mod["main"], target=target, params=None) + tune_autotvm_tasks(tasks, log_filename) + # Continue compilation with all the tuning records we have so far + dispatch_context = tvm.autotvm.task.ApplyHistoryBest(log_filename) + + with dispatch_context: + # Build the module. + exe = tvm.relay.vm.compile(mod, target) + + # Benchmark the module. 
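+    # (Clarifying note on the measurement below: vm.benchmark runs "main" MEASURE_NUMBER
+    # times per repeat and reports statistics over MEASURE_REPEAT repeats, using the random
+    # arguments constructed by arg_for above; only the mean is returned to the search.)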
+    vm = tvm.runtime.vm.VirtualMachine(exe, device)
+    main = mod["main"]
+    args = [arg_for(v.checked_type, device) for v in main.params]
+    benchmark_result = vm.benchmark(device, *args, repeat=MEASURE_REPEAT, number=MEASURE_NUMBER)
+    logging.info(benchmark_result)
+
+    return benchmark_result.mean  # seconds
+
+
+@register_func("tvm.relay.collage.establish_autotvm_logs")
+def establish_autotvm_logs():
+    """Establishes the autotvm tuning context to use for the remainder of compilation."""
+    log_filename = get_tuning_log_filename()
+    if log_filename != "":
+        dispatch_context = tvm.autotvm.task.ApplyHistoryBest(log_filename)
+        dispatch_context.__enter__()
+
+
+make_labelled_dfpattern_partition_rule = tvm._ffi.get_global_func(
+    "relay.collage.make_labelled_dfpattern_partition_rule")
+make_labelled_dfpattern_partition_rule_with_predicate = tvm._ffi.get_global_func(
+    "relay.collage.make_labelled_dfpattern_partition_rule_with_predicate")
+make_pattern_byoc_partition_rule = tvm._ffi.get_global_func(
+    "relay.collage.make_pattern_byoc_partition_rule")
+
+
+def make_labelled_dfpattern_partition_rule_wrapper(compiler, tuple):
+    if len(tuple) == 2:
+        rule_name, dataflow_pattern = tuple
+        return make_labelled_dfpattern_partition_rule(compiler, rule_name, dataflow_pattern)
+    else:
+        rule_name, dataflow_pattern, predicate = tuple
+        return make_labelled_dfpattern_partition_rule_with_predicate(
+            compiler, rule_name, dataflow_pattern, predicate)
+
+
+@register_func("tvm.relay.collage.make_byoc_partition_rule")
+def make_byoc_partition_rule(compiler):
+    """Returns the PartitionRule for BYOC compiler"""
+    pattern_table = tvm.relay.op.contrib.get_pattern_table(compiler)
+    assert pattern_table is not None, f"No pattern table entry was found for BYOC compiler {compiler}"
+    logging.info(
+        f"Converting {len(pattern_table)} rules for {compiler} for use in pattern-style BYOC lowering/codegen")
+    sub_rules = [make_labelled_dfpattern_partition_rule_wrapper(compiler, tuple)
+                 for tuple in pattern_table]
+    return make_pattern_byoc_partition_rule(compiler, sub_rules)
diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py
index 5c906f7e69bed..eb11008a6d9ed 100644
--- a/python/tvm/relay/op/contrib/cutlass.py
+++ b/python/tvm/relay/op/contrib/cutlass.py
@@ -22,6 +22,7 @@
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
 from ...dataflow_pattern import wildcard, is_op, is_constant
+from tvm.relay.op.contrib.register import register_pattern_table  # type: ignore
 
 
 def make_gelu_pattern(bias_out, out_dtype="float16"):
@@ -200,8 +201,8 @@ def check_conv2d_residual(call, binary_op):
     return all(x == y for (x, y) in zip(lhs.checked_type.shape, rhs.checked_type.shape))
 
 
-def partition_for_cutlass(mod, params=None):
-    """Partition the input module into CUTLASS-supported subgraphs."""
+@register_pattern_table("cutlass")
+def pattern_table():
     dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
     dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm)
     dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"), check_gemm)
@@ -273,9 +274,11 @@ def partition_for_cutlass(mod, params=None):
         )
     )
 
-    cutlass_patterns = (
-        residual_block_patterns + dense_patterns + conv2d_patterns + conv2d_grad_patterns
-    )
+    return residual_block_patterns + dense_patterns + conv2d_patterns + conv2d_grad_patterns
+
+
+def partition_for_cutlass(mod, params=None):
"""Partition the input module into CUTLASS-supported subgraphs.""" if params is not None: mod["main"] = bind_params_by_name(mod["main"], params) @@ -290,6 +293,8 @@ def partition_for_cutlass(mod, params=None): with PassContext(opt_level=3): mod = remove_bn_pass(mod) + cutlass_patterns = relay.op.contrib.get_pattern_table("cutlass") + seq = Sequential( [ transform.InferType(), diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 399873492f041..72f43cf093a8a 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -24,6 +24,7 @@ #include #include #include + // NOTE: reverse dependency on top/tir. // These dependencies do not happen at the interface-level, // and are only used in minimum cases where they are clearly marked. @@ -141,10 +142,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; }); -GlobalVar::GlobalVar(String name_hint, Type type) { +GlobalVar::GlobalVar(String name_hint, Type type, Span span) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); n->checked_type_ = std::move(type); + n->span = std::move(span); data_ = std::move(n); } diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 9b15893092f7f..fdecbfe5c5f23 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -533,33 +533,32 @@ class Parser { /*! \brief Convert a numeric token to an NDArray for embedding into the Relay program. */ NDArray NumberToNDArray(const Token& token) { + DLDevice dev = {DLDeviceType::kDLCPU, 0}; if (token->token_type == TokenType::kInteger) { - DLDevice dev = {DLDeviceType::kDLCPU, 0}; - int64_t i = Downcast(token->data); - if (i > std::numeric_limits::max()) { - auto dtype = String2DLDataType("int64"); - auto data = NDArray::Empty({}, dtype, dev); + auto int_imm = Downcast(token->data); + auto data = NDArray::Empty({}, int_imm->dtype, dev); + if (int_imm.dtype() == DataType::Int(64)) { auto array = reinterpret_cast(data->data); // revisit this, literal node issue. - array[0] = i; - return data; + array[0] = int_imm->value; } else { - auto dtype = String2DLDataType("int32"); - auto data = NDArray::Empty({}, dtype, dev); auto array = reinterpret_cast(data->data); // revisit this, literal node issue. - array[0] = i; - return data; + array[0] = static_cast(int_imm->value); } + return data; } else if (token->token_type == TokenType::kFloat) { - DLDevice dev = {DLDeviceType::kDLCPU, 0}; auto float_imm = Downcast(token->data); auto data = NDArray::Empty({}, float_imm->dtype, dev); - auto array = reinterpret_cast(data->data); - // revisit this, literal node issue. - // TODO(@jroesch): bounds checking - float value = float_imm->value; - array[0] = value; + if (float_imm.dtype() == DataType::Float(64)) { + auto array = reinterpret_cast(data->data); + // revisit this, literal node issue. + array[0] = float_imm->value; + } else { + auto array = reinterpret_cast(data->data); + // revisit this, literal node issue. 
+          array[0] = static_cast<float>(float_imm->value);
+        }
         return data;
       } else {
         LOG(FATAL) << "internal error: should only call this function on numeric tokens";
diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h
index f8098cf941005..f31e3e369f3c3 100644
--- a/src/parser/tokenizer.h
+++ b/src/parser/tokenizer.h
@@ -172,37 +172,18 @@ struct Tokenizer {
   }
 
   Token ParseNumber(bool is_pos, bool is_float, std::string number) {
-    ICHECK(number.size() > 0) << "an empty string is an invalid number";
+    ICHECK(number.size() > 0) << "an empty string is an invalid number";
 
-    if (!is_float) {
-      auto token = NewToken(TokenType::kInteger);
-      size_t index = 0;
-      int64_t value = 0;
-      try {
-        value = std::stoll(number, &index);
-      } catch (const std::invalid_argument& err) {
-        this->diag_ctx.Emit(Diagnostic::Error(token->span) << "invalid number `" << number << "`");
-      } catch (const std::out_of_range& err) {
-        this->diag_ctx.Emit(Diagnostic::Error(token->span) << "invalid number `" << number << "`");
-      }
-      if (number.size() <= index) {
-        value = is_pos ? value : -value;
-        if (value > std::numeric_limits<int32_t>::max()) {
-          token->data = tvm::IntImm(DataType::Int(64), value);
-        } else {
-          token->data = tvm::IntImm(DataType::Int(32), value);
-        }
-        return token;
-      }
+    Token token = NewToken(is_float ? TokenType::kFloat : TokenType::kInteger);
+    size_t suffix_pos = number.rfind(is_float ? 'f' : 'i');
+    if (suffix_pos == std::string::npos) {
+      suffix_pos = number.size();
+    }
+    std::string literal_text = number.substr(0, suffix_pos);
+    std::string suffix;
+    if (suffix_pos < number.size()) {
+      suffix = number.substr(suffix_pos + 1, number.size() - suffix_pos);
     }
-    auto token = NewToken(TokenType::kFloat);
-
-    auto suffix_pos = number.rfind("f");
-
-    auto literal_text = number.substr(0, suffix_pos);
-
-    auto suffix = number.substr(suffix_pos + 1, number.size() - suffix_pos);
-
     int width = 32;
 
     if (suffix.size()) {
@@ -215,11 +196,66 @@ struct Tokenizer {
       this->diag_ctx.Emit(Diagnostic::Error(token->span)
                           << "invalid numeric suffix `" << suffix << "`");
     }
+      if (width != 32 && width != 64) {
+        this->diag_ctx.Emit(Diagnostic::Error(token->span)
+                            << "invalid numeric suffix `" << suffix << "`");
+        width = 32;
+      }
+    }
+
+    if (is_float) {
+      double value = 0.0;
+      size_t index = 0;
+      try {
+        value = stod(literal_text, &index);
+      } catch (const std::invalid_argument& err) {
+        this->diag_ctx.Emit(Diagnostic::Error(token->span)
+                            << "invalid floating point number `" << literal_text << "`");
+      } catch (const std::out_of_range& err) {
+        this->diag_ctx.Emit(Diagnostic::Error(token->span)
+                            << "invalid floating point number `" << literal_text << "`");
+      }
+      if (index < literal_text.size()) {
+        this->diag_ctx.Emit(Diagnostic::Error(token->span)
+                            << "invalid floating point number `" << literal_text << "`");
+      }
+      value = is_pos ?
value : -value; + if (width == 32 && (value < -std::numeric_limits::max() || + value > std::numeric_limits::max())) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "floating point number `" << literal_text << "` out of range for float32"); + } + token->data = tvm::FloatImm(DataType::Float(width), value); + } else { + int64_t value = 0; + size_t index = 0; + try { + value = std::stoll(literal_text, &index); + } catch (const std::invalid_argument& err) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid integer number `" << literal_text << "`"); + } catch (const std::out_of_range& err) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid integer number `" << literal_text << "`"); + } + if (index < literal_text.size()) { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "invalid integer number `" << literal_text << "`"); + } + value = is_pos ? value : -value; + if (width == 32 && (value < std::numeric_limits::min() || + value > std::numeric_limits::max())) { + if (suffix.empty()) { + // Without any i suffix the legacy behavior was to choose the smallest width. + width = 64; + } else { + this->diag_ctx.Emit(Diagnostic::Error(token->span) + << "integer number `" << literal_text << "` out of range for int32"); + } + } + token->data = tvm::IntImm(DataType::Int(width), value); } - double value = stod(literal_text); - value = is_pos ? value : -value; - token->data = tvm::FloatImm(DataType::Float(width), value); return token; } @@ -230,14 +266,13 @@ struct Tokenizer { } bool is_float = false; - - // Remove trailing floating point prefix. - if (More() && Peek() == 'f') { + if (More() && (Peek() == 'f' || Peek() == 'i')) { + is_float = Peek() == 'f'; + // Capture trailing width suffix ss << Next(); while (More() && IsNumeric(Peek())) { ss << Next(); } - is_float = true; } return ParseNumber(is_pos, is_float, ss.str()); } diff --git a/src/printer/doc.cc b/src/printer/doc.cc index f7d9fdfd7dfb3..10977d083b56c 100644 --- a/src/printer/doc.cc +++ b/src/printer/doc.cc @@ -53,9 +53,6 @@ TVM_REGISTER_OBJECT_TYPE(DocTextNode); class DocText : public DocAtom { public: explicit DocText(std::string str) { - if (str.find_first_of("\t\n") != str.npos) { - LOG(WARNING) << "text node: '" << str << "' should not have tab or newline."; - } data_ = runtime::make_object(str); } diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 97231931ad88e..fb2b20fd59027 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -61,8 +61,17 @@ Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) { } // default annotations if (annotate_ == nullptr) { - if ((expr.as() || expr.as()) && expr->checked_type_.defined()) { - doc << " /* ty=" << Print(expr->checked_type()) << " */"; + if ((expr.as() || expr.as() || expr.as() || + expr.as() || expr.as()) && + (expr->checked_type_.defined() || expr->span.defined())) { + doc << " /*"; + if (expr->checked_type_.defined()) { + doc << " ty=" << Print(expr->checked_type()); + } + if (expr->span.defined()) { + doc << " span=" << PrintSpan(expr->span); + } + doc << " */"; } } else { std::string annotated_expr = annotate_(expr); @@ -219,7 +228,7 @@ Doc RelayTextPrinter::AllocVar(const Var& var) { name = "v" + name; } Doc val = GetUniqueName("%" + name); - memo_[var] = val; + memo_[var] = val; // Referential occurrences will not include the following. 
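+  // Hedged example: a second distinct Var also hinted "x" is allocated the fresh name %x1,
+  // and only this binding occurrence carries the annotation appended below; referential
+  // occurrences print as plain %x1.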
if (!var->virtual_device()->IsFullyUnconstrained()) { val << " {" << kVirtualDevice << "=" << PrintAttributeValue(var->virtual_device()) << "}"; } @@ -345,6 +354,8 @@ Doc RelayTextPrinter::ScalarLiteral(DataType dtype, const T& value) { std::ostringstream os; if (dtype == DataType::Int(32)) { os << value; + } else if (dtype == DataType::Int(64)) { + os << value << "i64"; } else if (dtype == DataType::Float(32)) { os << value << 'f'; } else if (dtype == DataType::Float(64)) { @@ -540,9 +551,6 @@ Doc RelayTextPrinter::VisitExpr_(const CallNode* op) { return doc; } else { doc << "(" << Doc::Concat(args) << ")"; - if (op->span.defined()) { - doc << " /* " << PrintSpan(op->span) << " */"; - } return doc; } } @@ -977,7 +985,7 @@ Doc RelayTextPrinter::PrintSpan(const Span& span) { Doc doc; const auto* span_node = span.as(); ICHECK(span_node); - doc << span_node->source_name->name; + doc << span_node->source_name->name << ":" << span_node->line << ":" << span_node->column; return doc; } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 99f0517d1b7fa..61c2d74856112 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -336,8 +336,8 @@ class RelayBuildModule : public runtime::ModuleNode { backend::BindParamsInModule(relay_module, params_); - Array pass_seqs = GetPassPrefix( - /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/false); + Array pass_seqs = + GetPassPrefix(/*is_homogenous=*/config_->primitive_targets.size() == 1, /*is_vm=*/false); transform::PassContext pass_ctx = PassContext::Current(); if (config_->optional_homogeneous_target.defined()) { diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index b12da1ac62cba..b4afce70e382c 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -785,11 +785,12 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase { std::pair> GenCutlassFunc(const Function& func) { ICHECK(func.defined()) << "Input error: expect a Relay function."; // Record the external symbol for runtime lookup. - auto sid = GetExtSymbol(func); + std::string sid = GetExtSymbol(func); const auto* attrs = func->attrs.as(); ICHECK(attrs != nullptr); const auto dict = attrs->dict; CodegenCutlass builder(sid, dict); + VLOG(1) << "Creating cutlass C code for '" << sid << "' from:\n" << PrettyPrint(func); auto out = builder.VisitExpr(func->body); code_stream_ << builder.JIT(out); return {sid, {}}; @@ -834,6 +835,7 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase { // Create a CSource module const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); ICHECK(pf != nullptr) << "Cannot find CSource module to create the external runtime module"; + VLOG(1) << "Creating cutlass CSource runtime::Module for '" << sym << "'"; return (*pf)(code, "cu", Array{sym}, variables); } @@ -843,15 +845,13 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase { }; // CutlassModuleCodegen /*! - * \brief The external cutlass compiler/codegen tool. It takes a Relay - * expression/module and compile it into a runtime module. + * \brief Compile a primitive function using the CUTLASS toolchain. 
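+ * (Registered below under "relay.ext.cutlass.create_c_source_module". The result is a
+ * CSource runtime::Module; presumably the Python-side CUTLASS build utilities still
+ * compile it down to a binary module.)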
*/ -runtime::Module CutlassCompiler(const ObjectRef& ref) { - CutlassModuleCodegen cutlass; - return cutlass.CreateCSourceModule(ref); +runtime::Module CreateCSourceModule(Function function) { + return CutlassModuleCodegen().CreateCSourceModule(function); } -TVM_REGISTER_GLOBAL("relay.ext.cutlass").set_body_typed(CutlassCompiler); +TVM_REGISTER_GLOBAL("relay.ext.cutlass.create_c_source_module").set_body_typed(CreateCSourceModule); } // namespace contrib } // namespace relay diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc index 431be8ed3dc31..ef5591b1b8f0e 100644 --- a/src/relay/backend/contrib/tensorrt/codegen.cc +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -252,6 +252,7 @@ runtime::Module TensorRTCompiler(const ObjectRef& ref) { auto param_names = serializer.GetParams(); const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create"); ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create function."; + VLOG(1) << "Creating tensorrt runtime::Module for '" << func_name << "'"; runtime::Module lib = (*pf)(func_name, graph_json, param_names); return lib; } diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 4209b0a8bbe74..deb6d780d5e3a 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -773,7 +773,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } else { // The target corresponding to the call_node expression's annotation. VirtualDevice virtual_device = GetVirtualDevice(GetRef(call_node)); - ICHECK(!virtual_device->IsFullyUnconstrained()); + ICHECK(!virtual_device->IsFullyUnconstrained()) << PrettyPrint(GetRef(call_node)); target = virtual_device->target; ICHECK(target.defined()); } diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 2bddf75566013..8f72bd1a60740 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -27,9 +27,7 @@ #include #include - -#include "te_compiler.h" -#include "tvm/runtime/ndarray.h" +#include namespace tvm { namespace relay { @@ -203,7 +201,7 @@ ExecutorCodegenMetadata::ExecutorCodegenMetadata( TVM_REGISTER_NODE_TYPE(ExecutorCodegenMetadataNode); -Array GetPassPrefix(bool is_homegeneous, bool is_vm) { +Array GetPassPrefix(bool is_homogeneous, bool is_vm) { Array pass_seqs; // TODO(mbs): Would be nice to get spans on all diagnostics, but since they arg forgotton // by most passes there's little utility in including this now. Plus we'd need to only do @@ -216,7 +214,7 @@ Array GetPassPrefix(bool is_homegeneous, bool is_vm) { pass_seqs.push_back(relay::qnn::transform::Legalize()); // Legalize pass is restricted to homogeneous execution for now. - if (is_homegeneous) { + if (is_homogeneous) { pass_seqs.push_back(transform::Legalize()); } @@ -252,7 +250,7 @@ Array GetPassPrefix(bool is_homegeneous, bool is_vm) { pass_seqs.push_back(transform::CanonicalizeOps()); // Alter layout transformation is currently only applied to homogeneous execution. - if (is_homegeneous) { + if (is_homogeneous) { if (!is_vm) { pass_seqs.push_back(transform::InferType()); } diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index a9035b9ae5a4b..b1b3f1ad6d6c7 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -512,11 +512,11 @@ inline bool IsMetaScheduleEnabled() { * difference. This function unifies the shared optimization pass prefix between vm and graph * runtime, and returns the pass prefix given the backend type. 
* - * \param is_homogenous True if all primitives are to be executed on the same device and target. + * \param is_homogeneous True if all primitives are to be executed on the same device and target. * \param is_vm True if passes are to be used for the vm executor. * \return An array of passes. */ -Array GetPassPrefix(bool is_homogenous, bool is_vm); +Array GetPassPrefix(bool is_homogeneous, bool is_vm); /*! \brief Target hash function */ struct TargetStrHash { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b634091543505..ad1591248d2a5 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -49,6 +49,8 @@ #include "../../../driver/internal_driver_api.h" #include "../../../target/metadata_module.h" #include "../../../target/source/codegen_source_base.h" +#include "../../collage/capture_index_in_spans.h" +#include "../../collage/collage_partitioner.h" #include "../../op/annotation/annotation.h" #include "../../op/memory/device_copy.h" #include "../../op/op_common.h" @@ -827,8 +829,8 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "lower") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.num_args, 3); - this->Lower(args[0], args[1], args[2]); + ICHECK_EQ(args.num_args, 2); + this->Lower(args[0], args[1]); }); } else if (name == "codegen") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -855,8 +857,8 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtrOptimizeModule(args[0], args[1], args[2]); + ICHECK_EQ(args.num_args, 2); + *rv = this->OptimizeModule(args[0], args[1]); }); } else { LOG(FATAL) << "Unknown packed function: " << name; @@ -868,10 +870,10 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; } -void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host) { +void VMCompiler::Lower(IRModule mod, Array targets) { VLOG_CONTEXT << "VM Lower"; exec_ = make_object(); - config_ = CompilationConfig(PassContext::Current(), std::move(targets), std::move(target_host)); + config_ = CompilationConfig(PassContext::Current(), std::move(targets)); // The first device is always for the host. CHECK(context_.virtual_devices_.empty()); @@ -1022,9 +1024,8 @@ transform::Sequential VMCompiler::FuseAndLowerOperators(const VirtualDevice& hos return transform::Sequential(std::move(pass_seqs)); } -IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets, - const Target& target_host) { - config_ = CompilationConfig(PassContext::Current(), targets, target_host); +IRModule VMCompiler::OptimizeModule(IRModule mod, Array targets) { + config_ = CompilationConfig(PassContext::Current(), targets); // The first device always corresponds to the host. 
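+  // Hedged usage sketch of the new interface (the targets are illustrative only):
+  // callers now pass a flat list of targets and let CompilationConfig recover the host
+  // target, rather than passing a device-type-to-target map plus an explicit host, e.g.:
+  //   compiler->OptimizeModule(mod, {Target("cuda"), Target("llvm")});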
  CHECK(context_.virtual_devices_.empty());
   context_.virtual_devices_.push_back(config_->host_virtual_device);
@@ -1034,14 +1035,22 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets,
 }
 
 IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
-  VLOG_CONTEXT << "VM Optimize";
   backend::BindParamsInModule(mod, params_);
+  Array<Pass> pass_seqs;
+
+  // ############# Collage ###############
+  pass_seqs.push_back(collage::CaptureIndexInSpans());
+  pass_seqs.push_back(collage::CollagePartition(config_));
+  // #####################################
 
-  Array<Pass> pass_seqs = relay::backend::GetPassPrefix(
-      /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/true);
+  for (const auto& pass : relay::backend::GetPassPrefix(
+           /*is_homogeneous=*/config_->primitive_targets.size() == 1, /*is_vm=*/true)) {
+    pass_seqs.push_back(pass);
+  }
 
   // Always plan devices so the remaining passes don't need to distinguish homogeneous vs
-  // hetrogeneous execution.
+  // heterogeneous execution.
+  // TODO(mbs): Move to before CollagePartition.
   pass_seqs.push_back(transform::PlanDevices(config_));
 
   pass_seqs.push_back(transform::FuseOps());
diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h
index 906e5148b593b..2aa3b42832e68 100644
--- a/src/relay/backend/vm/compiler.h
+++ b/src/relay/backend/vm/compiler.h
@@ -116,7 +116,7 @@ class VMCompiler : public runtime::ModuleNode {
    * to target mapping. For homogeneous compilation, it is a singleton build target.
    * \param target_host Host compilation target, if target is device.
    */
-  void Lower(IRModule mod, TargetMap targets, Target target_host);
+  void Lower(IRModule mod, Array<Target> targets);
 
   /*! \brief Generate the machine code for lowered functions. */
   void Codegen();
@@ -132,7 +132,7 @@ class VMCompiler : public runtime::ModuleNode {
    *
    * \return The optimized IRModule.
    */
-  IRModule OptimizeModule(IRModule mod, const TargetMap& targets, const Target& target_host);
+  IRModule OptimizeModule(IRModule mod, Array<Target> targets);
 
   IRModule OptimizeModuleImpl(IRModule mod);
 
diff --git a/src/relay/collage/README.md b/src/relay/collage/README.md
new file mode 100644
index 0000000000000..c7f741618fc92
--- /dev/null
+++ b/src/relay/collage/README.md
@@ -0,0 +1,9 @@
+The `CollagePartition` pass for finding optimal partitionings of Relay models.
+
+See the [RFC](https://github.com/mbs-octoml/mbs-tvm-rfcs/blob/mbs-rfcs-collage/rfcs/xxxx-collage.md).
+
+Based on:
+> *Collage: Automated Integration of Deep Learning Backends*
+> Byungsoo Jeon, Sunghyun Park, Peiyuan Liao, Sheng Xu, Tianqi Chen, Zhihao Jia
+
+CAUTION: This is a prototype; do not use in production.
diff --git a/src/relay/collage/candidate_partition.cc b/src/relay/collage/candidate_partition.cc
new file mode 100644
index 0000000000000..abe897831a97e
--- /dev/null
+++ b/src/relay/collage/candidate_partition.cc
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/candidate_partition.cc + * \brief A potential partition in the search. + */ + +#include "./candidate_partition.h" + +#include + +#include "./partition_rule.h" +#include "./partition_spec.h" +#include "utils.h" + +namespace tvm { +namespace relay { +namespace collage { + +PartitionSpec CandidatePartitionNode::partition_spec() const { + return Downcast(spec_); +} + +std::string CandidatePartitionNode::partition_spec_name() const { + return Downcast(spec_)->spec_name_; +} + +std::string CandidatePartitionNode::ToString() const { + std::ostringstream os; + os << "{rule_name=" << rule_name_; + os << ",sub_graph=" << sub_graph_->ToString(); + os << ",spec_name=" << partition_spec_name(); + if (target_.defined()) { + os << ",target=" << target_->ToDebugString(); + } + if (!cost_.is_unknown()) { + os << ",cost=" << cost_.ToString(); + } + os << "}"; + return os.str(); +} + +Cost CandidatePartitionNode::EstimatedCost(const DataflowGraph& dataflow_graph, + CostEstimator* cost_estimator, + NameSupply& name_supply) const { + ICHECK(target_.defined()); + if (cost_.is_unknown()) { + VLOG_CONTEXT << "spec " << partition_spec_name(); + Function extracted_function = sub_graph_->ExtractAsFunction(dataflow_graph, name_supply); + extracted_function = Downcast(transform::InferTypeExpr(extracted_function)); + VLOG(1) << "Validating function:\n" << PrettyPrint(extracted_function); + String error = partition_spec()->validate_sub_graph_func_(extracted_function); + if (!error.empty()) { + cost_ = Cost::Invalid(); + VLOG(1) << "Unable to rewrite function: " << error; + } else { + VLOG(1) << "Estimating cost of:\n" << PrettyPrint(extracted_function); + cost_ = cost_estimator->CachedEstimate(IRModule::FromExpr(extracted_function), target_); + VLOG(1) << "Estimated cost, candidate now " << ToString(); + } + } else { + VLOG(1) << "Reusing cost cached in candidate"; + } + return cost_; +} + +CandidatePartition::CandidatePartition(String rule_name, SubGraph sub_graph, + ObjectRef /* actually PartitionSpec */ spec, Target target, + Cost cost) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->sub_graph_ = std::move(sub_graph); + node->spec_ = std::move(spec); + node->target_ = std::move(target); + node->cost_ = cost; + data_ = std::move(node); +} + +CandidatePartition WithRuleName(CandidatePartition candidate, String rule_name) { + if (rule_name == candidate->rule_name_) { + return candidate; + } + auto* node = candidate.CopyOnWrite(); + node->rule_name_ = std::move(rule_name); + return GetRef(node); +} + +CandidatePartition WithTarget(CandidatePartition candidate, Target target) { + if (target == candidate->target_) { + return candidate; + } + auto* node = candidate.CopyOnWrite(); + node->target_ = std::move(target); + return GetRef(node); +} + +CandidatePartition WithSubGraph(CandidatePartition candidate, SubGraph sub_graph) { + if (sub_graph == candidate->sub_graph_) { + return candidate; + } + auto* node = candidate.CopyOnWrite(); + node->sub_graph_ = std::move(sub_graph); + return GetRef(node); +} + +bool 
CandidatePartition::operator<(const CandidatePartition& that) const { + // Order lexicographically on sub-graphs. + if (*get()->sub_graph_.get() < *that->sub_graph_.get()) { + return true; + } + if (*that->sub_graph_.get() < *get()->sub_graph_.get()) { + return false; + } + // Break ties by rule name. + return get()->rule_name_ < that->rule_name_; +} + +bool CandidatePartition::AreTouching(const DataflowGraph& dataflow_graph, + const CandidatePartition& that) const { + return get()->target_ == that->target_ && // ok if both are null + get()->sub_graph_.AreTouching(dataflow_graph, that->sub_graph_); +} + +CandidatePartition CandidatePartition::DisjointUnion(const DataflowGraph& dataflow_graph, + const CandidatePartition& that) const { + ICHECK_EQ(get()->spec_, that->spec_); + ICHECK_EQ(get()->target_, that->target_); // may be null + return CandidatePartition(UnionLabels(get()->rule_name_, that->rule_name_), + get()->sub_graph_.DisjointUnion(dataflow_graph, that->sub_graph_), + get()->spec_, get()->target_, get()->cost_ + that->cost_); +} + +/*static*/ +CandidatePartition CandidatePartition::DisjointUnion(const DataflowGraph& dataflow_graph, + std::vector candidates) { + ICHECK_GT(candidates.size(), 1); + CandidatePartition result = candidates.front(); + for (size_t i = 1; i < candidates.size(); ++i) { + result = result.DisjointUnion(dataflow_graph, candidates[i]); + } + return result; +} + +/*static*/ +Expr CandidatePartition::ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr, + const std::vector& candidates, + NameSupply& name_supply) { + std::vector sub_graphs; + for (const auto& candidate : candidates) { + sub_graphs.emplace_back(candidate->sub_graph_); + } + return SubGraph::ParallelRewrite(dataflow_graph, expr, sub_graphs, name_supply); +} + +/*static*/ +std::vector CandidatePartition::MaxCoalesce( + const DataflowGraph& dataflow_graph, std::vector candidates) { + // Sort the candidates by their first-inside index. + std::sort(candidates.begin(), candidates.end(), + [](const CandidatePartition& left, const CandidatePartition& right) { + return left->sub_graph_->first_inside_index_ < right->sub_graph_->first_inside_index_; + }); + std::vector result; + while (!candidates.empty()) { + size_t n = 1; + // Take the next original candidate. + CandidatePartition base = candidates.front(); + candidates.erase(candidates.begin()); + // Union as many remaining original candidates as possible. + for (auto itr = candidates.begin(); itr != candidates.end(); /*no-op*/) { + CandidatePartition rhs = *itr; + if (base.AreTouching(dataflow_graph, rhs)) { + base = base.DisjointUnion(dataflow_graph, rhs); + ++n; + itr = candidates.erase(itr); + } else { + ++itr; + } + } + if (n > 1) { + VLOG(1) << "Coalesced " << n << " candidates to " << base->ToString(); + } + result.push_back(base); + } + return result; +} + +} // namespace collage +} // namespace relay +} // namespace tvm \ No newline at end of file diff --git a/src/relay/collage/candidate_partition.h b/src/relay/collage/candidate_partition.h new file mode 100644 index 0000000000000..e55d19d54be76 --- /dev/null +++ b/src/relay/collage/candidate_partition.h @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/candidate_partition.cc + * \brief A potential partition in the search. + */ + +#ifndef SRC_RELAY_COLLAGE_CANDIDATE_PARTITION_H_ +#define SRC_RELAY_COLLAGE_CANDIDATE_PARTITION_H_ + +#include + +#include "./cost.h" +#include "./cost_estimator.h" +#include "./name_supply.h" +#include "./sub_graph.h" + +namespace tvm { +namespace relay { +namespace collage { + +class PartitionSpec; + +/*! + * \brief A candidate partition w.r.t. the body of an overall Relay expression. + * + * We represent the partition as a sub-graph. This means not only can we represent the scope + * of Relay sub-expressions intended for a particular partition (or kernel), but we can also + * represent various conventions for encoding how the operators in the partition should be + * tagged for downstream processing. + */ +class CandidatePartitionNode : public Object { + public: + CandidatePartitionNode() = default; + + /*! + * \brief Combination of all the partition rule names which produced this candidate. + * For debugging and explainability. + */ + String rule_name_; + + /*! + * \brief The sub-graph of the overall expression matched by the partition rule. + */ + SubGraph sub_graph_; + + /*! + * \brief The partition specification which produced this candidate. + */ + ObjectRef /* actually PartitionSpec */ spec_; + + /*! + * \brief The target for which to compile the above function. + * + * Will be null for intermediate candidates, and is determined only by + * \p PartitionSpec::AllCandidates as the candidates found by it's top-level \p PartitionRule + * are 'finalized' for insertion into \p CandidatePartitionIndex. + */ + Target target_; + + /*! + * \brief The (cached) cost of the partition. + * + * Initially Cost::Unknown, calculated and cached by EstimateCost. + */ + mutable Cost cost_ = Cost::Unknown(); + + /*! + * \brief Returns the partition specification which produced this candidate. + */ + PartitionSpec partition_spec() const; + + /*! + * \brief Returns the name of the partition specification which produced this candidate. + */ + std::string partition_spec_name() const; + + /*! + * \brief Return the estimated cost of the candidate partition, using \p cost_estimator if + * the cost is not already known. Internally cached. + */ + Cost EstimatedCost(const DataflowGraph& dataflow_graph, CostEstimator* cost_estimator, + NameSupply& name_supply) const; + + std::string ToString() const; + + static constexpr const char* _type_key = "relay.collage.CandidatePartition"; + TVM_DECLARE_FINAL_OBJECT_INFO(CandidatePartitionNode, Object); +}; + +class CandidatePartition : public ObjectRef { + public: + CandidatePartition(String rule_name, SubGraph sub_graph, + ObjectRef /* actually PartitionSpec */ spec, Target target = {}, + Cost cost = Cost::Unknown()); + + + bool operator<(const CandidatePartition& that) const; + + /*! 
+ * \brief Returns true if this and \p that candidate are disjoint, have the same (or no) target, + * and touch. This does not imply the \p DisjointUnion of this and that will be valid. For + * example, the result may be too deep or have too many outputs. + */ + bool AreTouching(const DataflowGraph& dataflow_graph, const CandidatePartition& that) const; + + /*! + * \brief Returns the disjoint union of this and \p that. + */ + CandidatePartition DisjointUnion(const DataflowGraph& dataflow_graph, + const CandidatePartition& that) const; + + /*! + * \brief Returns the disjoint union of all \p candidates. + */ + static CandidatePartition DisjointUnion(const DataflowGraph& dataflow_graph, + std::vector candidates); + + /*! + * \brief Returns \p expr rewritten to apply all the partitions implied by \p candidates. + * The candidates can be in any order but must be disjoint. + */ + static Expr ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr, + const std::vector& candidates, + NameSupply& name_supply); + + /*! + * Eagerly merge all touching candidates for the same target. The candidates must be disjoint + * and have their Targets filled in. This is typically called on the optimal list of candidate + * partitions found by the Collage search in order to remove unnecessary partition boundaries. + * Ideally the search would never produce such candidates however to keep the search space + * manageable Collage may only consider candidate partitions up to a particular depth. + */ + static std::vector MaxCoalesce(const DataflowGraph& dataflow_graph, + std::vector candidates); + + TVM_DEFINE_OBJECT_REF_METHODS(CandidatePartition, ObjectRef, CandidatePartitionNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(CandidatePartitionNode); +}; + +CandidatePartition WithRuleName(CandidatePartition candidate, String rule_name); +CandidatePartition WithTarget(CandidatePartition candidate, Target target); +CandidatePartition WithSubGraph(CandidatePartition candidate, SubGraph sub_graph); + +struct CandidatePartitionHash { + size_t operator()(const CandidatePartition& candidate) const { + return candidate->sub_graph_->hash(); + } +}; + +struct CandidatePartitionEquals { + bool operator()(const CandidatePartition& left, const CandidatePartition& right) const { + return *left->sub_graph_.get() == *right->sub_graph_.get(); + } +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // SRC_RELAY_COLLAGE_CANDIDATE_PARTITION_H_ diff --git a/src/relay/collage/candidate_partition_index.cc b/src/relay/collage/candidate_partition_index.cc new file mode 100644 index 0000000000000..c8fc4fcae51ea --- /dev/null +++ b/src/relay/collage/candidate_partition_index.cc @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file relay/collage/candidate_partition_index.h + * \brief Index for finding relevant candidate partitions for a particular search state. + */ + +#include "./candidate_partition_index.h" + +#include "./prune_candidates.h" + +namespace tvm { +namespace relay { +namespace collage { + +CandidatePartitionIndex::CandidatePartitionIndex( + const std::unordered_map* virtual_devices, + DataflowGraph* dataflow_graph) + : virtual_devices_(virtual_devices), + dataflow_graph_(dataflow_graph), + first_inside_index_to_candidates_(dataflow_graph->size()) {} + +void CandidatePartitionIndex::Index(const Array& partition_specs) { + std::vector candidates = Collect(partition_specs); + candidates = PruneCandidates(*dataflow_graph_, candidates); + // Index the candidates by their first inside index. + for (auto& candidate : candidates) { + VLOG(1) << "Indexing candidate " << candidate->ToString(); + first_inside_index_to_candidates_[candidate->sub_graph_->first_inside_index_].emplace_back( + candidate); + } + size_ = candidates.size(); +} + +bool CandidatePartitionIndex::IsCompatibleWithVirtualDevice(const CandidatePartition& candidate) { + ICHECK(candidate->target_.defined()); + for (PostDfsIndex index : candidate->sub_graph_->inside_) { + const ExprNode* sub_expr_node = dataflow_graph_->index_to_node(index)->node_ref_; + auto itr = virtual_devices_->find(sub_expr_node); + ICHECK(itr != virtual_devices_->end()); + if (!itr->second->target.defined()) { + // No constraint. + continue; + } + if (!candidate->target_.IsRefinementOf(itr->second->target)) { + return false; + } + } + return true; +} + +std::vector CandidatePartitionIndex::Collect( + const Array& partition_specs) { + VLOG_CONTEXT << "collecting"; + std::vector result; + for (const auto& spec : partition_specs) { + VLOG_CONTEXT << "spec " << spec->spec_name_; + VLOG(1) << "collecting candidates"; + std::vector candidates = spec->AllCandidates(*dataflow_graph_); + for (auto& candidate : candidates) { + if (!IsCompatibleWithVirtualDevice(candidate)) { + VLOG(1) << "Ignoring candidate " << candidate->ToString() + << " since incompatible with existing virtual device assignments for sub-graph"; + continue; + } + result.push_back(candidate); + } + } + VLOG(1) << "Found " << result.size() << " candidates"; + return result; +} + +} // namespace collage +} // namespace relay +} // namespace tvm diff --git a/src/relay/collage/candidate_partition_index.h b/src/relay/collage/candidate_partition_index.h new file mode 100644 index 0000000000000..bd3d0c1650a27 --- /dev/null +++ b/src/relay/collage/candidate_partition_index.h @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file relay/collage/candidate_partition_index.h + * \brief Index for finding relevant candidate partitions for a particular search state. + */ +#ifndef TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_INDEX_H_ +#define TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_INDEX_H_ + +#include + +#include "partition_spec.h" +namespace tvm { +namespace relay { +namespace collage { + +/*! + * \brief Collects and indexes all the candidate partitions for the overall expression. This index + * is used during partitioning search to find the next valid candidate partition to explore from the + * current search state. We do not yet attempt to estimate the cost of each candidate partition, and + * when we do so during the search we may discover it to be infeasible. + */ +class CandidatePartitionIndex { + public: + CandidatePartitionIndex(const std::unordered_map* virtual_devices, + DataflowGraph* dataflow_graph); + + /*! \brief Constructs the index. */ + void Index(const Array& partition_specs); + + /*! \brief Returns all the candidates which may begin at \p index. */ + const std::vector& candidates_at(PostDfsIndex index) const { + ICHECK_LT(index, dataflow_graph_->size()); + return first_inside_index_to_candidates_[index]; + } + + size_t size() const { return size_; } + + private: + /*! + * \brief Returns true if \p candidate's target is a refinement of the target's implied by + * the sub-expressions inside it. + */ + bool IsCompatibleWithVirtualDevice(const CandidatePartition& candidate); + + /*! \brief Returns all valid candidates found from \p partition_specs. */ + std::vector Collect(const Array& partition_specs); + + /*! + * \brief The \p VirtualDevice for every sub-expression in the overall expression. Needed to + * ensure candidates do not contradict the target/device placement already determined by + * device planning. + */ + const std::unordered_map* virtual_devices_; + + /*! \brief Dataflow graph for overall expression. */ + DataflowGraph* dataflow_graph_; + + /*! + * \brief Maps post-dfs indexes to the all the candidates which have that as their first inside + * index, and which should be considered in the Collage search. + */ + std::vector> first_inside_index_to_candidates_; + + /*! \brief Number of entries in above. */ + size_t size_ = 0; +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_INDEX_H_ diff --git a/src/relay/collage/candidate_set.cc b/src/relay/collage/candidate_set.cc new file mode 100644 index 0000000000000..af8e95335f98c --- /dev/null +++ b/src/relay/collage/candidate_set.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/candidate_set.cc + * \brief Collects a set of candidate partitions. 
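+ * (Additions and removals are only staged by Add/Remove below; PrepareForNextRound commits
+ * them, so code iterating over the current candidates never observes a half-updated set.)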
+ */ + +#include "./candidate_set.h" + +namespace tvm { +namespace relay { +namespace collage { + +void CandidateSet::Add(const DataflowGraph& dataflow_graph, + const CandidatePartition& new_candidate) { + if (seen.count(new_candidate)) { + VLOG(2) << "already seen candidate, ignoring"; + return; + } + seen.emplace(new_candidate); + candidates_to_add.emplace_back(new_candidate); +} + +void CandidateSet::Remove(const CandidatePartition& old_candidate) { + ICHECK(seen.count(old_candidate)); + VLOG(1) << "Removing " << old_candidate->ToString(); + candidates_to_remove.emplace_back(old_candidate); +} + +bool CandidateSet::PrepareForNextRound() { + size_t init_size = current_candidates.size(); + for (const auto& candidate_to_remove : candidates_to_remove) { + current_candidates.erase( + std::remove(current_candidates.begin(), current_candidates.end(), candidate_to_remove), + current_candidates.end()); + } + size_t num_removed = init_size - current_candidates.size(); + candidates_to_remove.clear(); + first_new_index = current_candidates.size(); + for (const auto& new_candidate : candidates_to_add) { + current_candidates.push_back(new_candidate); + } + size_t num_added = candidates_to_add.size(); + candidates_to_add.clear(); + VLOG(1) << "removed " << num_removed << " and added " << num_added << " candidates"; + return num_removed + num_added > 0; +} + +} // namespace collage +} // namespace relay +} // namespace tvm \ No newline at end of file diff --git a/src/relay/collage/candidate_set.h b/src/relay/collage/candidate_set.h new file mode 100644 index 0000000000000..d35f6c9adc0d6 --- /dev/null +++ b/src/relay/collage/candidate_set.h @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/candidate_set.h + * \brief Collects a set of candidate partitions. + */ + +#ifndef SRC_RELAY_COLLAGE_CANDIDATE_SET_H_ +#define SRC_RELAY_COLLAGE_CANDIDATE_SET_H_ + +#include "./candidate_partition.h" +#include "./dataflow_graph.h" + +namespace tvm { +namespace relay { +namespace collage { + +/*! + * \brief Holds a vector of current candidates and the additions/removals to apply to them. + */ +struct CandidateSet { + CandidateSet() = default; + + /*! + * \brief Schedule \p new_candidate for addition before the next round (unless it is not valid). + */ + void Add(const DataflowGraph& dataflow_graph, const CandidatePartition& new_candidate); + + /*! \brief Schedule \p old_candidate for removal before the next round. */ + void Remove(const CandidatePartition& old_candidate); + + /*! + * \brief Update \p current_candidates and \p first_new_index. Return false if no + * new candidates were added, in which case we have reached a fixed point. 
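+   * (Hedged usage sketch: a client such as a combiner rule would repeatedly Add/Remove
+   * candidates and then call PrepareForNextRound, stopping once it returns false, i.e.
+   * once a round stages no changes.)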
+   */
+  bool PrepareForNextRound();
+
+  size_t size() const { return current_candidates.size(); }
+
+  CandidatePartition operator[](size_t i) const {
+    ICHECK_LT(i, current_candidates.size());
+    return current_candidates[i];
+  }
+  CandidatePartition at(size_t i) const { return (*this)[i]; }
+
+  /*!
+   * \brief Index of the first candidate in current_candidates which was added in the last
+   * round. This can be used to avoid considering candidates or candidate combinations which
+   * have already been considered in an earlier round.
+   */
+  size_t first_new_index = 0;
+  /*! \brief Candidates gathered in previous rounds. */
+  std::vector<CandidatePartition> current_candidates;
+  /*! \brief New candidates gathered in the current round. */
+  std::vector<CandidatePartition> candidates_to_add;
+  /*! \brief Existing candidates to remove before starting the next round. */
+  std::vector<CandidatePartition> candidates_to_remove;
+  /*! \brief Which candidates have been seen so far and should not be added again. */
+  std::unordered_set<CandidatePartition, CandidatePartitionHash, CandidatePartitionEquals> seen;
+};
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // SRC_RELAY_COLLAGE_CANDIDATE_SET_H_
diff --git a/src/relay/collage/capture_index_in_spans.cc b/src/relay/collage/capture_index_in_spans.cc
new file mode 100644
index 0000000000000..55839358c805c
--- /dev/null
+++ b/src/relay/collage/capture_index_in_spans.cc
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/collage/capture_index_in_spans.cc
+ * \brief Pass to set spans to capture the post-dfs index of every node. For debugging only.
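+ * (For example, the node with post-dfs index 5 whose dominator has index 8 gets the span
+ * "index:5:8", which the Relay text printer renders via its source-name:line:column form.)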
+ */ + +#include "./capture_index_in_spans.h" + +#include + +#include "../ir/indexed_graph.h" +#include "dataflow_graph.h" + +namespace tvm { +namespace relay { +namespace collage { + +namespace { +class CaptureIndexInSpansRewriter : public ExprMutator { + public: + explicit CaptureIndexInSpansRewriter(const IndexedGraph* indexed_graph) + : source_name_(SourceName::Get("index")), indexed_graph_(indexed_graph) {} + + private: + Expr VisitExpr_(const VarNode* var_node) override { + auto var = GetRef(var_node); + MakeSpan(var); // for side effects + return var; + } + + Expr VisitExpr_(const ConstantNode* constant_node) override { + auto constant = GetRef(constant_node); + return WithFields(constant, {}, {}, MakeSpan(constant)); + } + + Expr VisitExpr_(const GlobalVarNode* global_var_node) override { + auto global_var = GetRef(global_var_node); + MakeSpan(global_var); // for side effects + return global_var; + } + + Expr VisitExpr_(const OpNode* op_node) override { + auto op = GetRef(op_node); + MakeSpan(op); // for side effects + return op; + } + + Expr VisitExpr_(const TupleNode* tuple_node) override { + auto tuple = GetRef(tuple_node); + auto new_tuple = Downcast(ExprMutator::VisitExpr_(tuple_node)); + return WithFields(new_tuple, {}, {}, MakeSpan(tuple)); + } + + Expr VisitExpr_(const FunctionNode* function_node) override { + auto function = GetRef(function_node); + // Don't recurse into the bodies of primitive functions. + // CAUTION: This is why we can't just use an ExprRewriter. + Function new_function = function_node->HasNonzeroAttr(attr::kPrimitive) + ? function + : Downcast(ExprMutator::VisitExpr_(function_node)); + return WithFields(new_function, {}, {}, {}, {}, {}, {}, MakeSpan(function)); + } + + Expr VisitExpr_(const CallNode* call_node) override { + auto call = GetRef(call_node); + auto new_call = Downcast(ExprMutator::VisitExpr_(call_node)); + return WithFields(new_call, {}, {}, {}, {}, {}, MakeSpan(call)); + } + + Expr VisitExpr_(const LetNode* let_node) override { + auto let = GetRef(let_node); + auto new_let = Downcast(ExprMutator::VisitExpr_(let_node)); + return WithFields(new_let, {}, {}, {}, {}, MakeSpan(let)); + } + + Expr VisitExpr_(const IfNode* if_node) override { + auto ife = GetRef(if_node); + auto new_ife = Downcast(ExprMutator::VisitExpr_(if_node)); + return WithFields(new_ife, {}, {}, {}, {}, MakeSpan(ife)); + } + + Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) override { + auto tuple_get_item = GetRef(tuple_get_item_node); + auto new_tuple_get_item = Downcast(ExprMutator::VisitExpr_(tuple_get_item_node)); + return WithFields(new_tuple_get_item, {}, {}, {}, MakeSpan(tuple_get_item)); + } + + Expr VisitExpr_(const RefCreateNode* ref_create_node) override { + auto ref_create = GetRef(ref_create_node); + auto new_ref_create = Downcast(ExprMutator::VisitExpr_(ref_create_node)); + return WithFields(new_ref_create, {}, {}, MakeSpan(ref_create)); + } + + Expr VisitExpr_(const RefReadNode* ref_read_node) override { + auto ref_read = GetRef(ref_read_node); + auto new_ref_read = Downcast(ExprMutator::VisitExpr_(ref_read_node)); + return WithFields(new_ref_read, {}, {}, MakeSpan(ref_read)); + } + + Expr VisitExpr_(const RefWriteNode* ref_write_node) override { + auto ref_write = GetRef(ref_write_node); + auto new_ref_write = Downcast(ExprMutator::VisitExpr_(ref_write_node)); + return WithFields(new_ref_write, {}, {}, {}, MakeSpan(ref_write)); + } + + Expr VisitExpr_(const ConstructorNode* constructor_node) override { + auto constructor = 
GetRef<Constructor>(constructor_node);
+    MakeSpan(constructor);  // for side effects
+    return constructor;
+  }
+
+  Expr VisitExpr_(const MatchNode* match_node) override {
+    auto match = GetRef<Match>(match_node);
+    auto new_match = Downcast<Match>(ExprMutator::VisitExpr_(match_node));
+    return WithFields(new_match, {}, {}, {}, MakeSpan(match));
+  }
+
+  Span MakeSpan(const Expr& expr) {
+    auto node = indexed_graph_->item_to_node(expr);
+    PostDfsIndex node_index = node->index_;
+    PostDfsIndex dominator_index = node->dominator_parent_ ? node->dominator_parent_->index_ : -1;
+    Span span(source_name_, /*line=*/node_index, /*end_line=*/node_index,
+              /*column=*/dominator_index, /*end_column=*/dominator_index);
+    ICHECK_EQ(index_, node_index)
+        << "expecting visit order to match dataflow graph's post-dfs index order at expression:\n"
+        << PrettyPrint(expr);
+    index_++;
+    return span;
+  }
+
+  SourceName source_name_;
+  const IndexedGraph<Expr>* indexed_graph_;
+  PostDfsIndex index_ = 0;
+};
+
+}  // namespace
+
+/*!
+ * Captures the post-dfs index and dominator post-dfs index of every node in its span, in
+ * the form "index:<post-dfs index>:<dominator's post-dfs index>".
+ * For debugging only.
+ */
+transform::Pass CaptureIndexInSpans() {
+  auto pass_func = [](Function f, IRModule m, transform::PassContext ctxt) {
+    std::unique_ptr<IndexedGraph<Expr>> indexed_graph = CreateIndexedGraph(f);
+    CaptureIndexInSpansRewriter rewriter(indexed_graph.get());
+    return Downcast<Function>(rewriter.VisitExpr(f));
+  };
+  return transform::CreateFunctionPass(pass_func, 0, "CaptureIndexInSpans", {});
+}
+
+TVM_REGISTER_GLOBAL("relay.collage.capture_index_in_spans").set_body_typed(CaptureIndexInSpans);
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/collage/capture_index_in_spans.h b/src/relay/collage/capture_index_in_spans.h
new file mode 100644
index 0000000000000..a8c8ca169e9b2
--- /dev/null
+++ b/src/relay/collage/capture_index_in_spans.h
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/collage/capture_index_in_spans.h
+ * \brief Pass to set spans to capture the post-dfs index of every node. For debugging only.
+ */
+#ifndef TVM_RELAY_COLLAGE_CAPTURE_INDEX_IN_SPANS_H_
+#define TVM_RELAY_COLLAGE_CAPTURE_INDEX_IN_SPANS_H_
+
+#include
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*!
+ * Captures the post-dfs index and dominator post-dfs index of every node in its span, in
+ * the form "index:<post-dfs index>:<dominator's post-dfs index>".
+ * For debugging only.
+ */ +transform::Pass CaptureIndexInSpans(); + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_CAPTURE_INDEX_IN_SPANS_H_ diff --git a/src/relay/collage/collage_partitioner.cc b/src/relay/collage/collage_partitioner.cc new file mode 100644 index 0000000000000..7f04313ca1b78 --- /dev/null +++ b/src/relay/collage/collage_partitioner.cc @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/collage_partitioner.cc + * \brief Search for an optimal partitioning of a Relay model. + */ + +#include "./collage_partitioner.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../ir/dataflow_matcher_impl.h" +#include "./candidate_partition.h" +#include "./candidate_partition_index.h" +#include "./cost.h" +#include "./cost_estimator.h" +#include "./gather_partition_specs.h" +#include "./name_supply.h" +#include "./partition_rule.h" +#include "./partition_spec.h" +#include "./priority_queue.h" +#include "./recover_virtual_device_map.h" +#include "./sub_graph.h" +#include "./utils.h" + +namespace tvm { +namespace relay { +namespace collage { +namespace { + +TVM_REGISTER_PASS_CONFIG_OPTION("relay.collage.enable_collage", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("relay.collage.autotvm_log_filename", String); + +/*! + * \brief Represents the overall expression after some number of non-overlapping candidate + * partitions have been applied. + */ +class SearchState { + public: + explicit SearchState(IndexSet covered) : covered_(std::move(covered)) {} + + /*! + * \brief Order states by increasing best cost, breaking ties by lexicographic order on + * the covering sub graph. + */ + bool operator<(const SearchState& that) const { + return std::tie(best_cost_, covered_) < std::tie(that.best_cost_, that.covered_); + } + + const IndexSet& covered() const { return covered_; } + + std::string ToString() const { + std::ostringstream os; + os << "State("; + os << "covered=" << covered_.ToString(); + os << ",best_cost=" << best_cost_.ToString(); + if (best_candidate_.defined()) { + os << ",best_candidate=" << best_candidate_->ToString(); + } + os << ")"; + return os.str(); + } + + private: + /*! \brief Which nodes of overall expression have been placed on all paths to this state. */ + IndexSet covered_; + /*! \brief Predecessor state for sequence of candidates reaching this state with least + * cost. Null if initial search state. */ + SearchState* pred_state_ = nullptr; + /*! + * \brief Cost of reaching this state using placement implied by path given by pred_state fields. + * Includes estimated/measured cost of all candidates plus any candidate launch penalty. + * Initially invalid cost. 
+   */
+  Cost best_cost_ = Cost::Invalid();
+  /*! \brief Candidate partition selected in transition from pred_state to this state. */
+  CandidatePartition best_candidate_;
+
+  friend class Partitioner;
+};
+
+struct CompareSearchStatePtrs {
+  bool operator()(const SearchState* left, const SearchState* right) const {
+    return *left < *right;
+  }
+};
+
+struct EqualSearchStatePtrs {
+  bool operator()(const SearchState* left, const SearchState* right) const {
+    return left->covered() == right->covered();
+  }
+};
+
+/*!
+ * \brief Finds the optimal partitioning of an expression to candidate partitions.
+ * Though no candidate partitions overlap, it is possible some sub-expressions end up in
+ * no candidate. Those sub-expressions must be evaluated by the host executor (eg VM).
+ */
+class Partitioner {
+ public:
+  explicit Partitioner(Array<PartitionSpec> partition_specs,
+                       const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices)
+      : partition_specs_(std::move(partition_specs)), virtual_devices_(virtual_devices) {}
+
+  Expr Partition(const Expr& expr) {
+    // Establish core data structures.
+    dataflow_graph_ = std::make_unique<DataflowGraph>(expr);
+    name_supply_ = std::make_unique<NameSupply>("collage");
+    VLOG(1) << "Created dataflow graph with " << dataflow_graph_->size() << " nodes";
+
+    // Build the candidate index. This is where all the partition rules are invoked.
+    CandidatePartitionIndex index(virtual_devices_, dataflow_graph_.get());
+    index.Index(partition_specs_);
+
+    // Set up the initial state.
+    SearchState* init_state = GetState(IndexSet(dataflow_graph_->size()));
+    init_state->best_cost_ = Cost::Zero();
+    pq_.Push(init_state);
+
+    size_t num_candidates = 0;
+
+    VLOG(1) << "#### Commencing Collage search over " << index.size() << " candidates ####";
+    while (!pq_.empty()) {
+      SearchState* curr_state = pq_.Pop();
+      VLOG(1) << "Looking at state " << curr_state->covered_.ToString();
+      PostDfsIndex next_index = curr_state->covered_.FirstOutsideIndex();
+
+      if (next_index >= dataflow_graph_->size()) {
+        // The entire expression has been explored. Collect the candidates on the optimal path.
+        VLOG(1) << "#### Finished Collage search after exploring " << num_candidates
+                << " candidates ####";
+        VLOG(1) << "----------------------------------------------------------------------";
+        std::vector<CandidatePartition> best_candidates;
+        while (curr_state != init_state) {
+          ICHECK(curr_state->best_candidate_.defined());
+          VLOG(1) << "Best candidate " << curr_state->best_candidate_->ToString();
+          best_candidates.emplace_back(curr_state->best_candidate_);
+          curr_state = curr_state->pred_state_;
+          ICHECK(curr_state != nullptr);
+        }
+        VLOG(1) << "----------------------------------------------------------------------";
+        return Finalize(expr, best_candidates);
+      }
+
+      size_t num_fires = 0;
+      Expr sub_expr = dataflow_graph_->index_to_node(next_index)->ref();
+      VLOG(1) << "Looking at index " << next_index << " for sub-expression "
+              << SubExprKindAndLabel(sub_expr).second;
+
+      // Explore all the outgoing candidates from the current state.
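+      // (A candidate is 'outgoing' if its sub-graph contains next_index and is disjoint from
+      // the covered set. Relaxing it is a Dijkstra-style transition from this state to the
+      // state covering 'covered | candidate.inside', weighted by the candidate's estimated
+      // cost.)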
+      for (const auto& candidate : index.candidates_at(next_index)) {
+        VLOG(1) << "Considering candidate " << candidate->ToString() << " (" << ++num_candidates
+                << ")";
+        if (!candidate->sub_graph_->inside_.AreDisjoint(curr_state->covered_)) {
+          VLOG(1) << "Candidate overlaps with already fused nodes";
+          continue;
+        }
+        IndexSet next_covered = curr_state->covered_ | candidate->sub_graph_->inside_;
+        SearchState* next_state = GetState(next_covered);
+        Relax(curr_state, next_state, candidate);
+        ++num_fires;
+      }
+      ICHECK_GT(num_fires, 0)
+          << "No candidate was found covering sub-expression at index " << next_index
+          << ", suggesting the partition rules are incomplete for the given targets.";
+    }
+    ICHECK(false) << "should have reached end state in which all sub-expressions are covered";
+    return {};
+  }
+
+  /*! \brief Returns the unique state corresponding to the \p covered sub-graph. */
+  SearchState* GetState(const IndexSet& covered) {
+    auto itr = covered_to_state_.find(covered);
+    if (itr != covered_to_state_.end()) {
+      return itr->second.get();
+    }
+    auto state = std::make_unique<SearchState>(covered);
+    SearchState* raw_ptr = state.get();
+    covered_to_state_.emplace(covered, std::move(state));
+    return raw_ptr;
+  }
+
+  /*!
+   * \brief Record that it is possible to reach \p next_state by choosing \p candidate
+   * in \p curr_state. If the resulting cost is better than the best known so far, update
+   * \p next_state's best cost, predecessor and candidate to match.
+   */
+  void Relax(SearchState* curr_state, SearchState* next_state,
+             const CandidatePartition& candidate) {
+    Cost candidate_cost =
+        candidate->EstimatedCost(*dataflow_graph_, &cost_estimator_, *name_supply_);
+    Cost new_state_cost = candidate_cost + curr_state->best_cost_;
+    const bool is_new = next_state->best_cost_.is_invalid();
+    CandidatePartition previously_best_candidate = next_state->best_candidate_;
+    if (is_new || new_state_cost < next_state->best_cost_) {
+      next_state->pred_state_ = curr_state;
+      Cost previously_best_cost = next_state->best_cost_;
+      next_state->best_cost_ = new_state_cost;
+      next_state->best_candidate_ = candidate;
+      if (is_new) {
+        VLOG(1) << "transition " << curr_state->ToString() << " --> " << next_state->ToString()
+                << " (New state for spec " << candidate->partition_spec_name() << ")";
+        pq_.Push(next_state);
+      } else {
+        VLOG(1) << "transition " << curr_state->ToString() << " --> " << next_state->ToString()
+                << " (Spec " << candidate->partition_spec_name() << " beats previous spec "
+                << previously_best_candidate->partition_spec_name() << " by "
+                << (previously_best_cost - new_state_cost).ToString() << ")";
+        pq_.Update(next_state);
+      }
+    } else {
+      VLOG(1) << "transition " << curr_state->ToString() << " --> " << next_state->ToString()
+              << " (Spec " << candidate->partition_spec_name() << " does not beat existing spec "
+              << previously_best_candidate->partition_spec_name() << ")";
+    }
+  }
+
+  /*!
+   * \brief Returns the result of partitioning \p expr according to 'optimal' candidates found
+   * by the search.
+   */
+  Expr Finalize(const Expr& expr, std::vector<CandidatePartition> best_candidates) {
+    best_candidates = CandidatePartition::MaxCoalesce(*dataflow_graph_, best_candidates);
+    return CandidatePartition::ParallelRewrite(*dataflow_graph_, expr, best_candidates,
+                                               *name_supply_);
+  }
+
+ private:
+  /*! \brief Available partition specs to use during search. */
+  Array<PartitionSpec> partition_specs_;
+  /*!
+   * \brief The virtual devices for every sub-expression so we can respect any existing target
+   * constraints.
+   */
+  const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices_;
+  /*! \brief Dataflow graph for overall expression. */
+  std::unique_ptr<DataflowGraph> dataflow_graph_;
+  /*! \brief How to generate globally unique and compiler-friendly names. */
+  std::unique_ptr<NameSupply> name_supply_;
+  /*! \brief Cost estimator to use for candidates. */
+  CostEstimator cost_estimator_;
+  /*! \brief Map from covered sub-graphs to the corresponding state. */
+  std::unordered_map<IndexSet, std::unique_ptr<SearchState>, IndexSetHash, IndexSetEqual>
+      covered_to_state_;
+  /*! \brief Priority queue of states, ordered by increasing cost. */
+  PriorityQueue<SearchState, CompareSearchStatePtrs, EqualSearchStatePtrs> pq_;
+};
+
+}  // namespace
+
+transform::Pass CollagePartition(CompilationConfig config) {
+  auto pass_func = [=](IRModule mod, transform::PassContext ctxt) {
+    Optional<Bool> opt_enable =
+        ctxt->GetConfig("relay.collage.enable_collage", Bool(false));
+    if (!opt_enable.value()) {
+      VLOG(1) << "ignoring since collage is disabled";
+      return mod;
+    }
+    Array<PartitionSpec> partition_specs = GatherPartitionSpecs(config);
+    VLOG(1) << "Gathered " << partition_specs.size() << " partition specs";
+    IRModule out_mod = mod->ShallowCopy();
+    for (const auto& kv : mod->functions) {
+      if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
+        auto function = GetRef<Function>(function_node);
+        VLOG(1) << "Partitioning " << kv.first->name_hint << " from:\n" << PrettyPrint(function);
+        std::unordered_map<const ExprNode*, VirtualDevice> virtual_devices =
+            RecoverVirtualDeviceMap(mod, function);
+        Partitioner partitioner(partition_specs, &virtual_devices);
+        Function result = Downcast<Function>(partitioner.Partition(function));
+        VLOG(1) << "Partitioned " << kv.first->name_hint << " to:\n" << PrettyPrint(result);
+        out_mod->Add(kv.first, result);
+      }
+    }
+
+    // Establish the tuning log for the rest of the compilation flow
+    // TODO(mbs): This is pretty gross.
+    static const runtime::PackedFunc* establish_autotvm_logs =
+        runtime::Registry::Get("tvm.relay.collage.establish_autotvm_logs");
+    ICHECK(establish_autotvm_logs);
+    (*establish_autotvm_logs)();
+
+    return out_mod;
+  };
+  return tvm::transform::CreateModulePass(pass_func, /*opt_level=*/0, "CollagePartition", {});
+}
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/collage/collage_partitioner.h b/src/relay/collage/collage_partitioner.h
new file mode 100644
index 0000000000000..38a95de65bad2
--- /dev/null
+++ b/src/relay/collage/collage_partitioner.h
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/collage/collage_partitioner.h
+ * \brief Search for an optimal partitioning of a Relay model.
+ * + * See: + * Collage: Automated Integration of Deep Learning Backends + * Byungsoo Jeon, Sunghyun Park, Peiyuan Liao, Sheng Xu, Tianqi Chen, Zhihao Jia + * https://arxiv.org/pdf/2111.00655.pdf + */ +#ifndef TVM_RELAY_COLLAGE_COLLAGE_PARTITIONER_H_ +#define TVM_RELAY_COLLAGE_COLLAGE_PARTITIONER_H_ + +#include + +namespace tvm { +namespace relay { +namespace collage { + +/*! + * \brief Explores the space of all possible (sub-graph, target) pairs which cover the + * model, and applies the globally optimal choice (assuming partition costs are additive). + */ +transform::Pass CollagePartition(CompilationConfig config); + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_COLLAGE_PARTITIONER_H_ diff --git a/src/relay/collage/combiner_rule.cc b/src/relay/collage/combiner_rule.cc new file mode 100644 index 0000000000000..73a1da78e0d86 --- /dev/null +++ b/src/relay/collage/combiner_rule.cc @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/relay/collage/combiner_rule.cc
+ * \brief Helpers for the \p CombinePartitionRule
+ */
+
+#include "./combiner_rule.h"
+
+#include "./partition_spec.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+bool SimpleCombinerRuleNode::Fires(const DataflowGraph& dataflow_graph,
+                                   const CandidatePartition& upstream,
+                                   const CandidatePartition& downstream) const {
+  return false;
+}
+
+std::string SimpleCombinerRuleNode::ToString() const {
+  return "SimpleCombinerRule(" + rule_name_ + ")";
+}
+
+SimpleCombinerRule::SimpleCombinerRule(String rule_name) {
+  auto node = runtime::make_object<SimpleCombinerRuleNode>();
+  node->rule_name_ = std::move(rule_name);
+  data_ = std::move(node);
+}
+
+bool ByKindSimpleCombinerRuleNode::Fires(const DataflowGraph& dataflow_graph,
+                                         const CandidatePartition& upstream,
+                                         const CandidatePartition& downstream) const {
+  return upstream->sub_graph_->kind_ <= upstream_kind_ &&
+         downstream->sub_graph_->kind_ <= downstream_kind_;
+}
+
+std::string ByKindSimpleCombinerRuleNode::ToString() const {
+  std::ostringstream os;
+  os << "ByKindSimpleCombinerRule(" << rule_name_ << ")";
+  return os.str();
+}
+
+ByKindSimpleCombinerRule::ByKindSimpleCombinerRule(OpPatternKind upstream_kind,
+                                                   OpPatternKind downstream_kind) {
+  auto node = runtime::make_object<ByKindSimpleCombinerRuleNode>();
+  String rule_name = KindToString(upstream_kind) + "->" + KindToString(downstream_kind);
+  node->rule_name_ = std::move(rule_name);
+  node->upstream_kind_ = upstream_kind;
+  node->downstream_kind_ = downstream_kind;
+  data_ = std::move(node);
+}
+
+void CombinerRuleNode::AppendAllResults(AppendAllResultsContext& ctxt) const {}
+
+std::string CombinerRuleNode::ToString() const { return "CombinerRuleNode(" + rule_name_ + ")"; }
+
+CombinerRule::CombinerRule(String rule_name) {
+  auto node = runtime::make_object<CombinerRuleNode>();
+  node->rule_name_ = std::move(rule_name);
+  data_ = std::move(node);
+}
+
+void AllSimpleCombinerRuleNode::AppendAllResults(AppendAllResultsContext& ctxt) const {
+  VLOG(1) << "running AllSimpleCombinerRule(" << rule_name_ << ")";
+  // Build map from post-dfs indices to the indices of candidates with corresponding entry node.
+  // NOTE: the index set is over candidate indices not post-dfs indices!
+  std::vector<IndexSet> entry_map(ctxt.dataflow_graph->size(),
+                                  IndexSet(ctxt.candidate_set->size()));
+  for (size_t i = 0; i < ctxt.candidate_set->size(); ++i) {
+    CandidatePartition candidate = ctxt.candidate_set->at(i);
+    for (PostDfsIndex entry_index : candidate->sub_graph_->entry_) {
+      entry_map[entry_index].Add(i);
+    }
+  }
+
+  for (size_t i = 0; i < ctxt.candidate_set->size(); ++i) {
+    CandidatePartition upstream = ctxt.candidate_set->at(i);
+    // Narrow our search to just those candidates which could touch.
+    IndexSet possible_downstream(ctxt.candidate_set->size());
+    for (PostDfsIndex output_index : upstream->sub_graph_->output_) {
+      possible_downstream = possible_downstream | entry_map[output_index];
+    }
+    size_t start_j =
+        i < ctxt.candidate_set->first_new_index ? ctxt.candidate_set->first_new_index : 0;
+    for (size_t j : possible_downstream) {
+      if (i == j) {
+        continue;
+      }
+      if (j < start_j) {
+        // We already explored the cross-product of candidates [0, first_new_index), so don't
+        // do it again.
+        continue;
+      }
+      // Note that the rules are not commutative so we can't just ignore if j < i.
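+      // Eg if first_new_index is 2 over candidates {0,1,2,3} then pairs such as (0,2) and
+      // (2,1) are considered, but not (0,1) or (1,0), which were already handled in an
+      // earlier round.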
+      CandidatePartition downstream = ctxt.candidate_set->at(j);
+      if (ctxt.max_max_depth > 0 &&
+          upstream->sub_graph_->max_depth_ + downstream->sub_graph_->max_depth_ >
+              ctxt.max_max_depth) {
+        continue;
+      }
+      if (!upstream.AreTouching(*ctxt.dataflow_graph, downstream)) {
+        continue;
+      }
+      for (const auto& simple_rule : simple_rules_) {
+        if (simple_rule->Fires(*ctxt.dataflow_graph, upstream, downstream)) {
+          CandidatePartition new_candidate =
+              upstream.DisjointUnion(*ctxt.dataflow_graph, downstream);
+          VLOG(2) << "Fired " << simple_rule->rule_name_ << " on upstream candidate "
+                  << upstream->ToString() << " and downstream candidate " << downstream->ToString()
+                  << " to yield " << new_candidate->ToString();
+          ctxt.candidate_set->Add(*ctxt.dataflow_graph, new_candidate);
+        }
+      }
+    }
+  }
+}
+
+std::string AllSimpleCombinerRuleNode::ToString() const {
+  std::ostringstream os;
+  os << "AllSimpleCombinerRule(" << rule_name_;
+  for (const auto& simple : simple_rules_) {
+    os << ", " << simple->ToString();
+  }
+  os << ")";
+  return os.str();
+}
+
+AllSimpleCombinerRule::AllSimpleCombinerRule(String rule_name,
+                                             Array<SimpleCombinerRule> simple_rules) {
+  auto node = runtime::make_object<AllSimpleCombinerRuleNode>();
+  node->rule_name_ = std::move(rule_name);
+  node->simple_rules_ = std::move(simple_rules);
+  data_ = std::move(node);
+}
+
+void TupleArgCombinerRuleNode::AppendAllResults(AppendAllResultsContext& ctxt) const {
+  VLOG(1) << "running TupleArgCombinerRule(" << rule_name_ << ")";
+  // Build map from post-dfs index to the indices of injective candidates with corresponding exit
+  // node. NOTE: the index set is over candidate indices not post-dfs indices!
+  std::vector<IndexSet> exit_map(ctxt.dataflow_graph->size(),
+                                 IndexSet(ctxt.candidate_set->size()));
+  for (size_t i = 0; i < ctxt.candidate_set->size(); ++i) {
+    CandidatePartition candidate = ctxt.candidate_set->at(i);
+    if (candidate->sub_graph_->kind_ > kInjective) {
+      continue;
+    }
+    for (PostDfsIndex exit_index : candidate->sub_graph_->exit_) {
+      exit_map[exit_index].Add(i);
+    }
+  }
+
+  // The two-step I -> tuple -> I rule.
+  // Look at all possible tuple consumers...
+  for (size_t i = 0; i < ctxt.candidate_set->size(); ++i) {
+    CandidatePartition tuple_consumer_candidate = ctxt.candidate_set->at(i);
+    if (tuple_consumer_candidate->sub_graph_->kind_ > kInjective) {
+      continue;
+    }
+    // For all possible tuples feeding into candidate...
+    for (PostDfsIndex input_index : tuple_consumer_candidate->sub_graph_->input_) {
+      auto node = ctxt.dataflow_graph->index_to_node(input_index);
+      Expr sub_expr = node->ref();
+      const auto* tuple_node = sub_expr.as<TupleNode>();
+      if (tuple_node == nullptr) {
+        continue;
+      }
+      // The tuple_consumer_candidate candidate consumes (at least one) tuple, eg as an argument
+      // to an operator.
+      //  eg: concatenate((field1, ..., fieldn))
+      auto tuple_dataflow_node = ctxt.dataflow_graph->item_to_node(tuple_node);
+
+      // Collect all the possible unions. There may be more than one if different candidates
+      // could supply the same tuple field.
+      std::vector<std::vector<CandidatePartition>> all_possible_unions;
+
+      // Obviously we must include the consumer.
+      all_possible_unions.emplace_back();
+      all_possible_unions.back().emplace_back(tuple_consumer_candidate);
+
+      // We must include the tuple itself.
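+      // (The tuple becomes its own singleton sub-graph, given kind kInjective so the union
+      // with its injective producers and consumer remains a valid fusible candidate.)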
+      SubGraph tuple_sub_graph(*ctxt.dataflow_graph,
+                               IndexSet(ctxt.dataflow_graph->size(), {node->index_}), kInjective,
+                               "tuple");
+      CandidatePartition tuple_candidate("", std::move(tuple_sub_graph),
+                                         tuple_consumer_candidate->partition_spec());
+      all_possible_unions.back().emplace_back(std::move(tuple_candidate));
+
+      // For all tuple fields...
+      bool all_tuple_fields_have_producer = true;
+      for (auto* tuple_field_dataflow_node : tuple_dataflow_node->inputs_) {
+        // Collect all the candidates which could produce this tuple field.
+        std::vector<CandidatePartition> to_appends;
+        size_t start_j =
+            i < ctxt.candidate_set->first_new_index ? ctxt.candidate_set->first_new_index : 0;
+        for (size_t j : exit_map[tuple_field_dataflow_node->index_]) {
+          if (i == j) {
+            continue;
+          }
+          if (j < start_j) {
+            // We already explored the cross-product of candidates [0, first_new_index), so don't
+            // do it again.
+            continue;
+          }
+          CandidatePartition tuple_field_producer = ctxt.candidate_set->at(j);
+          // The tuple_field_producer candidate can provide this tuple field.
+          //  eg concatenate((..., producer, ...))
+          to_appends.emplace_back(tuple_field_producer);
+        }
+        if (to_appends.empty()) {
+          // At least one of the tuple's fields does not have a producer candidate we can
+          // union in, so we need to give up.
+          all_tuple_fields_have_producer = false;
+          break;
+        } else {
+          // If to_appends = [A, B] and we already have possible unions [C, D] and [E, F] then
+          // the new possible unions are [C, D, A], [C, D, B], [E, F, A] and [E, F, B].
+          std::vector<std::vector<CandidatePartition>> new_all_possible_unions;
+          for (const auto& to_append : to_appends) {
+            for (const auto& possible_union : all_possible_unions) {
+              new_all_possible_unions.emplace_back(possible_union);
+              new_all_possible_unions.back().emplace_back(to_append);
+            }
+          }
+          all_possible_unions = std::move(new_all_possible_unions);
+        }
+      }
+
+      if (!all_tuple_fields_have_producer) {
+        continue;
+      }
+
+      // Actually build the candidates which union according to all_possible_unions.
+      for (const auto& possible_union : all_possible_unions) {
+        if (possible_union.size() > 2) {
+          CandidatePartition new_candidate =
+              CandidatePartition::DisjointUnion(*ctxt.dataflow_graph, possible_union);
+#if TVM_LOG_DEBUG
+          std::ostringstream os;
+          bool first = true;
+          for (const auto& candidate : possible_union) {
+            if (first) {
+              first = false;
+            } else {
+              os << ", ";
+            }
+            os << candidate->ToString();
+          }
+          VLOG(2) << "Fired rule " << rule_name_ << " on {" << os.str() << "} to yield "
+                  << new_candidate->ToString();
+#endif
+          ctxt.candidate_set->Add(*ctxt.dataflow_graph, new_candidate);
+        }
+      }
+    }
+  }
+}
+
+std::string TupleArgCombinerRuleNode::ToString() const {
+  return "TupleArgCombinerRule(" + rule_name_ + ")";
+}
+
+TupleArgCombinerRule::TupleArgCombinerRule(String rule_name) {
+  auto node = runtime::make_object<TupleArgCombinerRuleNode>();
+  node->rule_name_ = std::move(rule_name);
+  data_ = std::move(node);
+}
+
+void ConstantCombinerRuleNode::AppendAllResults(AppendAllResultsContext& ctxt) const {
+  VLOG(1) << "running ConstantCombinerRule(" << rule_name_ << ")";
+  // We already explored [0, first_new_index), so don't do it again.
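+  // For each not-yet-explored candidate, gather any constant arguments flowing into it and
+  // union them in as a fresh 'const' sub-graph under the same partition spec.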
+  for (size_t i = ctxt.candidate_set->first_new_index; i < ctxt.candidate_set->size(); ++i) {
+    CandidatePartition base = ctxt.candidate_set->at(i);
+    IndexSet new_constants(ctxt.dataflow_graph->size());
+    for (PostDfsIndex index : base->sub_graph_->input_) {
+      auto node = ctxt.dataflow_graph->index_to_node(index);
+      if (node->ref().as<ConstantNode>()) {
+        new_constants.Add(index);
+      }
+    }
+    if (!new_constants.IsZero()) {
+      SubGraph sub_graph(*ctxt.dataflow_graph, new_constants, kElemWise, "const");
+      CandidatePartition new_const_candidate("", std::move(sub_graph), base->spec_);
+      CandidatePartition new_candidate =
+          base.DisjointUnion(*ctxt.dataflow_graph, new_const_candidate);
+      VLOG(2) << "Fired rule " << rule_name_ << " on " << new_const_candidate->ToString() << " and "
+              << base->ToString() << " to yield " << new_candidate->ToString();
+      ctxt.candidate_set->Add(*ctxt.dataflow_graph, new_candidate);
+    }
+  }
+}
+
+std::string ConstantCombinerRuleNode::ToString() const {
+  return "ConstantCombinerRule(" + rule_name_ + ")";
+}
+
+ConstantCombinerRule::ConstantCombinerRule(String rule_name) {
+  auto node = runtime::make_object<ConstantCombinerRuleNode>();
+  node->rule_name_ = std::move(rule_name);
+  data_ = std::move(node);
+}
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
\ No newline at end of file
diff --git a/src/relay/collage/combiner_rule.h b/src/relay/collage/combiner_rule.h
new file mode 100644
index 0000000000000..75e8e9d526095
--- /dev/null
+++ b/src/relay/collage/combiner_rule.h
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/combiner_rule.h
+ * \brief Helpers for the \p CombinePartitionRule
+ */
+
+#ifndef TVM_RELAY_COLLAGE_COMBINER_RULE_H_
+#define TVM_RELAY_COLLAGE_COMBINER_RULE_H_
+
+#include
+#include
+
+#include "./candidate_partition.h"
+#include "./candidate_set.h"
+#include "./sub_graph.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*!
+ * \brief Base class for all 'simple' combiner rules.
+ *
+ * Given \p upstream and \p downstream candidates which touch, a simple combiner rule returns
+ * true if their union should also be considered a candidate.
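+ *
+ * Eg ByKindSimpleCombinerRule(kOutEWiseFusable, kBroadcast) fires when a candidate ending in
+ * a conv2d-like op feeds a candidate of elementwise-broadcast ops such as add.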
+ */
+class SimpleCombinerRuleNode : public Object {
+ public:
+  String rule_name_;
+
+  virtual bool Fires(const DataflowGraph& dataflow_graph, const CandidatePartition& upstream,
+                     const CandidatePartition& downstream) const;
+
+  virtual std::string ToString() const;
+
+  static constexpr const char* _type_key = "relay.collage.SimpleCombinerRule";
+  static constexpr const uint32_t _type_child_slots = 1;
+  TVM_DECLARE_BASE_OBJECT_INFO(SimpleCombinerRuleNode, Object);
+};
+
+class SimpleCombinerRule : public ObjectRef {
+ public:
+  explicit SimpleCombinerRule(String rule_name);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(SimpleCombinerRule, ObjectRef, SimpleCombinerRuleNode);
+};
+
+/*!
+ * \brief A simple combiner rule which fires if the \p upstream and \p downstream candidates have
+ * the given \p upstream_kind and \p downstream_kind (or less) respectively.
+ */
+class ByKindSimpleCombinerRuleNode : public SimpleCombinerRuleNode {
+ public:
+  OpPatternKind upstream_kind_;
+  OpPatternKind downstream_kind_;
+
+  bool Fires(const DataflowGraph& dataflow_graph, const CandidatePartition& upstream,
+             const CandidatePartition& downstream) const override;
+  std::string ToString() const override;
+
+  static constexpr const char* _type_key = "relay.collage.ByKindSimpleCombinerRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ByKindSimpleCombinerRuleNode, SimpleCombinerRuleNode);
+};
+
+class ByKindSimpleCombinerRule : public SimpleCombinerRule {
+ public:
+  ByKindSimpleCombinerRule(OpPatternKind upstream_kind, OpPatternKind downstream_kind);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ByKindSimpleCombinerRule, SimpleCombinerRule,
+                                ByKindSimpleCombinerRuleNode);
+};
+
+/*! \brief Context required by CombinerRuleNode::AppendAllResults. */
+struct AppendAllResultsContext {
+  AppendAllResultsContext(const DataflowGraph* dataflow_graph, size_t max_max_depth,
+                          CandidateSet* candidate_set)
+      : dataflow_graph(dataflow_graph),
+        max_max_depth(max_max_depth),
+        candidate_set(candidate_set) {}
+
+  const DataflowGraph* dataflow_graph;
+  size_t max_max_depth;
+  CandidateSet* candidate_set;
+};
+
+/*!
+ * \brief Base class for all 'combiner' rules.
+ *
+ * Given the current candidate set, a combiner rule looks for opportunities to form larger
+ * candidates, optionally removing existing candidates in the process.
+ *
+ * Currently only \p AllSimpleCombinerRule, \p TupleArgCombinerRule and \p ConstantCombinerRule
+ * are implemented.
+ */
+class CombinerRuleNode : public Object {
+ public:
+  String rule_name_;
+
+  virtual void AppendAllResults(AppendAllResultsContext& ctxt) const;
+  virtual std::string ToString() const;
+
+  static constexpr const char* _type_key = "relay.collage.CombinerRule";
+  static constexpr const uint32_t _type_child_slots = 3;
+  TVM_DECLARE_BASE_OBJECT_INFO(CombinerRuleNode, Object);
+};
+
+class CombinerRule : public ObjectRef {
+ public:
+  explicit CombinerRule(String rule_name);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(CombinerRule, ObjectRef, CombinerRuleNode);
+};
+
+/*!
+ * \brief A combiner rule which runs one or more simple combiner rules over the current
+ * touching candidates.
+ */
+class AllSimpleCombinerRuleNode : public CombinerRuleNode {
+ public:
+  Array<SimpleCombinerRule> simple_rules_;
+
+  void AppendAllResults(AppendAllResultsContext& ctxt) const override;
+  std::string ToString() const override;
+
+  static constexpr const char* _type_key = "relay.collage.AllSimpleCombinerRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AllSimpleCombinerRuleNode, CombinerRuleNode);
+};
+
+class AllSimpleCombinerRule : public CombinerRule {
+ public:
+  AllSimpleCombinerRule(String rule_name, Array<SimpleCombinerRule> simple_rules);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(AllSimpleCombinerRule, CombinerRule, AllSimpleCombinerRuleNode);
+};
+
+/*!
+ * \brief A combiner rule which combines injective sub-groups which appear inside tuples which are
+ * themselves inputs to injective sub-groups.
+ */
+class TupleArgCombinerRuleNode : public CombinerRuleNode {
+ public:
+  void AppendAllResults(AppendAllResultsContext& ctxt) const override;
+  std::string ToString() const override;
+
+  static constexpr const char* _type_key = "relay.collage.TupleArgCombinerRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TupleArgCombinerRuleNode, CombinerRuleNode);
+};
+
+class TupleArgCombinerRule : public CombinerRule {
+ public:
+  explicit TupleArgCombinerRule(String rule_name);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(TupleArgCombinerRule, CombinerRule, TupleArgCombinerRuleNode);
+};
+
+/*!
+ * \brief A combiner rule which combines constants in argument positions to existing candidates.
+ * Note that scalars are always inlined, so this rule only combines tensor constant arguments.
+ */
+class ConstantCombinerRuleNode : public CombinerRuleNode {
+ public:
+  void AppendAllResults(AppendAllResultsContext& ctxt) const override;
+  std::string ToString() const override;
+
+  static constexpr const char* _type_key = "relay.collage.ConstantCombinerRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ConstantCombinerRuleNode, CombinerRuleNode);
+};
+
+class ConstantCombinerRule : public CombinerRule {
+ public:
+  explicit ConstantCombinerRule(String rule_name);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ConstantCombinerRule, CombinerRule, ConstantCombinerRuleNode);
+};
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_COLLAGE_COMBINER_RULE_H_
diff --git a/src/relay/collage/cost.cc b/src/relay/collage/cost.cc
new file mode 100644
index 0000000000000..b2ee5b1197a0b
--- /dev/null
+++ b/src/relay/collage/cost.cc
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/cost.cc
+ * \brief Represents the estimated cost of a candidate partition.
+ */ + +#include "./cost.h" + +namespace tvm { +namespace relay { +namespace collage { + +std::string Cost::ToString() const { + if (is_invalid()) { + return "invalid"; + } else if (is_unknown()) { + return "unknown"; + } else if (value_ == 0.0) { + return "0"; + } else { + return std::to_string(value_ * 1e6) + "us"; + } +} + +} // namespace collage +} // namespace relay +} // namespace tvm \ No newline at end of file diff --git a/src/relay/collage/cost.h b/src/relay/collage/cost.h new file mode 100644 index 0000000000000..8ae276d22078f --- /dev/null +++ b/src/relay/collage/cost.h @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/cost.h + * \brief Represents the estimated cost of a candidate partition. + */ +#ifndef TVM_RELAY_COLLAGE_COST_H_ +#define TVM_RELAY_COLLAGE_COST_H_ + +#include + +#include +#include +#include + +namespace tvm { +namespace relay { +namespace collage { + +/*! + * \brief The assumed cost for a candidate partition. Generally average execution time in seconds. + * However other cost functions are possible, for example to introduce a penalty for high memory + * use, etc. + */ +class Cost { + public: + Cost() = delete; + + static Cost Zero() { return Cost(0.0); } + + /*! + * \brief Returns the distinguished 'invalid' cost signaling a candidate partition is not + * supported by the intended target, for example because the sub-graph has an unsupported operator + * or the intermediate memory required exceeds some system limit. + */ + static Cost Invalid() { return Cost(std::numeric_limits::infinity()); } + + bool is_invalid() const { return std::isinf(value_) && value_ > 0.0; } + + /*! + * \brief Returns the distinguished 'unknown' cost, signaling fixed priorities should be used to + * choose the best partitions. This can be used to disable tuning and fallback to fixed rules, + * much as TVM will use an un-tuned kernel if no tuning records are available. + */ + static Cost Unknown() { return Cost(std::numeric_limits::quiet_NaN()); } + + bool is_unknown() const { return std::isnan(value_); } + + /*! \brief Returns cost with given finite, non-negative value. */ + static Cost Value(double value) { + ICHECK(!std::isnan(value) && !std::isinf(value) && value >= 0.0); + return Cost(value); + } + + bool is_value() const { return !std::isnan(value_) && !std::isinf(value_); } + + /*! \brief Return true if the less-than relation is defined for this and that. */ + bool are_comparable(Cost that) const { return !std::isnan(value_) && !std::isnan(that.value_); } + + /*! \brief Returns sum of this and that. */ + Cost operator+(Cost that) const { return Cost(value_ + that.value_); } + + /*! \brief Returns difference of this and that. 
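+   *  This is raw IEEE arithmetic on the underlying value, so differences involving invalid
+   *  (+inf) or unknown (NaN) costs are themselves non-finite.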
*/ + Cost operator-(Cost that) const { return Cost(value_ - that.value_); } + + /*! \brief Returns true if this is cheaper than that, assuming they are comparable. */ + bool operator<(Cost that) const { return value_ < that.value_; } + + std::string ToString() const; + + private: + explicit Cost(double value) : value_(value) {} + + /*! + * \brief Non-negative value or: + * - +inf if candidate partition is not feasible. + * - NaN if candidate partition has an unknown cost (priority may be used to break ties). + */ + double value_ = 0.0; +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_COST_H_ diff --git a/src/relay/collage/cost_estimator.cc b/src/relay/collage/cost_estimator.cc new file mode 100644 index 0000000000000..4497ea33e5cad --- /dev/null +++ b/src/relay/collage/cost_estimator.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/cost_estimator.cc + * \brief Interface for measuring candidate partition cost. + */ + +#include "./cost_estimator.h" + +namespace tvm { +namespace relay { +namespace collage { + +Cost CostEstimator::Estimate(const IRModule& mod, const Target& target) const { + static const runtime::PackedFunc* estimate_seconds = + runtime::Registry::Get("tvm.relay.collage.estimate_seconds"); + ICHECK(estimate_seconds); + const double value = (*estimate_seconds)(mod, target); + if (std::isinf(value)) { + return Cost::Invalid(); + } else if (std::isnan(value)) { + return Cost::Unknown(); + } else { + return Cost::Value(value); + } +} + +Cost CostEstimator::CachedEstimate(const IRModule& mod, const Target& target) { + std::string key = CacheKey(mod, target); + auto itr = cache_.find(key); + if (itr != cache_.end()) { + VLOG(1) << "Reusing cost cached in estimator"; + return itr->second; + } + Cost cost = Estimate(mod, target); + cache_.emplace(key, cost); + return cost; +} + +std::string CostEstimator::CacheKey(const IRModule& mod, const Target& target) { + std::ostringstream os; + os << "{"; + os << PrettyPrint(mod); + os << "}{"; + os << target->ToDebugString(); + os << "}"; + return os.str(); +} + +} // namespace collage +} // namespace relay +} // namespace tvm \ No newline at end of file diff --git a/src/relay/collage/cost_estimator.h b/src/relay/collage/cost_estimator.h new file mode 100644 index 0000000000000..c15a356a47f96 --- /dev/null +++ b/src/relay/collage/cost_estimator.h @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/cost_estimator.h
+ * \brief Interface for measuring candidate partition cost.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_COST_ESTIMATOR_H_
+#define TVM_RELAY_COLLAGE_COST_ESTIMATOR_H_
+
+#include
+
+#include "./cost.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*!
+ * \brief An (abstract) estimator for the cost of executing "main" in an \p IRModule representing
+ * a candidate partition, using the given target for lowering and codegen.
+ *
+ * Generally the implementation will compile to a \p runtime::Module (possibly on a target-specific
+ * worker if cross-compilation is not available), repeatedly invoke "main" with random data until
+ * measure variance is acceptable (on a target-specific worker), and return the summarized costs.
+ * The result may be cached, however the cache lookup and update is hidden.
+ *
+ * If using a TVM native \p Target, it is possible compilation will itself invoke TVM tuning.
+ *
+ * TODO(mbs): Actually, currently not abstract so can get some local measurements.
+ */
+class CostEstimator {
+ public:
+  virtual ~CostEstimator() = default;
+
+  /*!
+   * \brief Returns the estimated cost (possibly after many, many minutes of tuning time) of
+   * running "main" in \p mod using \p target, which represents a possible partitioning of
+   * some overall Relay expression.
+   */
+  virtual Cost Estimate(const IRModule& mod, const Target& target) const;
+
+  /*!
+   * \brief As for \p Estimate, but use and update the internal in-memory cache.
+   */
+  Cost CachedEstimate(const IRModule& mod, const Target& target);
+
+ private:
+  /*! \brief Returns string which is 1:1 with \p mod and \p target. */
+  std::string CacheKey(const IRModule& mod, const Target& target);
+
+  /*!
+   * \brief In-memory cache.
+   *
+   * TODO(mbs): This is just to get us going.
+   */
+  std::unordered_map<std::string, Cost> cache_;
+};
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_COLLAGE_COST_ESTIMATOR_H_
diff --git a/src/relay/collage/dataflow_graph.cc b/src/relay/collage/dataflow_graph.cc
new file mode 100644
index 0000000000000..787f310b59f27
--- /dev/null
+++ b/src/relay/collage/dataflow_graph.cc
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/dataflow_graph.cc
+ * \brief A representation of the dataflow for an overall Relay expression.
+ */
+
+#include "./dataflow_graph.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+DataflowGraph::DataflowGraph(Expr expr) : expr_(std::move(expr)) {
+  indexed_graph_ = CreateIndexedGraph(expr_);
+  downstream_map_.reserve(indexed_graph_->size());
+  for (PostDfsIndex index = 0; index < indexed_graph_->size(); ++index) {
+    const Node* node = indexed_graph_->index_to_node(index);
+    std::unordered_set<const Node*> downstream_nodes;
+    node->AccumulateDownstreamNodes(downstream_nodes);
+    IndexSet index_set(indexed_graph_->size());
+    for (const Node* downstream_node : downstream_nodes) {
+      index_set.Add(downstream_node->index_);
+    }
+    downstream_map_.emplace_back(std::move(index_set));
+  }
+}
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/collage/dataflow_graph.h b/src/relay/collage/dataflow_graph.h
new file mode 100644
index 0000000000000..249ccaef74707
--- /dev/null
+++ b/src/relay/collage/dataflow_graph.h
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/dataflow_graph.h
+ * \brief A representation of the dataflow for an overall Relay expression.
+ */
+#ifndef TVM_RELAY_COLLAGE_DATAFLOW_GRAPH_H_
+#define TVM_RELAY_COLLAGE_DATAFLOW_GRAPH_H_
+
+#include
+
+#include "../ir/indexed_graph.h"
+#include "./index_set.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*!
+ * \brief Represents the dataflow of an overall Relay expression.
+ */
+class DataflowGraph {
+ public:
+  using Node = IndexedGraph<Expr>::Node;
+
+  explicit DataflowGraph(Expr expr);
+
+  size_t size() const { return indexed_graph_->size(); }
+  const Node* index_to_node(PostDfsIndex index) const {
+    return indexed_graph_->index_to_node(index);
+  }
+  const Node* item_to_node(const Expr& expr) const { return indexed_graph_->item_to_node(expr); }
+  const Node* item_to_node(const ExprNode* expr_node) const {
+    return indexed_graph_->item_to_node(expr_node);
+  }
+  const IndexedGraph<Expr>& indexed_graph() const { return *indexed_graph_; }
+
+  const IndexSet& downstream_of(PostDfsIndex index) const {
+    ICHECK_LT(index, indexed_graph_->size());
+    return downstream_map_[index];
+  }
+
+ private:
+  /*! \brief The overall expression. */
+  Expr expr_;
+  /*! \brief The indexed graph which captures the main dataflow. */
+  std::unique_ptr<IndexedGraph<Expr>> indexed_graph_;
+  /*! \brief Map from a node's PostDfsIndex to the set of its downstream dataflow node indexes.
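+   * Cached eagerly since the search needs it repeatedly, eg to quickly check whether unioning
+   * two candidates would create a dataflow cycle.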
+   */
+  std::vector<IndexSet> downstream_map_;
+};
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_COLLAGE_DATAFLOW_GRAPH_H_
diff --git a/src/relay/collage/gather_partition_specs.cc b/src/relay/collage/gather_partition_specs.cc
new file mode 100644
index 0000000000000..ade38aec5a37a
--- /dev/null
+++ b/src/relay/collage/gather_partition_specs.cc
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/gather_partition_specs.cc
+ * \brief Gather the relevant \p PartitionSpecs from the available \p Targets.
+ */
+
+#include "./gather_partition_specs.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+namespace {
+
+constexpr size_t kTVMMaxMaxDepth = 4;
+
+PartitionRule MakeCombinePartitionRule(PartitionRule sub_rule, Array<CombinerRule> combiner_rules,
+                                       size_t max_max_depth) {
+  if (combiner_rules.empty()) {
+    return sub_rule;
+  } else {
+    return CombinePartitionRule("", std::move(sub_rule), std::move(combiner_rules), max_max_depth);
+  }
+}
+
+/*! \brief Returns the primitive combiner rules which mimic \p FuseOps. */
+Array<CombinerRule> TVMCombinerRules() {
+  Array<SimpleCombinerRule> simple_rules;
+  // Mimic the FuseOps rules.
+  simple_rules.push_back(ByKindSimpleCombinerRule(kOutEWiseFusable, kBroadcast));
+  simple_rules.push_back(ByKindSimpleCombinerRule(kBroadcast, kCommReduce));
+  simple_rules.push_back(ByKindSimpleCombinerRule(kInjective, kInjective));
+
+  Array<CombinerRule> combiner_rules;
+  // Fire the simple fusion rules
+  combiner_rules.push_back(AllSimpleCombinerRule("combiner", std::move(simple_rules)));
+  // Fuse tuple arguments
+  combiner_rules.push_back(TupleArgCombinerRule("tuple"));
+
+  return combiner_rules;
+}
+
+/*! \brief Returns partition rule mimicking TVM FuseOps. */
+PartitionRule MakeTVMPartitionRule() {
+  // Build singleton candidates for all calls to ops <= kOutEWiseFusable.
+  OpCallByKindPartitionRule op_call_by_kind("");
+  // Combine candidates according to the TVM fusion rules.
+  PartitionRule combine =
+      MakeCombinePartitionRule(std::move(op_call_by_kind), TVMCombinerRules(), kTVMMaxMaxDepth);
+  // Discard invalid candidates.
+  SubGraphConfig sub_graph_config;
+  sub_graph_config.allow_taps = false;
+  sub_graph_config.max_max_depth = kTVMMaxMaxDepth;
+  sub_graph_config.max_exits = 1;
+  return OnlyValidPartitionRule("", std::move(combine), sub_graph_config);
+  // NOTE: We don't wrap by a "Primitive" since we want to defer making TVM fusion decisions until
+  // after running more Relay passes.
+}
+
+constexpr size_t kBYOCMaxMaxDepth = 4;
+
+/*!
+ * \brief Returns the fusion style for \p compiler.
+ *
+ * TODO(mbs): Defer to per-BYOC integration definition.
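+ *
+ * Eg the CUTLASS, cuBLAS and cuDNN patterns each pick out a complete library call, so no
+ * further fusion is useful, whereas TensorRT fuses internally in much the same way as TVM's
+ * FuseOps.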
+ */
+BYOCStyle BYOCFusionStyleForCompiler(const String& compiler) {
+  if (compiler == "cutlass" || compiler == "cublas" || compiler == "cudnn") {
+    return kNoFusionBYOCStyle;
+  } else if (compiler == "tensorrt") {
+    return kTVMFusionBYOCStyle;
+  } else {
+    return kArbitraryFusionBYOCStyle;
+  }
+}
+
+/*!
+ * \brief Returns true if BYOC patterns do not require "Composite" functions.
+ *
+ * TODO(mbs): Defer to per-BYOC integration definition.
+ */
+bool BYOCNoCompositeFunctionsForCompiler(const String& compiler) { return compiler == "tensorrt"; }
+
+/*!
+ * \brief Returns the primitive combiner rules which allow for any touching candidates
+ * to be fused provided they don't have kind \p kOpaque.
+ */
+Array<CombinerRule> BYOCCombinerRules(const String& compiler) {
+  Array<SimpleCombinerRule> simple_rules;
+  Array<CombinerRule> combiner_rules;
+  switch (BYOCFusionStyleForCompiler(compiler)) {
+    case kNoFusionBYOCStyle:
+      break;
+    case kTVMFusionBYOCStyle:
+      // Conservatively assume the BYOC toolchain follows the same rules as for TVM's FuseOps.
+      simple_rules.push_back(ByKindSimpleCombinerRule(kOutEWiseFusable, kBroadcast));
+      simple_rules.push_back(ByKindSimpleCombinerRule(kBroadcast, kCommReduce));
+      simple_rules.push_back(ByKindSimpleCombinerRule(kInjective, kInjective));
+      combiner_rules.push_back(AllSimpleCombinerRule("combiner", std::move(simple_rules)));
+      break;
+    case kArbitraryFusionBYOCStyle:
+      // Just try all combinations up to the max_max_depth limit.
+      simple_rules.push_back(ByKindSimpleCombinerRule(kOutEWiseFusable, kOutEWiseFusable));
+      combiner_rules.push_back(AllSimpleCombinerRule("combiner", std::move(simple_rules)));
+      break;
+  }
+  return combiner_rules;
+}
+
+/*! \brief Returns partition rule mimicking one entry in the patterns list passed to the
+ * MergeComposite pass. */
+PartitionRule MakeLabelledDFPatternPartitionRule(
+    const std::string& compiler, String rule_name, DFPattern dataflow_pattern,
+    TPatternPredicate predicate = DefaultPatternPredicate) {
+  if (BYOCNoCompositeFunctionsForCompiler(compiler)) {
+    return DFPatternPartitionRule(std::move(rule_name), std::move(dataflow_pattern),
+                                  std::move(predicate));
+  } else {
+    return CompositePartitionRule(
+        std::move(rule_name),
+        DFPatternPartitionRule("", std::move(dataflow_pattern), std::move(predicate)));
+  }
+}
+
+/*!
+ * \brief Returns partition rule mimicking
+ * MergeComposite/AnnotateTarget/MergeCompilerRegions/PartitionGraph passes for "compiler"
+ * attribute of \p target.
+ */
+PartitionRule MakePatternBYOCPartitionRule(const std::string& compiler,
+                                           Array<PartitionRule> sub_rules) {
+  // Union all the individual pattern rules.
+  UnionPartitionRule unioned("", std::move(sub_rules));
+  PartitionRule combine =
+      MakeCombinePartitionRule(std::move(unioned), BYOCCombinerRules(compiler), kBYOCMaxMaxDepth);
+  // Ignore invalid candidates.
+  SubGraphConfig sub_graph_config;
+  sub_graph_config.allow_taps = false;
+  sub_graph_config.max_max_depth = kBYOCMaxMaxDepth;
+  sub_graph_config.max_exits = 1;
+  OnlyValidPartitionRule valid("", std::move(combine), sub_graph_config);
+  // Wrap the candidates in a "Primitive" function with a "Compiler" attribute.
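+  // Overall the rule tree is:
+  //   Primitive(OnlyValid(Combine(Union(<pattern rules>), <BYOC combiner rules>)))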
+  return PrimitivePartitionRule("", std::move(valid));
+}
+
+TVM_REGISTER_GLOBAL("relay.collage.make_labelled_dfpattern_partition_rule")
+    .set_body_typed([](String compiler, String rule_name, DFPattern dataflow_pattern) {
+      return MakeLabelledDFPatternPartitionRule(std::move(compiler), std::move(rule_name),
+                                                std::move(dataflow_pattern));
+    });
+
+TVM_REGISTER_GLOBAL("relay.collage.make_labelled_dfpattern_partition_rule_with_predicate")
+    .set_body_typed([](String compiler, String rule_name, DFPattern dataflow_pattern,
+                       TPatternPredicate predicate) {
+      return MakeLabelledDFPatternPartitionRule(std::move(compiler), std::move(rule_name),
+                                                std::move(dataflow_pattern), std::move(predicate));
+    });
+
+TVM_REGISTER_GLOBAL("relay.collage.make_pattern_byoc_partition_rule")
+    .set_body_typed(MakePatternBYOCPartitionRule);
+
+/*!
+ * \brief Returns the rule to pick out expression nodes which can be 'left behind' for execution
+ * on the host.
+ */
+PartitionRule MakeHostPartitionRule() { return HostPartitionRule(""); }
+
+}  // namespace
+
+Array<PartitionSpec> GatherPartitionSpecs(const CompilationConfig& config) {
+  // First collect the partition rules by 'toolchain' (ie BYOC compiler name or the native "tvm").
+  // We'll assume but not verify rules derived from targets are uniquely determined by the
+  // toolchain name.
+  std::unordered_map<std::string, PartitionRule> toolchain_to_rule;
+  std::unordered_map<std::string, Array<Target>> toolchain_to_targets;
+
+  // Accumulate all the partition rules for the primitive targets.
+  for (const auto& primitive_target : config->primitive_targets) {
+    Optional<String> opt_compiler = primitive_target->GetAttr("compiler", Optional<String>());
+    std::string spec_name = opt_compiler.defined() ? opt_compiler.value() : kTVMSpecName;
+    auto itr = toolchain_to_rule.find(spec_name);
+    if (itr != toolchain_to_rule.end()) {
+      // Already constructed.
+      continue;
+    }
+    PartitionRule rule;
+    Optional<PartitionRule> opt_rule =
+        primitive_target->GetAttr("partition_rule", Optional<PartitionRule>());
+    if (opt_rule) {
+      rule = opt_rule.value();
+      VLOG(1) << "Target " << primitive_target->ToDebugString() << " has spec_name " << spec_name
+              << " and explicit 'partition_rule' attribute:\n"
+              << rule->ToString();
+    } else if (opt_compiler.defined()) {
+      // Transition to the Python side so we can get access to the BYOC pattern registry.
+      // That will bounce right back into the above construction helpers.
+      static const runtime::PackedFunc* make_byoc_partition_rule =
+          runtime::Registry::Get("tvm.relay.collage.make_byoc_partition_rule");
+      ICHECK(make_byoc_partition_rule);
+      rule = (*make_byoc_partition_rule)(opt_compiler.value());
+      VLOG(1) << "Target " << primitive_target->ToDebugString() << " is for BYOC spec_name "
+              << spec_name << " and has default partition rule:\n"
+              << rule->ToString();
+    } else {
+      rule = MakeTVMPartitionRule();
+      VLOG(1) << "Target " << primitive_target->ToDebugString() << " is for TVM spec_name "
+              << spec_name << " and has default partition rule:\n"
+              << rule->ToString();
+    }
+    toolchain_to_rule.emplace(spec_name, rule);
+    toolchain_to_targets[spec_name].push_back(primitive_target);
+  }
+
+  // Now group targets with their partition rules.
+  Array<PartitionSpec> result;
+  for (const auto& kv : toolchain_to_rule) {
+    result.push_back(PartitionSpec(kv.first, toolchain_to_targets[kv.first], kv.second));
+  }
+
+  // Add one more spec to cover the host target.
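+  // The host rule can pick out any individual node, so every sub-expression always has at
+  // least one candidate and the search is guaranteed to reach a fully covered state.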
+ result.push_back(PartitionSpec(kHostSpecName, {config->host_target}, MakeHostPartitionRule())); + + return result; +} + +} // namespace collage +} // namespace relay +} // namespace tvm \ No newline at end of file diff --git a/src/relay/collage/gather_partition_specs.h b/src/relay/collage/gather_partition_specs.h new file mode 100644 index 0000000000000..668a93b15634d --- /dev/null +++ b/src/relay/collage/gather_partition_specs.h @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/gather_partition_specs.h + * \brief Gather the relevant \p PartitionSpecs from the available \p Targets. + */ +#ifndef TVM_RELAY_COLLAGE_GATHER_PARTITION_SPECS_H_ +#define TVM_RELAY_COLLAGE_GATHER_PARTITION_SPECS_H_ + +#include + +#include "./partition_spec.h" + +namespace tvm { +namespace relay { +namespace collage { + +/*! + * \brief Distinguished partition spec names. + */ +constexpr const char* kTVMSpecName = "tvm"; +constexpr const char* kHostSpecName = "host"; + +/*! + * \brief The 'styles' of BYOC integrations. Used to influence how their corresponding + * partition rule is constructed. + */ +enum BYOCStyle { + /*! + * \brief The BYOC patterns pick out 'ideal' candidates directly, either because: + * - the BYOC toolchain does not perform any fusion so each matched sub-expression maps 1:1 to a + * BYOC-provided operator, or + * - the BYOC toolchain does perform fusion, however the patterns have been written to pick out + * fusable sub-graphs. + */ + kNoFusionBYOCStyle, + + /*! + * \brief The BYOC patterns pick out supported operators, but the BYOC backend may perform + * fusion over those operators in much the same way TVM does. + */ + kTVMFusionBYOCStyle, + + /*! + * \brief The BYOC patterns pick out supported operators, but the BYOC backend may perform + * arbitrary fusion over those operators. + */ + kArbitraryFusionBYOCStyle, +}; + +/*! + * \brief Returns all the partition specifications gathered from the \p Targets in \p config. + */ +Array GatherPartitionSpecs(const CompilationConfig& config); + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_GATHER_PARTITION_SPECS_H_ diff --git a/src/relay/collage/index_set.cc b/src/relay/collage/index_set.cc new file mode 100644 index 0000000000000..a36da0147238f --- /dev/null +++ b/src/relay/collage/index_set.cc @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/index_set.cc + * \brief Efficient representation of a set of post-dfs indexes. + */ + +#include "./index_set.h" + +namespace tvm { +namespace relay { +namespace collage { + +// TODO(mbs): These should operate one-word-at-a-time + +IndexSet::IndexSet(size_t size, const std::vector& indexes) : bitvec_(size, false) { + for (size_t index : indexes) { + ICHECK_LT(index, bitvec_.size()); + ICHECK(!bitvec_[index]); + bitvec_[index] = true; + } +} + +IndexSet IndexSet::operator&(const IndexSet& that) const { + ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); + std::vector result(bitvec_.size(), false); + for (size_t index = 0; index < bitvec_.size(); ++index) { + result[index] = bitvec_[index] && that.bitvec_[index]; + } + return IndexSet(result); +} + +IndexSet IndexSet::operator|(const IndexSet& that) const { + ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); + std::vector result(bitvec_.size(), false); + for (size_t index = 0; index < bitvec_.size(); ++index) { + result[index] = bitvec_[index] || that.bitvec_[index]; + } + return IndexSet(result); +} + +IndexSet IndexSet::operator-(const IndexSet& that) const { + ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); + std::vector result(bitvec_.size()); + for (size_t index = 0; index < bitvec_.size(); ++index) { + result[index] = bitvec_[index] && !that.bitvec_[index]; + } + return IndexSet(result); +} + +bool IndexSet::AreDisjoint(const IndexSet& that) const { + ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); + for (size_t index = 0; index < bitvec_.size(); index++) { + if (bitvec_[index] && that.bitvec_[index]) { + return false; + } + } + return true; +} + +bool IndexSet::IsSubset(const IndexSet& that) const { + ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); + for (size_t index = 0; index < bitvec_.size(); index++) { + if (bitvec_[index] && !that.bitvec_[index]) { + return false; + } + } + return true; +} + +bool IndexSet::Intersects(const IndexSet& that) const { + ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); + for (size_t index = 0; index < bitvec_.size(); index++) { + if (bitvec_[index] && that.bitvec_[index]) { + return true; + } + } + return false; +} + +IndexSet IndexSet::Subst(size_t new_size, const IndexSubst& subst) const { + std::vector result(new_size, false); + for (PostDfsIndex index = 0; index < bitvec_.size(); ++index) { + if (!bitvec_[index]) { + continue; + } + auto itr = subst.find(index); + ICHECK(itr != subst.end()); + PostDfsIndex new_index = itr->second; + ICHECK(new_index < new_size); + ICHECK(!result[new_index]); + result[new_index] = true; + } + return IndexSet(result); +} + +size_t IndexSet::PopCount() const { + size_t n = 0; + for (size_t index = 0; index < bitvec_.size(); index++) { + if (bitvec_[index]) { + ++n; + } + } + return n; +} + +bool IndexSet::IsZero() const { + for (size_t index = 0; index < bitvec_.size(); index++) { + if (bitvec_[index]) { + return false; + } + } + return true; +} + 
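+// Illustrative usage sketch (not part of the implementation above): the bit-vector
+// representation keeps the set algebra cheap. Indexes are post-dfs indexes into a
+// hypothetical 10-node dataflow graph.
+// \code
+//   IndexSet lhs(/*size=*/10, {2, 3});
+//   IndexSet rhs(/*size=*/10, {3, 8});
+//   IndexSet both = lhs | rhs;     // {2,3,8}
+//   ICHECK(lhs.Intersects(rhs));   // they share index 3
+//   ICHECK(!lhs.AreDisjoint(rhs));
+//   ICHECK_EQ(both.PopCount(), 3);
+// \endcode
+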
+size_t IndexSet::FirstInsideIndex() const { + for (size_t index = 0; index < bitvec_.size(); index++) { + if (bitvec_[index]) { + return index; + } + } + return bitvec_.size(); +} + +size_t IndexSet::LastInsideIndex() const { + for (size_t i = bitvec_.size(); i > 0; i--) { + const size_t index = i - 1; + if (bitvec_[index]) { + return index; + } + } + return bitvec_.size(); +} + +size_t IndexSet::NextIndex(size_t index) const { + ICHECK_LT(index, bitvec_.size()); + for (index++; index < bitvec_.size(); index++) { + if (bitvec_[index]) { + return index; + } + } + return bitvec_.size(); +} + +size_t IndexSet::FirstOutsideIndex() const { + for (size_t index = 0; index < bitvec_.size(); index++) { + if (!bitvec_[index]) { + return index; + } + } + return bitvec_.size(); +} + +bool IndexSet::operator==(const IndexSet& that) const { + ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); + return bitvec_ == that.bitvec_; +} + +bool IndexSet::operator!=(const IndexSet& that) const { + ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); + return bitvec_ != that.bitvec_; +} + +bool IndexSet::operator<(const IndexSet& that) const { + ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); + for (size_t index = 0; index < bitvec_.size(); index++) { + if (bitvec_[index] && !that.bitvec_[index]) { + return true; + } + if (!bitvec_[index] && that.bitvec_[index]) { + return false; + } + } + return false; +} + +size_t IndexSet::hash() const { + std::hash> h; + return h(bitvec_); +} + +std::string IndexSet::ToString() const { + std::ostringstream os; + os << "{"; + bool first = true; + for (size_t start = 0; start < bitvec_.size(); /*no-op*/) { + if (!bitvec_[start]) { + ++start; + continue; + } + size_t end; + for (end = start + 1; end < bitvec_.size() && bitvec_[end]; ++end) { + /*no-op*/ + } + if (first) { + first = false; + } else { + os << ","; + } + os << start; + if (end > start + 2) { + os << ".." << (end - 1); + start = end; + } else { + ++start; + } + } + os << "}"; + return os.str(); +} + +} // namespace collage +} // namespace relay +} // namespace tvm \ No newline at end of file diff --git a/src/relay/collage/index_set.h b/src/relay/collage/index_set.h new file mode 100644 index 0000000000000..13e131f52cce1 --- /dev/null +++ b/src/relay/collage/index_set.h @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/index_set.h + * \brief Efficient representation of a set of post-dfs indexes. 
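+ *
+ * For example, the set {2,3,4,8} over a 10-node dataflow graph is stored as a 10-entry bit
+ * vector, and ToString() renders it compactly as "{2..4,8}".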
+ */ + +#ifndef SRC_RELAY_COLLAGE_INDEX_SET_H_ +#define SRC_RELAY_COLLAGE_INDEX_SET_H_ + +#include + +#include "../ir/dataflow_matcher_impl.h" +#include "../ir/indexed_graph.h" + +namespace tvm { +namespace relay { +namespace collage { + +using IndexSubst = std::unordered_map; + +class IndexSet { + public: + IndexSet() = default; + explicit IndexSet(size_t size) : bitvec_(size, false) {} + IndexSet(size_t size, const std::vector& indexes); + + IndexSet operator&(const IndexSet& that) const; + IndexSet operator|(const IndexSet& that) const; + IndexSet operator-(const IndexSet& that) const; + bool AreDisjoint(const IndexSet& that) const; + bool IsSubset(const IndexSet& that) const; + bool Intersects(const IndexSet& that) const; + + bool operator[](size_t index) const { + ICHECK_LT(index, bitvec_.size()); + return bitvec_[index]; + } + + IndexSet& Add(size_t index) { + ICHECK_LT(index, bitvec_.size()); + bitvec_[index] = true; + return *this; + } + + IndexSet Subst(size_t new_size, const IndexSubst& subst) const; + + size_t end_index() const { return bitvec_.size(); } + size_t PopCount() const; + bool IsZero() const; + size_t FirstInsideIndex() const; + size_t LastInsideIndex() const; + size_t NextIndex(size_t index) const; + size_t FirstOutsideIndex() const; + bool operator==(const IndexSet& that) const; + bool operator!=(const IndexSet& that) const; + bool operator<(const IndexSet& that) const; + size_t hash() const; + std::string ToString() const; + + struct IndexSetIterator { + const IndexSet* set; + size_t i; + + size_t operator*() const { + ICHECK_LT(i, set->end_index()); + return i; + } + + const IndexSetIterator& operator++() { + ICHECK_LT(i, set->end_index()); + i = set->NextIndex(i); + return *this; + } + + bool operator==(const IndexSetIterator& that) const { + ICHECK(set == that.set); + return i == that.i; + } + + bool operator!=(const IndexSetIterator& that) const { + ICHECK(set == that.set); + return i != that.i; + } + }; + + IndexSetIterator begin() const { return IndexSetIterator{this, FirstInsideIndex()}; } + IndexSetIterator end() const { return IndexSetIterator{this, end_index()}; } + + private: + explicit IndexSet(std::vector bitvec) : bitvec_(std::move(bitvec)) {} + + std::vector bitvec_; +}; + +struct IndexSetEqual { + bool operator()(const IndexSet& left, const IndexSet& right) const { return left == right; } +}; + +struct IndexSetHash { + size_t operator()(const IndexSet& set) const { return set.hash(); } +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // SRC_RELAY_COLLAGE_INDEX_SET_H_ \ No newline at end of file diff --git a/src/relay/collage/name_supply.cc b/src/relay/collage/name_supply.cc new file mode 100644 index 0000000000000..040360212c703 --- /dev/null +++ b/src/relay/collage/name_supply.cc @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/name_supply.cc + * \brief A source of fresh variable names. + */ + +#include "./name_supply.h" + +#include +#include + +namespace tvm { +namespace relay { +namespace collage { + +namespace { +void AppendCSafe(bool& first, std::ostringstream& os, const std::string& str) { + for (size_t i = 0; i < str.size(); ++i) { + const char c = str[i]; + if (i == 0 && first && (!std::isalpha(c) && c != '_')) { + os << "_"; + } + if (c == '_' || std::isalnum(c)) { + os << c; + } else { + os << "_"; + } + first = false; + } +} +} // namespace + +NameSupply NameSupply::MakeSubNameSupply() { + NameSupply result(prefix_); + for (const auto& kv : next_free_index_) { + result.next_free_index_.emplace(kv.first, kv.second); + } + return result; +} + +std::string NameSupply::Fresh(const std::initializer_list& hints) { + std::ostringstream os; + bool first = true; + bool need_sep = false; + if (!prefix_.empty()) { + AppendCSafe(first, os, prefix_); + need_sep = true; + } + for (const auto& hint : hints) { + if (need_sep) { + os << "_"; + } + AppendCSafe(first, os, hint); + need_sep = true; + } + std::string name = os.str(); + auto itr = next_free_index_.find(name); + if (itr == next_free_index_.end()) { + next_free_index_.emplace(name, 1); + } else { + os << "_" << itr->second++; + name = os.str(); + } + return name; +} + +} // namespace collage +} // namespace relay +} // namespace tvm \ No newline at end of file diff --git a/src/relay/collage/name_supply.h b/src/relay/collage/name_supply.h new file mode 100644 index 0000000000000..d8791bf5a4dec --- /dev/null +++ b/src/relay/collage/name_supply.h @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/name_supply.h + * \brief A source of fresh variable names. + */ + +#ifndef SRC_RELAY_COLLAGE_NAME_SUPPLY_H_ +#define SRC_RELAY_COLLAGE_NAME_SUPPLY_H_ + +#include +#include + +namespace tvm { +namespace relay { +namespace collage { + +/*! \brief A supply of fresh names. */ +class NameSupply { + public: + explicit NameSupply(std::string prefix) : prefix_(std::move(prefix)) {} + + NameSupply MakeSubNameSupply(); + + void Reserve(const std::string& existing) { next_free_index_.emplace(existing, 1); } + + std::string Fresh(const std::initializer_list& hints); + + private: + /*! \brief Prefix for all names. May be empty. */ + std::string prefix_; + /*! \brief Next unused index for variables with given basename. 
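+   *
+   * For example, with an empty prefix, Fresh({"nn.conv2d"}) yields "nn_conv2d" on the first
+   * call, then "nn_conv2d_1", "nn_conv2d_2" and so on ('.' is not C-safe and is rewritten
+   * to '_').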
*/ + std::unordered_map next_free_index_; +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // SRC_RELAY_COLLAGE_NAME_SUPPLY_H_ diff --git a/src/relay/collage/partition_rule.cc b/src/relay/collage/partition_rule.cc new file mode 100644 index 0000000000000..ce1b9172fc955 --- /dev/null +++ b/src/relay/collage/partition_rule.cc @@ -0,0 +1,385 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/partition_rule.cc + * \brief Compositional partitioning rules. + */ + +#include "./partition_rule.h" + +#include + +#include "./partition_rule.h" +#include "./partition_spec.h" +#include "./utils.h" + +namespace tvm { +namespace relay { +namespace collage { + +std::vector PartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + ICHECK(false) << "PartitionRuleNode::AllCandidates should be overridden in sub-class"; + return {}; +} + +std::string PartitionRuleNode::ToString() const { return ToDoc().str(); } + +Doc PartitionRuleNode::ToDoc() const { + Doc doc; + doc << GetTypeKey() << "(" << Doc::NewLine(2); + std::vector body_items; + AppendBodyItems(body_items); + doc << Doc::Indent(2, Doc::Concat(body_items, Doc::NewLine())) << Doc::NewLine(); + doc << ")"; + return doc; +} + +void PartitionRuleNode::AppendBodyItems(std::vector& body_items) const { + body_items.emplace_back(); + body_items.back() << "rule_name=" << Doc::StrLiteral(rule_name_); +} + +PartitionRule::PartitionRule(String rule_name) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + data_ = std::move(node); +} + +bool DefaultPatternPredicate(const Expr& matched_sub_expr) { return true; } + +std::vector DFPatternPartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + VLOG(1) << "running DFPatternPartitionRule(" << rule_name_ << ")"; + std::vector result; + DFPatternMatcher matcher(&dataflow_graph.indexed_graph()); + for (PostDfsIndex index = 0; index < dataflow_graph.size(); ++index) { + Expr sub_expr = dataflow_graph.index_to_node(index)->ref(); + if (!matcher.Match(pattern_, sub_expr)) { + continue; + } + if (!predicate_(sub_expr)) { + continue; + } + IndexSet inside = MatcherToIndexSet(matcher); + OpPatternKind kind; + String label; + std::tie(kind, label) = SubGraphKindAndLabel(dataflow_graph, inside); + SubGraph sub_graph(dataflow_graph, std::move(inside), kind, std::move(label)); + String rule_name = rule_name_.empty() ? 
sub_graph->label_ : rule_name_; + CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec); + VLOG(2) << "DFPatternPartitionRule(" << rule_name_ << ") yields " << candidate->ToString(); + result.emplace_back(std::move(candidate)); + } + VLOG(1) << "DFPatternPartitionRule(" << rule_name_ << ") produced " << result.size() + << " candidates"; + return result; +} + +void DFPatternPartitionRuleNode::AppendBodyItems(std::vector& body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); + body_items.emplace_back(); + body_items.back() << "pattern=" << PrettyPrint(pattern_); +} + +DFPatternPartitionRule::DFPatternPartitionRule(String rule_name, DFPattern pattern, + TPatternPredicate predicate) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->pattern_ = std::move(pattern); + node->predicate_ = std::move(predicate); + data_ = std::move(node); +} + +std::vector CompositePartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + std::vector candidates = sub_rule_->AllCandidates(dataflow_graph, spec); + VLOG(1) << "running CompositePartitionRule(" << rule_name_ << ")"; + std::vector result; + FunctionAttrsMap attrs; + attrs.Set(attr::kComposite, rule_name_); + for (auto& candidate : candidates) { + String rule_name = NestLabels(rule_name_, candidate->rule_name_); + SubGraph sub_graph = candidate->sub_graph_.WithAttrs(dataflow_graph, attrs); + CandidatePartition new_candidate = WithSubGraph( + WithRuleName(std::move(candidate), std::move(rule_name)), std::move(sub_graph)); + VLOG(2) << "CompositePartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); + result.emplace_back(std::move(new_candidate)); + } + VLOG(1) << "CompositePartitionRule(" << rule_name_ << ") produced " << result.size() + << " candidates"; + return result; +} + +void CompositePartitionRuleNode::AppendBodyItems(std::vector& body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); + body_items.emplace_back(); + body_items.back() << "sub_rule=" << sub_rule_->ToDoc(); +} + +CompositePartitionRule::CompositePartitionRule(String rule_name, PartitionRule sub_rule) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->sub_rule_ = std::move(sub_rule); + data_ = std::move(node); +} + +std::vector PrimitivePartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + std::vector candidates = sub_rule_->AllCandidates(dataflow_graph, spec); + VLOG(1) << "running PrimitivePartitionRule(" << rule_name_ << ")"; + std::vector result; + FunctionAttrsMap attrs; + attrs.Set(attr::kPrimitive, Integer(1)); + std::unordered_set compilers; + for (const auto& target : spec->targets_) { + Optional opt_compiler = target->GetAttr("compiler", Optional()); + if (opt_compiler.defined()) { + attrs.Set(attr::kCompiler, opt_compiler.value()); + } + compilers.insert(opt_compiler.defined() ? 
opt_compiler.value() : ""); + } + ICHECK_EQ(compilers.size(), 1U); + for (auto& candidate : candidates) { + String rule_name = NestLabels(rule_name_, candidate->rule_name_); + SubGraph sub_graph = candidate->sub_graph_.WithAttrs(dataflow_graph, attrs); + CandidatePartition new_candidate = WithSubGraph( + WithRuleName(std::move(candidate), std::move(rule_name)), std::move(sub_graph)); + VLOG(2) << "PrimitivePartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); + result.emplace_back(std::move(new_candidate)); + } + VLOG(1) << "PrimitivePartitionRule(" << rule_name_ << ") produced " << result.size() + << " candidates"; + return result; +} + +void PrimitivePartitionRuleNode::AppendBodyItems(std::vector& body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); + body_items.emplace_back(); + body_items.back() << "sub_rule=" << sub_rule_->ToDoc(); +} + +PrimitivePartitionRule::PrimitivePartitionRule(String rule_name, PartitionRule sub_rule) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->sub_rule_ = std::move(sub_rule); + data_ = std::move(node); +} + +std::vector UnionPartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + std::vector result; + for (const auto& sub_rule : sub_rules_) { + std::vector candidates = sub_rule->AllCandidates(dataflow_graph, spec); + for (auto& candidate : candidates) { + String rule_name = NestLabels(rule_name_, candidate->rule_name_); + CandidatePartition new_candidate = WithRuleName(std::move(candidate), std::move(rule_name)); + VLOG(2) << "UnionPartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); + result.emplace_back(std::move(new_candidate)); + } + } + VLOG(1) << "UnionPartitionRule(" << rule_name_ << ") produced " << result.size() << " candidates"; + return result; +} + +void UnionPartitionRuleNode::AppendBodyItems(std::vector& body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); + for (const auto& sub_rule : sub_rules_) { + body_items.emplace_back(); + body_items.back() << "sub_rule=" << sub_rule->ToDoc(); + } +} + +UnionPartitionRule::UnionPartitionRule(String rule_name, Array sub_rules) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->sub_rules_ = std::move(sub_rules); + data_ = std::move(node); +} + +std::vector OpCallByKindPartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + VLOG(1) << "running OpCallByKindPartitionRule(" << rule_name_ << ")"; + std::vector result; + for (PostDfsIndex index = 0; index < dataflow_graph.size(); ++index) { + auto node = dataflow_graph.index_to_node(index); + Expr sub_expr = node->ref(); + if (sub_expr->IsInstance()) { + OpPatternKind kind; + String label; + std::tie(kind, label) = SubExprKindAndLabel(sub_expr); + if (kind <= kOutEWiseFusable) { + IndexSet inside(dataflow_graph.size(), {index}); + SubGraph sub_graph(dataflow_graph, std::move(inside), kind, std::move(label)); + String rule_name = NestLabels(rule_name_, sub_graph->label_); + CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec); + VLOG(2) << "OpCallByKindPartitionRule(" << rule_name_ << ") yields " + << candidate->ToString(); + result.emplace_back(std::move(candidate)); + } + } + } + VLOG(1) << "OpCallByKindPartitionRule(" << rule_name_ << ") produced " << result.size() + << " candidates"; + return result; +} + +void 
OpCallByKindPartitionRuleNode::AppendBodyItems(std::vector& body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); +} + +OpCallByKindPartitionRule::OpCallByKindPartitionRule(String rule_name) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + data_ = std::move(node); +} + +std::vector CombinePartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + // We'll accumulate all the candidates here, starting with those from the sub-rule. + // Once a candidate is added to this vector it is immutable. + std::vector initial_candidates = + sub_rule_->AllCandidates(dataflow_graph, spec); + CandidateSet combiner_results; + VLOG(1) << "running CombinePartitionRule(" << rule_name_ << ")"; + for (const auto& candidate : initial_candidates) { + combiner_results.Add(dataflow_graph, candidate); + } + + size_t num_rounds = 0; + AppendAllResultsContext ctxt(&dataflow_graph, max_max_depth_, &combiner_results); + while (combiner_results.PrepareForNextRound()) { + VLOG_CONTEXT << "round " << ++num_rounds; + VLOG(1) << "checking " << combiner_results.current_candidates.size() << " candidates (" + << combiner_results.first_new_index << " existing)"; + for (const auto& combiner_rule : combiner_rules_) { + combiner_rule->AppendAllResults(ctxt); + } + } + + std::vector result; + for (auto& candidate : combiner_results.current_candidates) { + String rule_name = NestLabels(rule_name_, candidate->rule_name_); + CandidatePartition new_candidate = WithRuleName(std::move(candidate), std::move(rule_name)); + VLOG(2) << "CombineByPrimitivesPartitionRule(" << rule_name_ << ") yields " + << new_candidate->ToString(); + result.emplace_back(std::move(new_candidate)); + } + VLOG(1) << "CombinePartitionRule(" << rule_name_ << ") produced " << result.size() + << " candidates"; + return result; +} + +void CombinePartitionRuleNode::AppendBodyItems(std::vector& body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); + body_items.emplace_back(); + body_items.back() << "sub_rule=" << sub_rule_->ToDoc(); + for (const auto& combiner_rule : combiner_rules_) { + body_items.emplace_back(); + body_items.back() << "combiner_rule=" << combiner_rule->ToString(); + } + body_items.emplace_back(); + body_items.back() << "max_max_depth=" << max_max_depth_; +} + +CombinePartitionRule::CombinePartitionRule(String rule_name, PartitionRule sub_rule, + Array combiner_rules, + size_t max_max_depth_) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->sub_rule_ = std::move(sub_rule); + node->combiner_rules_ = std::move(combiner_rules); + node->max_max_depth_ = max_max_depth_; + data_ = std::move(node); +} + +std::vector OnlyValidPartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + std::vector candidates = sub_rule_->AllCandidates(dataflow_graph, spec); + VLOG(1) << "running OnlyValidPartitionRule(" << rule_name_ << ")"; + std::vector result; + for (auto& candidate : candidates) { + if (!candidate->sub_graph_->IsValid(dataflow_graph, config_)) { + VLOG(2) << "Ignoring invalid candidate " << candidate->ToString(); + continue; + } + String rule_name = NestLabels(rule_name_, candidate->rule_name_); + CandidatePartition new_candidate = WithRuleName(std::move(candidate), std::move(rule_name)); + VLOG(2) << "OnlyValidPartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); + result.emplace_back(std::move(new_candidate)); + } + VLOG(1) << 
"OnlyValidPartitionRule(" << rule_name_ << ") produced " << result.size() + << " candidates"; + return result; +} + +void OnlyValidPartitionRuleNode::AppendBodyItems(std::vector& body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); + body_items.emplace_back(); + body_items.back() << "sub_rule=" << sub_rule_->ToDoc(); + body_items.emplace_back(); + body_items.back() << "config=" << config_.ToString(); +} + +OnlyValidPartitionRule::OnlyValidPartitionRule(String rule_name, PartitionRule sub_rule, + const SubGraphConfig& config) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->sub_rule_ = std::move(sub_rule); + node->config_ = config; + data_ = std::move(node); +} + +std::vector HostPartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + VLOG(1) << "running HostPartitionRule(" << rule_name_ << ")"; + std::vector result; + for (PostDfsIndex index = 0; index < dataflow_graph.size(); ++index) { + if (MustBeLowered(dataflow_graph.index_to_node(index)->ref())) { + continue; + } + IndexSet inside(dataflow_graph.size(), {index}); + OpPatternKind kind; + String label; + std::tie(kind, label) = SubGraphKindAndLabel(dataflow_graph, inside); + SubGraph sub_graph(dataflow_graph, std::move(inside), kind, label); + String rule_name = NestLabels(rule_name_, sub_graph->label_); + // We'll assign both the target (unique from the spec) and cost (zero) for the candidate now + // since we'll never want to actually estimate the cost of this 'partition'. + CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec, + spec->targets_.front(), Cost::Zero()); + VLOG(2) << "HostPartitionRule(" << rule_name_ << ") yields " << candidate->ToString(); + result.push_back(candidate); + } + VLOG(1) << "HostPartitionRule(" << rule_name_ << ") produced " << result.size() << " candidates"; + return result; +} + +void HostPartitionRuleNode::AppendBodyItems(std::vector& body_items) const {} + +HostPartitionRule::HostPartitionRule(String rule_name) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + data_ = std::move(node); +} + +} // namespace collage +} // namespace relay +} // namespace tvm \ No newline at end of file diff --git a/src/relay/collage/partition_rule.h b/src/relay/collage/partition_rule.h new file mode 100644 index 0000000000000..cbe69f577dc8d --- /dev/null +++ b/src/relay/collage/partition_rule.h @@ -0,0 +1,468 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/partition_rule.h + * \brief Compositional partitioning rules. 
+ */
+
+#ifndef SRC_RELAY_COLLAGE_PARTITION_RULE_H_
+#define SRC_RELAY_COLLAGE_PARTITION_RULE_H_
+
+#include <string>
+#include <vector>
+
+#include "../../printer/doc.h"
+#include "./candidate_partition.h"
+#include "./combiner_rule.h"
+#include "./sub_graph.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*!
+ * \brief Type of function to check if a matched sub-expression should be accepted by a rule. This
+ * can be used to, eg, reject operators of unsupported shape or dtype, or otherwise implement rules
+ * which are difficult to express in the dataflow pattern language directly.
+ */
+using TPatternPredicate = TypedPackedFunc<bool(const Expr&)>;
+
+/*!
+ * \brief The default pattern predicate. Always returns true.
+ */
+bool DefaultPatternPredicate(const Expr& matched_sub_expr);
+
+/*!
+ * \brief Base class of all partition rules.
+ *
+ * A \p PartitionRule describes how to find a set of \p CandidatePartitions for a \p DataflowGraph.
+ * The candidates are allowed to overlap, and ultimately it is the job of the Collage searcher to
+ * find a selection of candidates which covers the whole Relay expression without overlap.
+ * Partition rules are paired with their \p Target and other 'top level' configuration in a
+ * \p PartitionSpec.
+ *
+ * We provide a set of 'base' partition rules which produce candidates from the dataflow graph
+ * directly. We also provide a set of 'combinator' partition rules which can produce new candidates
+ * from the results of an arbitrary sub-rule or sub-rules. By mixing these base and combinator
+ * rules we can express a wide variety of partition strategies and encoding conventions.
+ *
+ * There may be many thousands of candidates in flight during the Collage search. We take care to
+ * defer constructing or rewriting Relay expressions until absolutely necessary. We only pay for
+ * extracting a function to represent a candidate when we need to measure its cost. And we only
+ * pay for rewriting the overall Relay expression to commit to a partitioning when the Collage
+ * search has completed.
+ *
+ * The base rules implemented so far:
+ * - \p DFPatternPartitionRule: Given a \p DFPattern and expression predicate, produces a candidate
+ *   for every sub-graph matched by the pattern and predicate. Unlike the \p PatternRewriter,
+ *   candidates are free to overlap. Used to bring BYOC patterns into the Collage framework.
+ * - \p OpCallByKindPartitionRule: Uses the "TOpPattern" attribute provided for every Relay
+ *   operator to produce a candidate for every call to a 'fusable Relay operator'. Used to
+ *   look ahead to how TVM will fuse sub-graphs.
+ *
+ * The combinator rules implemented so far:
+ * - \p CompositePartitionRule: Indicates all candidates matched by the sub-rule should be wrapped
+ *   by a "Composite" function. The "Composite" name is taken from the rule name. Used to indicate
+ *   Relay operators (or groups of Relay operators) should be mapped to target-specific operators,
+ *   both for BYOC and TVM external library integrations.
+ * - \p PrimitivePartitionRule: Indicates all candidates matched by the sub-rule should be wrapped
+ *   by a "Primitive" function, possibly with an additional "Compiler" attribute. Used to
+ *   delineate a partition (or kernel).
+ * - \p UnionPartitionRule: Simply unions all the candidates from all sub-rules together. Used to
+ *   combine individual \p DFPatternPartitionRules.
+ * - \p CombinePartitionRule: Given a sub-rule and a list of 'combiner' rules, finds
+ *   all possible ways of combining the sub-rule's candidates to yield even larger candidates.
+ *   Note that the sub-rule's candidates may also be directly included in the results. The
+ *   'combiner' rules allow combining by \p OpPatternKinds, combining the arguments to tuples
+ *   which themselves are arguments to Relay operator calls, and so on. This rule is intended to
+ *   mimic the existing TVM \p FuseOps pass, though:
+ *   i) all candidates are found rather than just the largest, ii) the starting set of candidates
+ *   can be provided by any other rule, and iii) we rely on \p SubGraph validity checking to weed
+ *   out infeasible candidates.
+ * - \p OnlyValidPartitionRule: Given a \p SubGraphConfig, ignores candidates with 'invalid'
+ *   sub-graphs. Used to limit the maximum candidate depth, the number of independent outputs,
+ *   and whether intermediate 'taps' are allowed.
+ * - \p HostPartitionRule: Produces candidates for all Relay expressions which could be
+ *   'left behind' for execution by the host (eg on the VM). This rule lets us simplify the
+ *   overall Collage search algorithm.
+ *
+ * (Though not yet implemented, we'd like to allow a combinator rule which will union candidates
+ * based on their 'anchor' operators. This can be used to implement 'vertical' and 'horizontal'
+ * partitioning on more primitive candidates. Note that the \p SubGraph machinery supports
+ * multiple-input and -output sub-graphs and their validation, so horizontal partitioning is easy
+ * to implement.)
+ *
+ * Here are some typical ways to combine \p PartitionRules for different partition/fusion
+ * strategies:
+ *
+ * - Classic pattern-based BYOC with \p MergeComposite/AnnotateTarget/PartitionGraph passes:
+ *   \code
+ *   PrimitivePartitionRule
+ *     OnlyValidPartitionRule
+ *       CombinePartitionRule (with join-anything combiner rule)
+ *         UnionPartitionRule
+ *           CompositePartitionRule(label1)
+ *             DFPatternPartitionRule(pattern1)
+ *           :
+ *           CompositePartitionRule(labeln)
+ *             DFPatternPartitionRule(patternn)
+ *   \endcode
+ *
+ * - "Consider this library implementation for these sub-expressions", using \p DFPatterns to
+ *   pick out which Relay operators are supported:
+ *   \code
+ *   OnlyValidPartitionRule
+ *     CombinePartitionRule (with default TVM combiner rules)
+ *       UnionPartitionRule
+ *         OpCallByKindPartitionRule
+ *         CompositePartitionRule(label1)
+ *           DFPatternPartitionRule(pattern1)
+ *         :
+ *         CompositePartitionRule(labeln)
+ *           DFPatternPartitionRule(patternn)
+ *   \endcode
+ *
+ * - Classic TVM \p FuseOps:
+ *   \code
+ *   PrimitivePartitionRule
+ *     OnlyValidPartitionRule
+ *       CombinePartitionRule (with default TVM combiner rules)
+ *         OpCallByKindPartitionRule
+ *   \endcode
+ *
+ * - "Just fuse what I tell you to fuse", using \p DFPatterns to directly select candidates:
+ *   \code
+ *   PrimitivePartitionRule
+ *     OnlyValidPartitionRule
+ *       UnionPartitionRule
+ *         DFPatternPartitionRule(pattern1)
+ *         :
+ *         DFPatternPartitionRule(patternn)
+ *   \endcode
+ */
+class PartitionRuleNode : public Object {
+ public:
+  /*!
+   * \brief A unique (over all rules for the same target) name for the rule. Rule names are
+   * combined and captured with \p CandidatePartition rule names for debuggability and
+   * explainability. Some rules will copy the rule name into function attributes.
+   */
+  String rule_name_;
+
+  void VisitAttrs(AttrVisitor* v) { v->Visit("rule_name", &rule_name_); }
+
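+  // Illustrative sketch (names here are hypothetical): the "classic FuseOps" stack from the
+  // class comment above, built programmatically. MakeCombinerRules() stands in for whatever
+  // combiner rules and 'config' for whatever SubGraphConfig the caller supplies.
+  //
+  // \code
+  //   PartitionRule rule =
+  //       PrimitivePartitionRule("prim",
+  //           OnlyValidPartitionRule("valid",
+  //               CombinePartitionRule("comb",
+  //                   OpCallByKindPartitionRule("ops"),
+  //                   MakeCombinerRules(), /*max_max_depth=*/5),
+  //               config));
+  // \endcode
+
+  /*!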
+ * \brief Returns all the possible candidate partitions according to this rule for the overall + * expression corresponding to \p dataflow_graph. The candidates will generally have unknown + * target and cost: the target will be filled in by the \p PartitionSpec, while the cost will + * be filled in lazily. + */ + virtual std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const; + + std::string ToString() const; + Doc ToDoc() const; + + protected: + virtual void AppendBodyItems(std::vector& body_items) const; + + public: + static constexpr const char* _type_key = "relay.collage.PartitionRule"; + static constexpr const uint32_t _type_child_slots = 9; + TVM_DECLARE_BASE_OBJECT_INFO(PartitionRuleNode, Object); +}; + +class PartitionRule : public ObjectRef { + public: + explicit PartitionRule(String rule_name); + + TVM_DEFINE_OBJECT_REF_METHODS(PartitionRule, ObjectRef, PartitionRuleNode); +}; + +/*! + * \brief Partition rule which fires on all sub-expressions matching a dataflow-pattern and pattern + * predicate. It is valid for matching candidates to overlap. + */ +class DFPatternPartitionRuleNode : public PartitionRuleNode { + public: + /*! + * \brief Relay pattern. + */ + DFPattern pattern_; + + /*! + * \brief Predicate on matched sub-expression to decide if partition rule should fire. + */ + TPatternPredicate predicate_; + + std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const override; + + void AppendBodyItems(std::vector& body_items) const override; + + static constexpr const char* _type_key = "relay.collage.DFPatternPartitionRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(DFPatternPartitionRuleNode, PartitionRuleNode); +}; + +class DFPatternPartitionRule : public PartitionRule { + public: + DFPatternPartitionRule(String rule_name, DFPattern pattern, + TPatternPredicate predicate = DefaultPatternPredicate); + + TVM_DEFINE_OBJECT_REF_METHODS(DFPatternPartitionRule, PartitionRule, DFPatternPartitionRuleNode); +}; + +/*! + * \brief Partition rule which wraps candidates within a function with the "Composite" attribute + * bound to the given rule name. + * + * This is the standard way by which operators or operator groups are tagged as being supported + * by a particular externally provided function. It is up to the BYOC lowering function to + * recognize the "Composite" name and emit the appropriate code or call. + */ +class CompositePartitionRuleNode : public PartitionRuleNode { + public: + /*! \brief The sub-partition rule. */ + PartitionRule sub_rule_; + + std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const override; + + void AppendBodyItems(std::vector& body_items) const override; + + static constexpr const char* _type_key = "relay.collage.CompositePartitionRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(CompositePartitionRuleNode, PartitionRuleNode); +}; + +class CompositePartitionRule : public PartitionRule { + public: + CompositePartitionRule(String rule_name, PartitionRule sub_rule); + + TVM_DEFINE_OBJECT_REF_METHODS(CompositePartitionRule, PartitionRule, CompositePartitionRuleNode); +}; + +/*! + * \brief Partition rule which wraps candidates within a function with the "Primitive" attribute + * bound to 1. If the partition spec target(s) have the "compiler" attribute then that name is + * also added to the function as a "Compiler" attribute. 
+ *
+ * This is the standard way by which sub-graphs are marked as being in a 'partition' whose
+ * compilation will be managed by an external BYOC toolchain. It can also be used to mark
+ * sub-graphs for lowering to a single kernel by the built-in TVM lowering machinery.
+ */
+class PrimitivePartitionRuleNode : public PartitionRuleNode {
+ public:
+  /*! \brief The sub-partition rule. */
+  PartitionRule sub_rule_;
+
+  std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph,
+                                                const PartitionSpec& spec) const override;
+
+  void AppendBodyItems(std::vector<Doc>& body_items) const override;
+
+  static constexpr const char* _type_key = "relay.collage.PrimitivePartitionRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PrimitivePartitionRuleNode, PartitionRuleNode);
+};
+
+class PrimitivePartitionRule : public PartitionRule {
+ public:
+  PrimitivePartitionRule(String rule_name, PartitionRule sub_rule);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(PrimitivePartitionRule, PartitionRule, PrimitivePartitionRuleNode);
+};
+
+/*!
+ * \brief Partition rule which simply unions all matches from all sub-partition rules.
+ *
+ * This can be used to combine the results of a set of, eg, DFPatternPartitionRules.
+ */
+class UnionPartitionRuleNode : public PartitionRuleNode {
+ public:
+  Array<PartitionRule> sub_rules_;
+
+  std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph,
+                                                const PartitionSpec& spec) const override;
+
+  void AppendBodyItems(std::vector<Doc>& body_items) const override;
+
+  static constexpr const char* _type_key = "relay.collage.UnionPartitionRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(UnionPartitionRuleNode, PartitionRuleNode);
+};
+
+class UnionPartitionRule : public PartitionRule {
+ public:
+  UnionPartitionRule(String rule_name, Array<PartitionRule> sub_rules);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(UnionPartitionRule, PartitionRule, UnionPartitionRuleNode);
+};
+
+/*!
+ * \brief Partition rule which places calls to Relay operators with a "TOpPattern" attribute of
+ * \p kOutEWiseFusable or less in their own singleton sub-graph. No other Relay sub-expressions
+ * (such as tuples or tuple projection) are selected, and it is up to outer partition rules to
+ * account for them.
+ */
+class OpCallByKindPartitionRuleNode : public PartitionRuleNode {
+ public:
+  std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph,
+                                                const PartitionSpec& spec) const override;
+
+  void AppendBodyItems(std::vector<Doc>& body_items) const override;
+
+  static constexpr const char* _type_key = "relay.collage.OpCallByKindPartitionRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(OpCallByKindPartitionRuleNode, PartitionRuleNode);
+};
+
+class OpCallByKindPartitionRule : public PartitionRule {
+ public:
+  explicit OpCallByKindPartitionRule(String rule_name);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(OpCallByKindPartitionRule, PartitionRule,
+                                OpCallByKindPartitionRuleNode);
+};
+
+/*!
+ * \brief Partition rule which combines sub-graphs to exploit optimizations commonly available in
+ * backends (including the TVM lowering backend). Those optimization rules are in turn described
+ * by one or more primitive 'combiner' rules (see combiner_rule.h).
+ *
+ * For TVM these combiner rules are guided by the \p OpPatternKind associated with every
+ * sub-graph. That in turn is the maximum of the kind of each expression node in the sub-graph,
+ * using the rules:
+ * - Constants are \p kElemWise.
+ * - A call to a Relay operator has the kind of its callee.
+ * - Tuple construction and projection are injective provided all tuple fields are of tensor type.
+ * - All other sub-expressions are opaque.
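+ *
+ * For example, a sub-graph covering a nn.conv2d call, an add call and a nn.relu call has kind
+ * max(kOutEWiseFusable, kBroadcast, kElemWise) = kOutEWiseFusable.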
+ *
+ * The available \p OpPatternKinds (and our abbreviations for them) are:
+ * - E: kElemWise, eg nn.relu
+ * - B: kBroadcast, eg add
+ * - I: kInjective, eg concatenate
+ * - R: kCommReduce, eg sum
+ * - A: kOutEWiseFusable, eg nn.conv2d (often called 'anchor nodes', hence the A abbreviation)
+ * - O: kOpaque, everything else
+ * (The kTuple kind is not used by this machinery.)
+ *
+ * Kinds are ordered as above from least- to most-constraining w.r.t. possible partition
+ * opportunities. When we write a kind abbreviation below we intend it to mean that kind *or
+ * less*. And when we write 'kl -> kr' we mean it to match a sub-expression of kind kr or less
+ * whose dataflow inputs are all of kind kl or less.
+ *
+ * We can then mimic the classic \p FuseOps TVM Pass with the following more primitive 'combiner'
+ * rules:
+ * - Sub-groups cannot have taps. In the classic \p FuseOps pass taps are avoided by construction
+ *   by always considering all node->dominator paths. Here we naively allow taps on all
+ *   candidates, but reject them using SubGraph::IsValid with a SubGraphConfig with
+ *   allow_taps = false.
+ * - Combine A -> B
+ * - Combine B -> R
+ * - Combine I -> I
+ * - Combine I -> tuple -> I. That is, if an I sub-graph has a tuple as input, and at least one
+ *   tuple field can be provided by an I sub-graph exit, then both the tuple and all such fields
+ *   may be joined.
+ *
+ * Note that \p FuseOps only considers the largest possible sub-graphs. However this partition
+ * rule considers all possibilities so as to 'make room' for other targets supplying other
+ * overlapping candidates.
+ *
+ * See combiner_rule.h for the more primitive combiner rules which implement the above.
+ */
+class CombinePartitionRuleNode : public PartitionRuleNode {
+ public:
+  /*! \brief The sub-rule supplying the initial set of candidates. */
+  PartitionRule sub_rule_;
+  /*! \brief The more primitive rules to use to combine the candidates found by the above rule. */
+  Array<CombinerRule> combiner_rules_;
+  /*! \brief Maximum max_depth for candidates. */
+  size_t max_max_depth_;
+
+  std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph,
+                                                const PartitionSpec& spec) const override;
+
+  void AppendBodyItems(std::vector<Doc>& body_items) const override;
+
+ public:
+  static constexpr const char* _type_key = "relay.collage.CombinePartitionRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(CombinePartitionRuleNode, PartitionRuleNode);
+};
+
+class CombinePartitionRule : public PartitionRule {
+ public:
+  CombinePartitionRule(String rule_name, PartitionRule sub_rule,
+                       Array<CombinerRule> combiner_rules, size_t max_max_depth_);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(CombinePartitionRule, PartitionRule, CombinePartitionRuleNode);
+};
+
+/*!
+ * \brief Partition rule which keeps only candidates from the sub-rule whose sub-graphs are valid
+ * w.r.t. the given \p SubGraphConfig.
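+ *
+ * A sketch of typical use; allow_taps appears in the comments above, while max_depth here is an
+ * assumed name for the depth-limit field:
+ * \code
+ *   SubGraphConfig config;
+ *   config.allow_taps = false;  // reject candidates with intermediate taps
+ *   config.max_depth = 8;       // hypothetical field bounding candidate depth
+ *   OnlyValidPartitionRule("only_valid", sub_rule, config);
+ * \endcode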
+ */
+class OnlyValidPartitionRuleNode : public PartitionRuleNode {
+ public:
+  PartitionRule sub_rule_;
+  SubGraphConfig config_;
+
+  std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph,
+                                                const PartitionSpec& spec) const override;
+
+  void AppendBodyItems(std::vector<Doc>& body_items) const override;
+
+ public:
+  static constexpr const char* _type_key = "relay.collage.OnlyValidPartitionRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(OnlyValidPartitionRuleNode, PartitionRuleNode);
+};
+
+class OnlyValidPartitionRule : public PartitionRule {
+ public:
+  OnlyValidPartitionRule(String rule_name, PartitionRule sub_rule, const SubGraphConfig& config);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(OnlyValidPartitionRule, PartitionRule, OnlyValidPartitionRuleNode);
+};
+
+/*!
+ * \brief Partition rule which selects nodes which can be 'left behind' to be executed by the
+ * host (eg on the VM). This includes most of the 'interstitial' Relay constructs, such as let
+ * bindings, operators on references, calls to non-operator functions, and so on. It can also
+ * include the construction of and projection from tuples which may not be supported within a
+ * partition.
+ */
+class HostPartitionRuleNode : public PartitionRuleNode {
+ public:
+  std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph,
+                                                const PartitionSpec& spec) const override;
+
+  void AppendBodyItems(std::vector<Doc>& body_items) const override;
+
+ public:
+  static constexpr const char* _type_key = "relay.collage.HostPartitionRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(HostPartitionRuleNode, PartitionRuleNode);
+};
+
+class HostPartitionRule : public PartitionRule {
+ public:
+  explicit HostPartitionRule(String rule_name);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(HostPartitionRule, PartitionRule, HostPartitionRuleNode);
+};
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // SRC_RELAY_COLLAGE_PARTITION_RULE_H_
diff --git a/src/relay/collage/partition_spec.cc b/src/relay/collage/partition_spec.cc
new file mode 100644
index 0000000000000..f0c963618d146
--- /dev/null
+++ b/src/relay/collage/partition_spec.cc
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/partition_spec.cc
+ * \brief Combine a \p PartitionRule with one or more \p Targets.
+ */ + +#include "./partition_spec.h" + +#include + +#include "./utils.h" + +namespace tvm { +namespace relay { +namespace collage { + +String DefaultValidateSubGraphFunc(const Function& function) { return String(); } + +PartitionSpec::PartitionSpec(String spec_name, Array targets, PartitionRule rule, + TValidateSubGraphFunc validate_sub_graph_func) { + auto node = runtime::make_object(); + node->spec_name_ = std::move(spec_name); + node->targets_ = std::move(targets); + node->rule_ = std::move(rule); + node->validate_sub_graph_func_ = std::move(validate_sub_graph_func); + data_ = std::move(node); +} + +std::vector PartitionSpecNode::AllCandidates( + const DataflowGraph& dataflow_graph) const { + // Gather all the candidates. They'll have no target, function or cost at this stage. + std::vector candidates = + rule_->AllCandidates(dataflow_graph, GetRef(this)); + std::vector result; + for (const auto& candidate : candidates) { + ICHECK_EQ(candidate->spec_, GetRef(this)); + // Emit a copy of the candidate for each possible target. + for (const auto& target : targets_) { + String rule_name = NestLabels(spec_name_, candidate->rule_name_); + CandidatePartition new_candidate = + WithTarget(WithRuleName(candidate, std::move(rule_name)), target); + result.emplace_back(std::move(new_candidate)); + } + } + return result; +} + +std::string PartitionSpecNode::ToString() const { + Doc doc; + doc << "PartitionSpec(" << Doc::NewLine(2); + std::vector body_items; + body_items.emplace_back(); + body_items.back() << "spec_name=" << Doc::StrLiteral(spec_name_); + for (const auto& target : targets_) { + body_items.emplace_back(); + body_items.back() << "target=" << target->ToDebugString(); + } + body_items.emplace_back(); + body_items.back() << "rule=" << rule_->ToDoc(); + doc << Doc::Indent(2, Doc::Concat(body_items, Doc::NewLine())) << Doc::NewLine(); + doc << ")"; + return doc.str(); +} + +} // namespace collage +} // namespace relay +} // namespace tvm diff --git a/src/relay/collage/partition_spec.h b/src/relay/collage/partition_spec.h new file mode 100644 index 0000000000000..d23934946e8be --- /dev/null +++ b/src/relay/collage/partition_spec.h @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/partition_spec.h + * \brief Combine a \p PartitionRule with one or more \p Targets. + */ + +#ifndef SRC_RELAY_COLLAGE_PARTITION_SPEC_H_ +#define SRC_RELAY_COLLAGE_PARTITION_SPEC_H_ + +#include +#include +#include + +#include "./partition_rule.h" +#include "./sub_graph.h" + +namespace tvm { +namespace relay { +namespace collage { + +/*! + * \brief Type of functions for checking the validity of partitions before they proceed to lowering + * and codegen. 
The argument is the function extracted from the overall expression to represent
+ * the partition. The result is a non-empty error message string if the candidate should be
+ * rejected.
+ */
+using TValidateSubGraphFunc = TypedPackedFunc<String(const Function&)>;
+
+/*!
+ * \brief The default validation function. Always returns the empty string, ie no error.
+ */
+String DefaultValidateSubGraphFunc(const Function& function);
+
+/*!
+ * \brief Pairs a \p PartitionRule with one or more \p Targets it can be used for.
+ */
+class PartitionSpecNode : public Object {
+ public:
+  /*!
+   * \brief Specification name to distinguish this spec from all others. Typically the BYOC
+   * 'compiler' name, "tvm", or "host".
+   */
+  String spec_name_;
+
+  /*!
+   * \brief The targets all candidate partitions should be compiled for. It is possible for
+   * multiple targets to share the same partition rules, eg if we are targeting multiple devices.
+   */
+  Array<Target> targets_;
+
+  /*!
+   * \brief The partition rule to use to gather candidates.
+   */
+  PartitionRule rule_;
+
+  /*!
+   * \brief The validation function to apply to each candidate's extracted function before
+   * proceeding to lowering/codegen.
+   */
+  TValidateSubGraphFunc validate_sub_graph_func_ = DefaultValidateSubGraphFunc;
+
+  /*!
+   * \brief Returns all the candidate partitions found by this specification. The candidates
+   * will be for a specific target, but will not yet have an extracted function or cost.
+   */
+  std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph) const;
+
+  std::string ToString() const;
+
+  static constexpr const char* _type_key = "relay.collage.PartitionSpec";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PartitionSpecNode, Object);
+};
+
+class PartitionSpec : public ObjectRef {
+ public:
+  PartitionSpec(String spec_name, Array<Target> targets, PartitionRule rule,
+                TValidateSubGraphFunc validate_sub_graph_func = DefaultValidateSubGraphFunc);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(PartitionSpec, ObjectRef, PartitionSpecNode);
+};
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // SRC_RELAY_COLLAGE_PARTITION_SPEC_H_
diff --git a/src/relay/collage/priority_queue.h b/src/relay/collage/priority_queue.h
new file mode 100644
index 0000000000000..1f2eb45dfdf11
--- /dev/null
+++ b/src/relay/collage/priority_queue.h
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/priority_queue.h
+ * \brief An updatable priority queue.
+ */
+
+#ifndef SRC_RELAY_COLLAGE_PRIORITY_QUEUE_H_
+#define SRC_RELAY_COLLAGE_PRIORITY_QUEUE_H_
+
+#include <algorithm>
+#include <set>
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*! \brief Priority queue of search states, ordered by increasing cost. */
+template <class T, class CmpTPtr, class EqTPtr>
+class PriorityQueue {
+ public:
+  PriorityQueue() = default;
+
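+  // Illustrative sketch; SearchState and the comparator/equality functors are hypothetical
+  // stand-ins for whatever the searcher instantiates this template with.
+  //
+  // \code
+  //   PriorityQueue<SearchState, CompareSearchStatePtrs, EqSearchStatePtrs> queue;
+  //   queue.Push(&state);
+  //   state.best_cost = improved_cost;  // the caller lowers the state's cost...
+  //   queue.Update(&state);             // ...then re-seats it to restore ordering
+  //   SearchState* cheapest = queue.Pop();
+  // \endcode
+
+  /*!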
\brief Pushes \p item onto the queue. */
+  void Push(T* item) { set_.emplace(item); }
+
+  /*! \brief Pops the item with the least cost off the queue. */
+  T* Pop() {
+    ICHECK(!set_.empty());
+    T* item = *set_.begin();
+    set_.erase(set_.begin());
+    return item;
+  }
+
+  /*! \brief Updates the queue to account for \p item's best cost being lowered. */
+  void Update(T* item) {
+    auto itr = std::find_if(set_.begin(), set_.end(),
+                            [item](const T* that) { return EqTPtr()(that, item); });
+    ICHECK(itr != set_.end());
+    set_.erase(itr);
+    set_.emplace(item);
+  }
+
+  bool empty() const { return set_.empty(); }
+  size_t size() const { return set_.size(); }
+
+ private:
+  // TODO(mbs): Actually use a pri-queue datastructure!
+  std::set<T*, CmpTPtr> set_;
+};
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // SRC_RELAY_COLLAGE_PRIORITY_QUEUE_H_
diff --git a/src/relay/collage/prune_candidates.cc b/src/relay/collage/prune_candidates.cc
new file mode 100644
index 0000000000000..608014ae2b842
--- /dev/null
+++ b/src/relay/collage/prune_candidates.cc
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/prune_candidates.cc
+ * \brief Try to remove candidates which will never contribute to an optimal partitioning.
+ */
+
+#include "./prune_candidates.h"
+
+#include "./dataflow_graph.h"
+#include "./gather_partition_specs.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+namespace {
+
+/*!
+ * \brief Returns a map from post-dfs dataflow node indices to the indices within \p candidates for
+ * those candidates which intersect that dataflow node.
+ *
+ * NOTE: The index set in the vector results is over candidate indices not post-dfs indices!
+ */
+std::vector<IndexSet> MakeInsideMap(const DataflowGraph& dataflow_graph,
+                                    const std::vector<CandidatePartition>& candidates) {
+  std::vector<IndexSet> result(dataflow_graph.size(), IndexSet(candidates.size()));
+  for (size_t i = 0; i < candidates.size(); ++i) {
+    CandidatePartition candidate = candidates[i];
+    for (PostDfsIndex index : candidate->sub_graph_->inside_) {
+      result[index].Add(i);
+    }
+  }
+  return result;
+}
+
+/*!
+ * \brief Returns the maximal candidates within \p candidates. A candidate is maximal if it is not
+ * contained by any super-candidate for the same target.
+ */
+std::vector<CandidatePartition> MaximalCandidates(
+    const DataflowGraph& dataflow_graph, const std::vector<CandidatePartition>& candidates) {
+  std::vector<IndexSet> inside_map = MakeInsideMap(dataflow_graph, candidates);
+  std::vector<CandidatePartition> result;
+  for (size_t i = 0; i < candidates.size(); ++i) {
+    CandidatePartition maximal_candidate = candidates[i];
+    bool has_super_candidate = false;
+    IndexSet explored_candidates(candidates.size());  // over candidates!
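+    // Walk each node inside this candidate and scan the other candidates covering that node:
+    // a distinct same-target candidate whose inside set contains this one's makes this
+    // candidate non-maximal.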
+    for (PostDfsIndex index : maximal_candidate->sub_graph_->inside_) {
+      for (size_t j : inside_map[index]) {
+        if (i == j) {
+          // Ignore self.
+          continue;
+        }
+        if (explored_candidates[j]) {
+          // Already checked.
+          continue;
+        }
+        explored_candidates.Add(j);
+        CandidatePartition super_candidate = candidates[j];
+        if (maximal_candidate->target_ == super_candidate->target_ &&
+            maximal_candidate->sub_graph_->inside_.IsSubset(
+                super_candidate->sub_graph_->inside_)) {
+          has_super_candidate = true;
+          break;
+        }
+      }
+      if (has_super_candidate) {
+        break;
+      }
+    }
+    if (!has_super_candidate) {
+      VLOG(2) << "Found maximal candidate " << maximal_candidate->ToString();
+      result.emplace_back(maximal_candidate);
+    }
+  }
+  VLOG(1) << "Have " << result.size() << " maximal candidates";
+  return result;
+}
+
+/*!
+ * \brief Returns all the candidates in \p candidates which intersect without being equal.
+ */
+std::vector<CandidatePartition> IntersectingCandidates(
+    const DataflowGraph& dataflow_graph, const std::vector<CandidatePartition>& candidates) {
+  std::vector<IndexSet> inside_map = MakeInsideMap(dataflow_graph, candidates);
+  IndexSet intersecting(candidates.size());  // over candidates!
+  for (size_t i = 0; i < candidates.size(); ++i) {
+    CandidatePartition intersecting_candidate = candidates[i];
+    IndexSet explored_candidates(candidates.size());  // over candidates!
+    for (PostDfsIndex index : intersecting_candidate->sub_graph_->inside_) {
+      for (size_t j : inside_map[index]) {
+        if (j < i) {
+          // Intersection is commutative.
+          continue;
+        }
+        if (i == j) {
+          // Ignore self.
+          continue;
+        }
+        if (explored_candidates[j]) {
+          // Already checked.
+          continue;
+        }
+        explored_candidates.Add(j);
+        CandidatePartition other_candidate = candidates[j];
+        if (intersecting_candidate->sub_graph_->inside_ == other_candidate->sub_graph_->inside_) {
+          // Have same inside set.
+          continue;
+        }
+        VLOG(2) << "Candidate " << intersecting_candidate->ToString() << " intersects with "
+                << other_candidate->ToString();
+        intersecting.Add(i);
+        intersecting.Add(j);
+      }
+    }
+  }
+  std::vector<CandidatePartition> result;
+  for (size_t i : intersecting) {
+    CandidatePartition candidate = candidates[i];
+    VLOG(2) << "Found intersecting candidate " << candidate->ToString();
+    result.emplace_back(candidate);
+  }
+  VLOG(1) << "Have " << result.size() << " intersecting candidates";
+  return result;
+}
+
+/*!
+ * \brief Returns the set operation left - right.
+ */
+std::vector<CandidatePartition> SetDifference(const std::vector<CandidatePartition>& left,
+                                              const std::vector<CandidatePartition>& right) {
+  std::unordered_set<CandidatePartition, CandidatePartitionHash, CandidatePartitionEquals>
+      right_set(right.begin(), right.end());
+  std::vector<CandidatePartition> result;
+  for (const auto& candidate : left) {
+    if (right_set.count(candidate) == 0) {
+      result.emplace_back(candidate);
+    }
+  }
+  return result;
+}
+
+/*!
+ * \brief Adds everything in right to left. Returns the number of elements added.
+ */
+size_t SetUnionInPlace(
+    std::unordered_set<CandidatePartition, CandidatePartitionHash, CandidatePartitionEquals>& left,
+    const std::vector<CandidatePartition>& right) {
+  size_t init_size = left.size();
+  for (const auto& candidate : right) {
+    left.emplace(candidate);
+  }
+  return left.size() - init_size;
+}
+
+}  // namespace
+
+std::vector<CandidatePartition> PruneCandidates(
+    const DataflowGraph& dataflow_graph,
+    const std::vector<CandidatePartition>& initial_candidates) {
+  VLOG_CONTEXT << "prune";
+  // Start with all candidates available.
+  std::vector<CandidatePartition> candidates = initial_candidates;
+  std::unordered_set<CandidatePartition, CandidatePartitionHash, CandidatePartitionEquals> pruned;
+  size_t num_rounds = 0;
+  while (true) {
+    VLOG_CONTEXT << "round " << ++num_rounds;
+    VLOG(1) << "checking " << candidates.size() << " candidates";
+    // Add all the maximal candidates to the pruned set.
+    std::vector<CandidatePartition> maximal_candidates =
+        MaximalCandidates(dataflow_graph, candidates);
+    size_t num_new_pruned = SetUnionInPlace(pruned, maximal_candidates);
+    VLOG(1) << "Added " << num_new_pruned << " new pruned candidates";
+    if (num_new_pruned == 0) {
+      // We've reached a fixed point.
+      break;
+    }
+    // If two pruned candidates intersect without being equal then we may miss valid
+    // paths during search. So remove those intersecting candidates from the available candidates
+    // and try again so as to find smaller candidates to 'bridge the gaps'.
+    std::vector<CandidatePartition> pruned_vec(pruned.begin(), pruned.end());
+    std::vector<CandidatePartition> intersecting_candidates =
+        IntersectingCandidates(dataflow_graph, pruned_vec);
+    // We need more maximal candidates to fill in the gaps between the current pruned candidates.
+    // Force that by removing the intersecting candidates from the set of available candidates
+    // and going around again.
+    candidates = SetDifference(candidates, intersecting_candidates);
+  }
+
+  VLOG(1) << "Have " << pruned.size() << " pruned candidates";
+  std::vector<CandidatePartition> result(pruned.begin(), pruned.end());
+  // Re-establish a canonical order of candidates.
+  std::sort(result.begin(), result.end());
+  return result;
+}
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
\ No newline at end of file
diff --git a/src/relay/collage/prune_candidates.h b/src/relay/collage/prune_candidates.h
new file mode 100644
index 0000000000000..97db0fdb60691
--- /dev/null
+++ b/src/relay/collage/prune_candidates.h
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/prune_candidates.h
+ * \brief Try to remove candidates which will never contribute to an optimal partitioning.
+ */
+
+#ifndef SRC_RELAY_COLLAGE_PRUNE_CANDIDATES_H_
+#define SRC_RELAY_COLLAGE_PRUNE_CANDIDATES_H_
+
+#include <vector>
+
+#include "./candidate_partition.h"
+#include "./dataflow_graph.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*!
+ * \brief Returns \p initial_candidates with all unnecessary candidates pruned.
+ *
+ * We prune according to the following two heuristics:
+ *  1. Given partitions (A, target) and (B, target) then
+ *     cost(A union B, target) < cost(A, target) + cost(B, target).
+ *     That is, there's no use estimating the cost of small partitions when a larger partition
+ *     containing them is also available. More precisely, call a partition 'maximal' if it is
+ *     not contained by any other partition for the same target. Then we want to prefer maximal
+ *     candidates when searching.
+ *  2. Given maximal partitions (A union B, target) and (A union B, target') where
+ *     target != target', then min(cost(A union B, target), cost(A union B, target')) <
+ *     min(cost(A, target) + cost(B, target'), cost(A, target') + cost(B, target)).
+ *     That is, there's no use estimating cross-combinations of partitions which are not maximal.
+ *
+ * However, we can't prune a non-maximal candidate if it will make some other maximal candidate
+ * unreachable during the Collage search. We avoid this by iterating until fixed point:
+ *  - Find maximal candidates of the current set of candidates.
+ *  - Add those maximal candidates to the output 'pruned' set.
+ *  - If any two candidates in the 'pruned' set intersect without being equal, remove those from
+ *    the current set of candidates and go around again. That will force more candidates to
+ *    be considered 'maximal'.
+ * That over-approximates the true necessary candidates but is at least simple.
+ */
+std::vector<CandidatePartition> PruneCandidates(
+    const DataflowGraph& dataflow_graph,
+    const std::vector<CandidatePartition>& initial_candidates);
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // SRC_RELAY_COLLAGE_PRUNE_CANDIDATES_H_
diff --git a/src/relay/collage/recover_virtual_device_map.cc b/src/relay/collage/recover_virtual_device_map.cc
new file mode 100644
index 0000000000000..e078903c652b6
--- /dev/null
+++ b/src/relay/collage/recover_virtual_device_map.cc
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/recover_virtual_device_map.cc
+ * \brief Recover the virtual device for every Relay expression node.
+ */
+
+#include "./recover_virtual_device_map.h"
+
+#include "../transforms/device_aware_visitors.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+std::unordered_map<const ExprNode*, VirtualDevice> RecoverVirtualDeviceMap(const IRModule& mod,
+                                                                           const Expr& expr) {
+  class Visitor : public transform::DeviceAwareExprVisitor {
+   public:
+    Visitor(const Optional<IRModule>& maybe_mod)
+        : transform::DeviceAwareExprVisitor(maybe_mod) {}
+
+    void VisitExpr(const Expr& expr) final {
+      map_[expr.get()] = GetVirtualDevice(expr);
+      transform::DeviceAwareExprVisitor::VisitExpr(expr);
+    }
+
+    std::unordered_map<const ExprNode*, VirtualDevice> map_;
+  };
+
+  Visitor visitor(mod);
+  visitor.VisitExpr(expr);
+  return std::move(visitor.map_);
+}
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
\ No newline at end of file
diff --git a/src/relay/collage/recover_virtual_device_map.h b/src/relay/collage/recover_virtual_device_map.h
new file mode 100644
index 0000000000000..886f67f158e0f
--- /dev/null
+++ b/src/relay/collage/recover_virtual_device_map.h
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/recover_virtual_device_map.h
+ * \brief Recover the virtual device for every Relay expression node.
+ *
+ * Temporary hack until virtual_device_ work is finished.
+ */
+#ifndef SRC_RELAY_COLLAGE_RECOVER_VIRTUAL_DEVICE_MAP_H_
+#define SRC_RELAY_COLLAGE_RECOVER_VIRTUAL_DEVICE_MAP_H_
+
+#include <unordered_map>
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+std::unordered_map<const ExprNode*, VirtualDevice> RecoverVirtualDeviceMap(const IRModule& mod,
+                                                                           const Expr& expr);
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // SRC_RELAY_COLLAGE_RECOVER_VIRTUAL_DEVICE_MAP_H_
diff --git a/src/relay/collage/sub_graph.cc b/src/relay/collage/sub_graph.cc
new file mode 100644
index 0000000000000..914be363b1539
--- /dev/null
+++ b/src/relay/collage/sub_graph.cc
@@ -0,0 +1,958 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.cc
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#include "./sub_graph.h"
+
+#include
+
+#include "../transforms/pass_utils.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+namespace {
+
+class Extractor;
+
+/*!
+ * \brief Helper class for rewriting expressions to replace a sub-graph according to the
+ * given extractor.
+ */
+class Rewriter : public ExprMutator {
+ public:
+  explicit Rewriter(const Extractor* extractor) : extractor_(extractor) {}
+
+  Expr VisitExpr(const Expr& expr) final;
+
+ private:
+  /*! \brief Already prepared extractor which will guide the rewrite. */
+  const Extractor* extractor_;
+};
+
+/*!
+ * \brief Helper class for extracting matched sub-graphs from the overall expression.
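+ *
+ * Extraction proceeds in two phases: Extract() makes a single reverse post-dfs pass over the
+ * sub-graph to collect the new function's parameters and outputs, then the Rewriter (above)
+ * substitutes the rewritten outputs back into the overall expression.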
*/ +class Extractor : public ExprMutator { + public: + Extractor(const DataflowGraph* dataflow_graph, NameSupply* name_supply, + const SubGraphNode* sub_graph, FunctionAttrsMap opt_attrs) + : dataflow_graph_(dataflow_graph), + name_supply_(name_supply), + sub_graph_(sub_graph), + opt_attrs_(std::move(opt_attrs)) { + ICHECK_EQ(dataflow_graph_->size(), sub_graph_->overall_size()); + } + + const DataflowGraph& dataflow_graph() const { return *dataflow_graph_; } + + /*! + * \brief Collect the parameters and output expressions for the function representing + * the sub-graph. + */ + void Extract() { + ICHECK(!sub_graph_->IsEmpty()); + VLOG(2) << "Extracting " << sub_graph_->ToString(); + const bool for_function = opt_attrs_.defined(); + + // In reverse dataflow order... + for (PostDfsIndex i = dataflow_graph_->size(); i > 0; --i) { + PostDfsIndex index = i - 1; + if (!sub_graph_->inside_[index]) { + // Node is outside sub-graph. + continue; + } + VLOG(2) << "index " << index; + auto node = dataflow_graph_->index_to_node(index); + if (sub_graph_->exit_[node->index_] || node->is_external_ || memo_.count(node->ref()) == 0) { + // This sub-expression is: + // - inside the sub-graph and needed outside the sub-graph. So it must contribute to an + // output (even if we've already visited it while constructing an output from a + // downstream sub-expression). + // - not yet visited, in which case it must still be considered an 'output' so it will + // be evaluated for any possible side effects. + Expr output = VisitExpr(GetRef(node->node_ref_)); + VLOG(2) << "index " << index << " added as output:\n" + << PrettyPrint(output) << "\nat " << outputs_.size(); + expr_to_output_index_.emplace(node->node_ref_, outputs_.size()); + outputs_.emplace_back(std::move(output)); + if (for_function) { + output_types_.emplace_back(node->node_ref_->checked_type()); + } + } + } + ICHECK(!outputs_.empty()); + + // Reverse the outputs so as to preserve the original evaluation order. + std::reverse(outputs_.begin(), outputs_.end()); + std::reverse(output_types_.begin(), output_types_.end()); + for (auto& kv : expr_to_output_index_) { + kv.second = static_cast(outputs_.size()) - 1 - kv.second; + } + + // Build a 'body' expression to represent the extracted sub-graph. If we have multiple + // outputs we'll place them in a tuple. + Expr body = outputs_.size() > 1 ? Tuple(outputs_) : outputs_.front(); + + // Re-express all the sub-sub-graphs in terms of the body. + DataflowGraph body_dataflow_graph(body); + std::vector sub_sub_graphs; + IndexSubst subst = MakeIndexSubst(body_dataflow_graph); + for (const auto& sub_sub_graph : sub_graph_->sub_sub_graphs_) { + sub_sub_graphs.emplace_back(sub_sub_graph.Subst(body_dataflow_graph, subst)); + } + + // Sweep backwards through the body, rewriting to account for each sub-sub-graph. + body = SubSubGraph::ParallelRewrite(body_dataflow_graph, body, std::move(sub_sub_graphs), + *name_supply_); + + if (for_function) { + Type ret_type = outputs_.size() > 1 ? TupleType(output_types_) : output_types_.front(); + FunctionAttrsMap attrs(opt_attrs_); + if (attrs.count(attr::kCompiler)) { + std::string compiler = Downcast(attrs.Get(attr::kCompiler).value_or(String(""))); + std::string label = sub_graph_->label_; + // Assign a unique global symbol name. + attrs.Set(tvm::attr::kGlobalSymbol, String(name_supply_->Fresh({compiler, label}))); + } + // Rewrite so all input nodes are now conveyed via call arguments to a new function. 
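+      // For example (illustrative only): extracting the sub-graph {add} from
+      //   %2 = add(%0, %1); %3 = nn.relu(%2)
+      // yields extracted_ = fn (%FunctionVar_0, %FunctionVar_1) { add(%FunctionVar_0, %FunctionVar_1) }
+      // and body becomes the call extracted_(%0, %1).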
+      extracted_ = Function(std::move(params_), std::move(body), std::move(ret_type),
+                            /*ty_params=*/{}, DictAttrs(attrs));
+      body = Call(extracted_, std::move(args_));
+    } else {
+      // Don't do anything with the inputs.
+      extracted_ = body;
+    }
+
+    // Setup the output substitution.
+    for (const auto& kv : expr_to_output_index_) {
+      Expr expr;
+      if (outputs_.size() == 1) {
+        expr = body;
+      } else if (for_function) {
+        expr = TupleGetItem(body, kv.second);
+      } else {
+        const auto* tuple_node = body.as<TupleNode>();
+        ICHECK(tuple_node);
+        expr = tuple_node->fields[kv.second];
+      }
+      VLOG(2) << "output " << dataflow_graph_->item_to_node(kv.first)->index_ << " is at index "
+              << kv.second << " (of " << outputs_.size() << " outputs)";
+      output_substitution_.emplace(kv.first, std::move(expr));
+    }
+  }
+
+  ////// Following members are valid only after Extract() has returned.
+
+  /*!
+   * \brief Returns the expression representing the extracted sub-graph. If opt_attrs_ is
+   * defined then it will be a function.
+   */
+  Expr extracted() const { return extracted_; }
+
+  /*!
+   * \brief Returns the substitution to apply to all expression nodes in the overall expression
+   * so as to replace references to outputs of the sub-graph with their rewritten form.
+   */
+  const std::unordered_map<const ExprNode*, Expr>& output_substitution() const {
+    return output_substitution_;
+  }
+
+ private:
+  /*!
+   * \brief Returns a map from original index to new index for each node inside the sub-graph.
+   * Only valid after \p Extract has made its backwards dataflow sweep.
+   */
+  IndexSubst MakeIndexSubst(const DataflowGraph& new_dataflow_graph) const {
+    VLOG(2) << "building extractor substitution";
+    IndexSubst subst;
+    for (PostDfsIndex index : sub_graph_->inside_) {
+      auto orig_node = dataflow_graph_->index_to_node(index);
+      ICHECK_EQ(orig_node->index_, index);
+      auto itr = memo_.find(orig_node->ref());
+      ICHECK(itr != memo_.end());
+      auto new_node = new_dataflow_graph.item_to_node(itr->second);
+      VLOG(2) << orig_node->index_ << " |-> " << new_node->index_;
+      subst.emplace(orig_node->index_, new_node->index_);
+    }
+    return subst;
+  }
+
+  /*! \brief Returns true if \p expr is inside the sub-graph. */
+  bool inside(const Expr& expr) {
+    return sub_graph_->inside_[dataflow_graph_->item_to_node(expr)->index_];
+  }
+
+  /*!
+   * \brief Returns the variable uniquely representing \p expr, which should be
+   * an input node (ie outside the sub-graph but feeding into a node inside the sub-graph).
+   *
+   * It is valid for:
+   *  - An expression outside the sub-graph to be used multiple times inside the sub-graph.
+   *  - An expression outside the sub-graph to be used both inside and outside the sub-graph.
+   */
+  Var VarFor(const Expr& expr) {
+    ICHECK(!inside(expr));
+    ICHECK(opt_attrs_.defined());
+    auto itr = expr_to_param_.find(expr.get());
+    if (itr != expr_to_param_.end()) {
+      return itr->second;
+    }
+    // Ok if checked type is null here.
+    auto fresh_var = Var("FunctionVar_" + std::to_string(params_.size()), expr->checked_type_);
+    params_.push_back(fresh_var);
+    args_.push_back(expr);
+    expr_to_param_.emplace(expr.get(), fresh_var);
+    return fresh_var;
+  }
+
+  /*!
+   * \brief If \p expr is inside the sub-graph then return its rewritten form.
+   * If \p expr is outside the sub-graph then it must correspond to an input node.
+   *  - If opt_attrs_ is defined return the variable to represent it.
+   *  - Otherwise just return the expression directly.
+   *
+   * Should be called only on inputs to nodes which are inside the sub-graph.
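+   *
+   * In particular, inlinable inputs are copied directly into the extracted body, while all
+   * other inputs become fresh parameters (when extracting as a function) or are left in
+   * place (when only rewriting).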
+   */
+  Expr VisitExpr(const Expr& expr) final {
+    if (inside(expr)) {
+      return ExprMutator::VisitExpr(expr);
+    } else if (CanInline(expr)) {
+      // Implicitly include inlinable input sub-expressions.
+      return expr;
+    } else if (opt_attrs_.defined()) {
+      // Map to a function parameter.
+      return VarFor(expr);
+    } else {
+      // Stop rewriting.
+      return expr;
+    }
+  }
+
+  Expr VisitExpr_(const FunctionNode* function_node) override {
+    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+      return GetRef<Function>(function_node);
+    }
+    return ExprMutator::VisitExpr_(function_node);
+  }
+
+  //// Context fields, passed in constructor.
+
+  /*! \brief The dataflow graph corresponding to the overall expression. */
+  const DataflowGraph* dataflow_graph_;
+  /*! \brief Where to get "global_symbol" names. */
+  NameSupply* name_supply_;
+  /*! \brief The sub-graph of the above we are extracting. */
+  const SubGraphNode* sub_graph_;
+  /*! \brief Optional attributes if the sub-graph should be extracted as a function. */
+  FunctionAttrsMap opt_attrs_;
+
+  //// Result fields, available after Extract() called.
+
+  /*!
+   * \brief The extracted expression. If opt_attrs_ is defined this will be a function.
+   */
+  Expr extracted_;
+  /*!
+   * \brief Map from output nodes to corresponding expressions. If the sub-graph has more than
+   * one exit node then each entry will be a tuple projection.
+   */
+  std::unordered_map<const ExprNode*, Expr> output_substitution_;
+
+  //// Accumulator fields, built as we visit expressions.
+
+  /*! \brief (If opt_attrs_ is defined) Parameters representing input expression nodes. */
+  Array<Var> params_;
+  /*!
+   * \brief (If opt_attrs_ is defined) The input expression nodes for each of the above params_.
+   */
+  Array<Expr> args_;
+  /*!
+   * \brief (If opt_attrs_ is defined) Map from existing input expression nodes to the parameters
+   * in params_ which now represent them.
+   */
+  std::unordered_map<const ExprNode*, Var> expr_to_param_;
+  /*!
+   * \brief Accumulated new expressions which represent the exit nodes of the rewritten sub-graph.
+   * It is possible to have multiple outputs. It is possible one output also contributes to other
+   * outputs (ie the output is a 'tap').
+   */
+  std::vector<Expr> outputs_;
+  /*! \brief (If opt_attrs_ is defined) Types of original expressions corresponding to outputs_. */
+  std::vector<Type> output_types_;
+  /*!
+   * \brief Map from existing exit expression nodes to the index in outputs_ which should
+   * represent them in the rewritten overall expression.
+   */
+  std::unordered_map<const ExprNode*, int> expr_to_output_index_;
+};
+
+Expr Rewriter::VisitExpr(const Expr& expr) {
+  auto itr = extractor_->output_substitution().find(expr.get());
+  if (itr == extractor_->output_substitution().end()) {
+    return ExprMutator::VisitExpr(expr);
+  } else {
+    return itr->second;
+  }
+}
+
+}  // namespace
+
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr) {
+  class Visitor : public ExprFunctor<std::pair<OpPatternKind, std::string>(const Expr&)> {
+   private:
+    std::pair<OpPatternKind, std::string> VisitExpr_(const CallNode* call_node) final {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        auto op = GetRef<Op>(op_node);
+        static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+        if (fpattern.count(op) == 0) {
+          VLOG(1) << "no TOpPattern known for " << op->name << ", considering opaque";
+          return {kOpaque, op->name};
+        } else if (IsDynamic(call_node->checked_type()) && IsDataDependent(call_node)) {
+          VLOG(1) << "call has dynamic shape which is data-dependent, considering opaque";
+          return {kOpaque, op->name};
+        } else {
+          OpPatternKind kind = static_cast<OpPatternKind>(fpattern[op]);
+          VLOG(2) << "TOpPattern for " << op->name << " is " << KindToString(kind);
+          return {kind, op->name};
+        }
+      } else if (const auto* function_node = call_node->op.as<FunctionNode>()) {
+        Optional<Integer> opt_i =
+            function_node->GetAttr<Integer>("TOpPattern", Optional<Integer>());
+        if (opt_i.defined()) {
+          OpPatternKind kind = static_cast<OpPatternKind>(opt_i.value()->value);
+          VLOG(1) << "TOpPattern for function is " << KindToString(kind);
+          return {kind, "call_prim"};
+        } else {
+          VLOG(1) << "calling function without TOpPattern, considering opaque";
+          return {kOpaque, "call_fun"};
+        }
+      } else {
+        VLOG(1) << "unsupported call, considering opaque";
+        return {kOpaque, "call_any"};
+      }
+    }
+
+    std::pair<OpPatternKind, std::string> VisitExpr_(const ConstantNode* constant_node) final {
+      VLOG(2) << "TOpPattern for constant is " << KindToString(kElemWise);
+      if (IsSimpleScalar(constant_node)) {
+        return {kElemWise, "scalar"};
+      } else {
+        return {kElemWise, "const"};
+      }
+    }
+
+    std::pair<OpPatternKind, std::string> VisitExpr_(const TupleNode* tuple_node) final {
+      const auto* tuple_type_node = tuple_node->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (std::all_of(tuple_type_node->fields.begin(), tuple_type_node->fields.end(),
+                      [](const Type& type) { return type.as<TensorTypeNode>() != nullptr; })) {
+        VLOG(2) << "TOpPattern for tuple is " << KindToString(kInjective);
+        return {kInjective, "tuple"};
+      } else {
+        VLOG(1) << "tuple contains non-tensors, considering opaque";
+        return {kOpaque, "tuple"};
+      }
+    }
+
+    std::pair<OpPatternKind, std::string> VisitExpr_(
+        const TupleGetItemNode* tuple_get_item_node) final {
+      const auto* tuple_type_node =
+          tuple_get_item_node->tuple->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (std::all_of(tuple_type_node->fields.begin(), tuple_type_node->fields.end(),
+                      [](const Type& type) { return type.as<TensorTypeNode>() != nullptr; })) {
+        VLOG(2) << "TOpPattern for tuple projection is " << KindToString(kInjective);
+        return {kInjective, "proj"};
+      } else {
+        VLOG(1) << "tuple being projected contains non-tensors, considering opaque";
+        return {kOpaque, "proj"};
+      }
+    }
+
+    // TODO(mbs): We implement the following mostly so we have a lightweight way of describing
+    // the current sub-expression. If partitioning is ever extended beyond the usual
+    // call/tuple/proj sub-language we should revise the returned operator kinds to match.
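+
+    // The remaining overloads are fallback cases: anything outside the call/tuple/projection
+    // sub-language is kOpaque, and the label is just a short tag used when building
+    // human-readable sub-graph labels.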
+ + std::pair VisitExpr_(const VarNode* var_node) final { + return {kOpaque, "%" + var_node->name_hint()}; + } + std::pair VisitExpr_(const GlobalVarNode* global_var_node) final { + return {kOpaque, "@" + global_var_node->name_hint}; + } + std::pair VisitExpr_(const OpNode* op_node) final { + return {kOpaque, "`" + op_node->name}; + } + std::pair VisitExpr_(const FunctionNode* function_node) final { + return {kOpaque, "fn"}; + } + std::pair VisitExpr_(const LetNode* let_node) final { + return {kOpaque, "let"}; + } + std::pair VisitExpr_(const IfNode* if_node) final { + return {kOpaque, "if"}; + } + std::pair VisitExpr_(const RefCreateNode* ref_create_node) final { + return {kOpaque, "ref"}; + } + std::pair VisitExpr_(const RefReadNode* op) final { + return {kOpaque, "ref_read"}; + } + std::pair VisitExpr_(const RefWriteNode* op) final { + return {kOpaque, "ref_write"}; + } + std::pair VisitExpr_(const ConstructorNode* op) final { + return {kOpaque, "`" + op->name_hint}; + } + std::pair VisitExpr_(const MatchNode* op) final { + return {kOpaque, "match"}; + } + }; + return Visitor().VisitExpr(sub_expr); +} + +std::pair SubGraphKindAndLabel(const DataflowGraph& dataflow_graph, + const IndexSet& inside) { + std::ostringstream os; + bool first = true; + OpPatternKind max_kind = kElemWise; + for (PostDfsIndex index : inside) { + OpPatternKind sub_kind; + std::string sub_label; + std::tie(sub_kind, sub_label) = SubExprKindAndLabel(dataflow_graph.index_to_node(index)->ref()); + if (!sub_label.empty()) { + if (first) { + first = false; + } else { + os << "+"; + } + os << sub_label; + } + max_kind = CombineKinds(max_kind, sub_kind); + } + return {max_kind, os.str()}; +} + +IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher) { + IndexSet result(matcher.size()); + for (const auto& kv : matcher.memo()) { + for (const auto& matched_sub_expr : kv.second) { + if (CanInline(matched_sub_expr)) { + // Trivial sub-expressions can just be included in the extracted function body + // when we construct it and don't need to be considered part of the sub-graph. + continue; + } + if (kv.first.as()) { + // Don't consider the expressions matched by a wildcard to be part of the sub-graph. 
+ continue; + } + result.Add(matcher.expr_to_node(matched_sub_expr)->index_); + } + } + return result; +} + +std::string SubGraphConfig::ToString() const { + std::ostringstream os; + os << "{max_exits=" << max_exits; + os << ",allow_taps=" << allow_taps; + os << ",max_max_depth=" << max_max_depth; + os << "}"; + return os.str(); +} + +SubGraph SubSubGraphNode::sub_graph() const { return Downcast(sub_graph_obj_); } + +bool SubSubGraphNode::operator==(const SubSubGraphNode& that) const { + return *sub_graph().get() == *that.sub_graph().get(); +} + +bool SubSubGraphNode::operator<(const SubSubGraphNode& that) const { + return *sub_graph().get() < *that.sub_graph().get(); +} + +size_t SubSubGraphNode::hash() const { + size_t h = StructuralHash()(attrs_); + h ^= sub_graph()->hash() + 0x9e3779b9 + (h << 6) + (h >> 2); + return h; +} + +std::string SubSubGraphNode::ToString() const { + std::ostringstream os; + os << "{sub_graph=" << sub_graph()->ToString(); + os << ",attrs=" << PrettyPrint(attrs_); + os << "}"; + return os.str(); +} + +Function SubSubGraphNode::Extract(const DataflowGraph& dataflow_graph, + NameSupply& name_supply) const { + Extractor extractor(&dataflow_graph, &name_supply, sub_graph().get(), attrs_); + extractor.Extract(); + return Downcast(extractor.extracted()); +} + +Expr SubSubGraphNode::Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr, + NameSupply& name_supply) const { + Extractor extractor(&dataflow_graph, &name_supply, sub_graph().get(), attrs_); + extractor.Extract(); + Rewriter rewriter(&extractor); + return rewriter.VisitExpr(expr); +} + +SubSubGraph::SubSubGraph(SubGraph sub_graph, FunctionAttrsMap attrs) { + auto data = runtime::make_object(); + data->sub_graph_obj_ = std::move(sub_graph); + data->attrs_ = std::move(attrs); + data_ = std::move(data); +} + +SubSubGraph SubSubGraph::Subst(const DataflowGraph& new_dataflow_graph, + const std::unordered_map& subst) const { + return SubSubGraph(get()->sub_graph().Subst(new_dataflow_graph, subst), get()->attrs_); +} + +/*static*/ +Expr SubSubGraph::ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr, + std::vector sub_sub_graphs, + NameSupply& name_supply) { + // IMPORTANT: See the corresponding comment in SubGraph::ParallelRewrite. + std::sort(sub_sub_graphs.begin(), sub_sub_graphs.end(), + [](const SubSubGraph& left, const SubSubGraph& right) { + return left->sub_graph()->last_inside_index_ > right->sub_graph()->last_inside_index_; + }); + + Expr result = expr; + for (const auto& sub_sub_graph : sub_sub_graphs) { + result = sub_sub_graph->Rewrite(dataflow_graph, result, name_supply); + } + return result; +} + +IndexSet SubGraphNode::Downstream(const DataflowGraph& dataflow_graph) const { + IndexSet downstream(dataflow_graph.size()); + for (PostDfsIndex exit_index : exit_) { + downstream = downstream | dataflow_graph.downstream_of(exit_index); + } + return downstream; +} + +bool SubGraphNode::IsValid(const DataflowGraph& dataflow_graph, + const SubGraphConfig& config) const { + // Check we don't have too many exit nodes. + if (config.max_exits > 0 && exit_.PopCount() > config.max_exits) { + VLOG(1) << "Subgraph " << ToString() << " is invalid: " << exit_.PopCount() + << " exits exceeds maximum " << config.max_exits; + return false; + } + + // Check the maximum path depth is in limit. 
+ if (config.max_max_depth > 0 && max_depth_ > config.max_max_depth) { + VLOG(1) << "Subgraph " << ToString() << " is invalid: maximum depth " << max_depth_ + << " exceeds limit " << config.max_max_depth; + return false; + } + + // All inside nodes must be in the same basic block. + const DataflowGraph::Node* basic_block = nullptr; + for (PostDfsIndex index : inside_) { + auto node = dataflow_graph.index_to_node(index); + if (basic_block == nullptr) { + basic_block = node->basic_block_; + } + if (node->basic_block_ != basic_block) { + VLOG(1) << "Subgraph " << ToString() << " is invalid: nodes are from different basic blocks"; + return false; + } + } + + // The sub-sub-graphs must be subsets and non-overlapping. + IndexSet union_inside(dataflow_graph.size()); + for (const auto& sub_sub_graph : sub_sub_graphs_) { + if (!sub_sub_graph->sub_graph()->inside_.AreDisjoint(union_inside)) { + VLOG(1) << "Subgraph " << ToString() << " is invalid: sub-sub-graphs overlap"; + return false; + } + if (!sub_sub_graph->sub_graph()->inside_.IsSubset(inside_)) { + VLOG(1) << "Subgraph " << ToString() + << " is invalid: sub-sub-graph is not subset of overall sub-graph"; + return false; + } + } + + if (!config.allow_taps) { + // Exit nodes cannot also contribute to inside nodes. + for (PostDfsIndex index : exit_) { + auto node = dataflow_graph.index_to_node(index); + if (AnyOutputInside(node)) { + VLOG(1) << "Subgraph " << ToString() + << " is invalid: inner node is 'tapped' and also contributes to output, but taps " + "are disabled"; + return false; + } + } + } + + // Check no output would end up feeding into any entry node. + for (PostDfsIndex output_index : output_) { + if (dataflow_graph.downstream_of(output_index).Intersects(entry_)) { + VLOG(1) << "Subgraph " << ToString() << " is invalid: output node " << output_index + << " feeds back into this sub-graph"; + return false; + } + } + + // Looks legit! + return true; +} + +Function SubGraphNode::ExtractAsFunction(const DataflowGraph& dataflow_graph, + NameSupply& name_supply) const { + SubSubGraph sub_sub_graph(GetRef(this), FunctionAttrsMap()); + return sub_sub_graph->Extract(dataflow_graph, name_supply); +} + +Expr SubGraphNode::Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr, + NameSupply& name_supply) const { + if (sub_sub_graphs_.empty()) { + // Nothing to rewrite. 
+    return expr;
+  }
+  Extractor extractor(&dataflow_graph, &name_supply, this, NullValue<FunctionAttrsMap>());
+  extractor.Extract();
+  Rewriter rewriter(&extractor);
+  return rewriter.VisitExpr(expr);
+}
+
+std::string SubGraphNode::ToString() const {
+  std::ostringstream os;
+  os << "{inside=" << inside_.ToString();
+  os << ",entry=" << entry_.ToString();
+  os << ",exit=" << exit_.ToString();
+  os << ",input=" << input_.ToString();
+  os << ",output=" << output_.ToString();
+  os << ",max_depth=" << max_depth_;
+  os << ",kind=" << KindToString(kind_);
+  if (!label_.empty()) {
+    os << ",label=" << label_;
+  }
+  for (const auto& sub_sub_graph : sub_sub_graphs_) {
+    os << ",sub_sub_graph=" << sub_sub_graph->ToString();
+  }
+  os << "}";
+  return os.str();
+}
+
+bool SubGraphNode::operator==(const SubGraphNode& that) const {
+  ICHECK_EQ(inside_.end_index(), that.inside_.end_index());
+  if (inside_ != that.inside_) {
+    return false;
+  }
+  if (sub_sub_graphs_.size() != that.sub_sub_graphs_.size()) {
+    return false;
+  }
+  for (size_t i = 0; i < sub_sub_graphs_.size(); ++i) {
+    if (*sub_sub_graphs_[i].get() != *that.sub_sub_graphs_[i].get()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool SubGraphNode::operator<(const SubGraphNode& that) const {
+  if (first_inside_index_ < that.first_inside_index_) {
+    return true;
+  }
+  if (that.first_inside_index_ < first_inside_index_) {
+    return false;
+  }
+  return inside_ < that.inside_;
+}
+
+size_t SubGraphNode::hash() const {
+  size_t h = inside_.hash();
+  for (const auto& sub_sub_graph : sub_sub_graphs_) {
+    h ^= sub_sub_graph->hash() + 0x9e3779b9 + (h << 6) + (h >> 2);
+  }
+  return h;
+}
+
+void SubGraphNode::Init(const DataflowGraph& dataflow_graph) {
+  for (PostDfsIndex index = 0; index < inside_.end_index(); ++index) {
+    auto node = dataflow_graph.index_to_node(index);
+    if (inside_[index]) {
+      if (AnyInputOutside(node)) {
+        entry_.Add(index);
+      }
+      if (AnyOutputOutside(node) || node->is_external_) {
+        exit_.Add(index);
+      }
+    } else {
+      if (AnyInputInside(node)) {
+        output_.Add(index);
+      }
+      if (AnyOutputInside(node) && !CanInline(node->ref())) {
+        input_.Add(index);
+      }
+    }
+  }
+  max_depth_ = MaxDepth(dataflow_graph);
+}
+
+size_t SubGraphNode::MaxDepth(const DataflowGraph& dataflow_graph) const {
+  std::unordered_map<const DataflowGraph::Node*, size_t> max_depths;
+  std::vector<const DataflowGraph::Node*> stack;
+  size_t max_depth = 0;
+  // All the entry nodes have max depth 0.
+  for (PostDfsIndex index : entry_) {
+    auto node = dataflow_graph.index_to_node(index);
+    max_depths.emplace(node, 0);
+    stack.push_back(node);
+  }
+  while (!stack.empty()) {
+    const DataflowGraph::Node* node = stack.back();
+    stack.pop_back();
+    size_t next_depth = max_depths[node] + 1;
+    if (exit_[node->index_]) {
+      // If this node is external then it will have no outputs but we still wish to consider
+      // the path to the implied output as requiring one more step.
+      // Otherwise we're accounting for reaching one of the external outputs below.
+      max_depth = std::max(max_depth, next_depth);
+    }
+    for (const DataflowGraph::Node* output_node : node->outputs_) {
+      if (!inside_[output_node->index_]) {
+        continue;
+      }
+      if (max_depths.count(output_node) == 0) {
+        max_depths.emplace(output_node, next_depth);
+        stack.push_back(output_node);
+      } else if (next_depth > max_depths[output_node]) {
+        // We found a deeper path to an already expanded node. We'll expand again.
+        max_depths[output_node] = next_depth;
+        stack.push_back(output_node);
+      }
+    }
+  }
+  return max_depth;
+}
+
+/*!
\brief Returns true if any (input/output) of node is (outside/inside) the sub-graph. */
+bool SubGraphNode::AnyInputOutside(const DataflowGraph::Node* node) const {
+  return std::any_of(node->inputs_.begin(), node->inputs_.end(),
+                     [this](const DataflowGraph::Node* sub_node) {
+                       return !inside_[sub_node->index_] && !CanInline(sub_node->ref());
+                     });
+}
+
+bool SubGraphNode::AnyInputInside(const DataflowGraph::Node* node) const {
+  return std::any_of(
+      node->inputs_.begin(), node->inputs_.end(),
+      [this](const DataflowGraph::Node* sub_node) { return inside_[sub_node->index_]; });
+}
+
+bool SubGraphNode::AnyOutputOutside(const DataflowGraph::Node* node) const {
+  return std::any_of(
+      node->outputs_.begin(), node->outputs_.end(),
+      [this](const DataflowGraph::Node* sub_node) { return !inside_[sub_node->index_]; });
+}
+
+bool SubGraphNode::AnyOutputInside(const DataflowGraph::Node* node) const {
+  return std::any_of(
+      node->outputs_.begin(), node->outputs_.end(),
+      [this](const DataflowGraph::Node* sub_node) { return inside_[sub_node->index_]; });
+}
+
+SubGraph::SubGraph(const DataflowGraph& dataflow_graph, IndexSet inside, OpPatternKind kind,
+                   String label, std::vector<SubSubGraph> sub_sub_graphs) {
+  std::sort(
+      sub_sub_graphs.begin(), sub_sub_graphs.end(),
+      [](const SubSubGraph& left, const SubSubGraph& right) { return *left.get() < *right.get(); });
+  auto node = runtime::make_object<SubGraphNode>();
+  node->inside_ = std::move(inside);
+  node->first_inside_index_ = node->inside_.FirstInsideIndex();
+  node->last_inside_index_ = node->inside_.LastInsideIndex();
+  node->entry_ = IndexSet(node->inside_.end_index());
+  node->exit_ = IndexSet(node->inside_.end_index());
+  node->input_ = IndexSet(node->inside_.end_index());
+  node->output_ = IndexSet(node->inside_.end_index());
+  node->kind_ = kind;
+  node->label_ = std::move(label);
+  node->sub_sub_graphs_ = sub_sub_graphs;
+  node->Init(dataflow_graph);
+  data_ = std::move(node);
+}
+
+SubGraph::SubGraph(const DataflowGraph& dataflow_graph)
+    : SubGraph(dataflow_graph, IndexSet(dataflow_graph.size())) {}
+
+bool SubGraph::AreDisjoint(const SubGraph& that) const {
+  return get()->inside_.AreDisjoint(that->inside_);
+}
+
+namespace {
+/*! \brief Returns true if an output of \p left not in \p right ultimately flows into \p right. */
+bool FlowsInto(const DataflowGraph& dataflow_graph, const SubGraph& left, const SubGraph& right) {
+  for (PostDfsIndex output_index : left->output_) {
+    if (!right->inside_[output_index] &&
+        dataflow_graph.downstream_of(output_index).Intersects(right->entry_)) {
+      return true;
+    }
+  }
+  return false;
+}
+}  // namespace
+
+bool SubGraph::AreTouching(const DataflowGraph& dataflow_graph, const SubGraph& that) const {
+  if (!get()->inside_.AreDisjoint(that->inside_)) {
+    // Easy rejection.
+    return false;
+  }
+  if (!get()->output_.Intersects(that->entry_)) {
+    // Not touching.
+    return false;
+  }
+  if (FlowsInto(dataflow_graph, *this, that) || FlowsInto(dataflow_graph, that, *this)) {
+    // Unioning would create a cycle.
+    return false;
+  }
+  return true;
+}
+
+bool SubGraph::AreSelfContained(const SubGraph& that) const {
+  return get()->output_.IsSubset(that->entry_) && that->input_.IsSubset(get()->exit_);
+}
+
+SubGraph SubGraph::DisjointUnion(const DataflowGraph& dataflow_graph, const SubGraph& that) const {
+  ICHECK(AreDisjoint(that));
+  IndexSet inside = get()->inside_ | that->inside_;
+  std::vector<SubSubGraph> sub_sub_graphs;
+  for (const auto& sub_sub_graph : get()->sub_sub_graphs_) {
+    sub_sub_graphs.push_back(sub_sub_graph);
+  }
+  for (const auto& sub_sub_graph : that->sub_sub_graphs_) {
+    sub_sub_graphs.push_back(sub_sub_graph);
+  }
+  return SubGraph(dataflow_graph, std::move(inside), CombineKinds(get()->kind_, that->kind_),
+                  UnionLabels(get()->label_, that->label_), std::move(sub_sub_graphs));
+}
+
+SubGraph SubGraph::WithAttrs(const DataflowGraph& dataflow_graph, FunctionAttrsMap attrs) const {
+  std::vector<SubSubGraph> sub_sub_graphs;
+  sub_sub_graphs.push_back(SubSubGraph(*this, attrs));
+  return SubGraph(dataflow_graph, get()->inside_, get()->kind_, get()->label_,
+                  std::move(sub_sub_graphs));
+}
+
+SubGraph SubGraph::Subst(const DataflowGraph& new_dataflow_graph, const IndexSubst& subst) const {
+  IndexSet new_inside = get()->inside_.Subst(new_dataflow_graph.size(), subst);
+  std::vector<SubSubGraph> new_sub_sub_graphs;
+  for (const auto& sub_sub_graph : get()->sub_sub_graphs_) {
+    new_sub_sub_graphs.push_back(sub_sub_graph.Subst(new_dataflow_graph, subst));
+  }
+  return SubGraph(new_dataflow_graph, std::move(new_inside), get()->kind_, get()->label_,
+                  std::move(new_sub_sub_graphs));
+}
+
+/*static*/
+Expr SubGraph::ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr,
+                               std::vector<SubGraph> sub_graphs, NameSupply& name_supply) {
+  // IMPORTANT:
+  //  - All the sub-graphs will be w.r.t. the dataflow graph for the original expression.
+  //    Each time we call Rewrite on one of those graphs the result expression will be rewritten
+  //    from the final output back to the inputs. The inputs will then be shared with the original
+  //    expression. Thus it is safe to iteratively rewrite all the sub-graphs without redoing the
+  //    dataflow_graph and substituting indexes provided we work in reverse dataflow order.
+  //  - We rely on the argument expression reference holding the original expression alive so that
+  //    the dataflow_graph will never contain dangling pointers (even though as per above we'll
+  //    never dereference them).
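+  // E.g. if disjoint sub-graphs A and B have last inside indexes 9 and 5 respectively then A is
+  // rewritten first: everything at or below B's indexes is shared unchanged with the original
+  // expression, so B can still be rewritten against the original dataflow graph.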
+ std::sort(sub_graphs.begin(), sub_graphs.end(), [](const SubGraph& left, const SubGraph& right) { + return left->last_inside_index_ > right->last_inside_index_; + }); + Expr result = expr; + for (const auto& sub_graph : sub_graphs) { + result = sub_graph->Rewrite(dataflow_graph, result, name_supply); + } + return result; +} + +transform::Pass PartitionOnIndexesForTesting(size_t max_exits, bool allow_taps, + Array indexes, Array labels) { + auto pass_func = [=](Function function, IRModule mod, transform::PassContext ctxt) { + ICHECK(!labels.defined() || indexes.size() == labels.size()); + VLOG(1) << "Considering partitioning for:\n" << PrettyPrint(function); + DataflowGraph dataflow_graph(function); + std::unordered_map> sub_sub_graph_indexes; + std::vector node_indexes; + node_indexes.reserve(indexes.size()); + for (size_t i = 0; i < indexes.size(); ++i) { + const Integer& index = indexes[i]; + ICHECK_GE(index->value, 0); + ICHECK_LT(index->value, dataflow_graph.size()); + PostDfsIndex index_int = static_cast(index->value); + node_indexes.push_back(index_int); + if (labels.defined()) { + const String& label = labels[i]; + if (!label.empty()) { + sub_sub_graph_indexes[label].push_back(index_int); + } + } + } + std::vector sub_sub_graphs; + for (const auto& kv : sub_sub_graph_indexes) { + FunctionAttrsMap attrs; + attrs.Set("Composite", kv.first); + sub_sub_graphs.push_back( + SubSubGraph(SubGraph(dataflow_graph, IndexSet(dataflow_graph.size(), kv.second)), attrs)); + } + OpPatternKind kind; + String label; + IndexSet inside(dataflow_graph.size(), node_indexes); + std::tie(kind, label) = SubGraphKindAndLabel(dataflow_graph, inside); + SubGraph sub_graph(dataflow_graph, std::move(inside), kind, std::move(label), + std::move(sub_sub_graphs)); + SubGraphConfig config; + config.max_exits = max_exits; + config.allow_taps = allow_taps; + if (sub_graph->IsValid(dataflow_graph, config)) { + VLOG(1) << "Sub-graph " << sub_graph->ToString() << " is considered valid"; + } else { + VLOG(1) << "Sub-graph " << sub_graph->ToString() + << " is NOT considered valid, not partitioning"; + return function; + } + NameSupply name_supply("test"); + Function result = Downcast(sub_graph->Rewrite(dataflow_graph, function, name_supply)); + VLOG(1) << "Partitioned to:\n" << PrettyPrint(result); + return result; + }; + return transform::CreateFunctionPass(pass_func, /*opt_level=*/0, "PartitionOnIndexesForTesting", + {}); +} + +TVM_REGISTER_GLOBAL("relay.collage.partition_on_indexes_for_testing") + .set_body_typed([](size_t max_outputs, bool allow_taps, Array indexes, + Array labels) { + return PartitionOnIndexesForTesting(max_outputs, allow_taps, indexes, labels); + }); + +} // namespace collage +} // namespace relay +} // namespace tvm \ No newline at end of file diff --git a/src/relay/collage/sub_graph.h b/src/relay/collage/sub_graph.h new file mode 100644 index 0000000000000..b6b5a9cd9baac --- /dev/null +++ b/src/relay/collage/sub_graph.h @@ -0,0 +1,425 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.h
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#ifndef SRC_RELAY_COLLAGE_SUB_GRAPH_H_
+#define SRC_RELAY_COLLAGE_SUB_GRAPH_H_
+
+#include
+
+#include
+
+#include "../ir/dataflow_matcher_impl.h"
+#include "../ir/indexed_graph.h"
+#include "./index_set.h"
+#include "./dataflow_graph.h"
+#include "./name_supply.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*! \brief Returns operator pattern kind as single-letter string. */
+std::string KindToString(OpPatternKind kind);
+
+/*!
+ * \brief Returns a kind and label for the single \p sub_expr, ignoring its sub-sub-expressions.
+ */
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr);
+
+/*!
+ * \brief Returns a kind and label for all the nodes in \p inside.
+ */
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside);
+
+/*!
+ * \brief Returns the index set representing all the sub-expressions matched by \p matcher.
+ */
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher);
+
+/*!
+ * \brief Configuration controlling which sub-graphs are considered valid.
+ */
+struct SubGraphConfig {
+  /*! \brief Maximum number of exit nodes in the sub-graph, or zero if no limit. */
+  size_t max_exits = 0;
+  /*!
+   * \brief Whether a node inside the sub-graph may flow to nodes both inside and outside
+   * the sub-graph (which we call a 'tap'). Note that it is still possible to have multiple
+   * outputs even with this flag false.
+   */
+  bool allow_taps = false;
+  /*!
+   * \brief Maximum allowed maximum depth, or zero if no limit.
+   */
+  size_t max_max_depth = 0;
+
+  std::string ToString() const;
+};
+
+class SubGraph;
+using FunctionAttrsMap = Map<String, ObjectRef>;
+
+/*!
+ * \brief A sub-sub graph is a sub-graph which is to be nested inside a function as part of some
+ * enclosing sub-graph.
+ *
+ * Extraction yields a function with input nodes replaced by parameters and exit nodes in the
+ * function result. Rewriting replaces the sub-graph with a call to that function, and all
+ * outputs with (projections from) the call result.
+ *
+ * (Note that it's tempting to move attrs_ into \p SubGraphNode and thus avoid this class.
+ * However we found the implementation was easier to understand in this form since it makes
+ * the result of \p Extract unambiguous.)
+ */
+class SubSubGraphNode : public Object {
+ public:
+  /*! \brief The nested sub-graph. */
+  ObjectRef /* actually SubGraph */ sub_graph_obj_;
+  /*! \brief Attributes (possibly empty) to attach to the extracted function. */
+  FunctionAttrsMap attrs_;
+
+  SubGraph sub_graph() const;
+
+  bool operator==(const SubSubGraphNode& that) const;
+  bool operator!=(const SubSubGraphNode& that) const { return !(*this == that); }
+  bool operator<(const SubSubGraphNode& that) const;
+  size_t hash() const;
+
+  std::string ToString() const;
+
+  /*!
+   * \brief Returns the function representing this sub-sub-graph within the overall expression
+   * represented by \p dataflow_graph:
+   *  - All sub-graph inputs become parameters.
+ * - All sub-graph outputs become function results (either directly or as a field in a tuple). + * - The function has attrs_ for attributes (which may be empty). + * - The function body accounts for any rewrites implied by the nested sub-graph. + */ + Function Extract(const DataflowGraph& dataflow_graph, NameSupply& name_supply) const; + + /*! + * \brief Returns \p expr (which has matching \p dataflow_graph) rewritten to encode the + * partitioning implied by this sub-sub-graph. + */ + Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr, + NameSupply& name_supply) const; + + static constexpr const char* _type_key = "relay.collage.SubSubGraph"; + TVM_DECLARE_FINAL_OBJECT_INFO(SubSubGraphNode, Object); +}; + +class SubSubGraph : public ObjectRef { + public: + SubSubGraph(SubGraph sub_graph, FunctionAttrsMap attrs); + + /*! + * \brief Returns copy of this sub-sub-graph with all indexes substituted according to \p subst, + * whose range is w.r.t. \p new_dataflow_graph. + */ + SubSubGraph Subst(const DataflowGraph& new_dataflow_graph, + const std::unordered_map& subst) const; + + /*! + * \brief Returns \p expr rewritten according to all the given sub-sub-graphs. The sub-sub-graphs + * can be given in any order, but must be disjoint. + */ + static Expr ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr, + std::vector sub_sub_graphs, NameSupply& name_supply); + + TVM_DEFINE_OBJECT_REF_METHODS(SubSubGraph, ObjectRef, SubSubGraphNode); +}; + +using SubSubGraphs = Array; + +/*! + * \brief A compact representation of a sub-graph within an (implied) overall Relay expression. + * + * Sub-graphs can be used to represent partitions/kernels/composite functions without having to + * pay the cost of constructing or rewriting any expressions. We also allow 'extracting' a + * function to use for measuring a partition/kernel's latency independently from 'rewriting' + * the overall Relay expression since only a tiny subset of candidate partitions will end up being + * needed after Collage has completed its search. + * + * We expect O(thousands) of sub-graphs to be in flight while processing a given model, so are + * mindful of space overhead. + * + * A sub-graph classifies every dataflow node of the overall expression as either 'inside' or + * 'outside' the sub-graph. Obviously not all such divisions make sense, for example it is not + * valid for an inside node to feed into another inside node via outside nodes. We provide the + * \p IsValid method to check for validity, and \p SubGraphConfig to control which validity rules + * apply (such as maximum depth). + * + * We generally work with the \p DataflowGraph representation of the overall Relay expression + * rather than the expression itself. We use the post-dfs visit index to uniquely refer to + * expression nodes. + * + * As well as 'inside' and 'outside' we have four other flavors of dataflow nodes, all uniquely + * determined from the 'inside' nodes: + * - 'entry' nodes are those inside with at least one dataflow input outside. + * - 'exit' nodes are those inside with at least one dataflow output outside, or which + * are considered 'external' in the underlying dataflow graph (eg because they represent + * the result of the overall function). + * - 'input' nodes are those outside with at least one dataflow output inside. + * - 'output' nodes are those outside with at least one dataflow input inside. + * Index sets for these are cached with the sub-graph for performance. 
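+ *
+ * For example, given the dataflow for %2 = add(%0, %1); %3 = nn.relu(%2), the sub-graph with
+ * inside = {add} has entry = {add} (it consumes values from outside), exit = {add} (its result
+ * is needed outside), input = {%0, %1} and output = {nn.relu}.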
+ *
+ * It is valid to have multiple entry nodes (we can bind a parameter for each). It may be valid to
+ * have multiple exit nodes (we can build a tuple of all such). It may be valid to have exit nodes
+ * which also contribute to other inside nodes (ie represent a 'tap' on an intermediate result).
+ *
+ * Sub-graphs are closed under:
+ *  - Disjoint union.
+ *  - Wrapping by a function with given attributes (see \p SubSubGraph above). This can be used
+ *    to encode "Composite" functions, or to represent a candidate kernel within a "Primitive"
+ *    function. (By combining 'wrapping' with 'union' we can encode, eg, 'this sub-graph should
+ *    be placed inside a primitive function which itself may have calls to composite functions'.)
+ *  - Substitution, which allows a sub-graph w.r.t. one dataflow graph to be transformed to
+ *    match some other (typically smaller) dataflow graph.
+ *
+ * See the subclasses of \p PartitionRule for how sub-graphs are built and combined during Collage
+ * search.
+ *
+ * To support some of the \p OpPatternKind-based fusion rule processing we give sub-graphs
+ * a kind, which is generally the maximum of the kinds of all the operator calls appearing
+ * inside it. We also give sub-graphs a (not necessarily unique) label to help debugging
+ * and guide the selection of global symbol names.
+ */
+class SubGraphNode : public Object {
+ public:
+  /*!
+   * \brief Which sub-expressions are inside the sub-graph (using their post-dfs indexes w.r.t.
+   * the implied DataflowGraph).
+   */
+  IndexSet inside_;
+
+  /*!
+   * \brief Index of first and last inside nodes.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  PostDfsIndex first_inside_index_ = 0;
+  PostDfsIndex last_inside_index_ = 0;
+
+  /*!
+   * \brief Which sub-expressions are entry/exit/input/output for this sub-graph.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  IndexSet entry_;
+  IndexSet exit_;
+  IndexSet input_;
+  IndexSet output_;
+
+  /*!
+   * \brief Maximum depth of any dataflow path from an entry to an output sub-expression.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  size_t max_depth_ = 0;
+
+  /*!
+   * \brief The \p OpPatternKind summarizing the input/output behavior of the sub-graph.
+   *
+   * A sub-graph consisting of a single Relay expression node is given kind:
+   *  - For Call to a Relay operator, the "TOpPattern" attribute of that operator (provided the
+   *    call does not involve data-dependent dynamic shapes).
+   *  - For Call to a Relay Function, the "TOpPattern" attribute of the function (provided it has
+   *    that attribute).
+   *  - For Constants, \p kElemWise.
+   *  - For Tuple and tuple projections, \p kInjective (provided all tuple fields are of tensor
+   *    type).
+   *  - All other nodes, \p kOpaque.
+   * Sub-graphs with more than one node have the maximum of the kind of each node.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  OpPatternKind kind_ = kOpaque;
+
+  /*!
+   * \brief A label for the sub-graph. Not guaranteed to be unique, but is a human-readable summary
+   * of the sub-graph which can help with debugging and guide the selection of global symbol names.
+   */
+  String label_;
+
+  /*!
+   * \brief Sub-sub-graphs of this sub-graph which must be represented by functions. These must
+   * be disjoint, but it's ok for this sub-graph to have nodes not inside any sub-sub-graph.
+   */
+  SubSubGraphs sub_sub_graphs_;
+
+  // TODO(mbs): 'Anchor nodes' and rules for unioning them.
+  // In FuseOps it's just the unique kOutEWiseFusable node, if any.
+  // I'd like to allow writing vertical fusion rules, eg if two candidates are directly
+  // connected and have nn.conv2d anchors allow their join.
+  // I'd also like to allow horizontal fusion rules, eg if two candidates are not directly
+  // connected but could be joined without producing an invalid (eg cyclic) sub-graph and
+  // have nn.conv2d anchors then do so. Come back to this.
+
+  /*! \brief Number of nodes in overall dataflow graph. */
+  size_t overall_size() const { return inside_.end_index(); }
+
+  bool IsEmpty() const { return inside_.IsZero(); }
+
+  /*! \brief Number of nodes in sub-graph. */
+  size_t Size() const { return inside_.PopCount(); }
+
+  /*!
+   * \brief Returns the dataflow nodes downstream of all exit nodes.
+   */
+  IndexSet Downstream(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns true if this sub-graph is valid. Ie:
+   *  - no output of the sub-graph can flow to any input of the sub-graph (otherwise we'd end up
+   *    with a dataflow cycle when we partition).
+   *  - all inputs and outputs of the sub-graph are in the same scope, ie not separated by
+   *    control flow (otherwise there'd be no consistent program point at which to eval the
+   *    partitioned function).
+   *  - no more than config.max_exits exit nodes are needed.
+   *  - if config.allow_taps is false, no inside node has outputs to nodes both inside and
+   *    outside the sub-graph.
+   */
+  bool IsValid(const DataflowGraph& dataflow_graph, const SubGraphConfig& config) const;
+
+  /*!
+   * \brief Returns this sub-graph extracted as a stand-alone function. The function will have
+   * no attributes, and is suitable for building and profiling by the \p CostEstimator.
+   */
+  Function ExtractAsFunction(const DataflowGraph& dataflow_graph, NameSupply& name_supply) const;
+
+  /*!
+   * \brief Returns \p expr (which has matching \p dataflow_graph) rewritten to encode the
+   * partitioning implied by this sub-graph.
+   */
+  Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr,
+               NameSupply& name_supply) const;
+
+  std::string ToString() const;
+
+  bool operator==(const SubGraphNode& that) const;
+  bool operator!=(const SubGraphNode& that) const { return !(*this == that); }
+  bool operator<(const SubGraphNode& that) const;
+  size_t hash() const;
+
+ private:
+  /*! \brief Initialize the entry/exit/input/output sets given the inside and \p dataflow_graph. */
+  void Init(const DataflowGraph& dataflow_graph);
+
+  /*! \brief Calculates and returns the maximum path depth. */
+  size_t MaxDepth(const DataflowGraph& dataflow_graph) const;
+
+  /*! \brief Returns true if any (input/output) of node is (outside/inside) the sub-graph. */
+  bool AnyInputOutside(const DataflowGraph::Node* node) const;
+  bool AnyInputInside(const DataflowGraph::Node* node) const;
+  bool AnyOutputOutside(const DataflowGraph::Node* node) const;
+  bool AnyOutputInside(const DataflowGraph::Node* node) const;
+
+ public:
+  static constexpr const char* _type_key = "relay.collage.SubGraph";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SubGraphNode, Object);
+
+  friend class SubGraph;
+};
+
+class SubGraph : public ObjectRef {
+ public:
+  /*! \brief Primitive constructor. The following constructors are generally more convenient. */
+  SubGraph(const DataflowGraph& dataflow_graph, IndexSet inside, OpPatternKind kind = kOpaque,
+           String label = {}, std::vector<SubSubGraph> sub_sub_graphs = {});
+
+  /*! \brief Constructs the empty sub-graph for \p dataflow_graph. */
+  explicit SubGraph(const DataflowGraph& dataflow_graph);
+
+  /*! \brief Returns true if this and that are disjoint. */
+  bool AreDisjoint(const SubGraph& that) const;
+
+  /*!
+   * \brief Returns true if:
+   *  - \p this and \p that are disjoint, and
+   *  - an output node of \p this coincides with an entry node of \p that, and
+   *  - \p this and \p that are not obviously invalid after \p DisjointUnion
+   *    (eg because such a sub-graph would produce a cycle).
+   * Note however that the \p DisjointUnion may not necessarily be valid even with the above
+   * checks.
+   */
+  bool AreTouching(const DataflowGraph& dataflow_graph, const SubGraph& that) const;
+
+  /*!
+   * \brief Returns true if:
+   *  - all the outputs of \p this are entries for \p that, and
+   *  - all the inputs of \p that are exits for \p this.
+   */
+  bool AreSelfContained(const SubGraph& that) const;
+
+  /*!
+   * \brief Returns disjoint union of this and \p that sub-graphs. The result may not be valid.
+   */
+  SubGraph DisjointUnion(const DataflowGraph& dataflow_graph, const SubGraph& that) const;
+
+  /*!
+   * \brief Returns copy of this sub-graph with all nodes placed inside a sub-sub-graph with
+   * given attributes.
+   */
+  SubGraph WithAttrs(const DataflowGraph& dataflow_graph, FunctionAttrsMap attrs) const;
+
+  /*!
+   * \brief Returns copy of this sub-graph with all indexes substituted according to \p subst,
+   * whose range is w.r.t. \p new_dataflow_graph.
+   */
+  SubGraph Subst(const DataflowGraph& new_dataflow_graph,
+                 const std::unordered_map<PostDfsIndex, PostDfsIndex>& subst) const;
+
+  /*!
+   * \brief Returns \p expr rewritten according to all the given sub-graphs. The sub-graphs can
+   * be given in any order, but must be disjoint.
+   */
+  static Expr ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr,
+                              std::vector<SubGraph> sub_graphs, NameSupply& name_supply);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(SubGraph, ObjectRef, SubGraphNode);
+};
+
+struct SubGraphEqual {
+  bool operator()(const SubGraph& left, const SubGraph& right) const {
+    return *left.get() == *right.get();
+  }
+};
+
+struct SubGraphHash {
+  size_t operator()(const SubGraph& sub_graph) const { return sub_graph->hash(); }
+};
+
+/*!
+ * \brief Pass to partition every global function according to the post-dfs indexes
+ * given in an array. Visible for testing from Python only, would never make sense to use
+ * as a generic pass!
+ */
+transform::Pass PartitionOnIndexesForTesting(Array<Integer> indexes);
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // SRC_RELAY_COLLAGE_SUB_GRAPH_H_
diff --git a/src/relay/collage/utils.cc b/src/relay/collage/utils.cc
new file mode 100644
index 0000000000000..e8517f12be3c4
--- /dev/null
+++ b/src/relay/collage/utils.cc
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/utils.cc
+ * \brief Misc helpers.
+ */
+
+#include "./utils.h"
+
+#include "../op/memory/device_copy.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+String UnionLabels(String left, String right) {
+  if (left.empty()) {
+    return right;
+  }
+  if (right.empty()) {
+    return left;
+  }
+  return left + "+" + right;
+}
+
+String NestLabels(String left, String right) {
+  if (left.empty()) {
+    return right;
+  }
+  if (right.empty()) {
+    return left;
+  }
+  return left + "." + right;
+}
+
+std::string KindToString(OpPatternKind kind) {
+  switch (kind) {
+    case kElemWise:
+      return "E";
+    case kBroadcast:
+      return "B";
+    case kInjective:
+      return "I";
+    case kCommReduce:
+      return "R";
+    case kOutEWiseFusable:
+      return "A";
+    case kTuple:
+      return "T";
+    case kOpaque:
+      return "O";
+  }
+  return "?";
+}
+
+OpPatternKind CombineKinds(OpPatternKind left, OpPatternKind right) {
+  return std::max(left, right);
+}
+
+bool IsSimpleScalar(const ConstantNode* constant_node) {
+  if (!constant_node->is_scalar()) {
+    return false;
+  }
+  DataType dtype = DataType(constant_node->data->dtype);
+  static DataType int32 = DataType::Int(32);
+  static DataType int64 = DataType::Int(64);
+  static DataType float32 = DataType::Float(32);
+  static DataType float64 = DataType::Float(64);
+  static DataType bool_ = DataType::Bool();
+  return dtype == int32 || dtype == int64 || dtype == float32 || dtype == float64 ||
+         dtype == bool_;
+}
+
+bool CanInline(const Expr& expr) {
+  if (expr.as<OpNode>() || expr.as<GlobalVarNode>() || expr.as<ConstructorNode>()) {
+    return true;
+  }
+  if (const auto* constant_node = expr.as<ConstantNode>()) {
+    return IsSimpleScalar(constant_node);
+  }
+  return false;
+}
+
+bool IsSpecialOp(const OpNode* op_node) {
+  auto op = GetRef<Op>(op_node);
+  static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");
+  if (fnoncomputational.count(op) && fnoncomputational[op]) {
+    // Operator has been marked as non-computational.
+    return true;
+  }
+  // TODO(mbs): This is incomplete.
+  static auto shape_of_op_ = Op::Get("shape_of");
+  static auto vm_shape_of_op_ = Op::Get("vm.shape_of");
+  if (op == DeviceCopyOp() || op == shape_of_op_ || op == vm_shape_of_op_) {
+    // Operator is compiled away by the VM compilation flow.
+    return true;
+  }
+  return false;
+}
+
+bool MustBeLowered(const Expr& expr) {
+  if (const auto* call_node = expr.as<CallNode>()) {
+    if (const auto* function_node = call_node->op.as<FunctionNode>()) {
+      if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+        // We've already committed to this call being to one or more operators which must be
+        // lowered.
+        return true;
+      }
+    } else if (const auto* op_node = call_node->op.as<OpNode>()) {
+      if (!IsSpecialOp(op_node)) {
+        // The VM compilation path won't rewrite this call.
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
\ No newline at end of file
diff --git a/src/relay/collage/utils.h b/src/relay/collage/utils.h
new file mode 100644
index 0000000000000..daa729ec31199
--- /dev/null
+++ b/src/relay/collage/utils.h
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/utils.h
+ * \brief Misc helpers.
+ */
+
+#ifndef SRC_RELAY_COLLAGE_UTILS_H_
+#define SRC_RELAY_COLLAGE_UTILS_H_
+
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+#include <tvm/relay/op_attr_types.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*! \brief Returns "<left>+<right>". */
+String UnionLabels(String left, String right);
+
+/*! \brief Returns "<outer>.<inner>". */
+String NestLabels(String outer, String inner);
+
+/*! \brief Returns abbreviation for \p kind. */
+std::string KindToString(OpPatternKind kind);
+
+/*! \brief Returns maximum of \p left and \p right. */
+OpPatternKind CombineKinds(OpPatternKind left, OpPatternKind right);
+
+/*!
+ * \brief Returns true if \p constant_node is a float/int/bool scalar which is always safe to
+ * inline.
+ */
+bool IsSimpleScalar(const ConstantNode* constant_node);
+
+/*!
+ * \brief Returns true if \p expr can be safely inlined in the body of a function extracted
+ * from a sub-graph, even if \p expr was not technically matched by the pattern which produced
+ * the sub-graph.
+ */
+bool CanInline(const Expr& expr);
+
+/*!
+ * \brief Returns true if \p op_node can be directly handled by the VM.
+ */
+bool IsSpecialOp(const OpNode* op_node);
+
+/*!
+ * \brief Returns true if the Relay expression node given by \p expr cannot be evaluated by
+ * the VM and must end up in a kernel.
+ */
+bool MustBeLowered(const Expr& expr);
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // SRC_RELAY_COLLAGE_UTILS_H_
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 8d7ed163a1975..07919fce61fde 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -299,51 +299,66 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
 // Recursively find the Dominator parent along all inputs paths.
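+// Each dataflow input of the expression must either match the dominator's parent pattern
+// directly (which completes that path), or match the path pattern and itself recursively
+// satisfy MatchesPath. The input which is the callee of a call expression is skipped.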
 bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
   auto call_node = expr.as<CallNode>();
-  for (auto node : expr_graph_.node_map_.at(expr)->inputs_) {
-    if (!(call_node && node->ref_ == call_node->op)) {
+  auto index_node = expr_to_node(expr);
+  for (auto node : index_node->inputs_) {
+    VLOG_CONTEXT << "input " << node->index_;
+    if (!(call_node && node->ref() == call_node->op)) {
       memoize_ = true;
-      if (VisitDFPattern(op->parent, node->ref_)) {
+      if (VisitDFPattern(op->parent, node->ref())) {
+        VLOG(1) << "path matches parent pattern at " << node->index_;
         return true;
       } else {
         memoize_ = false;
-        if (!VisitDFPattern(op->path, node->ref_) || !MatchesPath(op, node->ref_)) {
+        if (!VisitDFPattern(op->path, node->ref())) {
+          VLOG(1) << "path fails to match path pattern at " << node->index_;
+          return false;
+        }
+        if (!MatchesPath(op, node->ref())) {
           return false;
         }
       }
     }
   }
+  VLOG(1) << "visited all inputs from " << index_node->index_;
   return true;
 }
 
 // Iteratively ensure that the parent is dominated somewhere by the child or the path
 bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
   std::stack<Expr> stack;
-  std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visited;
+  std::unordered_set<const ExprNode*> visited;
   stack.push(expr);
   while (!stack.empty()) {
     Expr current = stack.top();
     stack.pop();
-    for (auto node : expr_graph_.node_map_.at(current)->dominator_children_) {
-      if (visited.count(node->ref_) == 0) {
-        if (VisitDFPattern(op->parent, node->ref_)) {
+    for (auto node : expr_to_node(current)->dominator_children_) {
+      VLOG_CONTEXT << "child " << node->index_;
+      if (visited.count(node->node_ref_) == 0) {
+        if (VisitDFPattern(op->parent, node->ref())) {
+          VLOG(1) << "matches dominator child at " << node->index_;
          return true;
        } else {
-          stack.push(node->ref_);
+          stack.push(node->ref());
        }
-        visited.insert(node->ref_);
+        visited.insert(node->node_ref_);
      }
    }
  }
+  VLOG(1) << "could not find dominator in children from " << expr_to_node(expr)->index_;
  return false;
 }
 
 bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  VLOG_CONTEXT << "looking for dominator pattern match at " << expr_to_node(expr)->index_;
   if (VisitDFPattern(op->child, expr)) {
+    VLOG(1) << "matches child pattern";
     bool matches_path = MatchesPath(op, expr);
     memoize_ = true;
     if (matches_path) {
       return DominatesParent(op, expr);
     }
+  } else {
+    VLOG(1) << "does not match child pattern";
   }
   return false;
 }
@@ -500,7 +515,8 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr
 }
 
 bool MatchPattern(DFPattern pattern, Expr expr) {
-  return DFPatternMatcher(expr).Match(pattern, expr);
+  std::unique_ptr<IndexedGraph<Expr>> expr_graph = CreateIndexedGraph(expr);
+  return DFPatternMatcher(expr_graph.get()).Match(pattern, expr);
 }
 
 TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern);
@@ -575,7 +591,8 @@ const std::unordered_map<int, PatternGrouper::Group>& PatternGrouper::GroupMatch
   pattern_ = pattern;
   pattern_graph_ = CreateIndexedGraph(pattern_);
-  auto matcher = DFPatternMatcher(pre);
+  std::unique_ptr<IndexedGraph<Expr>> expr_graph = CreateIndexedGraph(pre);
+  DFPatternMatcher matcher(expr_graph.get());
   matcher_ = &matcher;
   this->VisitExprs();
   return this->groups_;
@@ -583,9 +600,9 @@ const std::unordered_map<int, PatternGrouper::Group>& PatternGrouper::GroupMatch
 
 void PatternGrouper::VisitExprs() {
   std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> pre_partitioned;
-  for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; --i) {
-    size_t index = i - 1;
-    Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_;
+  for
(PostDfsIndex i = matcher_->size(); i != 0; --i) { + PostDfsIndex index = i - 1; + const auto current = matcher_->index_to_node(index)->ref(); if (gid_assignments_.count(current) == 0) { // Don't visit nodes we've already grouped if (auto op = current.as()) { if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) { @@ -607,9 +624,10 @@ void PatternGrouper::CreateGroup(const Expr& expr) { auto node_map = matcher_->GetMemo(); // Get fuzzy patterns std::unordered_set fuzzy_matches; - for (auto node : pattern_graph_.topological_order_) { + for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) { + auto node = pattern_graph_->index_to_node(index); // Don't treat fuzzy Dominator patterns input variables for partition - if (auto op = node->ref_.as()) { + if (auto op = node->ref().as()) { for (auto fuzzy_op : {op->parent, op->path}) { for (auto match : node_map[fuzzy_op]) { fuzzy_matches.insert(match); @@ -617,12 +635,13 @@ void PatternGrouper::CreateGroup(const Expr& expr) { } } // Don't treat Function params or body as input variables for partition - if (node->ref_.as()) { - auto matches = node_map[node->ref_]; + if (node->ref().as()) { + auto matches = node_map[node->ref()]; for (auto match : matches) { - auto graph = CreateIndexedGraph(match.as()->body); - for (auto node : graph.topological_order_) { - fuzzy_matches.insert(node->ref_); + auto sub_graph = CreateIndexedGraph(match.as()->body); + for (PostDfsIndex sub_index = 0; sub_index < sub_graph->size(); ++sub_index) { + auto sub_node = sub_graph->index_to_node(sub_index); + fuzzy_matches.insert(sub_node->ref()); } } } @@ -636,10 +655,11 @@ void PatternGrouper::CreateGroup(const Expr& expr) { std::unordered_map inputs; Array params; - for (auto node : pattern_graph_.topological_order_) { + for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) { + auto node = pattern_graph_->index_to_node(index); auto make_input = [&](const Expr& input) { if (fuzzy_matches.count(input) == 0 && input.as() == nullptr && - input.as() == nullptr && !EmbedConst(input, node->ref_)) { + input.as() == nullptr && !EmbedConst(input, node->ref())) { inputs[input] = Var("FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number), NullValue()); @@ -648,11 +668,11 @@ void PatternGrouper::CreateGroup(const Expr& expr) { var_number++; } }; - auto tuple = node->ref_.as(); - auto call = node->ref_.as(); + auto tuple = node->ref().as(); + auto call = node->ref().as(); if (tuple && !tuple->fields.defined()) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; + if (node_map.count(node->ref())) { + auto matches = node_map[node->ref()]; for (auto match : matches) { for (auto input : match.as()->fields) { make_input(input); @@ -660,8 +680,8 @@ void PatternGrouper::CreateGroup(const Expr& expr) { } } } else if (call && !call->args.defined()) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; + if (node_map.count(node->ref())) { + auto matches = node_map[node->ref()]; for (auto match : matches) { for (auto input : match.as()->args) { make_input(input); @@ -669,8 +689,8 @@ void PatternGrouper::CreateGroup(const Expr& expr) { } } } else if (node->inputs_.size() == 0) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; + if (node_map.count(node->ref())) { + auto matches = node_map[node->ref()]; for (auto match : matches) { make_input(match); } @@ -699,23 +719,39 @@ void PatternGrouper::CreateGroup(const Expr& expr) { // output 
would create an invalid graph tranformation, so we block the creation of such groups. auto memo = extractor.GetMemo(); for (auto kv : memo) { + VLOG(1) << "looking at matched:\n" << PrettyPrint(kv.first); // Check to ensure that this node isn't an input or a global if (inputs.count(kv.first) == 0 && kv.first.as() == nullptr && kv.first.as() == nullptr && kv.first.as() == nullptr) { if (gid_assignments_.count(kv.first) != 0) { // check to see if the node is use in other groups // Exit due to overlapping partitions + VLOG(1) << "matched used in another group"; return; - } else if (kv.second != body) { + } else if (kv.second == body) { + VLOG(1) << "matched rewritten to body"; + } else { // if the node isn't the output of the group - auto node = matcher_->expr_graph_.node_map_.at(kv.first); + auto node = matcher_->expr_to_node(kv.first); + VLOG(1) << "checking " << node->outputs_.size() << " outputs of matched"; for (auto* output : node->outputs_) { // and the node is used by nodes outside of the group - if (memo.count(output->ref_) == 0 && - !matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) { - // Exit because nodes in this pattern's body are used outside the pattern - // fusing it would be invalid - return; + if (memo.count(output->ref()) == 0) { + VLOG(1) << "output:\n" + << PrettyPrint(output->ref()) << "\nis not matched, checking against expr:\n" + << PrettyPrint(expr); + if (!matcher_->expr_to_node(expr)->Dominates(output)) { + // TODO(mbs): The dominates relation is backwards here, and will always fail. + // So all we're doing is failing if an internal node connects to an outside node. + // Exit because nodes in this pattern's body are used outside the pattern + // fusing it would be invalid + VLOG(1) << "does not dominate"; + return; + } else { + VLOG(1) << "dominates"; + } + } else { + VLOG(1) << "output:\n" << PrettyPrint(output->ref()) << "\nis also matched"; } } } diff --git a/src/relay/ir/dataflow_matcher_impl.h b/src/relay/ir/dataflow_matcher_impl.h index d993d4720e4ed..bd66aa0903902 100644 --- a/src/relay/ir/dataflow_matcher_impl.h +++ b/src/relay/ir/dataflow_matcher_impl.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -39,10 +40,21 @@ namespace relay { class DFPatternMatcher : public DFPatternFunctor { public: - explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {} + explicit DFPatternMatcher(const IndexedGraph* expr_graph) : expr_graph_(expr_graph) {} bool Match(const DFPattern& pattern, const Expr& expr); Map> GetMemo() { return Map>(memo_); } - const IndexedGraph expr_graph_; + + const IndexedGraph::Node* expr_to_node(const Expr& expr) const { + return expr_graph_->item_to_node(expr); + } + const IndexedGraph::Node* index_to_node(size_t index) const { + return expr_graph_->index_to_node(index); + } + size_t size() const { return expr_graph_->size(); } + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& memo() const { + return memo_; + }; + const IndexedGraph& dataflow_graph() const { return *expr_graph_; } protected: bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; @@ -67,6 +79,7 @@ class DFPatternMatcher : public DFPatternFunctor* expr_graph_; std::unordered_map, ObjectPtrHash, ObjectPtrEqual> memo_; std::vector matched_nodes_; bool memoize_ = true; @@ -131,7 +144,7 @@ class PatternGrouper { std::unordered_map groups_; std::unordered_map gid_assignments_; DFPatternMatcher* matcher_ = nullptr; - IndexedGraph pattern_graph_; + std::unique_ptr> pattern_graph_; 
int gid_ = 0; int graph_number_ = 0; }; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index fc76577bd7c07..bb32c17a22463 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -27,6 +27,26 @@ namespace tvm { +GlobalVar WithFields(GlobalVar global_var, Optional opt_name_hint, Optional opt_type, + Optional opt_virtual_device, Optional opt_span) { + String name_hint = opt_name_hint.value_or(global_var->name_hint); + Type type = opt_type.value_or(global_var->checked_type()); + VirtualDevice virtual_device = opt_virtual_device.value_or(global_var->virtual_device()); + Span span = opt_span.value_or(global_var->span); + bool all_fields_unchanged = + name_hint.same_as(global_var->name_hint) && type.same_as(global_var->checked_type()) && + virtual_device.same_as(global_var->virtual_device()) && span.same_as(global_var->span); + if (!all_fields_unchanged) { + GlobalVarNode* cow_global_var_node = global_var.CopyOnWrite(); + cow_global_var_node->name_hint = name_hint; + cow_global_var_node->checked_type_ = type; + cow_global_var_node->virtual_device_ = virtual_device; + cow_global_var_node->span = span; + } + + return global_var; +} + VirtualDevice RelayExprNode::virtual_device() const { if (!this->virtual_device_.defined()) { // virtual_device_ should always be defined, unless we imported this node from JSON using an old @@ -77,6 +97,25 @@ TensorType ConstantNode::tensor_type() const { return TensorType(shape, dtype); } +Constant WithFields(Constant constant, Optional opt_data, + Optional opt_virtual_device, Optional opt_span) { + runtime::NDArray data = opt_data.value_or(constant->data); + VirtualDevice virtual_device = opt_virtual_device.value_or(constant->virtual_device()); + Span span = opt_span.value_or(constant->span); + + bool all_fields_unchanged = data.same_as(constant->data) && + virtual_device.same_as(constant->virtual_device()) && + span.same_as(constant->span); + + if (!all_fields_unchanged) { + ConstantNode* cow_constant_node = constant.CopyOnWrite(); + cow_constant_node->data = data; + cow_constant_node->virtual_device_ = virtual_device; + cow_constant_node->span = span; + } + return constant; +} + Tuple::Tuple(tvm::Array fields, Span span) { ObjectPtr n = make_object(); n->fields = std::move(fields); diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index 4efe57b491db0..6da4f4a1adbd0 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ -26,21 +26,19 @@ #include #include #include -#include +#include namespace tvm { namespace relay { -// IndexedGraph - -IndexedGraph CreateIndexedGraph(const Expr& expr) { - using NodePtr = std::shared_ptr::Node>; - /*! \brief Creator Creates an IndexedGraph and determintes Topological order */ +std::unique_ptr> CreateIndexedGraph(const Expr& expr) { + /*! 
\brief Creates an IndexedGraph and determines topological order */ class Creator : public MixedModeVisitor { public: - IndexedGraph CreateGraph(const Expr& expr) { + std::unique_ptr> CreateGraph(const Expr& expr) { + graph_ = std::make_unique>(); VisitExpr(expr); - graph_.node_map_[expr]->is_external_ = true; + graph_->item_to_node(expr)->is_external_ = true; return std::move(graph_); } @@ -49,165 +47,296 @@ IndexedGraph CreateIndexedGraph(const Expr& expr) { void VisitLeaf(const Expr& expr) override { MixedModeVisitor::VisitLeaf(expr); - auto node = std::make_shared::Node>(expr, index_++); - graph_.node_map_[expr] = node; - graph_.topological_order_.push_back(node); + graph_->AddNode(expr); + } + + void VisitExpr_(const FunctionNode* function_node) override { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // Don't recurse into primitive functions. + return; + } + ExprVisitor::VisitExpr_(function_node); } - void VisitExpr_(const LetNode* let) override { + void VisitExpr_(const LetNode* let_node) override { auto pre_visit = [&](const LetNode* op) { - this->VisitSpan(op->span); - this->VisitExpr(op->value); - this->VisitExpr(op->var); + VisitExpr(op->value); + VisitExpr(op->var); }; auto post_visit = [&](const LetNode* op) { - this->VisitExpr(op->body); - if (let != op) { - Expr expr = GetRef(op); + VisitExpr(op->body); + if (let_node != op) { visit_counter_[op]++; - auto node = std::make_shared::Node>(expr, index_++); - graph_.node_map_[expr] = node; - graph_.topological_order_.push_back(node); + graph_->AddNode(GetRef(op)); } }; - ExpandANormalForm(let, pre_visit, post_visit); + ExpandANormalForm(let_node, pre_visit, post_visit); + } + + class PatternCreator : public PatternVisitor { + public: + explicit PatternCreator(Creator* creator) : creator_(creator) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + creator_->graph_->AddNode(pattern_var_node->var); + } + + Creator* creator_; + }; + + void VisitExpr_(const MatchNode* match_node) override { + VisitExpr(match_node->data); + for (const Clause& c : match_node->clauses) { + PatternCreator pattern_creator(this); + pattern_creator.VisitPattern(c->lhs); + VisitExpr(c->rhs); + } } - IndexedGraph graph_; - size_t index_ = 0; + std::unique_ptr> graph_; }; - /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does dominator tree + + /*! + * \brief Takes an IndexedGraph, fills it's forward outputs, and does dominator tree * analysis. * - * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined + * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined * topological order instead of recursing. 
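+   *
+   * For example (illustrative): when the node for 'mul(a, b)' is visited, AddOutput pushes
+   * the mul node onto the outputs_ of the nodes for 'a' and 'b', so after one pass every
+   * node's outputs_ lists exactly its dataflow consumers.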
*/ - class Annotator : public ExprFunctor { + class Annotator : public ExprFunctor { public: - Annotator(const IndexedGraph& graph) : graph_(graph) {} - IndexedGraph Annotate() { + explicit Annotator(std::unique_ptr> graph) : graph_(std::move(graph)) {} + + std::unique_ptr> Annotate() { // Visit all of the nodes in topological order to get forward outputs - for (const auto& node : graph_.topological_order_) { - ExprFunctor::VisitExpr(node->ref_, nullptr); + for (PostDfsIndex index = 0; index < graph_->size(); ++index) { + VisitExpr(graph_->index_to_node(index)->ref()); } // do the dominator analysis - graph_.PostDom(); + graph_->PostDom(); + for (PostDfsIndex index = 0; index < graph_->size(); ++index) { + auto node = graph_->index_to_node(index); + if (node->dominator_parent_) { + VLOG(2) << "node index " << index << " has dominator parent index " + << node->dominator_parent_->index_; + } + } return std::move(graph_); } - /*! Default visitation pushes the parent to the child's outputs and the child to the parent's - * inputs*/ - void VisitExpr(const Expr& expr, NodePtr parent) override { - auto current = graph_.node_map_[expr]; - if (parent) { - current->outputs_.push_back(parent.get()); - parent->inputs_.push_back(current.get()); - } + /*! + * \brief Add \p parent as a possible output of the node corresponding to \p expr. + */ + void AddOutput(const Expr& expr, IndexedGraph::Node* parent) { + auto current = graph_->item_to_node(expr); + current->outputs_.push_back(parent); + parent->inputs_.push_back(current); } protected: - IndexedGraph graph_; - void VisitExpr_(const VarNode* op, NodePtr parent) override { - if (op->type_annotation.defined()) { - this->VisitType(op->type_annotation); - } - } + void VisitExpr_(const VarNode* var_node) override {} - void VisitExpr_(const GlobalVarNode* op, NodePtr parent) override {} + void VisitExpr_(const GlobalVarNode* global_var_node) override {} - void VisitExpr_(const ConstantNode* op, NodePtr parent) override {} + void VisitExpr_(const ConstantNode* constant_node) override {} - void VisitExpr_(const TupleNode* op, NodePtr parent) override { - for (auto field : op->fields) { - this->VisitExpr(field, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const TupleNode* tuple_node) override { + auto node = graph_->item_to_node(GetRef(tuple_node)); + for (auto field : tuple_node->fields) { + AddOutput(field, node); } } - void VisitExpr_(const FunctionNode* op, NodePtr parent) override { - for (auto param : op->params) { - this->VisitExpr(param, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const FunctionNode* function_node) override { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // No dataflow analysis inside primitive functions + return; } + auto node = graph_->item_to_node(GetRef(function_node)); + // Nothing to do for parameters -- each use of a parameter will contribute to its outputs. + AddOutput(function_node->body, node); + } + + void VisitExpr_(const CallNode* call_node) override { + auto node = graph_->item_to_node(GetRef(call_node)); + AddOutput(call_node->op, node); + for (auto arg : call_node->args) { + AddOutput(arg, node); + } + } + + void VisitExpr_(const LetNode* let_node) override { + auto node = graph_->item_to_node(GetRef(let_node)); + auto let_var_node = graph_->item_to_node(let_node->var); + AddOutput(let_node->value, let_var_node); + // Nothing to do for the let-bound variable -- each use of that variable in the let-body + // will contribute to its outputs. 
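+      // For example, in 'let %v = add(%x, 1); mul(%v, 2)' the add node flows into %v, %v
+      // flows into the mul, and the mul (the let body) flows into the let node itself.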
+ AddOutput(let_node->body, node); + } + + void VisitExpr_(const IfNode* if_node) override { + auto node = graph_->item_to_node(GetRef(if_node)); + AddOutput(if_node->cond, node); + AddOutput(if_node->true_branch, node); + AddOutput(if_node->false_branch, node); + } + + void VisitExpr_(const OpNode* op_node) override {} + + void VisitExpr_(const TupleGetItemNode* tuple_get_item_node) override { + auto node = graph_->item_to_node(GetRef(tuple_get_item_node)); + AddOutput(tuple_get_item_node->tuple, node); + } - this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const RefCreateNode* ref_create_node) override { + auto node = graph_->item_to_node(GetRef(ref_create_node)); + AddOutput(ref_create_node->value, node); } - void VisitExpr_(const CallNode* op, NodePtr parent) override { - this->VisitExpr(op->op, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const RefReadNode* ref_read_node) override { + auto node = graph_->item_to_node(GetRef(ref_read_node)); + AddOutput(ref_read_node->ref, node); + } + + void VisitExpr_(const RefWriteNode* ref_write_node) override { + auto node = graph_->item_to_node(GetRef(ref_write_node)); + AddOutput(ref_write_node->ref, node); + AddOutput(ref_write_node->value, node); + } + + void VisitExpr_(const ConstructorNode* constructor_node) override {} - for (auto ty_arg : op->type_args) { - this->VisitType(ty_arg); + class PatternAnnotator : public PatternVisitor { + public: + PatternAnnotator(Annotator* annotator, const ExprNode* adt_node) + : annotator_(annotator), adt_node_(adt_node) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + auto node = annotator_->graph_->item_to_node(pattern_var_node->var); + annotator_->AddOutput(GetRef(adt_node_), node); } - for (auto arg : op->args) { - this->VisitExpr(arg, graph_.node_map_[GetRef(op)]); + Annotator* annotator_; + const ExprNode* adt_node_; + }; + + void VisitExpr_(const MatchNode* match_node) override { + // Data flows from the match data to pattern vars into match arms and out into overall match. + auto node = graph_->item_to_node(GetRef(match_node)); + for (const Clause& c : match_node->clauses) { + PatternAnnotator pattern_annotator(this, match_node->data.get()); + pattern_annotator.VisitPattern(c->lhs); + AddOutput(c->rhs, node); } } - void VisitExpr_(const LetNode* op, NodePtr parent) override { - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->var, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); - } + std::unique_ptr> graph_; + }; + + /*! \brief Fills in the basic blocks for all nodes. 
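+   *
+   * For example (illustrative): in 'if (%c) { add(%x, %x) } else { %x }' the add node's
+   * basic block is the node for the if expression itself, while %c belongs to the
+   * enclosing block.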
*/ + class Blocker : public MixedModeVisitor { + public: + explicit Blocker(std::unique_ptr> graph) : graph_(std::move(graph)) {} - void VisitExpr_(const IfNode* op, NodePtr parent) override { - this->VisitExpr(op->cond, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->true_branch, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->false_branch, graph_.node_map_[GetRef(op)]); + std::unique_ptr> Scope(const Expr& expr) { + VisitExpr(expr); + return std::move(graph_); } - void VisitExpr_(const OpNode* op, NodePtr parent) override { return; } + private: + using MixedModeVisitor::VisitExpr_; - void VisitExpr_(const TupleGetItemNode* op, NodePtr parent) override { - this->VisitExpr(op->tuple, graph_.node_map_[GetRef(op)]); + void VisitLeaf(const Expr& expr) override { + MixedModeVisitor::VisitLeaf(expr); + SetScope(expr); } - void VisitExpr_(const RefCreateNode* op, NodePtr parent) override { - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const FunctionNode* function_node) override { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return; + } + auto node = graph_->item_to_node(GetRef(function_node)); + basic_block_stack_.push_back(node); + ExprVisitor::VisitExpr_(function_node); + basic_block_stack_.pop_back(); } - void VisitExpr_(const RefReadNode* op, NodePtr parent) override { - this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const IfNode* if_node) override { + VisitExpr(if_node->cond); + auto node = graph_->item_to_node(GetRef(if_node)); + basic_block_stack_.push_back(node); + VisitExpr(if_node->true_branch); + VisitExpr(if_node->false_branch); + basic_block_stack_.pop_back(); } - void VisitExpr_(const RefWriteNode* op, NodePtr parent) override { - this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const LetNode* let_node) override { + auto pre_visit = [&](const LetNode* op) { + VisitExpr(op->value); + VisitExpr(op->var); + }; + auto post_visit = [&](const LetNode* op) { + VisitExpr(op->body); + if (let_node != op) { + visit_counter_[op]++; + SetScope(GetRef(op)); + } + }; + ExpandANormalForm(let_node, pre_visit, post_visit); } - void VisitExpr_(const ConstructorNode* op, NodePtr parent) override { - for (const Type& t : op->inputs) { - this->VisitType(t); + class PatternBlocker : public PatternVisitor { + public: + explicit PatternBlocker(Blocker* scoper) : scoper_(scoper) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + scoper_->SetScope(pattern_var_node->var); } - this->VisitType(op->belong_to); - } - void VisitExpr_(const MatchNode* op, NodePtr parent) override { - this->VisitExpr(op->data, graph_.node_map_[GetRef(op)]); - for (const Clause& c : op->clauses) { - this->VisitClause(c, graph_.node_map_[GetRef(op)]); + Blocker* scoper_; + }; + + void VisitExpr_(const MatchNode* match_node) override { + VisitExpr(match_node->data); + auto node = graph_->item_to_node(GetRef(match_node)); + basic_block_stack_.push_back(node); + for (const Clause& c : match_node->clauses) { + PatternBlocker pattern_scoper(this); + pattern_scoper.VisitPattern(c->lhs); + VisitExpr(c->rhs); } + basic_block_stack_.pop_back(); } - void VisitClause(const Clause& op, NodePtr parent) { - this->VisitPattern(op->lhs); - this->VisitExpr(op->rhs, parent); + void SetScope(const Expr& expr) { + auto node = graph_->item_to_node(expr); + if (!basic_block_stack_.empty()) { + node->basic_block_ = basic_block_stack_.back(); + VLOG(2) << 
"node index " << node->index_ << " has basic block index " + << node->basic_block_->index_; + } else { + VLOG(2) << "node index " << node->index_ << " has no basic block"; + } } - void VisitPattern(const Pattern& p) { return; } - - void VisitType(const Type& t) { return; } + std::unique_ptr> graph_; + std::vector::Node*> basic_block_stack_; }; - return Annotator(Creator().CreateGraph(expr)).Annotate(); + + return Blocker(Annotator(Creator().CreateGraph(expr)).Annotate()).Scope(expr); } -IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { - using NodePtr = std::shared_ptr::Node>; - /*! \brief Creator Creates an IndexedGraph and determintes Toplogical order */ +std::unique_ptr> CreateIndexedGraph(const DFPattern& pattern) { + /*! \brief Creates an IndexedGraph and determines topological order */ class Creator : public DFPatternVisitor { public: - IndexedGraph CreateGraph(const DFPattern& pattern) { + std::unique_ptr> CreateGraph(const DFPattern& pattern) { + graph_ = std::make_unique>(); VisitDFPattern(pattern); - graph_.node_map_[pattern]->is_external_ = true; + graph_->item_to_node(pattern)->is_external_ = true; return std::move(graph_); } @@ -215,121 +344,135 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { void VisitDFPattern(const DFPattern& pattern) override { if (this->visited_.count(pattern.get()) == 0) { DFPatternVisitor::VisitDFPattern(pattern); - auto node = std::make_shared::Node>(pattern, index_++); - graph_.node_map_[pattern] = node; - graph_.topological_order_.push_back(node); + graph_->AddNode(pattern); } } - IndexedGraph graph_; - size_t index_ = 0; + + std::unique_ptr> graph_; }; + /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does domiantor tree * analysis. * * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined * topological order instead of recursing. */ - class Annotator : public DFPatternFunctor { + class Annotator : public DFPatternFunctor { public: - Annotator(const IndexedGraph& graph) : graph_(graph) {} - IndexedGraph Annotate() { + Annotator(std::unique_ptr> graph) : graph_(std::move(graph)) {} + + std::unique_ptr> Annotate() { // Visit all of the nodes in topological order to get forward outputs - for (const auto& node : graph_.topological_order_) { - DFPatternFunctor::VisitDFPattern(node->ref_, nullptr); + for (PostDfsIndex index = 0; index < graph_->size(); ++index) { + VisitDFPattern(graph_->index_to_node(index)->ref()); } - graph_.PostDom(); // do the dominator analysis + graph_->PostDom(); return std::move(graph_); } /*! 
Default visitation pushes the parent to the child's outputs */ - void VisitDFPattern(const DFPattern& pattern, NodePtr parent) override { - auto current = graph_.node_map_[pattern]; + void AddOutput(const DFPattern& pattern, IndexedGraph::Node* parent) { + auto current = graph_->item_to_node(pattern); if (parent) { - current->outputs_.push_back(parent.get()); - parent->inputs_.push_back(current.get()); + current->outputs_.push_back(parent); + parent->inputs_.push_back(current); } } protected: - IndexedGraph graph_; - void VisitDFPattern_(const AltPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->left, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->right, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const AltPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->left, node); + AddOutput(op->right, node); } - void VisitDFPattern_(const AttrPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const AttrPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->pattern, node); } - void VisitDFPattern_(const CallPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->op, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const CallPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->op, node); if (op->args.defined()) { for (auto arg : op->args) { - VisitDFPattern(arg, graph_.node_map_[GetRef(op)]); + AddOutput(arg, node); } } } - void VisitDFPattern_(const ConstantPatternNode* op, NodePtr parent) override {} + void VisitDFPattern_(const ConstantPatternNode* op) override {} - void VisitDFPattern_(const DataTypePatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const DataTypePatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->pattern, node); } - void VisitDFPattern_(const DominatorPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->parent, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->path, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->child, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const DominatorPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->parent, node); + AddOutput(op->path, node); + AddOutput(op->child, node); } - void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {} + void VisitDFPattern_(const ExprPatternNode* op) override {} - void VisitDFPattern_(const FunctionPatternNode* op, NodePtr parent) override { + void VisitDFPattern_(const FunctionPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); if (op->params.defined()) { for (auto param : op->params) { - VisitDFPattern(param, graph_.node_map_[GetRef(op)]); + AddOutput(param, node); } } - VisitDFPattern(op->body, graph_.node_map_[GetRef(op)]); + AddOutput(op->body, node); } - void VisitDFPattern_(const ShapePatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const ShapePatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->pattern, node); } - void VisitDFPattern_(const TupleGetItemPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->tuple, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const TupleGetItemPatternNode* op) override { + auto node = 
graph_->item_to_node(GetRef(op)); + AddOutput(op->tuple, node); } - void VisitDFPattern_(const TuplePatternNode* op, NodePtr parent) override { + void VisitDFPattern_(const TuplePatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); if (op->fields.defined()) { for (auto field : op->fields) { - VisitDFPattern(field, graph_.node_map_[GetRef(op)]); + AddOutput(field, node); } } } - void VisitDFPattern_(const IfPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->cond, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->true_branch, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->false_branch, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const IfPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->cond, node); + AddOutput(op->true_branch, node); + AddOutput(op->false_branch, node); } - void VisitDFPattern_(const LetPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->var, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->value, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->body, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const LetPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->var, node); + AddOutput(op->value, node); + AddOutput(op->body, node); } - void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const TypePatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->pattern, node); } - void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {} + void VisitDFPattern_(const VarPatternNode* op) override {} + + void VisitDFPattern_(const WildcardPatternNode* op) override {} - void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override {} + std::unique_ptr> graph_; }; + return Annotator(Creator().CreateGraph(pattern)).Annotate(); } diff --git a/src/relay/ir/indexed_graph.h b/src/relay/ir/indexed_graph.h index d073bcaeea5c9..365c2f75fe7c1 100644 --- a/src/relay/ir/indexed_graph.h +++ b/src/relay/ir/indexed_graph.h @@ -19,7 +19,7 @@ /*! * \file src/relay/ir/indexed_graph.h - * \brief A pattern matcher for matching dataflow properties. + * \brief A graph representation of the dataflow in a Relay expression. */ #ifndef TVM_RELAY_IR_INDEXED_GRAPH_H_ #define TVM_RELAY_IR_INDEXED_GRAPH_H_ @@ -36,47 +36,102 @@ namespace tvm { namespace relay { +/*! \brief The index of a node in the post-dfs traversal of overall expression. */ +using PostDfsIndex = size_t; + /*! - * \brief A Wrapper around a templated graph type - * Holds a forward-backward indexed representation of the graph and a dominator tree representation - * of the graph + * \brief Represents the dataflow of an expression (or dataflow pattern) as a graph which is + * overlaid on the underlying expression (or dataflow pattern) graph. + * + * Each graph node references the corresponding sub-expression (or dataflow sub-pattern) node, + * and captures: + * - dataflow inputs + * - dataflow outputs (or a flag indicating the node is an implied output) + * - dominator parent + * - dominator children + * - basic block * - * This class is templated and the implementaiton is in the header file so we can analyze both - * DFPattern and Expr with the same infrastructure. + * This class is templated so we can analyze both DFPatterns and Exprs with the same infrastructure. 
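+ *
+ * A typical use (sketch only; 'expr' and 'sub_expr' are hypothetical):
+ *   auto graph = CreateIndexedGraph(expr);
+ *   auto* node = graph->item_to_node(sub_expr);
+ *   // node->inputs_ / node->outputs_ are the immediate dataflow neighbours, and
+ *   // node->dominator_parent_ is the closest node all dataflow from node passes through.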
* - * IndexedGraph should be instantiated through the CreateIndexedGraph utilities. + * IndexedGraph should be instantiated through the CreateIndexedGraph utilities. */ template class IndexedGraph { public: - /*! \brief A Node that wraps the input type and represents the indexed graph and dominator tree */ + using TNode = typename T::ContainerType; + + /*! \brief A Node in the dataflow graph. */ struct Node { /*! \brief Node Constructor - * \param ref The input graph node + * \param node_ref The input graph node * \param index The index of the node in toplogical order */ - Node(const T& ref, const size_t index) : ref_(ref), index_(index) {} + Node(const TNode* node_ref, PostDfsIndex index) : node_ref_(node_ref), index_(index) {} - /*! \brief The input node */ - const T ref_; - /*! \brief The topological order index */ - const size_t index_; + /*! \brief The underlying expression or pattern node. */ + const TNode* node_ref_; - /*! \brief A boolean to determine if this node is external to the graph */ + T ref() const { + ICHECK(node_ref_ != nullptr); + return GetRef(node_ref_); + } + + /*! + * \brief The index of this node in post-dfs order. If left.index_ > right.index_ then + * left does not flow into right. If left.index_ = right.index_ then left and right are + * the same node. + */ + const PostDfsIndex index_; + + /*! \brief If true this node has implicit outputs, for example as the result of a function. */ bool is_external_ = false; - /*! \brief The forward inputs of the node */ + /*! \brief Immediate dataflow inputs to this node. */ std::vector inputs_; - /*! \brief The forward outputs/users of the node */ + /*! \brief Immediate dataflow outputs of this node -- may be empty if is_external_ is true. */ std::vector outputs_; - /*! \brief The depth of the node in the dominator tree */ + /*! + * \brief The node representing the 'basic block' containing this node: + * - Function bodies start a new basic block for their bodies. + * - The true and false branches of an if start their own blocks. + * - The arms of a match each have their own blocks. + */ + Node* basic_block_ = nullptr; + + /*! \brief The depth of this node in the dominator tree */ size_t depth_ = 0; - /*! \brief The dominator parent/final user of the outputs of this node */ - Node* dominator_parent_; - /*! \brief The nodes this node dominates */ + /*! + * \brief The dominator parent of this node. This is the node N with least index such that + * all possible dataflows from this node pass through N. + */ + Node* dominator_parent_ = nullptr; + /*! \brief The nodes this node dominates. */ std::vector dominator_children_; - bool Dominates(const Node* other) { + /*! + * Add to \p nodes all the nodes which are strictly downstream of \p this, ie can be + * reached by following output paths. + */ + void AccumulateDownstreamNodes(std::unordered_set& nodes) const { + std::stack stack; + stack.push(this); + while (!stack.empty()) { + const Node* current = stack.top(); + stack.pop(); + for (auto node : current->outputs_) { + if (nodes.count(node) == 0) { + stack.push(node); + nodes.insert(node); + } + } + } + } + + /*! + * \brief Returns true if \p this is a dominator of \p other. Ie all dataflow paths from \p + * other pass through \p this. + */ + bool Dominates(const Node* other) const { std::stack stack; std::unordered_set visited; stack.push(this); @@ -97,10 +152,11 @@ class IndexedGraph { return false; } }; + /*! 
\brief Construct the domination tree inside IndexedGraph */ void PostDom() { - for (size_t i = topological_order_.size(); i != 0; --i) { - size_t index = i - 1; + for (PostDfsIndex i = topological_order_.size(); i != 0; --i) { + PostDfsIndex index = i - 1; auto* current = topological_order_[index].get(); if (current->is_external_) { current->depth_ = 1; @@ -109,16 +165,41 @@ class IndexedGraph { auto parent = LeastCommonAncestor(current->outputs_); current->depth_ = parent ? parent->depth_ + 1 : 1; current->dominator_parent_ = parent; - parent->dominator_children_.push_back(current); + if (parent) { + parent->dominator_children_.push_back(current); + } } } } - /*! \brief Map of input nodes to IndexedGraph Nodes */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> node_map_; - /*! \brief Topological IndexedGraph Nodes */ - std::vector> topological_order_; - protected: + PostDfsIndex size() const { return topological_order_.size(); } + + Node* item_to_node(const T& item) { return item_to_node(item.get()); } + const Node* item_to_node(const T& item) const { return item_to_node(item.get()); } + + Node* item_to_node(const TNode* item) { + auto itr = node_map_.find(item); + ICHECK(itr != node_map_.end()) << PrettyPrint(GetRef(item)); + return itr->second; + } + + const Node* item_to_node(const TNode* item) const { + auto itr = node_map_.find(item); + ICHECK(itr != node_map_.end()) << PrettyPrint(GetRef(item)); + return itr->second; + } + + Node* index_to_node(PostDfsIndex index) { + ICHECK_LT(index, topological_order_.size()) << index; + return topological_order_[index].get(); + } + + const Node* index_to_node(PostDfsIndex index) const { + ICHECK_LT(index, topological_order_.size()) << index; + return topological_order_[index].get(); + } + + private: /*! \brief Find the least common ancestor of all outputs of a node */ Node* LeastCommonAncestor(const std::vector& outputs) { if (outputs.size() == 0) { @@ -150,13 +231,35 @@ class IndexedGraph { } return lhs; } + + void AddNode(const T& item) { + PostDfsIndex index = topological_order_.size(); + VLOG(2) << "node index " << index << " is:\n" << PrettyPrint(item); + auto node = std::make_unique(item.get(), index); + node_map_[item.get()] = node.get(); + topological_order_.emplace_back(std::move(node)); + } + + /*! \brief Map from underlying sub-expressions or dataflow sub-pattern graph nodes. */ + std::unordered_map node_map_; + /*! \brief All nodes in increasing post-dfs index order. This vector owns all the nodes. */ + std::vector> topological_order_; + + friend std::unique_ptr> CreateIndexedGraph(const Expr& expr); + friend std::unique_ptr> CreateIndexedGraph(const DFPattern& pattern); }; -/*! \brief Create an Indexed Graph based on an Expr */ -IndexedGraph CreateIndexedGraph(const Expr& expr); -/*! \brief Create an Indexed Graph based on an DFPattern */ -IndexedGraph CreateIndexedGraph(const DFPattern& pattern); +/*! \brief Returns an Indexed Graph for \p expr, which much outlive the result. */ +std::unique_ptr> CreateIndexedGraph(const Expr& expr); + +/*! + * \brief Returns an Indexed Graph for \p pattern, which must outlive the result. + * The dataflow for a pattern mimics the dataflow for the expression which would match + * that pattern. 
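+ * For example (illustrative), for IsOp("nn.conv2d")({IsWildcard(), IsWildcard()}) the two
+ * wildcard pattern nodes are inputs of the call pattern node, just as the data and weight
+ * expressions would be inputs of a matching conv2d call.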
+ */ +std::unique_ptr> CreateIndexedGraph(const DFPattern& pattern); } // namespace relay } // namespace tvm + #endif // TVM_RELAY_IR_INDEXED_GRAPH_H_ diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 27f295b8b39dc..cc16ead38f084 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -1010,6 +1010,7 @@ Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT f - **out**: `(b, m, n)`. )code" TVM_ADD_FILELINE) + .set_attrs_type() .set_num_inputs(2) .add_argument("tensor_a", "3D Tensor", "The first input.") .add_argument("tensor_b", "3D Tensor", "The second input.") diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index e25b8db152c49..6050d570b1e43 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -88,6 +88,26 @@ static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion"); TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.max_depth", Integer); TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.link_params", Bool); +std::string P2S(OpPatternKind pat) { + switch (pat) { + case kElemWise: + return "E"; + case kBroadcast: + return "B"; + case kInjective: + return "I"; + case kCommReduce: + return "R"; + case kOutEWiseFusable: + return "A"; + case kTuple: + return "T"; + case kOpaque: + return "O"; + } + return "?"; +} + /*! * \brief Indexed data flow graph in forward direction. * This is a temporary data structure used for operator fusion analysis. @@ -129,15 +149,27 @@ class IndexedForwardGraph { /*! \brief Dump the graph into string. */ void DebugDump() { std::ostringstream os; + os << "\n"; for (size_t i = 0; i < post_dfs_order.size(); ++i) { Node* node = post_dfs_order[i]; - os << "node[" << i << "], " << GetRef(node->ref) << " outputs=["; + os << "node[" << i << "], "; + auto ref = GetRef(node->ref); + if (const auto* call_node = ref.as()) { + os << call_node->op << "(...)"; + } else if (ref.as()) { + os << "let ..."; + } else if (ref.as()) { + os << "const ..."; + } else { + os << ref; + } + os << " outputs=["; for (auto* link = node->outputs.head; link != nullptr; link = link->next) { os << link->value.node->index << ", "; } os << "]\n"; } - LOG(INFO) << os.str(); + VLOG(1) << "indexed graph:\n" << os.str(); } /*! * \brief create a indexed forward graph. @@ -264,6 +296,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { OpPatternKind edge_pattern = op_pattern; if (edge_pattern == kBroadcast && arg_type != nullptr && rtype != nullptr && attr_equal_(rtype->shape, arg_type->shape)) { + VLOG(1) << "weird case: revising from broadcast to elemwise on edge for call"; edge_pattern = kElemWise; } this->Update(call->args[i], node, edge_pattern); @@ -373,7 +406,9 @@ class IndexedForwardGraph::Creator : private ExprVisitor { }; IndexedForwardGraph IndexedForwardGraph::Create(support::Arena* arena, const Expr& body) { - return Creator(arena).Prepare(body); + IndexedForwardGraph result = Creator(arena).Prepare(body); + // result.DebugDump(); + return result; } /*! @@ -447,9 +482,9 @@ class DominatorTree { * The combined edge pattern across all the parents. * \return The least common ancestor of all nodes. 
*/ - Node* LeastCommonAncestor(const LinkedList& input_nodes, + Node* LeastCommonAncestor(const LinkedList& outputs, OpPatternKind* edge_pattern) { - auto link = input_nodes.head; + auto link = outputs.head; if (link == nullptr) { return nullptr; } @@ -510,8 +545,8 @@ DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForward */ class GraphPartitioner { public: - explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth) - : arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {} + explicit GraphPartitioner(support::Arena* arena, int fuse_opt_level, size_t max_fuse_depth) + : arena_(arena), fuse_opt_level_(fuse_opt_level), max_fuse_depth_(max_fuse_depth) {} /*! * \brief Group as a union find data structure. */ @@ -562,7 +597,7 @@ class GraphPartitioner { /*! \brief The internal arena for temporary space. */ support::Arena* arena_; /*! \brief optimization level for fuse operation. */ - int opt_level_; + int fuse_opt_level_; /*! \brief The maximum number of operations in one fused function */ size_t max_fuse_depth_; /*! \brief The internal groups. */ @@ -577,6 +612,8 @@ class GraphPartitioner { Group* gnode = groups_[src->index]; ICHECK(gnode != nullptr); gnode = gnode->FindRoot(); + VLOG(1) << "primitive checking path from " << src->index << " to " << sink->index + << " with src root pat " << P2S(gnode->pattern); if (!fcond(gnode->pattern, src == sink)) return false; if (src == sink) return true; for (auto link = src->outputs.head; link != nullptr; link = link->next) { @@ -598,6 +635,7 @@ class GraphPartitioner { */ template bool CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) { + VLOG(1) << "checking path from " << src->index << " to " << sink->index; ICHECK(!src->extern_ref); visited_.clear(); ICHECK(src != sink); @@ -630,11 +668,16 @@ class GraphPartitioner { if (child->anchor_ref != nullptr) { ICHECK(parent->anchor_ref == nullptr); parent->anchor_ref = child->anchor_ref; - parent->pattern = CombinePattern(child->pattern, parent->pattern); + OpPatternKind new_pattern = CombinePattern(child->pattern, parent->pattern); + VLOG(1) << "binding anchor, so combining patterns " << P2S(child->pattern) << " and " + << P2S(parent->pattern) << " to " << P2S(new_pattern); + parent->pattern = new_pattern; } } // Internal implelementation of CommitFuse void CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, Group* target) { + VLOG(1) << "commit fusing path from " << src->index << " to " << sink->index + << " for target pattern " << P2S(target->pattern); if (src == sink) return; if (visited_.count(src)) return; visited_.insert(src); @@ -653,6 +696,7 @@ class GraphPartitioner { * \note sink must be a post-dominator of src. */ void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { + VLOG(1) << "fusing path from " << src->index << " to " << sink->index; Group* target = groups_[sink->index]; visited_.clear(); ICHECK(src != sink); @@ -703,94 +747,151 @@ class GraphPartitioner { // execute the fusion algorithm. void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase) { + VLOG_CONTEXT << "RunFuse(" << phase << ")"; for (size_t nid = 0; nid < groups_.size(); ++nid) { // the group of current node has been specified already. 
auto* graph_node = graph.post_dfs_order[nid]; auto* dom_node = post_dom_tree.nodes[nid]; Group* group_node = groups_[nid]; ICHECK(group_node != nullptr); + VLOG_CONTEXT << nid << ":" << P2S(group_node->pattern); // no actions for opaque nodes - if (group_node->pattern == kOpaque) continue; + if (group_node->pattern == kOpaque) { + VLOG(1) << "opaque, ignoring"; + continue; + } // no actions needed if the current node have no dominator - if (dom_node->parent == nullptr) continue; + if (dom_node->parent == nullptr) { + VLOG(1) << "no dominator, ignoring"; + continue; + } ICHECK(!graph_node->extern_ref); size_t dom_parent_gindex = dom_node->parent->gnode->index; + Group* dom_parent_group = groups_[dom_parent_gindex]; + ICHECK(dom_parent_group != nullptr); + Group* dom_root_group = dom_parent_group->FindRoot(); + ICHECK(dom_root_group != nullptr); + + VLOG(1) << "dominator: " << dom_node->parent->gnode->index + << ", dom pat: " << P2S(dom_node->pattern) + << ", dom parent group pat: " << P2S(dom_parent_group->pattern) + << ", dom root group pat: " << P2S(dom_root_group->pattern); // refuse the fusion if too many ops are going to be fused together - if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_) + if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_) { + VLOG(1) << "already fused to max depth, ignoring"; continue; + } if (phase == 2) { // Fuse injective ops into intermediate tuples, if any - if (group_node->pattern > kInjective) continue; - Group* dom_parent_group = groups_[dom_parent_gindex]; - Group* dom_root_group = dom_parent_group->FindRoot(); - // If dom node group has a tuple as its root, we do not fuse tuple fields into it - if (dom_root_group->pattern == kTuple) continue; - if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) { - // Now we know the tuple has been fused into subsequent injective ops - auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; - // dom_root_group can also be tuple, as in inception layers - // CheckPath is needed to avoid fusing two intermediate tuples - if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { - CommitFuse(graph_node, dom_node->parent->gnode); - } + if (!(group_node->pattern == kElemWise || group_node->pattern == kBroadcast || + group_node->pattern == kInjective)) { + VLOG(1) << "not injective, ignoring"; + continue; + } + if (dom_parent_group->pattern != kTuple) { + VLOG(1) << "dominator is not tuple, ignoring"; + continue; + } + if (!(dom_root_group->pattern == kElemWise || dom_root_group->pattern == kBroadcast || + dom_root_group->pattern == kInjective)) { + VLOG(1) << "dom root group not injective, ignoring"; + continue; + } + // Now we know the tuple has been fused into subsequent injective ops + auto fcond = [](OpPatternKind kind, bool is_sink) { + return kind == kElemWise || kind == kBroadcast || kind == kInjective; + }; + // dom_root_group can also be tuple, as in inception layers + // CheckPath is needed to avoid fusing two intermediate tuples + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + VLOG(1) << "fusing elemwise/broadcast/injective into tuple"; + CommitFuse(graph_node, dom_node->parent->gnode); + } else { + VLOG(1) << "invalid path for tuple, ignoring"; } continue; } // Skip if current node is already fused to the parent. 
-      if (groups_[dom_parent_gindex] != nullptr &&
-          group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
+      if (group_node->FindRoot() == dom_root_group) {
+        VLOG(1) << "already fused to dominator, ignoring";
         continue;
       }
       // Do not fuse into tuple for now
-      if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
+      if (dom_parent_group->pattern == kTuple) {
+        VLOG(1) << "dominator is tuple, ignoring";
+        continue;
+      }
       // Try to fuse current node to its post-dominator.
       if (group_node->pattern == kOutEWiseFusable) {
-        if (phase != 0) continue;
+        if (phase != 0) {
+          VLOG(1) << "not phase 0, ignoring";
+          continue;
+        }
         // Path for OutEWiseFusable: conv2d
         // Check if the dominator relation is elemwise.
-        if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
-          ICHECK(dom_node->parent->gnode != nullptr);
+        if (dom_node->pattern == kElemWise) {
           // The fuse can be executed if all the intermediate ops are still broadcast.
-          auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; };
+          auto fcond = [](OpPatternKind kind, bool is_sink) {
+            return kind == kElemWise || kind == kBroadcast;
+          };
           if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
+            VLOG(1) << "fusing for anchor";
             CommitFuse(graph_node, dom_node->parent->gnode);
+          } else {
+            VLOG(1) << "invalid path for anchor, ignoring";
           }
+        } else {
+          VLOG(1) << "dominator pattern not elemwise, ignoring";
         }
-      } else if (group_node->pattern <= kBroadcast) {
+      } else if (group_node->pattern == kElemWise || group_node->pattern == kBroadcast) {
         // Pre-condition: can only be fused to parent which is injective or reduction.
-        if (dom_node->parent != nullptr &&
-            (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) {
+        if (dom_node->pattern == kElemWise || dom_node->pattern == kBroadcast ||
+            dom_node->pattern == kInjective || dom_node->pattern == kCommReduce) {
           // Check if all the intermediate ops are still broadcast.
           // The final terminal node can already be fused to an OutEWiseFusable group.
           auto fcond = [](OpPatternKind kind, bool is_sink) {
             if (!is_sink) {
               // Elemwise, broadcast, and injective ops on the parallel branches
               // are allowed to be fused to the elemwise/broadcast anchor.
-              return kind <= kInjective;
+              return kind == kElemWise || kind == kBroadcast || kind == kInjective;
             } else {
-              return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective ||
-                      kind == kOutEWiseFusable);
+              return (kind == kElemWise || kind == kBroadcast || kind == kInjective ||
+                      kind == kCommReduce);
             }
           };
           if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
+            VLOG(1) << "fusing elemwise/broadcast";
             CommitFuse(graph_node, dom_node->parent->gnode);
+          } else {
+            VLOG(1) << "invalid path for elemwise/broadcast, ignoring";
          }
+        } else {
+          VLOG(1) << "dominator pattern not reduce, ignoring";
         }
       } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
         // defer injective fusion to second phase.
         // so conv2d always finishes fusing.
-        if (phase != 1) continue;
+        if (phase != 1) {
+          VLOG(1) << "not phase 1, ignoring";
+          continue;
+        }
         // Check if all paths are injective.
- auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; + auto fcond = [](OpPatternKind kind, bool is_sink) { + return kind == kElemWise || kind == kBroadcast || kind == kInjective; + }; if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + VLOG(1) << "fusing injective/tuple"; CommitFuse(graph_node, dom_node->parent->gnode); + } else { + VLOG(1) << "invalid path for injective/tuple, ignoring"; } + } else if (group_node->pattern == kCommReduce) { + VLOG(1) << "reduce, ignoring"; } else { - // do nothing. - ICHECK(group_node->pattern == kCommReduce); + ICHECK(false) << "unsupported pattern " << P2S(group_node->pattern); } } } @@ -799,7 +900,7 @@ class GraphPartitioner { std::vector GraphPartitioner::Partition( const IndexedForwardGraph& graph) { this->InitGroups(graph); - if (opt_level_ == 0) return std::move(groups_); + if (fuse_opt_level_ == 0) return std::move(groups_); // get post dominator tree auto post_dom_tree = DominatorTree::PostDom(arena_, graph); // run fusion algorithm. @@ -818,15 +919,9 @@ class FuseMutator : private MixedModeMutator { // Run the transform Expr Transform(const Expr& body) { - return Transform(body, fuse_opt_level_, max_fuse_depth_, link_params_); - } - - protected: - // Run the transform - Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth, bool link_params) { // setup the group map. auto graph = IndexedForwardGraph::Create(&arena_, body); - auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth).Partition(graph); + auto groups = GraphPartitioner(&arena_, fuse_opt_level_, max_fuse_depth_).Partition(graph); for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { ICHECK(graph.post_dfs_order[nid]->ref != nullptr); gmap_[graph.post_dfs_order[nid]->ref] = groups[nid]; @@ -1034,7 +1129,7 @@ class FuseMutator : private MixedModeMutator { os << " /* group=" << group << " */"; return os.str(); }); - LOG(INFO) << "Dump of group info:\n" << text; + VLOG(1) << "group info:\n" << text; } }; @@ -1048,6 +1143,7 @@ namespace transform { Pass FuseOps(int fuse_opt_level) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { + VLOG(1) << "fusing:\n" << PrettyPrint(f); bool link_params = false; Executor executor = m->GetAttr(tvm::attr::kExecutor).value_or(NullValue()); @@ -1057,7 +1153,9 @@ Pass FuseOps(int fuse_opt_level) { link_params = pc->GetConfig("relay.FuseOps.link_params", Bool(link_params)).value(); int opt_level = fuse_opt_level == -1 ? 
pc->opt_level : fuse_opt_level; auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth", Integer(kMaxFusedOps)); - return Downcast(FuseOps(f, opt_level, max_fuse_depth.value(), link_params, m)); + auto new_f = Downcast(FuseOps(f, opt_level, max_fuse_depth.value(), link_params, m)); + VLOG(1) << "fused:\n" << PrettyPrint(new_f); + return new_f; }; return CreateFunctionPass(pass_func, 0, "FuseOps", {"InferType"}); } diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 9c01c40517f4c..774c4eaf1a28e 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -949,6 +949,20 @@ TVM_REGISTER_GLOBAL("relay._transform.InferTypeLocal").set_body_typed([](const E return InferTypeLocal(expr); }); +Expr InferTypeExpr(const Expr& expr) { + auto mod = IRModule::FromExpr(expr); + mod = transform::InferType()(mod); + if (expr.as()) { + return mod->Lookup("main"); + } else { + return mod->Lookup("main").as()->body; + } +} + +TVM_REGISTER_GLOBAL("relay._transform.InferTypeExpr").set_body_typed([](const Expr& expr) { + return InferTypeExpr(expr); +}); + Pass InferType() { auto pass_info = PassInfo(0, "InferType", {}); return tvm::transform::CreateModulePass( diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc index 4f196265b51b7..9d27d776683d2 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc @@ -71,6 +71,12 @@ TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger, #endif } +nvinfer1::DataType DLDataType2NVDataType(DLDataType data_type) { + ICHECK(data_type.code == kDLFloat && (data_type.bits == 16 || data_type.bits == 32)) + << "Invalid input Tensor type. Only float16 and float32 are supported"; + return (data_type.bits == 16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT; +} + void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode& node) { auto node_name = node.GetOpName(); auto shapes = node.GetOpShape(); @@ -85,13 +91,7 @@ void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode& shape.erase(shape.begin()); } nvinfer1::Dims dims = VectorToTrtDims(shape); - ICHECK((dtypes[i].bits != 16 || dtypes[i].bits != 32)) - << "Invalid input Tensor type. Float16 and Float32 are supported"; - - auto tensor_dtype = - (dtypes[i].bits == 16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT; - - auto input_tensor = network_->addInput(name.c_str(), tensor_dtype, dims); + auto input_tensor = network_->addInput(name.c_str(), DLDataType2NVDataType(dtypes[i]), dims); node_output_map_[nid].push_back(TensorRTOpInput(input_tensor)); network_input_names_.push_back(name); entry_id_map_[name] = entry_id + i; @@ -146,18 +146,19 @@ void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) { } params.inputs.push_back(input); } + ICHECK(converter->variable_input_count || converter->input_types.size() == params.inputs.size()) + << "Op " << node.GetOpName() << " expected a different number of inputs."; // Convert op to TRT. converter->Convert(¶ms); // Get outputs. node_output_map_[nid] = {}; - for (auto out : params.outputs) { - auto out_type = params.inputs.at(1).weight.type == params.inputs.at(0).tensor->getType() - ? 
params.inputs.at(0).tensor->getType() - : params.inputs.at(1).weight.type; - out->setType(out_type); - + std::vector dtype = node.GetOpDataType(); + ICHECK_EQ(params.outputs.size(), dtype.size()); + for (size_t i = 0; i < params.outputs.size(); ++i) { + auto out = params.outputs[i]; + out->setType(DLDataType2NVDataType(dtype[i])); node_output_map_[nid].push_back(TensorRTOpInput(out)); } } diff --git a/src/runtime/contrib/tensorrt/tensorrt_calibrator.h b/src/runtime/contrib/tensorrt/tensorrt_calibrator.h index 58bfcc248f6e8..523676b947022 100755 --- a/src/runtime/contrib/tensorrt/tensorrt_calibrator.h +++ b/src/runtime/contrib/tensorrt/tensorrt_calibrator.h @@ -80,7 +80,7 @@ class TensorRTCalibrator : public nvinfer1::IInt8EntropyCalibrator2 { } num_batches_calibrated_++; // TODO(trevmorr): Free data from previous batch? - return (num_batches_calibrated_ < data_.size()); + return (num_batches_calibrated_ < static_cast(data_.size())); } const void* readCalibrationCache(size_t& length) noexcept override { diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc index 2c5f293bc4311..30844f3cd848f 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc @@ -157,6 +157,8 @@ void TensorRTOpConverter::GetPadding3D(const std::vector& padding, class ActivationOpConverter : public TensorRTOpConverter { public: ActivationOpConverter() : TensorRTOpConverter({kTensor}) {} + ~ActivationOpConverter() = default; + void Convert(TensorRTOpConverterParams* params) const { static const std::unordered_map op_map = { @@ -191,8 +193,9 @@ class ActivationOpConverter : public TensorRTOpConverter { class ElementWiseBinaryOpConverter : public TensorRTOpConverter { public: ElementWiseBinaryOpConverter() : TensorRTOpConverter({kTensor, kTensor}) {} + ~ElementWiseBinaryOpConverter() = default; - void Convert(TensorRTOpConverterParams* params) const { + void Convert(TensorRTOpConverterParams* params) const { static const std::unordered_map op_map = { {"add", nvinfer1::ElementWiseOperation::kSUM}, {"subtract", nvinfer1::ElementWiseOperation::kSUB}, @@ -231,6 +234,7 @@ class ElementWiseBinaryOpConverter : public TensorRTOpConverter { class Conv1DOpConverter : public TensorRTOpConverter { public: Conv1DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + ~Conv1DOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input_tensor = params->inputs.at(0).tensor; @@ -282,6 +286,7 @@ class Conv1DOpConverter : public TensorRTOpConverter { class Conv2DOpConverter : public TensorRTOpConverter { public: Conv2DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + ~Conv2DOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input_tensor = params->inputs.at(0).tensor; @@ -345,6 +350,7 @@ class Conv2DOpConverter : public TensorRTOpConverter { class Conv3DOpConverter : public TensorRTOpConverter { public: Conv3DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + ~Conv3DOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input_tensor = params->inputs.at(0).tensor; @@ -394,6 +400,7 @@ class Conv3DOpConverter : public TensorRTOpConverter { class DenseOpConverter : public TensorRTOpConverter { public: DenseOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + ~DenseOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input_tensor = params->inputs.at(0).tensor; @@ -428,6 +435,7 @@ class 
DenseOpConverter : public TensorRTOpConverter { class BatchNormOpConverter : public TensorRTOpConverter { public: BatchNormOpConverter() : TensorRTOpConverter({kTensor, kWeight, kWeight, kWeight, kWeight}) {} + ~BatchNormOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input = params->inputs.at(0).tensor; @@ -525,6 +533,7 @@ class BatchNormOpConverter : public TensorRTOpConverter { class LayerNormOpConverter : public TensorRTOpConverter { public: LayerNormOpConverter() : TensorRTOpConverter({kTensor, kWeight, kWeight}) {} + ~LayerNormOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input = params->inputs.at(0).tensor; @@ -597,6 +606,7 @@ class LayerNormOpConverter : public TensorRTOpConverter { class BatchFlattenOpConverter : public TensorRTOpConverter { public: BatchFlattenOpConverter() : TensorRTOpConverter({kTensor}) {} + ~BatchFlattenOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { std::vector new_shape{-1}; @@ -610,6 +620,7 @@ class BatchFlattenOpConverter : public TensorRTOpConverter { class SoftmaxOpConverter : public TensorRTOpConverter { public: SoftmaxOpConverter() : TensorRTOpConverter({kTensor}) {} + ~SoftmaxOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input = params->inputs.at(0).tensor; @@ -626,6 +637,7 @@ class SoftmaxOpConverter : public TensorRTOpConverter { class PoolingOpConverter : public TensorRTOpConverter { public: PoolingOpConverter() : TensorRTOpConverter({kTensor}) {} + ~PoolingOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input = params->inputs.at(0).tensor; @@ -699,6 +711,7 @@ class PoolingOpConverter : public TensorRTOpConverter { class Pooling3DOpConverter : public TensorRTOpConverter { public: Pooling3DOpConverter() : TensorRTOpConverter({kTensor}) {} + ~Pooling3DOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input = params->inputs.at(0).tensor; @@ -744,6 +757,7 @@ class Pooling3DOpConverter : public TensorRTOpConverter { class GlobalPoolingOpConverter : public TensorRTOpConverter { public: GlobalPoolingOpConverter() : TensorRTOpConverter({kTensor}) {} + ~GlobalPoolingOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input_tensor = params->inputs.at(0).tensor; @@ -766,6 +780,7 @@ class GlobalPoolingOpConverter : public TensorRTOpConverter { class ExpandDimsOpConverter : public TensorRTOpConverter { public: ExpandDimsOpConverter() : TensorRTOpConverter({kTensor}) {} + ~ExpandDimsOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input_tensor = params->inputs.at(0).tensor; @@ -784,6 +799,7 @@ class ExpandDimsOpConverter : public TensorRTOpConverter { class SqueezeOpConverter : public TensorRTOpConverter { public: SqueezeOpConverter() : TensorRTOpConverter({kTensor}) {} + ~SqueezeOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input_tensor = params->inputs.at(0).tensor; @@ -801,6 +817,7 @@ class SqueezeOpConverter : public TensorRTOpConverter { class UnaryOpConverter : public TensorRTOpConverter { public: UnaryOpConverter() : TensorRTOpConverter({kTensor}) {} + ~UnaryOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { // The following ops are supported by TRT but don't exist in relay yet: @@ -834,6 +851,7 @@ class UnaryOpConverter : public TensorRTOpConverter { class ConcatOpConverter : public 
TensorRTOpConverter { public: ConcatOpConverter() : TensorRTOpConverter({}, /*variable_input_count=*/true) {} + ~ConcatOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { const int num_inputs = params->inputs.size(); @@ -860,6 +878,7 @@ class ConcatOpConverter : public TensorRTOpConverter { class SplitOpConverter : public TensorRTOpConverter { public: SplitOpConverter() : TensorRTOpConverter({kTensor}) {} + ~SplitOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input = params->inputs.at(0).tensor; @@ -907,6 +926,7 @@ class SplitOpConverter : public TensorRTOpConverter { class BiasAddOpConverter : public TensorRTOpConverter { public: BiasAddOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + ~BiasAddOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input_tensor = params->inputs.at(0).tensor; @@ -940,6 +960,7 @@ class BiasAddOpConverter : public TensorRTOpConverter { class Conv2DTransposeOpConverter : public TensorRTOpConverter { public: Conv2DTransposeOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + ~Conv2DTransposeOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input_tensor = params->inputs.at(0).tensor; @@ -1010,6 +1031,7 @@ class Conv2DTransposeOpConverter : public TensorRTOpConverter { class Conv3DTransposeOpConverter : public TensorRTOpConverter { public: Conv3DTransposeOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + ~Conv3DTransposeOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input_tensor = params->inputs.at(0).tensor; @@ -1066,6 +1088,7 @@ class Conv3DTransposeOpConverter : public TensorRTOpConverter { class TransposeOpConverter : public TensorRTOpConverter { public: TransposeOpConverter() : TensorRTOpConverter({kTensor}) {} + ~TransposeOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input = params->inputs.at(0).tensor; @@ -1081,6 +1104,7 @@ class TransposeOpConverter : public TensorRTOpConverter { class LayoutTransformOpConverter : public TensorRTOpConverter { public: LayoutTransformOpConverter() : TensorRTOpConverter({kTensor}) {} + ~LayoutTransformOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input = params->inputs.at(0).tensor; @@ -1103,12 +1127,15 @@ class LayoutTransformOpConverter : public TensorRTOpConverter { class ReshapeOpConverter : public TensorRTOpConverter { public: ReshapeOpConverter() : TensorRTOpConverter({kTensor}) {} + ~ReshapeOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input->getDimensions()); auto str_newshape = params->node.GetAttr>("newshape"); std::vector new_shape; - const int start_index = TRT_HAS_IMPLICIT_BATCH(params) ? 1 : 0; + int start_index = TRT_HAS_IMPLICIT_BATCH(params) ? 
1 : 0; + if(std::stoi(str_newshape[0])==-1) start_index = 0; for (size_t i = start_index; i < str_newshape.size(); ++i) { const int value = std::stoi(str_newshape[i]); ICHECK_GE(value, -1); @@ -1120,7 +1147,8 @@ class ReshapeOpConverter : public TensorRTOpConverter { class PadOpConverter : public TensorRTOpConverter { public: - PadOpConverter() : TensorRTOpConverter({kTensor}) {} + PadOpConverter() : TensorRTOpConverter({kTensor, kIgnored}) {} + ~PadOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input = params->inputs.at(0).tensor; @@ -1137,6 +1165,7 @@ class PadOpConverter : public TensorRTOpConverter { class ReduceOpConverter : public TensorRTOpConverter { public: ReduceOpConverter() : TensorRTOpConverter({kTensor}) {} + ~ReduceOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { static const std::unordered_map op_map = { @@ -1172,10 +1201,66 @@ class ReduceOpConverter : public TensorRTOpConverter { } }; +class VarianceOpConverter : public TensorRTOpConverter { + public: + VarianceOpConverter() : TensorRTOpConverter({kTensor, kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + ICHECK_EQ(std::stoi(params->node.GetAttr>("exclude")[0]), false); + bool keepdims = std::stoi(params->node.GetAttr>("keepdims")[0]); + auto str_axis = params->node.GetAttr>("axis"); + // TODO(trevmorr): Support reduce to scalar. + ICHECK_GT(str_axis.size(), 0); + uint32_t reduce_axes = 0; + for (size_t i = 0; i < str_axis.size(); ++i) { + const int axis = ConvertAxis(params, std::stoi(str_axis[i]), input->getDimensions().nbDims); + reduce_axes |= 1 << axis; + } + + //auto mean_layer = params->network->addReduce(*input, nvinfer1::ReduceOperation::kAVG, reduce_axes, keepdims); + //ICHECK(mean_layer != nullptr); + //auto mean = mean_layer->getOutput(0); + auto mean = params->inputs.at(1).tensor; + + auto diff_layer = params->network->addElementWise(*input, *mean, nvinfer1::ElementWiseOperation::kSUB); + ICHECK(diff_layer != nullptr); + auto diff = diff_layer->getOutput(0); + + auto square_layer = params->network->addElementWise(*diff, *diff, nvinfer1::ElementWiseOperation::kPROD); + ICHECK(square_layer != nullptr); + auto square = square_layer->getOutput(0); + + auto var_layer = params->network->addReduce(*square, nvinfer1::ReduceOperation::kAVG, reduce_axes, keepdims); + ICHECK(var_layer != nullptr); + auto var = var_layer->getOutput(0); + + const float epsilon = 1e-6; + auto epsilon_tensor = CreateScalar(params, epsilon, var->getDimensions()); + auto eps_add_layer = params->network->addElementWise(*var, *epsilon_tensor, nvinfer1::ElementWiseOperation::kSUM); + ICHECK(eps_add_layer != nullptr); + auto output = eps_add_layer->getOutput(0); + + params->outputs.push_back(output); + + //const float epsilon = std::stof(params->node.GetAttr>("epsilon")[0]); + + //nvinfer1::Weights pes{nvinfer1::DataType::kFLOAT, weight_shift_ptr, gamma.count}; + //auto var = params->network->addScale(*delta, nvinfer1::ScaleMode::kUNIFORM, + + + //nvinfer1::IScaleLayer* scale_layer = params->network->addScale( + // *input, nvinfer1::ScaleMode::kCHANNEL, weight_shift, weight_scale, power); + + } +}; + + #if TRT_VERSION_GE(5, 1, 5) class StridedSliceOpConverter : public TensorRTOpConverter { public: StridedSliceOpConverter() : TensorRTOpConverter({kTensor}) {} + ~StridedSliceOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input = params->inputs.at(0).tensor; @@ -1205,6 
+1290,7 @@ class StridedSliceOpConverter : public TensorRTOpConverter { class AdaptivePoolingOpConverter : public TensorRTOpConverter { public: AdaptivePoolingOpConverter() : TensorRTOpConverter({kTensor}) {} + ~AdaptivePoolingOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto input_tensor = params->inputs.at(0).tensor; @@ -1235,6 +1321,7 @@ class AdaptivePoolingOpConverter : public TensorRTOpConverter { class BatchMatmulOpConverter : public TensorRTOpConverter { public: BatchMatmulOpConverter() : TensorRTOpConverter({kTensor, kTensor}) {} + ~BatchMatmulOpConverter() = default; void Convert(TensorRTOpConverterParams* params) const { auto transa = std::stoi(params->node.GetAttr>("transpose_a")[0]); @@ -1264,6 +1351,7 @@ GetOpConverters() { map->emplace("nn.conv1d", std::make_shared()); map->emplace("nn.conv2d", std::make_shared()); map->emplace("nn.dense", std::make_shared()); + map->emplace("nn.batch_matmul", std::make_shared()); map->emplace("nn.bias_add", std::make_shared()); map->emplace("add", std::make_shared()); map->emplace("subtract", std::make_shared()); @@ -1296,6 +1384,7 @@ GetOpConverters() { map->emplace("max", std::make_shared()); map->emplace("min", std::make_shared()); map->emplace("mean", std::make_shared()); + map->emplace("variance", std::make_shared()); map->emplace("nn.adaptive_max_pool2d", std::make_shared()); map->emplace("nn.adaptive_avg_pool2d", std::make_shared()); map->emplace("nn.batch_matmul", std::make_shared()); diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.h b/src/runtime/contrib/tensorrt/tensorrt_ops.h index b71dec00c9bee..fa3d99ac8933b 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.h +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.h @@ -49,12 +49,13 @@ namespace contrib { using JSONGraphNode = tvm::runtime::json::JSONGraphNode; /*! - * \brief An input to a op may be either kTensor in the case of nvinfer::ITensor* - * or kWeight for nvinfer1::Weights. + * \brief An input to a op may be either kTensor in the case of nvinfer::ITensor*, + * a kWeight for nvinfer1::Weights, or ignored (eg for the nn.pad value). */ enum TensorRTInputType { kTensor, kWeight, + kIgnored }; /*! @@ -106,6 +107,8 @@ struct TensorRTOpConverterParams { /*! \brief Base class for an op converter from Relay to TRT. */ class TensorRTOpConverter { public: + virtual ~TensorRTOpConverter() = default; + /*! \brief Used to specify whether each input is tensor or weight. */ const std::vector input_types; /*! \brief If set to true, any number of tensor inputs can be used for the op. 
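For reference, a minimal sketch (not part of the patch) of a converter declaring the new kIgnored input category, in the same style as the PadOpConverter change above; the class and its pass-through body are hypothetical:

// Hypothetical converter: one real tensor input plus one ignored input slot
// (e.g. a fill value that is consumed via attributes rather than as a tensor).
class PadLikeOpConverter : public TensorRTOpConverter {
 public:
  PadLikeOpConverter() : TensorRTOpConverter({kTensor, kIgnored}) {}
  ~PadLikeOpConverter() = default;

  void Convert(TensorRTOpConverterParams* params) const override {
    nvinfer1::ITensor* input = params->inputs.at(0).tensor;
    // Pass-through for illustration only; a real converter adds TRT layers here.
    // Note the kIgnored slot still counts towards the input_types.size() check
    // added to TensorRTBuilder::AddLayer.
    params->outputs.push_back(input);
  }
};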
diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index d8e0231ebcd60..e8d66fa06d819 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -298,8 +298,8 @@ class TensorRTRuntime : public JSONRuntimeBase { } } - LOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_ - << " with batch size " << batch_size; + VLOG(1) << "Finished building TensorRT engine for subgraph " << symbol_name_ + << " with batch size " << batch_size; CacheEngineToDisk(); return trt_engine_cache_.at(std::make_pair(symbol_name_, batch_size)); } diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index b4d7b41b7f4ae..fd8e99d2c9997 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -118,20 +118,22 @@ class CUDADeviceAPI final : public DeviceAPI { CUDA_CALL(cudaSetDevice(dev.device_id)); size_t free_mem, total_mem; CUDA_CALL(cudaMemGetInfo(&free_mem, &total_mem)); - VLOG(1) << "allocating " << nbytes << " bytes on device, with " << free_mem - << " bytes currently free out of " << total_mem << " bytes available"; + VLOG(1) << "allocating " << nbytes << " bytes on device " << dev.device_id << " with " + << free_mem << " bytes currently free out of " << total_mem << " bytes available"; CUDA_CALL(cudaMalloc(&ret, nbytes)); } + VLOG(1) << "allocated at " << std::hex << reinterpret_cast(ret); return ret; } void FreeDataSpace(Device dev, void* ptr) final { if (dev.device_type == kDLCUDAHost) { - VLOG(1) << "freeing host memory"; + VLOG(1) << "freeing host memory at " << std::hex << reinterpret_cast(ptr); CUDA_CALL(cudaFreeHost(ptr)); } else { CUDA_CALL(cudaSetDevice(dev.device_id)); - VLOG(1) << "freeing device memory"; + VLOG(1) << "freeing device " << dev.device_id << " memory at " << std::hex + << reinterpret_cast(ptr); CUDA_CALL(cudaFree(ptr)); } } diff --git a/src/runtime/dso_library.cc b/src/runtime/dso_library.cc index 81eb30ee12d25..818161064549f 100644 --- a/src/runtime/dso_library.cc +++ b/src/runtime/dso_library.cc @@ -83,6 +83,7 @@ class DSOLibrary final : public Library { #else // \brief Linux library handle void* lib_handle_{nullptr}; + std::string filename_; #endif }; @@ -115,12 +116,23 @@ void DSOLibrary::Unload() { #else void DSOLibrary::Load(const std::string& name) { + VLOG(1) << "dlopen('" << name << "')"; + filename_ = name; lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); - ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name << " " - << dlerror(); + ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library '" << name + << "': " << dlerror(); } -void* DSOLibrary::GetSymbol_(const char* name) { return dlsym(lib_handle_, name); } +void* DSOLibrary::GetSymbol_(const char* name) { + VLOG(1) << "dlsym('" << name << "')"; + void* res = dlsym(lib_handle_, name); + if (res == nullptr) { + VLOG(1) << "Failed to lookup symbol '" << name << "' from dynamic shared library '" << filename_ + << "': " << dlerror(); + // User is responsible for checking nullptr result. 
+ } + return res; +} void DSOLibrary::Unload() { dlclose(lib_handle_); diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 7efa91d912eb3..4d88f8e908969 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -43,6 +43,7 @@ class LibraryModuleNode final : public ModuleNode { const char* type_key() const final { return "library"; } PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + VLOG(1) << "LibraryModule::GetFunction('" << name << "')"; TVMBackendPackedCFunc faddr; if (name == runtime::symbol::tvm_module_main) { const char* entry_name = @@ -53,7 +54,10 @@ class LibraryModuleNode final : public ModuleNode { } else { faddr = reinterpret_cast(lib_->GetSymbol(name.c_str())); } - if (faddr == nullptr) return PackedFunc(); + if (faddr == nullptr) { + // Caller is responsible for checking nullptr result. + return PackedFunc(); + } return packed_func_wrapper_(faddr, sptr_to_self); } diff --git a/src/runtime/logging.cc b/src/runtime/logging.cc index 0f614a7eaff19..1d74ae97d3553 100644 --- a/src/runtime/logging.cc +++ b/src/runtime/logging.cc @@ -251,7 +251,6 @@ TvmLogDebugSettings TvmLogDebugSettings::ParseSpec(const char* opt_spec) { LOG(FATAL) << "TVM_LOG_DEBUG ill-formed, invalid level"; return settings; } - LOG(INFO) << "TVM_LOG_DEBUG enables VLOG statements in '" << name << "' up to level " << level; settings.vlog_level_map_.emplace(name, level_val); } return settings; diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 097d6a2f53e78..26e4144699aeb 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -83,6 +83,7 @@ Module Module::LoadFromFile(const std::string& file_name, const std::string& for fmt = "so"; } std::string load_f_name = "runtime.module.loadfile_" + fmt; + VLOG(1) << "Loading module from '" << file_name << "' of format '" << fmt << "'"; const PackedFunc* f = Registry::Get(load_f_name); ICHECK(f != nullptr) << "Loader for `." << format << "` files is not registered," << " resolved to (" << load_f_name << ") in the global registry." 
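Since DSOLibrary::GetSymbol_ and LibraryModuleNode::GetFunction now leave nullptr handling to the caller, here is a minimal sketch (not part of the patch; the library and symbol names are made up) of the defensive pattern callers are expected to follow:

#include <tvm/runtime/logging.h>
#include <tvm/runtime/module.h>

void LookupOrDie() {
  tvm::runtime::Module mod = tvm::runtime::Module::LoadFromFile("lib.so");
  // GetFunction returns a null PackedFunc rather than failing internally.
  tvm::runtime::PackedFunc f = mod.GetFunction("my_kernel");
  if (f == nullptr) {
    LOG(FATAL) << "symbol 'my_kernel' not found in 'lib.so'";
  }
}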
diff --git a/src/runtime/vm/pooled_allocator.h b/src/runtime/vm/pooled_allocator.h index e5f236983a735..15af3e06bb240 100644 --- a/src/runtime/vm/pooled_allocator.h +++ b/src/runtime/vm/pooled_allocator.h @@ -52,6 +52,8 @@ class PooledAllocator final : public Allocator { auto&& pool = it->second; auto ret = pool.back(); pool.pop_back(); + VLOG(1) << "reusing buffer of " << ret.size << " bytes at " << + std::hex << reinterpret_cast(ret.data); return ret; } Buffer buf; @@ -67,7 +69,9 @@ class PooledAllocator final : public Allocator { } used_memory_.fetch_add(size, std::memory_order_relaxed); - VLOG(1) << "allocate " << size << " B, used memory " << used_memory_ << " B"; + VLOG(1) << "allocated " << size << " bytes at " << std::hex + << reinterpret_cast(buf.data) << std::dec << ", total used memory is now " + << used_memory_ << " bytes"; return buf; } @@ -77,7 +81,8 @@ class PooledAllocator final : public Allocator { memory_pool_.emplace(buffer.size, std::vector{}); } memory_pool_.at(buffer.size).push_back(buffer); - VLOG(1) << "reclaim buffer " << buffer.size; + VLOG(1) << "reclaiming buffer of " << buffer.size << " bytes at " << std::hex + << reinterpret_cast(buffer.data); } size_t UsedMemory() const override { return used_memory_.load(std::memory_order_relaxed); } @@ -88,12 +93,14 @@ class PooledAllocator final : public Allocator { for (auto const& it : memory_pool_) { auto const& pool = it.second; for (auto const& buf : pool) { + VLOG(1) << "freeing " << buf.size << " bytes at " << std::hex + << reinterpret_cast(buf.data); DeviceAPI::Get(buf.device)->FreeDataSpace(buf.device, buf.data); } } memory_pool_.clear(); used_memory_ = 0; - VLOG(1) << "release all buffers"; + VLOG(1) << "released all buffers"; } private: diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 41b9395237ee9..bc4dda9eb4f4f 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -264,6 +264,7 @@ void VirtualMachine::SetOneInput(std::string func_name, const TVMArgValue& tag, } else { LOG(FATAL) << "The type of input tensor tag (" << tag.type_code() << ") doesn't match integer or string"; + inp_index = 0; } ICHECK_LT(inp_index, params_num); @@ -362,11 +363,11 @@ void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector& args) { - DLOG(INFO) << "Executing Function: " << std::endl << func; + VLOG(1) << "Executing Function: " << std::endl << func; for (int i = 0; i < static_cast(devices_.size()); ++i) { - DLOG(INFO) << "Device " << i << " has device type " << devices_[i].device_type - << " and device id " << devices_[i].device_id - << (i == exec_->host_device_index ? " (using as host device)" : ""); + VLOG(1) << "Device " << i << " has device type " << devices_[i].device_type << " and device id " + << devices_[i].device_id + << (i == exec_->host_device_index ? " (using as host device)" : ""); } InvokeGlobal(func, args); diff --git a/src/target/compilation_config.cc b/src/target/compilation_config.cc index a56e0ad0777cf..d5a06cf4a76ce 100644 --- a/src/target/compilation_config.cc +++ b/src/target/compilation_config.cc @@ -236,6 +236,51 @@ CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx, data_ = std::move(node); } +CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx, + Array targets) { + VLOG_CONTEXT << "CompilationConfig"; + + auto node = make_object(); + + for (const auto& target : targets) { + VLOG(0) << "Available primitive target " << target->ToDebugString(); + } + + // Capture the arguments in our preferred representation. 
+ node->primitive_targets = std::move(targets); + node->host_target = NullValue(); + + // Complete the targets vector and establish default scopes. After this primitive_targets will + // contain the definitive list of all required targets, target_host will be defined, and + // all primitive targets will have host target_host. + node->EstablishDefaultVirtualDevices(pass_ctx); + + // LEGACY: Reconstruct the target map from all the primitive targets. + // Note that we require pointer equality between targets in legacy_target_map and + // primitive_targets. + for (const auto& primitive_target : node->primitive_targets) { + node->legacy_target_map.Set(Integer(primitive_target->kind->device_type), primitive_target); + } + + ICHECK(node->default_primitive_virtual_device->target.defined()); + ICHECK(node->host_virtual_device->target.defined()); + ICHECK_GT(node->primitive_targets.size(), 0U); + + // Legacy: Some passes only support homogenous compilation and expect the target to be + // given by the global target context. Make this easy to detect. + node->optional_homogeneous_target = + node->legacy_target_map.size() == 1 ? (*node->legacy_target_map.begin()).second : Target(); + + for (const auto& target : node->primitive_targets) { + VLOG(1) << "Target " << target->ToDebugString() << " of device type " + << target->kind->device_type << " is available for primitives"; + } + VLOG(1) << "Using default primitive virtual device " << node->default_primitive_virtual_device; + VLOG(1) << "Using host virtual device " << node->host_virtual_device; + + data_ = std::move(node); +} + TVM_REGISTER_GLOBAL("target.MakeCompilationConfig") .set_body_typed([](const transform::PassContext& pass_ctx, TargetMap legacy_target_map, Target optional_host_target) -> CompilationConfig { diff --git a/src/target/target.cc b/src/target/target.cc index a5c493a582ab2..f76a5c85832ed 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -493,6 +493,28 @@ Target::Target(Target target, Target host) { data_ = std::move(n); } +Target::Target(TargetKind kind, Optional host, String tag, Array keys, + Map attrs) { + auto data = runtime::make_object(); + data->kind = std::move(kind); + data->host = std::move(host); + data->tag = std::move(tag); + data->keys = std::move(keys); + data->attrs = std::move(attrs); + data_ = std::move(data); +} + +bool Target::IsRefinementOf(const Target& that) const { + if (get()->attrs.count("compiler") == 0 && get()->attrs.count("fusion_rule") == 0) { + return StructuralEqual()(*this, that); + } + Map attrs = get()->attrs; + attrs.erase("compiler"); + attrs.erase("fusion_rule"); + Target this_without_extra_attrs(get()->kind, get()->host, get()->tag, get()->keys, attrs); + return StructuralEqual()(this_without_extra_attrs, that); +} + std::vector TargetNode::GetKeys() const { std::vector result; for (auto& expr : keys) { diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index c105979402211..c1bf1c146dbea 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -16,7 +16,6 @@ # under the License. 
import logging import math -import pytest import tvm from tvm import relay from tvm.contrib.cudnn import conv_output_shape @@ -941,4 +940,7 @@ def test_conv2d_bwd(): if __name__ == "__main__": - pytest.main([__file__]) + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/collage/test_collage_partitioner.py b/tests/python/relay/collage/test_collage_partitioner.py new file mode 100644 index 0000000000000..78a48e603ff3c --- /dev/null +++ b/tests/python/relay/collage/test_collage_partitioner.py @@ -0,0 +1,1508 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import onnx +import numpy as np +import logging + +# The following are necessary to force the pattern table to be registered +from tvm.contrib.cutlass import build_cutlass_kernels_vm +from tvm.relay.op.contrib.cublas import partition_for_cublas + +logging.basicConfig(level=logging.INFO) + +# rtx3070 is 8.6 but CUTLASS only supports 8.0 +cutlass_sm = 80 + +TUNING_LOG = "collage_autotvm.tuninglog" + +MODEL_PREFIX = "/home/mbs/gauntlet/models/" +MNIST = {'filename': "mnist-8.onnx", 'input_shapes': {"Input3": [1, 1, 28, 28]}, 'input_dtypes': {"Input3": "float32"}} +GPT2 = {'filename': "gpt2.onnx", 'input_shapes': {"input1": [1, 50, 32]}, 'input_dtypes': {"input1": 'int64'}} + + +def make_const_float32(*shape): + return tvm.relay.const(np.random.rand(*shape).astype("float32")) + + +def mnist(): + const0 = make_const_float32(8, 1, 5, 5) + const1 = make_const_float32(8, 1, 1) + const2 = make_const_float32(16, 8, 5, 5) + const3 = make_const_float32(16, 1, 1) + const4 = make_const_float32(10, 256) + const5 = make_const_float32(1, 10) + metatable = {"relay.Constant": [const0, const1, const2, const3, const4, const5]} + mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(1, 1, 28, 28), float32]) -> Tensor[(1, 10), float32] { + %0 = nn.pad(%x, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); + %1 = nn.conv2d(%0, meta[relay.Constant][0], padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]); + %2 = add(%1, meta[relay.Constant][1]); + %3 = nn.relu(%2); + %4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]); + %5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); + %6 = nn.conv2d(%5, meta[relay.Constant][2], padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]); + %7 = add(%6, meta[relay.Constant][3]); + %8 = nn.relu(%7); + %9 = nn.max_pool2d(%8, pool_size=[3, 3], strides=[3, 3], padding=[0, 0, 0, 0]); + %10 = reshape(%9, newshape=[1, 256]); + %11 = nn.dense(%10, meta[relay.Constant][4], units=None, out_dtype="float32"); + add(%11, meta[relay.Constant][5]) + } + """, + "from_string", + None, + metatable + ) + + return {'input_shapes': {"x": [1, 1, 28, 28]}, 'input_dtypes': {"x": 
"float32"}, 'mod': mod, 'params': None} + + +def gpt2(): + const0 = make_const_float32(50257, 768) + const1 = make_const_float32(1, 32, 768) + const2 = make_const_float32(768, 768) + const3 = make_const_float32(2304, 768) + const4 = make_const_float32(768) + const5 = make_const_float32(768) + const6 = make_const_float32(2304) + const7 = make_const_float32(1, 1, 32, 32) + const8 = make_const_float32(1, 1, 32, 32) + const9 = make_const_float32(768) + const10 = make_const_float32(768, 3072) + const11 = make_const_float32(3072, 768) + const12 = make_const_float32(768) + const13 = make_const_float32(768) + const14 = make_const_float32(3072) + const15 = make_const_float32(768) + const16 = make_const_float32(768, 768) + const17 = make_const_float32(2304, 768) + const18 = make_const_float32(768) + const19 = make_const_float32(768) + const20 = make_const_float32(2304) + const21 = make_const_float32(1, 1, 32, 32) + const22 = make_const_float32(1, 1, 32, 32) + const23 = make_const_float32(768) + const24 = make_const_float32(768, 3072) + const25 = make_const_float32(3072, 768) + const26 = make_const_float32(768) + const27 = make_const_float32(768) + const28 = make_const_float32(3072) + const29 = make_const_float32(768) + const30 = make_const_float32(768, 768) + const31 = make_const_float32(2304, 768) + const32 = make_const_float32(768) + const33 = make_const_float32(768) + const34 = make_const_float32(2304) + const35 = make_const_float32(1, 1, 32, 32) + const36 = make_const_float32(1, 1, 32, 32) + const37 = make_const_float32(768) + const38 = make_const_float32(768, 3072) + const39 = make_const_float32(3072, 768) + const40 = make_const_float32(768) + const41 = make_const_float32(768) + const42 = make_const_float32(3072) + const43 = make_const_float32(768) + const44 = make_const_float32(768, 768) + const45 = make_const_float32(2304, 768) + const46 = make_const_float32(768) + const47 = make_const_float32(768) + const48 = make_const_float32(2304) + const49 = make_const_float32(1, 1, 32, 32) + const50 = make_const_float32(1, 1, 32, 32) + const51 = make_const_float32(768) + const52 = make_const_float32(768, 3072) + const53 = make_const_float32(3072, 768) + const54 = make_const_float32(768) + const55 = make_const_float32(768) + const56 = make_const_float32(3072) + const57 = make_const_float32(768) + const58 = make_const_float32(768, 768) + const59 = make_const_float32(2304, 768) + const60 = make_const_float32(768) + const61 = make_const_float32(768) + const62 = make_const_float32(2304) + const63 = make_const_float32(1, 1, 32, 32) + const64 = make_const_float32(1, 1, 32, 32) + const65 = make_const_float32(768) + const66 = make_const_float32(768, 3072) + const67 = make_const_float32(3072, 768) + const68 = make_const_float32(768) + const69 = make_const_float32(768) + const70 = make_const_float32(3072) + const71 = make_const_float32(768) + const72 = make_const_float32(768, 768) + const73 = make_const_float32(2304, 768) + const74 = make_const_float32(768) + const75 = make_const_float32(768) + const76 = make_const_float32(2304) + const77 = make_const_float32(1, 1, 32, 32) + const78 = make_const_float32(1, 1, 32, 32) + const79 = make_const_float32(768) + const80 = make_const_float32(768, 3072) + const81 = make_const_float32(3072, 768) + const82 = make_const_float32(768) + const83 = make_const_float32(768) + const84 = make_const_float32(3072) + const85 = make_const_float32(768) + const86 = make_const_float32(768, 768) + const87 = make_const_float32(2304, 768) + const88 = make_const_float32(768) + const89 = 
make_const_float32(768) + const90 = make_const_float32(2304) + const91 = make_const_float32(1, 1, 32, 32) + const92 = make_const_float32(1, 1, 32, 32) + const93 = make_const_float32(768) + const94 = make_const_float32(768, 3072) + const95 = make_const_float32(3072, 768) + const96 = make_const_float32(768) + const97 = make_const_float32(768) + const98 = make_const_float32(3072) + const99 = make_const_float32(768) + const100 = make_const_float32(768, 768) + const101 = make_const_float32(2304, 768) + const102 = make_const_float32(768) + const103 = make_const_float32(768) + const104 = make_const_float32(2304) + const105 = make_const_float32(1, 1, 32, 32) + const106 = make_const_float32(1, 1, 32, 32) + const107 = make_const_float32(768) + const108 = make_const_float32(768, 3072) + const109 = make_const_float32(3072, 768) + const110 = make_const_float32(768) + const111 = make_const_float32(768) + const112 = make_const_float32(3072) + const113 = make_const_float32(768) + const114 = make_const_float32(768, 768) + const115 = make_const_float32(2304, 768) + const116 = make_const_float32(768) + const117 = make_const_float32(768) + const118 = make_const_float32(2304) + const119 = make_const_float32(1, 1, 32, 32) + const120 = make_const_float32(1, 1, 32, 32) + const121 = make_const_float32(768) + const122 = make_const_float32(768, 3072) + const123 = make_const_float32(3072, 768) + const124 = make_const_float32(768) + const125 = make_const_float32(768) + const126 = make_const_float32(3072) + const127 = make_const_float32(768) + const128 = make_const_float32(768, 768) + const129 = make_const_float32(2304, 768) + const130 = make_const_float32(768) + const131 = make_const_float32(768) + const132 = make_const_float32(2304) + const133 = make_const_float32(1, 1, 32, 32) + const134 = make_const_float32(1, 1, 32, 32) + const135 = make_const_float32(768) + const136 = make_const_float32(768, 3072) + const137 = make_const_float32(3072, 768) + const138 = make_const_float32(768) + const139 = make_const_float32(768) + const140 = make_const_float32(3072) + const141 = make_const_float32(768) + const142 = make_const_float32(768, 768) + const143 = make_const_float32(2304, 768) + const144 = make_const_float32(768) + const145 = make_const_float32(768) + const146 = make_const_float32(2304) + const147 = make_const_float32(1, 1, 32, 32) + const148 = make_const_float32(1, 1, 32, 32) + const149 = make_const_float32(768) + const150 = make_const_float32(768, 3072) + const151 = make_const_float32(3072, 768) + const152 = make_const_float32(768) + const153 = make_const_float32(768) + const154 = make_const_float32(3072) + const155 = make_const_float32(768) + const156 = make_const_float32(768, 768) + const157 = make_const_float32(2304, 768) + const158 = make_const_float32(768) + const159 = make_const_float32(768) + const160 = make_const_float32(2304) + const161 = make_const_float32(1, 1, 32, 32) + const162 = make_const_float32(1, 1, 32, 32) + const163 = make_const_float32(768) + const164 = make_const_float32(768, 3072) + const165 = make_const_float32(3072, 768) + const166 = make_const_float32(768) + const167 = make_const_float32(768) + const168 = make_const_float32(3072) + const169 = make_const_float32(768) + const170 = make_const_float32(768) + const171 = make_const_float32(768) + metatable = {"relay.Constant": [ + const0, + const1, + const2, + const3, + const4, + const5, + const6, + const7, + const8, + const9, + const10, + const11, + const12, + const13, + const14, + const15, + const16, + const17, + const18, + const19, + const20, + 
const21, + const22, + const23, + const24, + const25, + const26, + const27, + const28, + const29, + const30, + const31, + const32, + const33, + const34, + const35, + const36, + const37, + const38, + const39, + const40, + const41, + const42, + const43, + const44, + const45, + const46, + const47, + const48, + const49, + const50, + const51, + const52, + const53, + const54, + const55, + const56, + const57, + const58, + const59, + const60, + const61, + const62, + const63, + const64, + const65, + const66, + const67, + const68, + const69, + const70, + const71, + const72, + const73, + const74, + const75, + const76, + const77, + const78, + const79, + const80, + const81, + const82, + const83, + const84, + const85, + const86, + const87, + const88, + const89, + const90, + const91, + const92, + const93, + const94, + const95, + const96, + const97, + const98, + const99, + const100, + const101, + const102, + const103, + const104, + const105, + const106, + const107, + const108, + const109, + const110, + const111, + const112, + const113, + const114, + const115, + const116, + const117, + const118, + const119, + const120, + const121, + const122, + const123, + const124, + const125, + const126, + const127, + const128, + const129, + const130, + const131, + const132, + const133, + const134, + const135, + const136, + const137, + const138, + const139, + const140, + const141, + const142, + const143, + const144, + const145, + const146, + const147, + const148, + const149, + const150, + const151, + const152, + const153, + const154, + const155, + const156, + const157, + const158, + const159, + const160, + const161, + const162, + const163, + const164, + const165, + const166, + const167, + const168, + const169, + const170, + const171, + ]} + mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(1, 50, 32), int64]) -> (Tensor[(1, 50, 32, 768), float32], + Tensor[(2, 50, 12, 32, 64), float32], + Tensor[(2, 50, 12, 32, 64), float32], + Tensor[(2, 50, 12, 32, 64), float32], + Tensor[(2, 50, 12, 32, 64), float32], + Tensor[(2, 50, 12, 32, 64), float32], + Tensor[(2, 50, 12, 32, 64), float32], + Tensor[(2, 50, 12, 32, 64), float32], + Tensor[(2, 50, 12, 32, 64), float32], + Tensor[(2, 50, 12, 32, 64), float32], + Tensor[(2, 50, 12, 32, 64), float32], + Tensor[(2, 50, 12, 32, 64), float32], + Tensor[(2, 50, 12, 32, 64), float32]) { + %0 = reshape(%x, newshape=[-1, 32]); + %1 = less(%0, 0i64); + %2 = add(%0, 50257i64); + %3 = where(%1, %2, %0); + %4 = take(meta[relay.Constant][0], %3, axis=0); + %5 = add(%4, meta[relay.Constant][1]); + %6 = mean(%5, axis=[-1], keepdims=True); + %7 = subtract(%5, %6); + %8 = power(%7, 2f); + %9 = mean(%8, axis=[-1], keepdims=True); + %10 = add(%9, 1e-05f); + %11 = sqrt(%10); + %12 = divide(%7, %11); + %13 = multiply(%12, meta[relay.Constant][4]); + %14 = add(%13, meta[relay.Constant][5]); + %15 = reshape(%14, newshape=[-1, 768]); + %17 = nn.dense(%15, meta[relay.Constant][3], units=2304); + %18 = add(%17, meta[relay.Constant][6]); + %19 = reshape(%18, newshape=[50, 32, 2304]); + %20 = split(%19, indices_or_sections=[768, 1536], axis=2); + %21 = %20.0; + %22 = reshape(%21, newshape=[50, 32, 12, 64]); + %23 = transpose(%22, axes=[0, 2, 1, 3]); + %24 = %20.1; + %25 = reshape(%24, newshape=[50, 32, 12, 64]); + %26 = transpose(%25, axes=[0, 2, 3, 1]); + %27 = reshape(%26, newshape=[-1, 64, 32]); + %28 = reshape(%23, newshape=[-1, 32, 64]); + %29 = transpose(%27, axes=[0, 2, 1]); + %30 = nn.batch_matmul(%28, %29, out_dtype="float32", transpose_b=True); + %31 = reshape(%30, 
newshape=[50, 12, 32, 32]); + %32 = divide(%31, 8f); + %33 = multiply(%32, meta[relay.Constant][7]); + %34 = subtract(%33, meta[relay.Constant][8]); + %35 = nn.softmax(%34, axis=3); + %36 = %20.2; + %37 = reshape(%36, newshape=[50, 32, 12, 64]); + %38 = transpose(%37, axes=[0, 2, 1, 3]); + %39 = reshape(%38, newshape=[-1, 32, 64]); + %40 = reshape(%35, newshape=[-1, 32, 32]); + %41 = transpose(%39, axes=[0, 2, 1]); + %42 = nn.batch_matmul(%40, %41, out_dtype="float32", transpose_b=True); + %43 = reshape(%42, newshape=[50, 12, 32, 64]); + %44 = transpose(%43, axes=[0, 2, 1, 3]); + %45 = reshape(%44, newshape=[50, 32, 768]); + %46 = reshape(%45, newshape=[-1, 768]); + %48 = nn.dense(%46, meta[relay.Constant][2], units=768); + %49 = add(%48, meta[relay.Constant][9]); + %50 = reshape(%49, newshape=[50, 32, 768]); + %51 = add(%5, %50); + %52 = mean(%51, axis=[-1], keepdims=True); + %53 = subtract(%51, %52); + %54 = power(%53, 2f); + %55 = mean(%54, axis=[-1], keepdims=True); + %56 = add(%55, 1e-05f); + %57 = sqrt(%56); + %58 = divide(%53, %57); + %59 = multiply(%58, meta[relay.Constant][12]); + %60 = add(%59, meta[relay.Constant][13]); + %61 = reshape(%60, newshape=[-1, 768]); + %63 = nn.dense(%61, meta[relay.Constant][11], units=3072); + %64 = add(%63, meta[relay.Constant][14]); + %65 = reshape(%64, newshape=[50, 32, 3072]); + %66 = power(%65, 3f); + %67 = multiply(%66, 0.044715f); + %68 = add(%65, %67); + %69 = multiply(%68, 0.797885f); + %70 = tanh(%69); + %71 = multiply(%65, 0.5f); + %72 = add(%70, 1f); + %73 = multiply(%71, %72); + %74 = reshape(%73, newshape=[-1, 3072]); + %76 = nn.dense(%74, meta[relay.Constant][10], units=768); + %77 = add(%76, meta[relay.Constant][15]); + %78 = reshape(%77, newshape=[50, 32, 768]); + %80 = add(%51, %78); + %81 = mean(%80, axis=[-1], keepdims=True); + %82 = subtract(%80, %81); + %83 = power(%82, 2f); + %84 = mean(%83, axis=[-1], keepdims=True); + %85 = add(%84, 1e-05f); + %86 = sqrt(%85); + %87 = divide(%82, %86); + %88 = multiply(%87, meta[relay.Constant][18]); + %89 = add(%88, meta[relay.Constant][19]); + %90 = reshape(%89, newshape=[-1, 768]); + %92 = nn.dense(%90, meta[relay.Constant][17], units=2304); + %93 = add(%92, meta[relay.Constant][20]); + %94 = reshape(%93, newshape=[50, 32, 2304]); + %95 = split(%94, indices_or_sections=[768, 1536], axis=2); + %96 = %95.0; + %97 = reshape(%96, newshape=[50, 32, 12, 64]); + %98 = transpose(%97, axes=[0, 2, 1, 3]); + %99 = %95.1; + %100 = reshape(%99, newshape=[50, 32, 12, 64]); + %101 = transpose(%100, axes=[0, 2, 3, 1]); + %102 = reshape(%101, newshape=[-1, 64, 32]); + %103 = reshape(%98, newshape=[-1, 32, 64]); + %104 = transpose(%102, axes=[0, 2, 1]); + %105 = nn.batch_matmul(%103, %104, out_dtype="float32", transpose_b=True); + %106 = reshape(%105, newshape=[50, 12, 32, 32]); + %107 = divide(%106, 8f); + %108 = multiply(%107, meta[relay.Constant][21]); + %109 = subtract(%108, meta[relay.Constant][22]); + %110 = nn.softmax(%109, axis=3); + %111 = %95.2; + %112 = reshape(%111, newshape=[50, 32, 12, 64]); + %113 = transpose(%112, axes=[0, 2, 1, 3]); + %114 = reshape(%113, newshape=[-1, 32, 64]); + %115 = reshape(%110, newshape=[-1, 32, 32]); + %116 = transpose(%114, axes=[0, 2, 1]); + %118 = nn.batch_matmul(%115, %116, out_dtype="float32", transpose_b=True); + %119 = reshape(%118, newshape=[50, 12, 32, 64]); + %120 = transpose(%119, axes=[0, 2, 1, 3]); + %121 = reshape(%120, newshape=[50, 32, 768]); + %122 = reshape(%121, newshape=[-1, 768]); + %124 = nn.dense(%122, meta[relay.Constant][16], units=768); + 
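// NOTE: the layer norm -> QKV dense/split -> scaled dot-product attention -> output projection -> GELU MLP pattern above repeats for each of the remaining transformer blocks below. +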
%125 = add(%124, meta[relay.Constant][23]); + %126 = reshape(%125, newshape=[50, 32, 768]); + %127 = add(%80, %126); + %128 = mean(%127, axis=[-1], keepdims=True); + %129 = subtract(%127, %128); + %130 = power(%129, 2f); + %131 = mean(%130, axis=[-1], keepdims=True); + %132 = add(%131, 1e-05f); + %133 = sqrt(%132); + %134 = divide(%129, %133); + %135 = multiply(%134, meta[relay.Constant][26]); + %136 = add(%135, meta[relay.Constant][27]); + %137 = reshape(%136, newshape=[-1, 768]); + %139 = nn.dense(%137, meta[relay.Constant][25], units=3072); + %140 = add(%139, meta[relay.Constant][28]); + %141 = reshape(%140, newshape=[50, 32, 3072]); + %142 = power(%141, 3f); + %143 = multiply(%142, 0.044715f); + %144 = add(%141, %143); + %145 = multiply(%144, 0.797885f); + %146 = tanh(%145); + %147 = multiply(%141, 0.5f); + %148 = add(%146, 1f); + %149 = multiply(%147, %148); + %150 = reshape(%149, newshape=[-1, 3072]); + %152 = nn.dense(%150, meta[relay.Constant][24], units=768); + %153 = add(%152, meta[relay.Constant][29]); + %154 = reshape(%153, newshape=[50, 32, 768]); + %156 = add(%127, %154); + %157 = mean(%156, axis=[-1], keepdims=True); + %158 = subtract(%156, %157); + %159 = power(%158, 2f); + %160 = mean(%159, axis=[-1], keepdims=True); + %161 = add(%160, 1e-05f); + %162 = sqrt(%161); + %163 = divide(%158, %162); + %164 = multiply(%163, meta[relay.Constant][32]); + %165 = add(%164, meta[relay.Constant][33]); + %166 = reshape(%165, newshape=[-1, 768]); + %168 = nn.dense(%166, meta[relay.Constant][31], units=2304); + %169 = add(%168, meta[relay.Constant][34]); + %170 = reshape(%169, newshape=[50, 32, 2304]); + %171 = split(%170, indices_or_sections=[768, 1536], axis=2); + %172 = %171.0; + %173 = reshape(%172, newshape=[50, 32, 12, 64]); + %174 = transpose(%173, axes=[0, 2, 1, 3]); + %175 = %171.1; + %176 = reshape(%175, newshape=[50, 32, 12, 64]); + %177 = transpose(%176, axes=[0, 2, 3, 1]); + %178 = reshape(%177, newshape=[-1, 64, 32]); + %179 = reshape(%174, newshape=[-1, 32, 64]); + %180 = transpose(%178, axes=[0, 2, 1]); + %181 = nn.batch_matmul(%179, %180, out_dtype="float32", transpose_b=True); + %182 = reshape(%181, newshape=[50, 12, 32, 32]); + %183 = divide(%182, 8f); + %184 = multiply(%183, meta[relay.Constant][35]); + %185 = subtract(%184, meta[relay.Constant][36]); + %186 = nn.softmax(%185, axis=3); + %187 = %171.2; + %188 = reshape(%187, newshape=[50, 32, 12, 64]); + %189 = transpose(%188, axes=[0, 2, 1, 3]); + %190 = reshape(%189, newshape=[-1, 32, 64]); + %191 = reshape(%186, newshape=[-1, 32, 32]); + %192 = transpose(%190, axes=[0, 2, 1]); + %194 = nn.batch_matmul(%191, %192, out_dtype="float32", transpose_b=True); + %195 = reshape(%194, newshape=[50, 12, 32, 64]); + %196 = transpose(%195, axes=[0, 2, 1, 3]); + %197 = reshape(%196, newshape=[50, 32, 768]); + %198 = reshape(%197, newshape=[-1, 768]); + %200 = nn.dense(%198, meta[relay.Constant][30], units=768); + %201 = add(%200, meta[relay.Constant][37]); + %202 = reshape(%201, newshape=[50, 32, 768]); + %203 = add(%156, %202); + %204 = mean(%203, axis=[-1], keepdims=True); + %205 = subtract(%203, %204); + %206 = power(%205, 2f); + %207 = mean(%206, axis=[-1], keepdims=True); + %208 = add(%207, 1e-05f); + %209 = sqrt(%208); + %210 = divide(%205, %209); + %211 = multiply(%210, meta[relay.Constant][40]); + %212 = add(%211, meta[relay.Constant][41]); + %213 = reshape(%212, newshape=[-1, 768]); + %215 = nn.dense(%213, meta[relay.Constant][39], units=3072); + %216 = add(%215, meta[relay.Constant][42]); + %217 = reshape(%216, 
newshape=[50, 32, 3072]); + %218 = power(%217, 3f); + %219 = multiply(%218, 0.044715f); + %220 = add(%217, %219); + %221 = multiply(%220, 0.797885f); + %222 = tanh(%221); + %223 = multiply(%217, 0.5f); + %224 = add(%222, 1f); + %225 = multiply(%223, %224); + %226 = reshape(%225, newshape=[-1, 3072]); + %228 = nn.dense(%226, meta[relay.Constant][38], units=768); + %229 = add(%228, meta[relay.Constant][43]); + %230 = reshape(%229, newshape=[50, 32, 768]); + %232 = add(%203, %230); + %233 = mean(%232, axis=[-1], keepdims=True); + %234 = subtract(%232, %233); + %235 = power(%234, 2f); + %236 = mean(%235, axis=[-1], keepdims=True); + %237 = add(%236, 1e-05f); + %238 = sqrt(%237); + %239 = divide(%234, %238); + %240 = multiply(%239, meta[relay.Constant][46]); + %241 = add(%240, meta[relay.Constant][47]); + %242 = reshape(%241, newshape=[-1, 768]); + %244 = nn.dense(%242, meta[relay.Constant][45], units=2304); + %245 = add(%244, meta[relay.Constant][48]); + %246 = reshape(%245, newshape=[50, 32, 2304]); + %247 = split(%246, indices_or_sections=[768, 1536], axis=2); + %248 = %247.0; + %249 = reshape(%248, newshape=[50, 32, 12, 64]); + %250 = transpose(%249, axes=[0, 2, 1, 3]); + %251 = %247.1; + %252 = reshape(%251, newshape=[50, 32, 12, 64]); + %253 = transpose(%252, axes=[0, 2, 3, 1]); + %254 = reshape(%253, newshape=[-1, 64, 32]); + %255 = reshape(%250, newshape=[-1, 32, 64]); + %256 = transpose(%254, axes=[0, 2, 1]); + %257 = nn.batch_matmul(%255, %256, out_dtype="float32", transpose_b=True); + %258 = reshape(%257, newshape=[50, 12, 32, 32]); + %259 = divide(%258, 8f); + %260 = multiply(%259, meta[relay.Constant][49]); + %261 = subtract(%260, meta[relay.Constant][50]); + %262 = nn.softmax(%261, axis=3); + %263 = %247.2; + %264 = reshape(%263, newshape=[50, 32, 12, 64]); + %265 = transpose(%264, axes=[0, 2, 1, 3]); + %266 = reshape(%265, newshape=[-1, 32, 64]); + %267 = reshape(%262, newshape=[-1, 32, 32]); + %268 = transpose(%266, axes=[0, 2, 1]); + %270 = nn.batch_matmul(%267, %268, out_dtype="float32", transpose_b=True); + %271 = reshape(%270, newshape=[50, 12, 32, 64]); + %272 = transpose(%271, axes=[0, 2, 1, 3]); + %273 = reshape(%272, newshape=[50, 32, 768]); + %274 = reshape(%273, newshape=[-1, 768]); + %276 = nn.dense(%274, meta[relay.Constant][44], units=768); + %277 = add(%276, meta[relay.Constant][51]); + %278 = reshape(%277, newshape=[50, 32, 768]); + %279 = add(%232, %278); + %280 = mean(%279, axis=[-1], keepdims=True); + %281 = subtract(%279, %280); + %282 = power(%281, 2f); + %283 = mean(%282, axis=[-1], keepdims=True); + %284 = add(%283, 1e-05f); + %285 = sqrt(%284); + %286 = divide(%281, %285); + %287 = multiply(%286, meta[relay.Constant][54]); + %288 = add(%287, meta[relay.Constant][55]); + %289 = reshape(%288, newshape=[-1, 768]); + %291 = nn.dense(%289, meta[relay.Constant][53], units=3072); + %292 = add(%291, meta[relay.Constant][56]); + %293 = reshape(%292, newshape=[50, 32, 3072]); + %294 = power(%293, 3f); + %295 = multiply(%294, 0.044715f); + %296 = add(%293, %295); + %297 = multiply(%296, 0.797885f); + %298 = tanh(%297); + %299 = multiply(%293, 0.5f); + %300 = add(%298, 1f); + %301 = multiply(%299, %300); + %302 = reshape(%301, newshape=[-1, 3072]); + %304 = nn.dense(%302, meta[relay.Constant][52], units=768); + %305 = add(%304, meta[relay.Constant][57]); + %306 = reshape(%305, newshape=[50, 32, 768]); + %307 = add(%279, %306); + %308 = mean(%307, axis=[-1], keepdims=True); + %309 = subtract(%307, %308); + %310 = power(%309, 2f); + %311 = mean(%310, axis=[-1], 
keepdims=True); + %312 = add(%311, 1e-05f); + %313 = sqrt(%312); + %314 = divide(%309, %313); + %315 = multiply(%314, meta[relay.Constant][60]); + %316 = add(%315, meta[relay.Constant][61]); + %317 = reshape(%316, newshape=[-1, 768]); + %319 = nn.dense(%317, meta[relay.Constant][59], units=2304); + %320 = add(%319, meta[relay.Constant][62]); + %321 = reshape(%320, newshape=[50, 32, 2304]); + %322 = split(%321, indices_or_sections=[768, 1536], axis=2); + %323 = %322.0; + %324 = reshape(%323, newshape=[50, 32, 12, 64]); + %325 = transpose(%324, axes=[0, 2, 1, 3]); + %326 = %322.1; + %327 = reshape(%326, newshape=[50, 32, 12, 64]); + %328 = transpose(%327, axes=[0, 2, 3, 1]); + %329 = reshape(%328, newshape=[-1, 64, 32]); + %330 = reshape(%325, newshape=[-1, 32, 64]); + %331 = transpose(%329, axes=[0, 2, 1]); + %332 = nn.batch_matmul(%330, %331, out_dtype="float32", transpose_b=True); + %333 = reshape(%332, newshape=[50, 12, 32, 32]); + %334 = divide(%333, 8f); + %335 = multiply(%334, meta[relay.Constant][63]); + %336 = subtract(%335, meta[relay.Constant][64]); + %337 = nn.softmax(%336, axis=3); + %338 = %322.2; + %339 = reshape(%338, newshape=[50, 32, 12, 64]); + %340 = transpose(%339, axes=[0, 2, 1, 3]); + %341 = reshape(%340, newshape=[-1, 32, 64]); + %342 = reshape(%337, newshape=[-1, 32, 32]); + %343 = transpose(%341, axes=[0, 2, 1]); + %344 = nn.batch_matmul(%342, %343, out_dtype="float32", transpose_b=True); + %345 = reshape(%344, newshape=[50, 12, 32, 64]); + %346 = transpose(%345, axes=[0, 2, 1, 3]); + %347 = reshape(%346, newshape=[50, 32, 768]); + %348 = reshape(%347, newshape=[-1, 768]); + %350 = nn.dense(%348, meta[relay.Constant][58], units=768); + %351 = add(%350, meta[relay.Constant][65]); + %352 = reshape(%351, newshape=[50, 32, 768]); + %353 = add(%307, %352); + %354 = mean(%353, axis=[-1], keepdims=True); + %355 = subtract(%353, %354); + %356 = power(%355, 2f); + %357 = mean(%356, axis=[-1], keepdims=True); + %358 = add(%357, 1e-05f); + %359 = sqrt(%358); + %360 = divide(%355, %359); + %361 = multiply(%360, meta[relay.Constant][68]); + %362 = add(%361, meta[relay.Constant][69]); + %363 = reshape(%362, newshape=[-1, 768]); + %365 = nn.dense(%363, meta[relay.Constant][67], units=3072); + %366 = add(%365, meta[relay.Constant][70]); + %367 = reshape(%366, newshape=[50, 32, 3072]); + %368 = power(%367, 3f); + %369 = multiply(%368, 0.044715f); + %370 = add(%367, %369); + %371 = multiply(%370, 0.797885f); + %372 = tanh(%371); + %373 = multiply(%367, 0.5f); + %374 = add(%372, 1f); + %375 = multiply(%373, %374); + %376 = reshape(%375, newshape=[-1, 3072]); + %378 = nn.dense(%376, meta[relay.Constant][66], units=768); + %379 = add(%378, meta[relay.Constant][71]); + %380 = reshape(%379, newshape=[50, 32, 768]); + %381 = add(%353, %380); + %382 = mean(%381, axis=[-1], keepdims=True); + %383 = subtract(%381, %382); + %384 = power(%383, 2f); + %385 = mean(%384, axis=[-1], keepdims=True); + %386 = add(%385, 1e-05f); + %387 = sqrt(%386); + %388 = divide(%383, %387); + %389 = multiply(%388, meta[relay.Constant][74]); + %390 = add(%389, meta[relay.Constant][75]); + %391 = reshape(%390, newshape=[-1, 768]); + %393 = nn.dense(%391, meta[relay.Constant][73], units=2304); + %394 = add(%393, meta[relay.Constant][76]); + %395 = reshape(%394, newshape=[50, 32, 2304]); + %396 = split(%395, indices_or_sections=[768, 1536], axis=2); + %397 = %396.0; + %398 = reshape(%397, newshape=[50, 32, 12, 64]); + %399 = transpose(%398, axes=[0, 2, 1, 3]); + %400 = %396.1; + %401 = reshape(%400, newshape=[50, 32, 
12, 64]); + %402 = transpose(%401, axes=[0, 2, 3, 1]); + %403 = reshape(%402, newshape=[-1, 64, 32]); + %404 = reshape(%399, newshape=[-1, 32, 64]); + %405 = transpose(%403, axes=[0, 2, 1]); + %406 = nn.batch_matmul(%404, %405, out_dtype="float32", transpose_b=True); + %407 = reshape(%406, newshape=[50, 12, 32, 32]); + %408 = divide(%407, 8f); + %409 = multiply(%408, meta[relay.Constant][77]); + %410 = subtract(%409, meta[relay.Constant][78]); + %411 = nn.softmax(%410, axis=3); + %412 = %396.2; + %413 = reshape(%412, newshape=[50, 32, 12, 64]); + %414 = transpose(%413, axes=[0, 2, 1, 3]); + %415 = reshape(%414, newshape=[-1, 32, 64]); + %416 = reshape(%411, newshape=[-1, 32, 32]); + %417 = transpose(%415, axes=[0, 2, 1]); + %418 = nn.batch_matmul(%416, %417, out_dtype="float32", transpose_b=True); + %419 = reshape(%418, newshape=[50, 12, 32, 64]); + %420 = transpose(%419, axes=[0, 2, 1, 3]); + %421 = reshape(%420, newshape=[50, 32, 768]); + %422 = reshape(%421, newshape=[-1, 768]); + %424 = nn.dense(%422, meta[relay.Constant][72], units=768); + %425 = add(%424, meta[relay.Constant][79]); + %426 = reshape(%425, newshape=[50, 32, 768]); + %427 = add(%381, %426); + %428 = mean(%427, axis=[-1], keepdims=True); + %429 = subtract(%427, %428); + %430 = power(%429, 2f); + %431 = mean(%430, axis=[-1], keepdims=True); + %432 = add(%431, 1e-05f); + %433 = sqrt(%432); + %434 = divide(%429, %433); + %435 = multiply(%434, meta[relay.Constant][82]); + %436 = add(%435, meta[relay.Constant][83]); + %437 = reshape(%436, newshape=[-1, 768]); + %439 = nn.dense(%437, meta[relay.Constant][81], units=3072); + %440 = add(%439, meta[relay.Constant][84]); + %441 = reshape(%440, newshape=[50, 32, 3072]); + %442 = power(%441, 3f); + %443 = multiply(%442, 0.044715f); + %444 = add(%441, %443); + %445 = multiply(%444, 0.797885f); + %446 = tanh(%445); + %447 = multiply(%441, 0.5f); + %448 = add(%446, 1f); + %449 = multiply(%447, %448); + %450 = reshape(%449, newshape=[-1, 3072]); + %452 = nn.dense(%450, meta[relay.Constant][80], units=768); + %453 = add(%452, meta[relay.Constant][85]); + %454 = reshape(%453, newshape=[50, 32, 768]); + %455 = add(%427, %454); + %456 = mean(%455, axis=[-1], keepdims=True); + %457 = subtract(%455, %456); + %458 = power(%457, 2f); + %459 = mean(%458, axis=[-1], keepdims=True); + %460 = add(%459, 1e-05f); + %461 = sqrt(%460); + %462 = divide(%457, %461); + %463 = multiply(%462, meta[relay.Constant][88]); + %464 = add(%463, meta[relay.Constant][89]); + %465 = reshape(%464, newshape=[-1, 768]); + %467 = nn.dense(%465, meta[relay.Constant][87], units=2304); + %468 = add(%467, meta[relay.Constant][90]); + %469 = reshape(%468, newshape=[50, 32, 2304]); + %470 = split(%469, indices_or_sections=[768, 1536], axis=2); + %471 = %470.0; + %472 = reshape(%471, newshape=[50, 32, 12, 64]); + %473 = transpose(%472, axes=[0, 2, 1, 3]); + %474 = %470.1; + %475 = reshape(%474, newshape=[50, 32, 12, 64]); + %476 = transpose(%475, axes=[0, 2, 3, 1]); + %477 = reshape(%476, newshape=[-1, 64, 32]); + %478 = reshape(%473, newshape=[-1, 32, 64]); + %479 = transpose(%477, axes=[0, 2, 1]); + %480 = nn.batch_matmul(%478, %479, out_dtype="float32", transpose_b=True); + %481 = reshape(%480, newshape=[50, 12, 32, 32]); + %482 = divide(%481, 8f); + %483 = multiply(%482, meta[relay.Constant][91]); + %484 = subtract(%483, meta[relay.Constant][92]); + %485 = nn.softmax(%484, axis=3); + %486 = %470.2; + %487 = reshape(%486, newshape=[50, 32, 12, 64]); + %488 = transpose(%487, axes=[0, 2, 1, 3]); + %489 = reshape(%488, 
newshape=[-1, 32, 64]); + %490 = reshape(%485, newshape=[-1, 32, 32]); + %491 = transpose(%489, axes=[0, 2, 1]); + %492 = nn.batch_matmul(%490, %491, out_dtype="float32", transpose_b=True); + %493 = reshape(%492, newshape=[50, 12, 32, 64]); + %494 = transpose(%493, axes=[0, 2, 1, 3]); + %495 = reshape(%494, newshape=[50, 32, 768]); + %496 = reshape(%495, newshape=[-1, 768]); + %498 = nn.dense(%496, meta[relay.Constant][86], units=768); + %499 = add(%498, meta[relay.Constant][93]); + %500 = reshape(%499, newshape=[50, 32, 768]); + %501 = add(%455, %500); + %502 = mean(%501, axis=[-1], keepdims=True); + %503 = subtract(%501, %502); + %504 = power(%503, 2f); + %505 = mean(%504, axis=[-1], keepdims=True); + %506 = add(%505, 1e-05f); + %507 = sqrt(%506); + %508 = divide(%503, %507); + %509 = multiply(%508, meta[relay.Constant][96]); + %510 = add(%509, meta[relay.Constant][97]); + %511 = reshape(%510, newshape=[-1, 768]); + %513 = nn.dense(%511, meta[relay.Constant][95], units=3072); + %514 = add(%513, meta[relay.Constant][98]); + %515 = reshape(%514, newshape=[50, 32, 3072]); + %516 = power(%515, 3f); + %517 = multiply(%516, 0.044715f); + %518 = add(%515, %517); + %519 = multiply(%518, 0.797885f); + %520 = tanh(%519); + %521 = multiply(%515, 0.5f); + %522 = add(%520, 1f); + %523 = multiply(%521, %522); + %524 = reshape(%523, newshape=[-1, 3072]); + %526 = nn.dense(%524, meta[relay.Constant][94], units=768); + %527 = add(%526, meta[relay.Constant][99]); + %528 = reshape(%527, newshape=[50, 32, 768]); + %529 = add(%501, %528); + %530 = mean(%529, axis=[-1], keepdims=True); + %531 = subtract(%529, %530); + %532 = power(%531, 2f); + %533 = mean(%532, axis=[-1], keepdims=True); + %534 = add(%533, 1e-05f); + %535 = sqrt(%534); + %536 = divide(%531, %535); + %537 = multiply(%536, meta[relay.Constant][102]); + %538 = add(%537, meta[relay.Constant][103]); + %539 = reshape(%538, newshape=[-1, 768]); + %541 = nn.dense(%539, meta[relay.Constant][101], units=2304); + %542 = add(%541, meta[relay.Constant][104]); + %543 = reshape(%542, newshape=[50, 32, 2304]); + %544 = split(%543, indices_or_sections=[768, 1536], axis=2); + %545 = %544.0; + %546 = reshape(%545, newshape=[50, 32, 12, 64]); + %547 = transpose(%546, axes=[0, 2, 1, 3]); + %548 = %544.1; + %549 = reshape(%548, newshape=[50, 32, 12, 64]); + %550 = transpose(%549, axes=[0, 2, 3, 1]); + %551 = reshape(%550, newshape=[-1, 64, 32]); + %552 = reshape(%547, newshape=[-1, 32, 64]); + %553 = transpose(%551, axes=[0, 2, 1]); + %554 = nn.batch_matmul(%552, %553, out_dtype="float32", transpose_b=True); + %555 = reshape(%554, newshape=[50, 12, 32, 32]); + %556 = divide(%555, 8f); + %557 = multiply(%556, meta[relay.Constant][105]); + %558 = subtract(%557, meta[relay.Constant][106]); + %559 = nn.softmax(%558, axis=3); + %560 = %544.2; + %561 = reshape(%560, newshape=[50, 32, 12, 64]); + %562 = transpose(%561, axes=[0, 2, 1, 3]); + %563 = reshape(%562, newshape=[-1, 32, 64]); + %564 = reshape(%559, newshape=[-1, 32, 32]); + %565 = transpose(%563, axes=[0, 2, 1]); + %566 = nn.batch_matmul(%564, %565, out_dtype="float32", transpose_b=True); + %567 = reshape(%566, newshape=[50, 12, 32, 64]); + %568 = transpose(%567, axes=[0, 2, 1, 3]); + %569 = reshape(%568, newshape=[50, 32, 768]); + %570 = reshape(%569, newshape=[-1, 768]); + %572 = nn.dense(%570, meta[relay.Constant][100], units=768); + %573 = add(%572, meta[relay.Constant][107]); + %574 = reshape(%573, newshape=[50, 32, 768]); + %575 = add(%529, %574); + %576 = mean(%575, axis=[-1], keepdims=True); + %577 = 
subtract(%575, %576); + %578 = power(%577, 2f); + %579 = mean(%578, axis=[-1], keepdims=True); + %580 = add(%579, 1e-05f); + %581 = sqrt(%580); + %582 = divide(%577, %581); + %583 = multiply(%582, meta[relay.Constant][110]); + %584 = add(%583, meta[relay.Constant][111]); + %585 = reshape(%584, newshape=[-1, 768]); + %587 = nn.dense(%585, meta[relay.Constant][109], units=3072); + %588 = add(%587, meta[relay.Constant][112]); + %589 = reshape(%588, newshape=[50, 32, 3072]); + %590 = power(%589, 3f); + %591 = multiply(%590, 0.044715f); + %592 = add(%589, %591); + %593 = multiply(%592, 0.797885f); + %594 = tanh(%593); + %595 = multiply(%589, 0.5f); + %596 = add(%594, 1f); + %597 = multiply(%595, %596); + %598 = reshape(%597, newshape=[-1, 3072]); + %600 = nn.dense(%598, meta[relay.Constant][108], units=768); + %601 = add(%600, meta[relay.Constant][113]); + %602 = reshape(%601, newshape=[50, 32, 768]); + %603 = add(%575, %602); + %604 = mean(%603, axis=[-1], keepdims=True); + %605 = subtract(%603, %604); + %606 = power(%605, 2f); + %607 = mean(%606, axis=[-1], keepdims=True); + %608 = add(%607, 1e-05f); + %609 = sqrt(%608); + %610 = divide(%605, %609); + %611 = multiply(%610, meta[relay.Constant][116]); + %612 = add(%611, meta[relay.Constant][117]); + %613 = reshape(%612, newshape=[-1, 768]); + %615 = nn.dense(%613, meta[relay.Constant][115], units=2304); + %616 = add(%615, meta[relay.Constant][118]); + %617 = reshape(%616, newshape=[50, 32, 2304]); + %618 = split(%617, indices_or_sections=[768, 1536], axis=2); + %619 = %618.0; + %620 = reshape(%619, newshape=[50, 32, 12, 64]); + %621 = transpose(%620, axes=[0, 2, 1, 3]); + %622 = %618.1; + %623 = reshape(%622, newshape=[50, 32, 12, 64]); + %624 = transpose(%623, axes=[0, 2, 3, 1]); + %625 = reshape(%624, newshape=[-1, 64, 32]); + %626 = reshape(%621, newshape=[-1, 32, 64]); + %627 = transpose(%625, axes=[0, 2, 1]); + %628 = nn.batch_matmul(%626, %627, out_dtype="float32", transpose_b=True); + %629 = reshape(%628, newshape=[50, 12, 32, 32]); + %630 = divide(%629, 8f); + %631 = multiply(%630, meta[relay.Constant][119]); + %632 = subtract(%631, meta[relay.Constant][120]); + %633 = nn.softmax(%632, axis=3); + %634 = %618.2; + %635 = reshape(%634, newshape=[50, 32, 12, 64]); + %636 = transpose(%635, axes=[0, 2, 1, 3]); + %637 = reshape(%636, newshape=[-1, 32, 64]); + %638 = reshape(%633, newshape=[-1, 32, 32]); + %639 = transpose(%637, axes=[0, 2, 1]); + %640 = nn.batch_matmul(%638, %639, out_dtype="float32", transpose_b=True); + %641 = reshape(%640, newshape=[50, 12, 32, 64]); + %642 = transpose(%641, axes=[0, 2, 1, 3]); + %643 = reshape(%642, newshape=[50, 32, 768]); + %644 = reshape(%643, newshape=[-1, 768]); + %646 = nn.dense(%644, meta[relay.Constant][114], units=768); + %647 = add(%646, meta[relay.Constant][121]); + %648 = reshape(%647, newshape=[50, 32, 768]); + %649 = add(%603, %648); + %650 = mean(%649, axis=[-1], keepdims=True); + %651 = subtract(%649, %650); + %652 = power(%651, 2f); + %653 = mean(%652, axis=[-1], keepdims=True); + %654 = add(%653, 1e-05f); + %655 = sqrt(%654); + %656 = divide(%651, %655); + %657 = multiply(%656, meta[relay.Constant][124]); + %658 = add(%657, meta[relay.Constant][125]); + %659 = reshape(%658, newshape=[-1, 768]); + %661 = nn.dense(%659, meta[relay.Constant][123], units=3072); + %662 = add(%661, meta[relay.Constant][126]); + %663 = reshape(%662, newshape=[50, 32, 3072]); + %664 = power(%663, 3f); + %665 = multiply(%664, 0.044715f); + %666 = add(%663, %665); + %667 = multiply(%666, 0.797885f); + %668 = 
tanh(%667); + %669 = multiply(%663, 0.5f); + %670 = add(%668, 1f); + %671 = multiply(%669, %670); + %672 = reshape(%671, newshape=[-1, 3072]); + %674 = nn.dense(%672, meta[relay.Constant][122], units=768); + %675 = add(%674, meta[relay.Constant][127]); + %676 = reshape(%675, newshape=[50, 32, 768]); + %677 = add(%649, %676); + %678 = mean(%677, axis=[-1], keepdims=True); + %679 = subtract(%677, %678); + %680 = power(%679, 2f); + %681 = mean(%680, axis=[-1], keepdims=True); + %682 = add(%681, 1e-05f); + %683 = sqrt(%682); + %684 = divide(%679, %683); + %685 = multiply(%684, meta[relay.Constant][130]); + %686 = add(%685, meta[relay.Constant][131]); + %687 = reshape(%686, newshape=[-1, 768]); + %689 = nn.dense(%687, meta[relay.Constant][129], units=2304); + %690 = add(%689, meta[relay.Constant][132]); + %691 = reshape(%690, newshape=[50, 32, 2304]); + %692 = split(%691, indices_or_sections=[768, 1536], axis=2); + %693 = %692.0; + %694 = reshape(%693, newshape=[50, 32, 12, 64]); + %695 = transpose(%694, axes=[0, 2, 1, 3]); + %696 = %692.1; + %697 = reshape(%696, newshape=[50, 32, 12, 64]); + %698 = transpose(%697, axes=[0, 2, 3, 1]); + %699 = reshape(%698, newshape=[-1, 64, 32]); + %700 = reshape(%695, newshape=[-1, 32, 64]); + %701 = transpose(%699, axes=[0, 2, 1]); + %702 = nn.batch_matmul(%700, %701, out_dtype="float32", transpose_b=True); + %703 = reshape(%702, newshape=[50, 12, 32, 32]); + %704 = divide(%703, 8f); + %705 = multiply(%704, meta[relay.Constant][133]); + %706 = subtract(%705, meta[relay.Constant][134]); + %707 = nn.softmax(%706, axis=3); + %708 = %692.2; + %709 = reshape(%708, newshape=[50, 32, 12, 64]); + %710 = transpose(%709, axes=[0, 2, 1, 3]); + %711 = reshape(%710, newshape=[-1, 32, 64]); + %712 = reshape(%707, newshape=[-1, 32, 32]); + %713 = transpose(%711, axes=[0, 2, 1]); + %714 = nn.batch_matmul(%712, %713, out_dtype="float32", transpose_b=True); + %715 = reshape(%714, newshape=[50, 12, 32, 64]); + %716 = transpose(%715, axes=[0, 2, 1, 3]); + %717 = reshape(%716, newshape=[50, 32, 768]); + %718 = reshape(%717, newshape=[-1, 768]); + %720 = nn.dense(%718, meta[relay.Constant][128], units=768); + %721 = add(%720, meta[relay.Constant][135]); + %722 = reshape(%721, newshape=[50, 32, 768]); + %723 = add(%677, %722); + %724 = mean(%723, axis=[-1], keepdims=True); + %725 = subtract(%723, %724); + %726 = power(%725, 2f); + %727 = mean(%726, axis=[-1], keepdims=True); + %728 = add(%727, 1e-05f); + %729 = sqrt(%728); + %730 = divide(%725, %729); + %731 = multiply(%730, meta[relay.Constant][138]); + %732 = add(%731, meta[relay.Constant][139]); + %733 = reshape(%732, newshape=[-1, 768]); + %735 = nn.dense(%733, meta[relay.Constant][137], units=3072); + %736 = add(%735, meta[relay.Constant][140]); + %737 = reshape(%736, newshape=[50, 32, 3072]); + %738 = power(%737, 3f); + %739 = multiply(%738, 0.044715f); + %740 = add(%737, %739); + %741 = multiply(%740, 0.797885f); + %742 = tanh(%741); + %743 = multiply(%737, 0.5f); + %744 = add(%742, 1f); + %745 = multiply(%743, %744); + %746 = reshape(%745, newshape=[-1, 3072]); + %748 = nn.dense(%746, meta[relay.Constant][136], units=768); + %749 = add(%748, meta[relay.Constant][141]); + %750 = reshape(%749, newshape=[50, 32, 768]); + %751 = add(%723, %750); + %752 = mean(%751, axis=[-1], keepdims=True); + %753 = subtract(%751, %752); + %754 = power(%753, 2f); + %755 = mean(%754, axis=[-1], keepdims=True); + %756 = add(%755, 1e-05f); + %757 = sqrt(%756); + %758 = divide(%753, %757); + %759 = multiply(%758, meta[relay.Constant][144]); + 
%760 = add(%759, meta[relay.Constant][145]); + %761 = reshape(%760, newshape=[-1, 768]); + %763 = nn.dense(%761, meta[relay.Constant][143], units=2304); + %764 = add(%763, meta[relay.Constant][146]); + %765 = reshape(%764, newshape=[50, 32, 2304]); + %766 = split(%765, indices_or_sections=[768, 1536], axis=2); + %767 = %766.0; + %768 = reshape(%767, newshape=[50, 32, 12, 64]); + %769 = transpose(%768, axes=[0, 2, 1, 3]); + %770 = %766.1; + %771 = reshape(%770, newshape=[50, 32, 12, 64]); + %772 = transpose(%771, axes=[0, 2, 3, 1]); + %773 = reshape(%772, newshape=[-1, 64, 32]); + %774 = reshape(%769, newshape=[-1, 32, 64]); + %775 = transpose(%773, axes=[0, 2, 1]); + %776 = nn.batch_matmul(%774, %775, out_dtype="float32", transpose_b=True); + %777 = reshape(%776, newshape=[50, 12, 32, 32]); + %778 = divide(%777, 8f); + %779 = multiply(%778, meta[relay.Constant][147]); + %780 = subtract(%779, meta[relay.Constant][148]); + %781 = nn.softmax(%780, axis=3); + %782 = %766.2; + %783 = reshape(%782, newshape=[50, 32, 12, 64]); + %784 = transpose(%783, axes=[0, 2, 1, 3]); + %785 = reshape(%784, newshape=[-1, 32, 64]); + %786 = reshape(%781, newshape=[-1, 32, 32]); + %787 = transpose(%785, axes=[0, 2, 1]); + %788 = nn.batch_matmul(%786, %787, out_dtype="float32", transpose_b=True); + %789 = reshape(%788, newshape=[50, 12, 32, 64]); + %790 = transpose(%789, axes=[0, 2, 1, 3]); + %791 = reshape(%790, newshape=[50, 32, 768]); + %792 = reshape(%791, newshape=[-1, 768]); + %794 = nn.dense(%792, meta[relay.Constant][142], units=768); + %795 = add(%794, meta[relay.Constant][149]); + %796 = reshape(%795, newshape=[50, 32, 768]); + %797 = add(%751, %796); + %798 = mean(%797, axis=[-1], keepdims=True); + %799 = subtract(%797, %798); + %800 = power(%799, 2f); + %801 = mean(%800, axis=[-1], keepdims=True); + %802 = add(%801, 1e-05f); + %803 = sqrt(%802); + %804 = divide(%799, %803); + %805 = multiply(%804, meta[relay.Constant][152]); + %806 = add(%805, meta[relay.Constant][153]); + %807 = reshape(%806, newshape=[-1, 768]); + %809 = nn.dense(%807, meta[relay.Constant][151], units=3072); + %810 = add(%809, meta[relay.Constant][154]); + %811 = reshape(%810, newshape=[50, 32, 3072]); + %812 = power(%811, 3f); + %813 = multiply(%812, 0.044715f); + %814 = add(%811, %813); + %815 = multiply(%814, 0.797885f); + %816 = tanh(%815); + %817 = multiply(%811, 0.5f); + %818 = add(%816, 1f); + %819 = multiply(%817, %818); + %820 = reshape(%819, newshape=[-1, 3072]); + %822 = nn.dense(%820, meta[relay.Constant][150], units=768); + %823 = add(%822, meta[relay.Constant][155]); + %824 = reshape(%823, newshape=[50, 32, 768]); + %825 = add(%797, %824); + %826 = mean(%825, axis=[-1], keepdims=True); + %827 = subtract(%825, %826); + %828 = power(%827, 2f); + %829 = mean(%828, axis=[-1], keepdims=True); + %830 = add(%829, 1e-05f); + %831 = sqrt(%830); + %832 = divide(%827, %831); + %833 = multiply(%832, meta[relay.Constant][158]); + %834 = add(%833, meta[relay.Constant][159]); + %835 = reshape(%834, newshape=[-1, 768]); + %837 = nn.dense(%835, meta[relay.Constant][157], units=2304); + %838 = add(%837, meta[relay.Constant][160]); + %839 = reshape(%838, newshape=[50, 32, 2304]); + %840 = split(%839, indices_or_sections=[768, 1536], axis=2); + %841 = %840.0; + %842 = reshape(%841, newshape=[50, 32, 12, 64]); + %843 = transpose(%842, axes=[0, 2, 1, 3]); + %844 = %840.1; + %845 = reshape(%844, newshape=[50, 32, 12, 64]); + %846 = transpose(%845, axes=[0, 2, 3, 1]); + %847 = reshape(%846, newshape=[-1, 64, 32]); + %848 = reshape(%843, 
newshape=[-1, 32, 64]); + %849 = transpose(%847, axes=[0, 2, 1]); + %850 = nn.batch_matmul(%848, %849, out_dtype="float32", transpose_b=True); + %851 = reshape(%850, newshape=[50, 12, 32, 32]); + %852 = divide(%851, 8f); + %853 = multiply(%852, meta[relay.Constant][161]); + %854 = subtract(%853, meta[relay.Constant][162]); + %855 = nn.softmax(%854, axis=3); + %856 = %840.2; + %857 = reshape(%856, newshape=[50, 32, 12, 64]); + %858 = transpose(%857, axes=[0, 2, 1, 3]); + %859 = reshape(%858, newshape=[-1, 32, 64]); + %860 = reshape(%855, newshape=[-1, 32, 32]); + %861 = transpose(%859, axes=[0, 2, 1]); + %862 = nn.batch_matmul(%860, %861, out_dtype="float32", transpose_b=True); + %863 = reshape(%862, newshape=[50, 12, 32, 64]); + %864 = transpose(%863, axes=[0, 2, 1, 3]); + %865 = reshape(%864, newshape=[50, 32, 768]); + %866 = reshape(%865, newshape=[-1, 768]); + %868 = nn.dense(%866, meta[relay.Constant][156], units=768); + %869 = add(%868, meta[relay.Constant][163]); + %870 = reshape(%869, newshape=[50, 32, 768]); + %871 = add(%825, %870); + %872 = mean(%871, axis=[-1], keepdims=True); + %873 = subtract(%871, %872); + %874 = power(%873, 2f); + %875 = mean(%874, axis=[-1], keepdims=True); + %876 = add(%875, 1e-05f); + %877 = sqrt(%876); + %878 = divide(%873, %877); + %879 = multiply(%878, meta[relay.Constant][166]); + %880 = add(%879, meta[relay.Constant][167]); + %881 = reshape(%880, newshape=[-1, 768]); + %883 = nn.dense(%881, meta[relay.Constant][165], units=3072); + %884 = add(%883, meta[relay.Constant][168]); + %885 = reshape(%884, newshape=[50, 32, 3072]); + %886 = power(%885, 3f); + %887 = multiply(%886, 0.044715f); + %888 = add(%885, %887); + %889 = multiply(%888, 0.797885f); + %890 = tanh(%889); + %891 = multiply(%885, 0.5f); + %892 = add(%890, 1f); + %893 = multiply(%891, %892); + %894 = reshape(%893, newshape=[-1, 3072]); + %896 = nn.dense(%894, meta[relay.Constant][164], units=768); + %897 = add(%896, meta[relay.Constant][169]); + %898 = reshape(%897, newshape=[50, 32, 768]); + %899 = add(%871, %898); + %900 = mean(%899, axis=[-1], keepdims=True); + %901 = subtract(%899, %900); + %902 = power(%901, 2f); + %903 = mean(%902, axis=[-1], keepdims=True); + %904 = add(%903, 1e-05f); + %905 = sqrt(%904); + %906 = divide(%901, %905); + %907 = multiply(%906, meta[relay.Constant][170]); + %908 = add(%907, meta[relay.Constant][171]); + %909 = transpose(%25, axes=[0, 2, 1, 3]); + %910 = expand_dims(%909, axis=0); + %911 = expand_dims(%38, axis=0); + %912 = (%910, %911); + %913 = transpose(%100, axes=[0, 2, 1, 3]); + %914 = expand_dims(%913, axis=0); + %915 = expand_dims(%113, axis=0); + %916 = (%914, %915); + %917 = transpose(%176, axes=[0, 2, 1, 3]); + %918 = expand_dims(%917, axis=0); + %919 = expand_dims(%189, axis=0); + %920 = (%918, %919); + %921 = transpose(%252, axes=[0, 2, 1, 3]); + %922 = expand_dims(%921, axis=0); + %923 = expand_dims(%265, axis=0); + %924 = (%922, %923); + %925 = transpose(%327, axes=[0, 2, 1, 3]); + %926 = expand_dims(%925, axis=0); + %927 = expand_dims(%340, axis=0); + %928 = (%926, %927); + %929 = transpose(%401, axes=[0, 2, 1, 3]); + %930 = expand_dims(%929, axis=0); + %931 = expand_dims(%414, axis=0); + %932 = (%930, %931); + %933 = transpose(%475, axes=[0, 2, 1, 3]); + %934 = expand_dims(%933, axis=0); + %935 = expand_dims(%488, axis=0); + %936 = (%934, %935); + %937 = transpose(%549, axes=[0, 2, 1, 3]); + %938 = expand_dims(%937, axis=0); + %939 = expand_dims(%562, axis=0); + %940 = (%938, %939); + %941 = transpose(%623, axes=[0, 2, 1, 3]); + %942 = 
expand_dims(%941, axis=0);
+            %943 = expand_dims(%636, axis=0);
+            %944 = (%942, %943);
+            %945 = transpose(%697, axes=[0, 2, 1, 3]);
+            %946 = expand_dims(%945, axis=0);
+            %947 = expand_dims(%710, axis=0);
+            %948 = (%946, %947);
+            %949 = transpose(%771, axes=[0, 2, 1, 3]);
+            %950 = expand_dims(%949, axis=0);
+            %951 = expand_dims(%784, axis=0);
+            %952 = (%950, %951);
+            %953 = transpose(%845, axes=[0, 2, 1, 3]);
+            %954 = expand_dims(%953, axis=0);
+            %955 = expand_dims(%858, axis=0);
+            %956 = (%954, %955);
+            %957 = reshape(%908, newshape=[1, 50, 32, 768]);
+            %958 = concatenate(%912);
+            %959 = concatenate(%916);
+            %960 = concatenate(%920);
+            %961 = concatenate(%924);
+            %962 = concatenate(%928);
+            %963 = concatenate(%932);
+            %964 = concatenate(%936);
+            %965 = concatenate(%940);
+            %966 = concatenate(%944);
+            %967 = concatenate(%948);
+            %968 = concatenate(%952);
+            %969 = concatenate(%956);
+            (%957, %958, %959, %960, %961, %962, %963, %964, %965, %966, %967, %968, %969)
+        }
+        """,
+        "from_string",
+        None,
+        metatable
+    )
+
+    return {'input_shapes': {"x": [1, 50, 32]}, 'input_dtypes': {"x": 'int64'}, 'mod': mod, 'params': None}
+
+
+def arg_for(shape, dtype, device):
+    return tvm.nd.array(
+        np.random.uniform(-1.0, 1.0, size=shape).astype(dtype), device=device)
+
+
+def describe_onnx(filename):
+    """Prints the 'model' dictionary of input shapes and dtypes needed to invoke the
+    ONNX model in filename. Note that '?' (i.e. unknown) shape dimensions must be changed
+    to concrete dimensions which are consistent with the overall model."""
+    onnx_model = onnx.load(MODEL_PREFIX + filename)
+    input_shapes = {}
+    input_dtypes = {}
+    initializer_names = [n.name for n in onnx_model.graph.initializer]
+    for input_info in onnx_model.graph.input:
+        if input_info.name not in initializer_names:
+            _, shape, dtype, _ = tvm.relay.frontend.onnx.get_info(input_info)
+            if dtype is None:
+                raise ValueError(
+                    f"Unknown dtype on input '{input_info.name}' is not supported."
+                )
+            input_shapes.update({input_info.name: shape})
+            input_dtypes.update({input_info.name: dtype})
+    print(f"{{'filename': {filename}, 'input_shapes': {input_shapes}, 'input_dtypes': {input_dtypes}}}")
+
+
+def from_onnx(model):
+    logging.info("-------------------- BEGIN ONNX IMPORT --------------------")
+
+    filename = MODEL_PREFIX + model['filename']
+    logging.info(f"Loading ONNX model from {filename}")
+
+    onnx_model = onnx.load(filename)
+    logging.info(f"Loaded model from {filename}")
+
+    mod, params = tvm.relay.frontend.from_onnx(onnx_model, model['input_shapes'], freeze_params=True)
+    mod = tvm.relay.transform.InferType()(mod)
+    logging.info("-------------------- END ONNX IMPORT --------------------")
+
+    logging.info(f"Imported model:\n{mod}")
+    logging.info(f"Params:\n{params}")
+
+    return {'input_shapes': model['input_shapes'], 'input_dtypes': model['input_dtypes'], 'mod': mod, 'params': params}
+
+
+def compile_and_benchmark(model, targets, dev):
+    exe = tvm.relay.vm.compile(model['mod'], target=targets, params=model['params'])
+    vm = tvm.runtime.vm.VirtualMachine(exe, dev)
+    args = {
+        input_name: arg_for(model['input_shapes'][input_name], model['input_dtypes'][input_name], dev)
+        for input_name in model['input_shapes'].keys()
+    }
+    profile_result = vm.benchmark(
+        dev,
+        func_name="main",
+        number=1,
+        repeat=1,
+        min_repeat_ms=10,
+        **args,
+    ).results
+    logging.info("time: {}us".format(np.mean(profile_result) * 1e6))
+
+
+def trt_options():
+    return {"use_implicit_batch": False,
+            "max_workspace_size": 1 << 30,
+            "remove_no_mac_subgraphs": False,
+            "use_fp16": False,
+            "use_uint8": False}
+
+
+def collage(model):
+    host_target = tvm.target.Target("llvm")
+    targets = []
+    targets += [tvm.target.Target("cuda", host_target)]
+    targets += [tvm.target.Target("cuda -compiler=cutlass", host_target)]
+    targets += [tvm.target.Target("cuda -compiler=tensorrt", host_target)]
+    targets += [tvm.target.Target("cuda -compiler=cublas", host_target)]
+    targets += [tvm.target.Target("cuda -compiler=cudnn", host_target)]
+    dev = tvm.device("cuda")
+    with tvm.transform.PassContext(config={"relay.fallback_device_type": dev.device_type,
+                                           "relay.collage.enable_collage": True,
+                                           "relay.collage.autotvm_log_filename": TUNING_LOG,
+                                           "relay.ext.tensorrt.options": trt_options()}):
+        compile_and_benchmark(model, targets, dev)
+
+
+def optional_tuning_records(log_filename):
+    if log_filename == "":
+        return tvm.autotvm.task.FallbackContext()
+    else:
+        return tvm.autotvm.task.ApplyHistoryBest(log_filename)
+
+
+def just_trt(model):
+    host_target = tvm.target.Target("llvm")
+    generic_target = tvm.target.Target("cuda", host_target)
+    with optional_tuning_records(TUNING_LOG):
+        targets = [generic_target]
+        dev = tvm.device(generic_target.kind.device_type)
+        options = trt_options()
+        mod, options = tvm.relay.op.contrib.partition_for_tensorrt(mod=model['mod'], params=model['params'],
+                                                                   **options)
+        logging.info("-------------- BEGIN PARTITIONED --------------")
+        logging.info(mod)
+        logging.info("-------------- END PARTITIONED ----------------")
+        with tvm.transform.PassContext(config={"relay.fallback_device_type": generic_target.kind.device_type,
+                                               "relay.ext.tensorrt.options": options}):
+            # Benchmark the partitioned module, not the original unpartitioned one.
+            compile_and_benchmark({**model, 'mod': mod}, targets, dev)
+
+
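+# Like just_trt above, the remaining drivers benchmark a single backend (or plain
+# TVM) end-to-end as a baseline for the Collage search in collage(). For example,
+# an ONNX model can be benchmarked the same way (hypothetical filename and input
+# name; describe_onnx() prints the dictionary to pass):
+#
+#   just_cutlass(from_onnx({'filename': 'mnist.onnx',
+#                           'input_shapes': {'Input3': [1, 1, 28, 28]},
+#                           'input_dtypes': {'Input3': 'float32'}}))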
config={"relay.fallback_device_type": generic_target.kind.device_type}): + targets = [generic_target] + dev = tvm.device(generic_target.kind.device_type) + mod = tvm.relay.op.contrib.partition_for_cutlass(model['mod'], model['params']) + logging.info("-------------- BEGIN PARTITIONED --------------") + logging.info(mod) + logging.info("-------------- END PARTITIONED ----------------") + compile_and_benchmark(model, targets, dev) + + +def just_tvm(model): + host_target = tvm.target.Target("llvm") + generic_target = tvm.target.Target("cuda", host_target) + with optional_tuning_records(TUNING_LOG): + with tvm.transform.PassContext(config={"relay.fallback_device_type": generic_target.kind.device_type}): + targets = [generic_target] + dev = tvm.device(generic_target.kind.device_type) + mod = model['mod'] + logging.info("-------------- BEGIN MODULE --------------") + logging.info(mod) + logging.info("-------------- END MODULE ----------------") + compile_and_benchmark(model, targets, dev) + + +if __name__ == "__main__": + collage(gpt2()) diff --git a/tests/python/relay/collage/test_sub_graph.py b/tests/python/relay/collage/test_sub_graph.py new file mode 100644 index 0000000000000..df4a0257dc5d4 --- /dev/null +++ b/tests/python/relay/collage/test_sub_graph.py @@ -0,0 +1,374 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import logging + +logging.basicConfig(level=logging.INFO) + +capture_index_in_spans = tvm._ffi.get_global_func("relay.collage.capture_index_in_spans") +partition_on_indexes_for_testing = tvm._ffi.get_global_func("relay.collage.partition_on_indexes_for_testing") + + +def print_with_indexes(mod): + mod = capture_index_in_spans()(mod) + print(mod) + + +def process(mod, max_outputs, allow_taps, indexes, labels=None): + mod = tvm.relay.transform.InferType()(mod) + mod = capture_index_in_spans()(mod) + mod = partition_on_indexes_for_testing(max_outputs, allow_taps, indexes, labels)(mod) + return mod + + +def assert_eq(in_mod, expected_mod, actual_mod): + if not tvm.ir.structural_equal(actual_mod, expected_mod, True): + # Print everything in full so we can see what's going on when things fail. + print("Input module:") + print(in_mod) + print("Expected module:") + print(expected_mod) + print("Actual module:") + print(actual_mod) + # Assert again so as to see the actual disagreeing sub-expressions. 
+ tvm.ir.assert_structural_equal(actual_mod, expected_mod, True) + + +def test_single_op(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = add(%c, %d); // node 7 + subtract(%0, %1) + } + """) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = (fn(%x, %y) { add(%x, %y) })(%c, %d); + subtract(%0, %1) + } + """) + + assert_eq(input(), expected(), process(input(), 1, False, [7])) + + +def test_multi_output(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); // node 6 + %1 = add(%c, %d); // node 7 + subtract(%0, %1) + } + """) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = (fn(%w, %x, %y, %z) { (add(%y, %z), add(%w, %x)) })(%c, %d, %a, %b); + %1 = %0.0; + %2 = %0.1; + subtract(%1, %2) + } + """) + + # No rewrite since 2 outputs + assert_eq(input(), input(), process(input(), 1, False, [6, 7])) + # Rewrite + assert_eq(input(), expected(), process(input(), 2, False, [6, 7])) + + +def test_classic_conv2d_add_relu(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 3, 32, 32), float32], %b: Tensor[(2, 3, 5, 5), float32], + %c: Tensor[(5, 2, 28, 28), float32], %d: Tensor[(5, 2, 28, 28), float32]) { + %0 = nn.conv2d(%a, %b); // node 8 + %1 = add(%0, %c); // node 9 + %2 = nn.relu(%1); // node 10 + subtract(%2, %d) + } + """) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 3, 32, 32), float32], %b: Tensor[(2, 3, 5, 5), float32], + %c: Tensor[(5, 2, 28, 28), float32], %d: Tensor[(5, 2, 28, 28), float32]) { + %2 = (fn(%x, %y, %z) { + %0 = nn.conv2d(%x, %y); + %1 = add(%0, %z); + nn.relu(%1) + })(%a, %b, %c); + subtract(%2, %d) + } + """) + + assert_eq(input(), expected(), process(input(), 1, False, [8, 9, 10])) + + +def test_diamond_single_output(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 3, 32, 32), float32], %b: Tensor[(2, 3, 5, 5), float32]) { + %0 = nn.conv2d(%a, %b, padding=[0, 0, 0, 0]); // node 5 + %1 = nn.relu(%0); // node 6 + %2 = nn.relu(%1); // node 7 + %3 = nn.leaky_relu(%0, alpha=0f); // node 9 + add(%2, %3) // node 10 + } + """) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 3, 32, 32), float32], %b: Tensor[(2, 3, 5, 5), float32]) { + (fn (%x: Tensor[(5, 3, 32, 32), float32], %y: Tensor[(2, 3, 5, 5), float32]) { + %0 = nn.conv2d(%x, %y, padding=[0, 0, 0, 0]); + %1 = nn.relu(%0); + %2 = nn.relu(%1); + %3 = nn.leaky_relu(%0, alpha=0f); + add(%2, %3) + })(%a, %b) + } + """) + + assert_eq(input(), expected(), process(input(), 1, False, [5, 6, 7, 9, 10])) + + +def test_diamond_multi_output(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 3, 32, 32), float32], %b: Tensor[(2, 3, 5, 5), float32]) { + %0 = 
nn.conv2d(%a, %b, padding=[0, 0, 0, 0]); // node 5 + %1 = nn.relu(%0); // node 6 + %2 = nn.relu(%1); // node 7 + %3 = nn.leaky_relu(%0, alpha=0f); // node 9 + add(%2, %3) + } + """) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 3, 32, 32), float32], %b: Tensor[(2, 3, 5, 5), float32]) { + %4 = (fn (%x: Tensor[(5, 3, 32, 32), float32], %y: Tensor[(2, 3, 5, 5), float32]) { + %0 = nn.conv2d(%x, %y, padding=[0, 0, 0, 0]); + %1 = nn.relu(%0); + %2 = nn.relu(%1); + %3 = nn.leaky_relu(%0, alpha=0f); + (%2, %3) + })(%a, %b); + %5 = %4.0; + %6 = %4.1; + add(%5, %6) + } + """) + + assert_eq(input(), expected(), process(input(), 2, False, [5, 6, 7, 9])) + + +def test_with_tap(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 3, 32, 32), float32], %b: Tensor[(2, 3, 5, 5), float32]) { + %0 = nn.conv2d(%a, %b, padding=[0, 0, 0, 0]); // node 5 + %1 = nn.relu(%0); // node 6 + add(%1, %0) + } + """) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 3, 32, 32), float32], %b: Tensor[(2, 3, 5, 5), float32]) { + %2 = (fn (%x, %y) { + %0 = nn.conv2d(%x, %y, padding=[0, 0, 0, 0]); + %1 = nn.relu(%0); + (%0, %1) + })(%a, %b); + %3 = %2.1; + %4 = %2.0; + add(%3, %4) + } + """) + + # No rewrite since has tap + assert_eq(input(), input(), process(input(), 2, False, [5, 6])) + # Rewrite + assert_eq(input(), expected(), process(input(), 2, True, [5, 6])) + + +def test_no_cycles(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); // node 3 + %1 = add(%0, %b); + add(%1, %b) // node 5 + } + """) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) { + (fn(%x, %y) { + %0 = add(%x, %y); + %1 = add(%0, %y); + add(%1, %y) + })(%a, %b) + } + """) + + # No rewrite since would create cycle + assert_eq(input(), input(), process(input(), 2, False, [3, 5])) + # No cycle + assert_eq(input(), expected(), process(input(), 2, False, [3, 4, 5])) + + +def test_labels_direct_connection(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32]) { + %0 = nn.relu(%a); // node 3 + %1 = nn.relu(%0); // node 4 + %2 = nn.relu(%1); // node 5 + %3 = nn.relu(%1); // node 6 + %4 = add(%2, %3); // node 7 + %5 = nn.relu(%4); // node 8 + %6 = nn.relu(%4); // node 9 + %7 = add(%5, %6); // node 10 + nn.relu(%7) // node 11 + } + """) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32]) { + (fn(%x) { + %0 = nn.relu(%x); + %4 = (fn(%y, Composite="a") { + %1 = nn.relu(%y); + %2 = nn.relu(%1); + %3 = nn.relu(%1); + add(%2, %3) + })(%0); + %7 = (fn(%z, Composite="b") { + %5 = nn.relu(%z); + %6 = nn.relu(%z); + add(%5, %6) + })(%4); + nn.relu(%7) + })(%a) + } + """) + + assert_eq(input(), expected(), process(input(), 1, False, + [3, 4, 5, 6, 7, 8, 9, 10, 11], + ["", "a", "a", "a", "a", "b", "b", "b", ""])) + + +def test_labels_nested_tap(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32]) { + %0 = nn.relu(%a); // node 3 + %1 = nn.relu(%0); // node 4 + %2 = nn.relu(%1); // node 5 + %3 = nn.relu(%1); // node 6 + %4 = add(%2, %3); // node 7 + %5 = nn.relu(%4); // node 8 + %6 = nn.relu(%4); // node 9 
+ %7 = add(%5, %6); // node 10 + add(%2, %7) // node 11 + } + """) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32]) { + %0 = nn.relu(%a); + %9 = (fn(%x) { + %5 = (fn(%y, Composite="a") { + %1 = nn.relu(%y); + %2 = nn.relu(%1); + %3 = nn.relu(%1); + %4 = add(%2, %3); + (%2, %4) + })(%x); + %8 = (fn(%z, Composite="b") { + %6 = nn.relu(%z); + %7 = nn.relu(%z); + add(%6, %7) + })(%5.1); + (%5.0, %8) + })(%0); + add(%9.0, %9.1) + } + """) + + assert_eq(input(), expected(), process(input(), 2, True, + [4, 5, 6, 7, 8, 9, 10], + ["a", "a", "a", "a", "b", "b", "b"])) + + +if __name__ == "__main__": + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 74e03f6a97551..96f363a84ad89 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=unused-wildcard-import import numpy as np -import pytest import tvm from tvm import relay @@ -601,6 +600,72 @@ def test_match_fake_diamond(): assert not diamond.match(out) +def test_at_most_one_parent(): + # Pattern + P = is_op("nn.conv2d")(wildcard(), wildcard()) # 'parent' + I = is_op("nn.relu")(wildcard()) # 'intermediate' ('path' in the code) + C = is_op("add")(wildcard(), wildcard()) # 'child' + pattern = dominates(P, I, C) + + # n6(P) + # / \ + # n7 \ + # / \ + # n8(P) n10(I) + # \ / + # n9(I) / + # \ / + # n11(C) + + x = relay.var("x") + w = relay.var("w") + n6 = relay.op.nn.conv2d(x, w) # matches P + n7 = relay.op.tanh(n6) # does not match I + n8 = relay.op.nn.conv2d(n7, w) # matches P + n9 = relay.op.nn.relu(n8) # matches I + n10 = relay.op.nn.relu(n6) # matches I + n11 = relay.add(n9, n10) # matches C + + # Does not match: Can't match the parent pattern P at both 8 and 6. + # Note that if we did allow P to be used twice the implementation would + # need to be changed to not 'jump over' n7. 
+ assert not pattern.match(n11) + +def partition(pattern, x): + return relay.transform.InferType()(tvm.IRModule.from_expr(pattern.partition(x))) + +def fuse(x): + return relay.transform.FuseOps(fuse_opt_level=2)(relay.transform.InferType()(tvm.IRModule.from_expr(x))) + +def test_parallel_injective(): + # Pattern + P = is_op("add")(wildcard(), wildcard()) # 'parent' + I = is_op("squeeze")(wildcard()) | is_op("transpose")(wildcard()) # 'intermediate' ('path' in the code) + C = is_op("left_shift")(wildcard(), wildcard()) # 'child' + pattern = dominates(P, I, C) + + # + # n5(P) + # / \ + # n6(I) n8(I) + # \ / + # n9(C) + # + + x = relay.var("x", shape=(10, 20)) + n5 = relay.add(x, relay.const(1, "float32")) + n6 = relay.squeeze(n5) + n8 = relay.transpose(n5, axes=[0, 1]) + n9 = relay.left_shift(n6, n8) + + assert pattern.match(n9) + + print(partition(pattern, n9)) + + print(fuse(n9)) + + + def test_match_dominator(): # Pattern is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()) @@ -900,7 +965,7 @@ def __init__(self): self.eps = is_constant() self.pattern = ( - self.gamma * (self.x - self.mean) / is_op("sqrt")(self.var + self.eps) + self.beta + self.gamma * (self.x - self.mean) / is_op("sqrt")(self.var + self.eps) + self.beta ) def callback(self, pre, post, node_map): @@ -1183,8 +1248,8 @@ def test_double_partition(): [inpf, weightf, biasf], relay.op.nn.relu(relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf)), ) - .with_attr("Composite", "conv_bias_relu") - .with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_") + .with_attr("Composite", "conv_bias_relu") + .with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_") ) inpf = relay.var("input") weightf = relay.var("weight") @@ -1193,8 +1258,8 @@ def test_double_partition(): relay.Function( [inpf, weightf, biasf], relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf) ) - .with_attr("Composite", "conv_bias") - .with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_") + .with_attr("Composite", "conv_bias") + .with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_") ) expected = func1(func0(x, w, b), w2, b2) @@ -1400,6 +1465,17 @@ def test_partition_overused(): assert pattern.partition(out) == out +def test_partition_not_overused(): + pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard())) + + x = relay.var("input") + w = relay.var("weight") + conv2d = relay.op.nn.conv2d(x, w) + relu = relay.op.nn.relu(conv2d) + out = relu + relu + + assert pattern.partition(out) == out + def test_partition_fuzzy_tuple(): x = relay.var("x") @@ -1427,7 +1503,6 @@ def concat(*args): def test_partition_fuzzy_function_args(): - func_pattern = FunctionPattern(None, wildcard() + wildcard())(None) + wildcard() x = relay.var("x") y = relay.var("y") @@ -1760,4 +1835,7 @@ def callback(self, pre, post, node_map): if __name__ == "__main__": - pytest.main([__file__]) + import sys + import pytest + sys.exit(pytest.main([__file__] + sys.argv[1:])) + diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index cacce5603e5f9..7399979d95c3c 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -15,14 +15,41 @@ # specific language governing permissions and limitations # under the License. 
import numpy as np
-import pytest

 import tvm
 from tvm import relay
 from tvm.relay import transform
-from tvm.relay.testing import run_opt_pass
 import tvm.testing
-import tvm.topi.testing
+
+
+def assert_eq(in_mod, expected_mod, actual_mod):
+    if not tvm.ir.structural_equal(actual_mod, expected_mod, True):
+        # Print everything in full so we can see what's going on when things fail.
+        print("Input module:")
+        print(in_mod)
+        print("Expected module:")
+        print(expected_mod)
+        print("Actual module:")
+        print(actual_mod)
+        # Assert again so as to see the actual disagreeing sub-expressions.
+        tvm.ir.assert_structural_equal(actual_mod, expected_mod, True)
+
+
+def process_and_compare(input, expected):
+    # Type-check the input, sanity-check fusion at opt level 0, then compare the
+    # result of FuseOps at opt level 1 against the expected module.
+    in_mod = tvm.IRModule.from_expr(input)
+    in_mod = relay.transform.InferType()(in_mod)
+
+    no_fusion_mod = transform.FuseOps(fuse_opt_level=0)(in_mod)
+    assert not relay.analysis.free_vars(no_fusion_mod["main"])
+
+    expected_mod = tvm.IRModule.from_expr(expected)
+    expected_mod = relay.transform.InferType()(expected_mod)
+    # Run InferType a second time (yes, really): somewhere in the type checker a
+    # constant gets rewritten from int64 to int32, and only the second run reaches
+    # a fixed point.
+    expected_mod = relay.transform.InferType()(expected_mod)
+
+    actual_mod = transform.FuseOps(fuse_opt_level=1)(in_mod)
+    assert not relay.analysis.free_vars(actual_mod["main"])
+    assert_eq(in_mod, expected_mod, actual_mod)


 def test_fuse_simple():
@@ -46,16 +73,14 @@ def expected():
         y = relay.Call(f1, [x])
         return relay.Function([x], y)

-    z = before()
-    zz = run_opt_pass(z, transform.FuseOps())
-    after = run_opt_pass(expected(), transform.InferType())
-    assert tvm.ir.structural_equal(zz, after)
+    process_and_compare(before(), expected())


 def test_conv2d_fuse():
     """Test fusion case of conv2d"""
+    dshape = (1, 16, 64, 64)

-    def before(dshape):
+    def before():
         x = relay.var("x", shape=dshape)
         x = relay.add(x, relay.const(1, "float32"))
         y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=16)
@@ -69,7 +94,7 @@ def before(dshape):
         z = relay.add(z2, z3)
         return relay.Function(relay.analysis.free_vars(z), z)

-    def expected(dshape):
+    def expected():
         # segment 0
         x = relay.var("p0", shape=dshape)
         y = relay.add(x, relay.const(1, "float32"))
@@ -110,17 +135,15 @@ def expected(dshape):
         z = z3
         return relay.Function(relay.analysis.free_vars(z), z)

-    dshape = (1, 16, 64, 64)
-    z = before(dshape)
-    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
-    after = run_opt_pass(expected(dshape), transform.InferType())
-    assert tvm.ir.structural_equal(zz, after)
+    with tvm.transform.PassContext(opt_level=2, config={"relay.FuseOps.link_params": True}):
+        process_and_compare(before(), expected())


 def test_concatenate():
     """Test fusion case involving concat op and Tuple node"""
+    dshape = (1, 16, 64, 64)

-    def before(dshape):
+    def before():
         x = relay.var("x", shape=dshape)
         pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
         upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW")
@@ -128,7 +151,7 @@ def before(dshape):
         out = relay.add(concat, relay.const(1, "float32"))
         return relay.Function(relay.analysis.free_vars(out), out)

-    def expected(dshape):
+    def expected():
         x = relay.var("x", shape=dshape)
         pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
         f0 = relay.Function([x], pooled)
@@ -147,27 +170,21 @@ def expected(dshape):
         z = relay.Call(f1, [y, x])
         return relay.Function([x], z)

-    dshape = (1, 16, 64, 64)
-    z = before(dshape)
-    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
-
assert not relay.analysis.free_vars(zz) - zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) - assert not relay.analysis.free_vars(zz) - after = run_opt_pass(expected(dshape), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + process_and_compare(before(), expected()) def test_tuple_root(): """Test fusion case where Tuple node is the root in its group""" + dshape = (1, 16, 64, 64) - def before(dshape): + def before(): x = relay.var("x", shape=dshape) pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW") out = relay.Tuple((upsampled, x)) return relay.Function(relay.analysis.free_vars(out), out) - def expected(dshape): + def expected(): x = relay.var("x", shape=dshape) pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) f0 = relay.Function([x], pooled) @@ -184,25 +201,20 @@ def expected(dshape): tup = relay.Tuple((z, x)) return relay.Function([x], tup) - dshape = (1, 16, 64, 64) - z = before(dshape) - zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0)) - assert not relay.analysis.free_vars(zz) - zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) - assert not relay.analysis.free_vars(zz) - after = run_opt_pass(expected(dshape), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + process_and_compare(before(), expected()) def test_stop_fusion(): - def before(dshape): + dshape = (10, 20) + + def before(): x = relay.var("x", shape=dshape) y = relay.add(x, relay.const(1, "float32")) y = relay.annotation.stop_fusion(y) z = relay.exp(y) return relay.Function([x], z) - def expected(dshape): + def expected(): x = relay.var("p0", shape=dshape) y = relay.add(x, relay.const(1, "float32")) f1 = relay.Function([x], y) @@ -218,15 +230,14 @@ def expected(dshape): z = relay.Call(f2, [y]) return relay.Function([x], z) - dshape = (10, 20) - z = before(dshape) - zz = run_opt_pass(z, transform.FuseOps()) - after = run_opt_pass(expected(dshape), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + process_and_compare(before(), expected()) def test_fuse_myia_regression(): - def before(dshape, dtype): + dshape = () + dtype = "int64" + + def before(): x = relay.var("x", shape=dshape, dtype=dtype) y = relay.var("y", shape=dshape, dtype=dtype) sb = relay.ScopeBuilder() @@ -236,7 +247,7 @@ def before(dshape, dtype): sb.ret(relay.Function([], y)) return relay.Function([x, y], relay.Call(sb.get(), [])) - def expected(dshape, dtype): + def expected(): x = relay.var("x", shape=dshape, dtype=dtype) y = relay.var("y", shape=dshape, dtype=dtype) sb = relay.ScopeBuilder() @@ -250,16 +261,13 @@ def expected(dshape, dtype): sb.ret(relay.Function([], y)) return relay.Function([x, y], relay.Call(sb.get(), [])) - dshape = () - dtype = "int64" - f = before(dshape, dtype) - zz = run_opt_pass(f, transform.FuseOps()) - after = run_opt_pass(expected(dshape, dtype), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + process_and_compare(before(), expected()) def test_fuse_tuple_get_elemwise(): - def before(dim): + dim = 10 + + def before(): X = relay.var("X", shape=(1, dim)) W = relay.var("W", shape=(3 * dim, dim)) matmul = relay.nn.dense(X, W) @@ -267,7 +275,7 @@ def before(dim): out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2]) return relay.Function([X, W], out) - def expected(dim): + def expected(): p0 = relay.var("p0", shape=(1, dim)) p1 = relay.var("p1", shape=(3 * dim, dim)) 
matmul = relay.nn.dense(p0, p1) @@ -286,25 +294,20 @@ def expected(dim): z = relay.Call(f1, [y]) return relay.Function([X, W], z) - dim = 10 - z = before(dim) - zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0)) - assert not relay.analysis.free_vars(zz) - zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) - assert not relay.analysis.free_vars(zz) - after = run_opt_pass(expected(dim), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + process_and_compare(before(), expected()) def test_tuple_get_root(): - def before(dim): + dim = 10 + + def before(): X = relay.var("X", shape=(1, 3 * dim)) W = relay.var("W", shape=(dim, dim)) splitted = relay.split(X, indices_or_sections=3, axis=1) out = relay.nn.dense(splitted[0], W) return relay.Function([X, W], out) - def expected(dim): + def expected(): p0 = relay.var("p0", shape=(1, 3 * dim)) splitted = relay.split(p0, indices_or_sections=3, axis=1) out = splitted[0] @@ -323,28 +326,14 @@ def expected(dim): z = relay.Call(f1, [y, W]) return relay.Function([X, W], z) - dim = 10 - z = before(dim) - zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0)) - assert not relay.analysis.free_vars(zz) - zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) - assert not relay.analysis.free_vars(zz) - after = run_opt_pass(expected(dim), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) - - -def fuse0(mod): - mod = relay.transform.InferType()(mod) - return relay.transform.FuseOps(fuse_opt_level=0)(mod) - - -def fuse2(mod): - mod = relay.transform.InferType()(mod) - return relay.transform.FuseOps(fuse_opt_level=2)(mod) + process_and_compare(before(), expected()) def test_tuple_intermediate(): - def before(x): + dshape = (1, 16, 64, 64) + + def before(): + x = relay.var("x", shape=dshape) inj = relay.squeeze(x) y1 = relay.add(inj, relay.const(1, "float32")) tmp = relay.squeeze(inj) @@ -356,24 +345,19 @@ def before(x): out = relay.add(out_inj, relay.const(1, "float32")) return relay.Function(relay.analysis.free_vars(out), out) - def expected(p0): - f0 = before(p0) + def expected(): + f0 = before() f1 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) - x = relay.var("x", shape=dshape) - y = relay.Call(f1, [x]) - return relay.Function([x], y) + xx = relay.var("xx", shape=dshape) + y = relay.Call(f1, [xx]) + return relay.Function([xx], y) - dshape = (1, 16, 64, 64) - x = relay.var("x", shape=dshape) - orig = before(x) - fuse0(tvm.IRModule.from_expr(orig)) - m = fuse2(tvm.IRModule.from_expr(orig)) - relay.build(m, "llvm") - after = run_opt_pass(expected(x), transform.InferType()) - assert tvm.ir.structural_equal(m["main"], after) + process_and_compare(before(), expected()) def test_tuple_consecutive(): + dshape = (1, 16, 64, 64) + def gen_intermediate_tuple(x): y1 = relay.add(x, relay.const(1, "float32")) y2 = relay.add(x, relay.const(1, "float32")) @@ -389,7 +373,8 @@ def gen_consecutive_tuple(x): concat = relay.concatenate((y1, y2, y3), axis=1) return concat - def before(x): + def before(): + x = relay.var("x", shape=dshape) concat = gen_consecutive_tuple(x) pooled = relay.nn.max_pool2d(concat, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) out = relay.add(pooled, relay.const(1, "float32")) @@ -397,7 +382,7 @@ def before(x): out_tup = relay.Tuple((out, out2)) return relay.Function(relay.analysis.free_vars(out_tup), out_tup) - def expected(dshape): + def expected(): p0 = relay.var("p0", shape=dshape) concat = gen_consecutive_tuple(p0) f0 = relay.Function([p0], concat) @@ -421,17 +406,13 @@ def 
expected(dshape):
         return relay.Function([x], relay.Tuple((z, z2)))

-    dshape = (1, 16, 64, 64)
-    x = relay.var("x", shape=dshape)
-    orig = before(x)
-    fuse0(tvm.IRModule.from_expr(orig))
-    m = fuse2(tvm.IRModule.from_expr(orig))
-    relay.build(m, "llvm")
-    after = run_opt_pass(expected(dshape), transform.InferType())
-    assert tvm.ir.structural_equal(m["main"], after)
+    with tvm.transform.PassContext(opt_level=2, config={"relay.FuseOps.link_params": True}):
+        process_and_compare(before(), expected())


 def test_inception_like():
+    dshape = (1, 16, 64, 64)
+
     def conv(data):
         y = relay.nn.conv2d(data, relay.var("w"), kernel_size=(3, 3), padding=(1, 1), channels=16)
         return relay.nn.relu(data=y)
@@ -441,13 +422,13 @@ def inception_like(data):
         c1 = conv(data)
         return relay.concatenate((c0, c1), axis=1)

-    def before(dshape):
+    def before():
         x = relay.var("x", shape=dshape)
         in1 = inception_like(x)
         in2 = inception_like(in1)
         return relay.Function(relay.analysis.free_vars(in2), in2)

-    def expected(dshape):
+    def expected():
         p0 = relay.var("p0", shape=dshape)
         c = conv(p0)
         f0 = relay.Function(relay.analysis.free_vars(c), c)
@@ -492,13 +473,7 @@ def expected(dshape):

         return relay.Function(relay.analysis.free_vars(out), out)

-    dshape = (1, 16, 64, 64)
-    orig = before(dshape)
-    fuse0(tvm.IRModule.from_expr(orig))
-    m = fuse2(tvm.IRModule.from_expr(orig))
-    relay.build(m, "llvm")
-    after = run_opt_pass(expected(dshape), transform.InferType())
-    assert tvm.ir.structural_equal(m["main"], after)
+    process_and_compare(before(), expected())


 def test_fuse_parallel_injective():
@@ -524,13 +499,7 @@ def expected():
         y = relay.Call(f1, [x])
         return relay.Function([x], y)

-    z = before()
-    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
-    assert not relay.analysis.free_vars(zz)
-    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
-    assert not relay.analysis.free_vars(zz)
-    after = run_opt_pass(expected(), transform.InferType())
-    assert tvm.ir.structural_equal(zz, after)
+    process_and_compare(before(), expected())


 def test_immutable():
@@ -558,9 +527,9 @@ def expected():
         mod["main"] = relay.Function([x], y)
         return mod

-    mod = transform.InferType()(before())
-    new_mod = transform.FuseOps(fuse_opt_level=2)(mod)
-    assert tvm.ir.structural_equal(mod, transform.InferType()(before()))
+    orig_mod = transform.InferType()(before())
+    new_mod = transform.FuseOps(fuse_opt_level=1)(orig_mod)
+    assert tvm.ir.structural_equal(orig_mod, transform.InferType()(before()))
     assert tvm.ir.structural_equal(new_mod, transform.InferType()(expected()))


@@ -608,21 +577,13 @@ def expected(n, max_fused_ops):

     max_fused_ops = 256
     n = 300
-    z = before(n)
-    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
-    zz = run_opt_pass(z, transform.FuseOps())
-    after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())
-    assert tvm.ir.structural_equal(zz, after)
+    with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
+        process_and_compare(before(n), expected(n, max_fused_ops))

     max_fused_ops = 10
     n = 20
-    z = before(n)
-    after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())
-    with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
-        zz = run_opt_pass(z, transform.FuseOps())
-
-    assert tvm.ir.structural_equal(zz, after)
+    # The max_depth cap must also be in effect for this case, otherwise FuseOps
+    # uses its default depth limit and fuses all n ops into a single group.
+    with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
+        process_and_compare(before(n), expected(n, max_fused_ops))


 link_params = tvm.testing.parameter(False, True)
@@ -654,12 +615,8 @@ def expected(link_params):
         y = relay.Call(f0, [x] if link_params else [x, c])
         return relay.Function([x], y)

-
after = run_opt_pass(expected(link_params), transform.InferType()) with tvm.transform.PassContext(opt_level=2, config={"relay.FuseOps.link_params": link_params}): - m = run_opt_pass(before(), transform.InferType()) - m = run_opt_pass(m, transform.FuseOps()) - assert tvm.ir.structural_equal(m, after) - relay.build(m, "llvm") + process_and_compare(before(), expected(link_params)) def test_fuse_gather_nd(link_params): @@ -688,15 +645,29 @@ def expected(link_params): y = relay.Call(f0, [x] if link_params else [x, c]) return relay.Function([x], y) - after = run_opt_pass(expected(link_params), transform.InferType()) with tvm.transform.PassContext(opt_level=2, config={"relay.FuseOps.link_params": link_params}): - m = run_opt_pass(before(), transform.InferType()) - m = run_opt_pass(m, transform.FuseOps()) - assert tvm.ir.structural_equal(m, after) - relay.build(m, "llvm") + process_and_compare(before(), expected(link_params)) + + +def test_just_reduce(): + def before(): + x = relay.var("x", shape=(5, 7), dtype="float32") + z = relay.min(x) + return relay.Function([x], z) + + def expected(): + p0 = relay.var("p0", shape=(5, 7), dtype="float32") + z0 = relay.min(p0) + f0 = relay.Function([p0], z0) + f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + + x = relay.var("x", shape=(5, 7), dtype="float32") + f = relay.Call(f0, [x]) + return relay.Function([x], f) + + process_and_compare(before(), expected()) -@tvm.testing.uses_gpu def test_fuse_bcast_reduce_scalar(): """Test fusion case with broadcast and reduction involving scalar""" @@ -717,12 +688,7 @@ def expected(): f = relay.Call(f0, [x]) return relay.Function([x], f) - orig = before() - m = fuse2(tvm.IRModule.from_expr(orig)) - for tgt, dev in tvm.testing.enabled_targets(): - relay.build(m, tgt) - after = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(m["main"], after) + process_and_compare(before(), expected()) def test_fuse_max_diamond(): @@ -760,38 +726,36 @@ def create_diamond_func(inp): num_diamond = 3 with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}): - fused = run_opt_pass(before(branch_len, num_diamond), transform.FuseOps()) - - expected = run_opt_pass(after(branch_len, num_diamond), transform.InferType()) - assert tvm.ir.structural_equal(fused, expected) - + process_and_compare(before(branch_len, num_diamond), after(branch_len, num_diamond)) def test_fuse_dynamic_squeeze_slice_take(): - input_data = [ - np.random.random([1, 2, 4]).astype("float32"), - np.array([0]).astype("int64"), - ] - - x = relay.var("p0107", shape=(relay.Any(), relay.Any(), 4), dtype="float32") - take_val = relay.var("p166", shape=(relay.Any(),), dtype="int64") - - squeeze = relay.op.squeeze(x, axis=[0]) - strided_slice = relay.op.strided_slice( - squeeze, begin=[0, 0], end=[15130, 9223372036854775807], strides=[1, 1] - ) - take = relay.op.take(strided_slice, take_val, axis=0) - - mod = tvm.IRModule.from_expr(take) - result = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm").evaluate()( - *input_data - ) - - np_result = np.squeeze(input_data[0][:, input_data[1][0], :], axis=0) + def before(): + x = relay.var("p0107", shape=(relay.Any(), relay.Any(), 4), dtype="float32") + y = relay.var("p166", shape=(relay.Any(),), dtype="int64") + squeeze = relay.op.squeeze(x, axis=[0]) + strided_slice = relay.op.strided_slice( + squeeze, begin=[0, 0], end=[15130, 9223372036854775807], strides=[1, 1] + ) + take = relay.op.take(strided_slice, y, axis=0) + return relay.Function([x, y], take) 
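+
+    # The squeeze/strided_slice/take chain is entirely injective, so FuseOps
+    # should collapse it into a single "Primitive" function even though the
+    # shapes are dynamic (relay.Any()).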
-    assert np.allclose(result.numpy(), np_result)
+    def expected():
+        x = relay.var("p0107", shape=(relay.Any(), relay.Any(), 4), dtype="float32")
+        y = relay.var("p166", shape=(relay.Any(),), dtype="int64")
+        xx = relay.var("xx", shape=(relay.Any(), relay.Any(), 4), dtype="float32")
+        yy = relay.var("yy", shape=(relay.Any(),), dtype="int64")
+        squeeze = relay.op.squeeze(xx, axis=[0])
+        strided_slice = relay.op.strided_slice(
+            squeeze, begin=[0, 0], end=[15130, 9223372036854775807], strides=[1, 1]
+        )
+        take = relay.op.take(strided_slice, yy, axis=0)
+        f = relay.Function([xx, yy], take)
+        f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        return relay.Function([x, y], relay.Call(f, [x, y]))
+
+    process_and_compare(before(), expected())


-@tvm.testing.uses_gpu
 def test_fuse_softmax():
     """Test if softmax can be fused with following ops."""
     channel_size = 16
@@ -814,19 +778,61 @@ def expected():
         y = relay.Call(f0, [x])
         return relay.Function([x], y)

-    orig = before()
-    m = fuse2(tvm.IRModule.from_expr(orig))
-    after = run_opt_pass(expected(), transform.InferType())
-    assert tvm.ir.structural_equal(m["main"], after)
+    process_and_compare(before(), expected())

-    inp = np.random.randn(16, channel_size).astype("float32")
-    ref = tvm.topi.testing.softmax_python(inp).astype("float16")

+def test_anchor_broadcast_reduce():
+    shape = (1, 16, 64, 64)
+
+    def before():
+        x = relay.var("x", shape=shape, dtype="float32")
+        y = relay.var("y")
+        a = relay.nn.conv2d(x, y, kernel_size=(3, 3), padding=(1, 1), channels=16)
+        b = relay.add(a, relay.const(1, "float32"))
+        r = relay.min(b)
+        return relay.Function([x, y], r)
+
+    def expected():
+        x = relay.var("x", shape=shape, dtype="float32")
+        y = relay.var("y")
+        xx = relay.var("xx", shape=shape, dtype="float32")
+        yy = relay.var("yy")
+        a = relay.nn.conv2d(xx, yy, kernel_size=(3, 3), padding=(1, 1), channels=16)
+        b = relay.add(a, relay.const(1, "float32"))
+        f1 = relay.Function([xx, yy], b)
+        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        zz = relay.var("zz")
+        r = relay.min(zz)
+        f2 = relay.Function([zz], r)
+        f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        return relay.Function([x, y], relay.Call(f2, [relay.Call(f1, [x, y])]))
+
+    with tvm.transform.PassContext(opt_level=2, config={"relay.FuseOps.link_params": True}):
+        process_and_compare(before(), expected())
+
+
+def test_broadcast_reduce():
+    shape = (1, 16, 64, 64)
+
+    def before():
+        x = relay.var("x", shape=shape, dtype="float32")
+        b = relay.add(x, relay.const(1, "float32"))
+        r = relay.min(b)
+        return relay.Function([x], r)
+
+    def expected():
+        x = relay.var("x", shape=shape, dtype="float32")
+        xx = relay.var("xx", shape=shape, dtype="float32")
+        b = relay.add(xx, relay.const(1, "float32"))
+        r = relay.min(b)
+        f = relay.Function([xx], r)
+        f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        return relay.Function([x], relay.Call(f, [x]))
+
+    with tvm.transform.PassContext(opt_level=2, config={"relay.FuseOps.link_params": True}):
+        process_and_compare(before(), expected())

-    for tgt, dev in tvm.testing.enabled_targets():
-        ex = relay.create_executor("graph", mod=m, device=dev, target=tgt)
-        result = ex.evaluate()(inp).numpy()
-        tvm.testing.assert_allclose(result, ref, rtol=1e-4, atol=1e-4)


 if __name__ == "__main__":
-    pytest.main([__file__])
+    import sys
+    import pytest
+
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))