From 01687df1597d5eee1b890773d7f48edd6eaf9cc1 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 17 Dec 2024 16:06:56 -0800 Subject: [PATCH] pointwise scheduler fails to validate reference tv (#3513) Fixes: #3512 When picking reference tv, pointwise scheduler fails to validate that the transformation on reference tv can be safely propagated to all outputs in the fusion. The issue occurs when an IterDomain that's not in the reference tv is merged with another dimension in the output tv, preventing the merge on reference tv to be propagated to the target. This PR adds an optional check `areAllOutputIdsMappedTo` in `nvfuser::pointwise_utils::DomainMap::isValidReference` The added check in this PR checks that all source producer IterDomain producing the IterDomain on outputs are covered by reference tv. This is safe for pointwise scheduler, since the scheduler checks that there's no reversible view present in the fusion. The check is optional and is disabled by transpose scheduler, where the reference_tv is not supposed to cover the entire fusion, but rather a subset of fusion IO tensors. We should extent that in future PRs. --------- Co-authored-by: Naoya Maruyama Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> --- csrc/scheduler/pointwise.cpp | 4 +- csrc/scheduler/pointwise_utils.cpp | 2 +- csrc/scheduler/pointwise_utils.h | 2 +- csrc/scheduler/tools/domain_map.cpp | 147 ++++++++- csrc/scheduler/tools/domain_map.h | 9 +- csrc/scheduler/transpose.cpp | 18 +- tests/cpp/test_matmul.cpp | 63 ++++ tests/cpp/test_pointwise.cpp | 479 +++++++++++++++++++++++++++- 8 files changed, 708 insertions(+), 16 deletions(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index 9b5068c04c8..5dc0d761091 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -493,11 +493,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { int64_t max_dims = 0; for (auto inp : input_tvs) { - max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims); + max_dims = std::max(pointwise_utils::nLogicalDims(inp), max_dims); } for (auto out : output_tvs) { - max_dims = std::max(pointwise_utils::nRootDims(out), max_dims); + max_dims = std::max(pointwise_utils::nLogicalDims(out), max_dims); } // If everything is zero dim tensors, just return. diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index be100a9f54d..20dbf6d8af5 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -19,7 +19,7 @@ TensorView* PointwiseDomainMap::findReferenceTensor( if (isValidReference(output_tv) && hasMinimumSize(output_tv, minimum_num_axes) && !output_tv->isFusionInput()) { - int64_t n_dims = pointwise_utils::nRootDims(output_tv); + int64_t n_dims = nLogicalDims(output_tv); if (n_dims > max_dims) { result = output_tv; max_dims = n_dims; diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index cc9a43d5c0f..35af977b61e 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -19,7 +19,7 @@ namespace pointwise_utils { // Returns number of non-reduction/non-broadcas/non-device dims in logical // domain -inline int64_t nRootDims(const TensorView* tv) { +inline int64_t nLogicalDims(const TensorView* tv) { auto logical_dom = tv->getLogicalDomain(); int64_t tv_n_dims = 0; for (auto dim : logical_dom) { diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index 0a713d346a0..473d3821bd2 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -137,6 +137,137 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) return in_concrete_ids.empty(); } +// Note: ideally we would want to check that reference_tv contains all iter +// domains in target_tv, so that transformation applied on reference_tv can be +// propagated to target_tv. But we don't have an easy way to check that. Instead +// of that, this function checks that all source iter domains involved in +// transformation on target_tv is covered by reference_tv. Source iter domains +// of TensorViews are IDs that doesn't have an definition and are producers of +// any IDs on the logical domain of the given TensorView. +// +// ------ +// +// e.g 0. +// T34 [i0, i1] +// T185 [i0, b2, i1] = broadcast(T34) +// T192 [i0, b3(ex), i1] = expand(T185) +// T198 [i0, b3(ex)*i1] = reshape(T192) +// output(T34) +// output(T198) +// +// if we consider taking T34 as reference_tv. T198 is the target_tv. We can't +// replay T34's transform of merging all the dimensions to T198, since b3(ex)*i1 +// can't be reversed. The check in this function would give us T34 with source +// i0, i1; where T198 would have source i0, b3, i1, where b3 isn't contained in +// T34. Hence we'll reject this reference_tv. +// +// ------ +// +// e.g 1. +// T0 [i0, i1] +// T1 [i2, i0, i1] +// T2 [i0*i1] = reshape(T0) +// T3 [b3, i0, i1] = broadcast(T0) +// T4 [i2, i0, i1] = add(T1, T3) +// output(T2) +// output(T4) +// +// the example above should be able to pick T4 as reference_tv. T2's source i0, +// i1 are both contained by the source of T4, so this example could be scheduled +// as a single fusion. +bool DomainMap::areAllTargetIdsCoveredBy( + TensorView* target_tv, + TensorView* reference_tv) const { + auto get_source_iter_domains = [this](TensorView* tv) { + // traverse back to collect all disjoint set producer IDs for each ID in the + // logical domain of tv. + VectorOfUniqueEntries>> + all_producer_sets; + std::for_each( + tv->getLogicalDomain().begin(), + tv->getLogicalDomain().end(), + [&](IterDomain* tv_logical_id) { + all_producer_sets.pushBack( + ca_map_.disjointSetOf(tv_logical_id, IdMappingMode::EXACT)); + }); + all_producer_sets.pushBack( + ca_map_.getAllDisjointSetProducers(all_producer_sets)); + + std::vector source_ids; + // filtering all producer IDs with empty definition to get source iter + // domains + std::for_each( + all_producer_sets.vector().begin(), + all_producer_sets.vector().end(), + [&source_ids, + this](const std::shared_ptr>& + producer_set_ptr) { + IterDomain* producer_id = producer_set_ptr->front(); + if (ca_map_.uniqueExactDefinitions(producer_id).empty()) { + source_ids.push_back(producer_id); + } + }); + return source_ids; + }; + + // this contains all source iter domain that's covered by reference_tv, so + // it's safe for target_tv to have them. + std::unordered_set covered_source_ids; + for (IterDomain* source_id_ref : get_source_iter_domains(reference_tv)) { + covered_source_ids.insert(source_id_ref); + } + // It's safe to have unmapped broadcast IterDomain. There're quite a few tests + // expecting pointwise scheduler to handle this pattern + for (IterDomain* id_out : target_tv->getLogicalDomain()) { + if (id_out->isBroadcast()) { + NVF_ERROR( + id_out->definition() == nullptr || + id_out->definition()->isA()); + + // Note that ideally we should also be able to handle merge/split on + // broadcast IDs, so we should really move this skip inside the loop below + // `get_source_iter_domains(target_tv)` and skip broadcast source IDs. + // currently we have the issue that split/merge does not preserve expanded + // broadcasts, see issue: https://github.com/NVIDIA/Fuser/issues/1126 + covered_source_ids.insert(id_out); + } + } + // Note: there's certain cases where it's safe to have dangling IDs, + // e.g + // T34 [i0, i1] + // T185 [i0, b2, i1] = broadcast(T34) + // T192 [i0, b3(ex), i1] = expand(T185) + // It's safe to propagate T34 to T192, since b3(ex) is not involved in the + // propagation. But this isn't generally safe. If the above example is changed + // to e.g + // T34 [i0, i1] + // T185 [i0, b2, i1] = broadcast(T34) + // T186 [i0, i4, i1] = ones({i0, i4, i1}) + // T193 [i0, i4, i1] = add(T185, T186) + // It's unsafe to propagate from T34 to T193, see issue + // https://github.com/NVIDIA/Fuser/issues/3542 + + // Check all source iter domain involved in producing target_tv + for (IterDomain* source_id_out : get_source_iter_domains(target_tv)) { + // NOTE: we use concrete id instead. This allows us to link indirect + // broadcast. So in the example below: T2[i0, i1] = T0[i0, b0] + T1[i0, i1] + // T3[i0, i9] = pad(T0[i0, b0]) + // We have i9 in T3 + // -> source ID b0 + // -> concrete map to i1 + // So T3 is contained by T2. See test `PointwiseTest.DomainMapPad1` + auto concrete_source_id_out = + ca_map_.getConcreteMappedID(source_id_out, IdMappingMode::PERMISSIVE); + // if we find any source_id_out that's not contained, it's possible our + // propagation would fail since transformation involving this iter domain + // can't be resolved. + if (!getMappedInputConcreteID(covered_source_ids, concrete_source_id_out)) { + return false; + } + } + return true; +} + // Reference domains must exactly match with the input domains. See // also PR #661 IterDomain* DomainMap::getMappedInputConcreteID( @@ -228,7 +359,7 @@ IterDomain* DomainMap::anyMapped( } // Determine if output TensorView is a valid reference tensor for this fusion. -// The reference tensor must map to all the iterDomains in each input. +// The reference tensor must map to all the iterDomains in each input and output bool DomainMap::isValidReference(TensorView* tv) const { for (auto input_tv : ir_utils::filterByType(fusion_->inputs())) { if (input_tv->uses().empty()) { @@ -240,6 +371,20 @@ bool DomainMap::isValidReference(TensorView* tv) const { return false; } } + // The check on outputs are optional, transpose scheduler might propose a + // secondary reference that only applies to a subset of IO tensors. Ideally we + // should have a more robust check and consider the IO groups instead of + // blindly skip outputs. + for (auto output_tv : + ir_utils::filterByType(fusion_->outputs())) { + // no need to check for self. + if (output_tv == tv) { + continue; + } + if (!areAllTargetIdsCoveredBy(output_tv, tv)) { + return false; + } + } return true; } diff --git a/csrc/scheduler/tools/domain_map.h b/csrc/scheduler/tools/domain_map.h index 88dadcba721..61833eda33d 100644 --- a/csrc/scheduler/tools/domain_map.h +++ b/csrc/scheduler/tools/domain_map.h @@ -32,7 +32,8 @@ class DomainMap { } // Determine if a TensorView is a valid reference tensor for this fusion. - // The reference tensor must map to all the iterDomains in each input. + // The reference tensor must map to all the iterDomains in each input and + // output. bool isValidReference(TensorView* tv) const; protected: @@ -40,6 +41,12 @@ class DomainMap { bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv) const; + // Determine if all source IterDomains in target_tv are contained by the + // reference_tv, this ensures transformations from reference_tv can be + // propagated to target_tv + bool areAllTargetIdsCoveredBy(TensorView* target_tv, TensorView* reference_tv) + const; + virtual IterDomain* getMappedInputConcreteID( const std::unordered_set& in_concrete_ids, IterDomain* out_id) const; diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 95100c7fc5f..1db2a3cc85b 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -170,8 +170,16 @@ class TransposeDomainMap : public scheduler_tools::DomainMap { TensorView* result = nullptr; int64_t max_dims = -1; for (auto tv : group) { + // since transpose scheduler have different set of reference, we skip IDs + // coverage check of the reference on outputs of the fusion. Note that + // this is not ideal, we would want to instead have reference tensor + // checked against all its target IO tensors. + // TODO: open an issue for this one. transpose scheduler is not supposed + // to reuse pointwise_utils::DomainMap::isValidRefrence. This function is + // too restrictive and doesn't align well with the scheme of transpose + // scheduler if (isValidReference(tv)) { - int64_t dims = (int64_t)pointwise_utils::nRootDims(tv); + int64_t dims = (int64_t)pointwise_utils::nLogicalDims(tv); if (dims > max_dims) { result = tv; max_dims = dims; @@ -992,12 +1000,12 @@ std::unique_ptr getTransposeHeuristics( << "max_io_dtype_size: " << max_io_dtype_size << "\n" << "group 1: " << ir_utils::toString(grouped_inputs_outputs[0]) << "\n" - << "reference1: " << reference1 << "\n" + << "reference1: " << reference1->toString() << "\n" << "inner_most_id1 position: " << inner_most_pos1_in_ref1 << " (in reference 1)\n" << "group 2: " << ir_utils::toString(grouped_inputs_outputs[1]) << "\n" - << "reference2: " << reference2 << "\n" + << "reference2: " << reference2->toString() << "\n" << "inner_most_id2 position: " << inner_most_pos2_in_ref1 << " (in reference 1)" << std::endl; if (hasSmallTransposeDimensions(tparams)) { @@ -1047,11 +1055,11 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { int64_t max_dims = 0; for (auto inp : input_tvs) { - max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims); + max_dims = std::max(pointwise_utils::nLogicalDims(inp), max_dims); } for (auto out : output_tvs) { - max_dims = std::max(pointwise_utils::nRootDims(out), max_dims); + max_dims = std::max(pointwise_utils::nLogicalDims(out), max_dims); } // If everything is zero dim tensors, just return. diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index cbd51d97cfb..e132bb6208e 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3993,4 +3993,67 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5)); } +TEST_F(HopperMatmulTest, HSH_NT_UseScheduler) { + Fusion fusion; + FusionGuard fg(&fusion); + + constexpr int64_t M = 2048, N = 2048, K = 8192; + const auto dtype = DataType::Half; + + auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); + auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = fusedMultiplySum(tv0, tv1, {0}); + + // Reorder the accumulator as [M, N, K] + // [K, M, N] -> [M, N, K] + tv2->reorder({{-3, -1}}); + tv2->commitLeafToLogical(); + + auto tv3 = castOp(DataType::Half, tv2); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto a_ref = at::randn({K, M, 1}, options); + auto b_ref = at::randn({K, 1, N}, options); + auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 256, 16); + gemm_tile.warp_tile = GemmTile(64, 256, 16); + + MatmulParams mparams; + mparams.supported_vec_size = {8, 8, 8}; + mparams.mma_macro = MmaMacro::Hopper_64_256_16; + mparams.tile_sizes = gemm_tile; + mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor; + mparams.async_gmem_load_operands = true; + mparams.circular_buffer_options.circular_buffer_smem_write = true; + mparams.circular_buffer_options.circular_buffer_smem_read = false; + mparams.circular_buffer_options.smem_circular_buffer_stage = 4; + mparams.splitk_factor = 1; + mparams.use_smem_epilogue = true; + mparams.cluster_dims = {2, 1, 1}; + mparams.promote_prologue_smem_reuse = true; + + std::cout << mparams.toString() << std::endl; + + SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) + ->schedule(&fusion, &mparams); + + std::vector inputs = {a_ref, b_ref}; + + KernelExecutor ke; + ke.compile(&fusion, inputs); + EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty()); + auto cg_outputs = ke.run(inputs); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel())); + + // Relax tolerance for larger sum due to large K + EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K)); +} + } // namespace nvfuser diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index bb1c6bd7bfb..e24f37d7348 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -49,6 +50,15 @@ bool hasVectorizationCache(TensorView* tv) { return false; } +class DomainMapUnitTest : public scheduler_tools::DomainMap { + public: + DomainMapUnitTest(Fusion* fusion) : scheduler_tools::DomainMap(fusion) {}; + bool testTargetCoverage(TensorView* target_tv, TensorView* reference_tv) + const { + return areAllTargetIdsCoveredBy(target_tv, reference_tv); + } +}; + } // namespace TEST_F(PointwiseTest, VectorizeStrideContiguity2D) { @@ -306,7 +316,7 @@ TEST_F(PointwiseTest, Issue1567VectorizeAllocationDomain) { at::Tensor input1 = at::empty_strided({1, 128, 1}, {128, 1, 128}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); auto pparams = cg_results.heuristic_params->as(); @@ -340,7 +350,7 @@ TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase0) { at::Tensor input1 = at::randn({1024, 2, 512}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs, false); auto pparams = cg_results.heuristic_params->as(); @@ -374,7 +384,7 @@ TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase1) { at::Tensor input1 = at::randn({1024, 512, 2}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); auto pparams = cg_results.heuristic_params->as(); @@ -414,7 +424,7 @@ TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase2) { at::Tensor input1 = at::empty_strided({1024, 512, 2}, {2, 2048, 1}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); auto pparams = cg_results.heuristic_params->as(); @@ -451,7 +461,7 @@ TEST_F(PointwiseTest, VIssue1567ectorizationFactorAnalysisCase3) { at::Tensor input1 = at::randn({512, 1024, 2}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); auto pparams = cg_results.heuristic_params->as(); @@ -773,4 +783,463 @@ TEST_F(PointwiseTest, VectorizePadLoweringPermuted) { EXPECT_TRUE(found_vectorize); testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } + +TEST_F(PointwiseTest, DomainMapTestEg0) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {i0, i1} + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + // tv1 {i0, i1} + auto tv1 = relu(tv0); + fusion->addOutput(tv1); + // tv2 {i0, b2, i1} + auto tv2 = broadcast(tv1, {false, true, false}); + // tv3 {i0, b3{1 ex 4}, i1} + auto tv3 = expand( + tv2, + {tv2->axis(0)->extent(), + IrBuilder::create(4), + tv2->axis(2)->extent()}); + // NOTE hat currently expand doesn't introduce an iter domain operation, so + // we don't see that i4 is produced by realizing the expanded extent of b3{1 + // ex 4} tv4 {i0, i4*i1} + auto tv4 = reshape(tv3, {2, 4, 3}, {2, 12}); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + // tv4 is not covered by tv1, because the expanded ID i4 participates in + // transformation + EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv1)); + + // tv3 is not covered by tv1, because the missing ID b3{1 ex 4} is concretized + // as i4, which is not mapped on tv1 + EXPECT_FALSE(domain_map.testTargetCoverage(tv3, tv1)); + + // tv1 is covered by tv4 + EXPECT_TRUE(domain_map.testTargetCoverage(tv1, tv4)); + + // tv1 is not a valid reference + EXPECT_FALSE(domain_map.isValidReference(tv1)); + + // tv4 is a valid reference + EXPECT_TRUE(domain_map.isValidReference(tv4)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({4, 7}, options); + std::vector aten_inputs = {t0}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapTestEg1) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {i0, i1} + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + // tv1 {i2, i0, i1} + TensorView* tv1 = makeContigTensor(3); + fusion->addInput(tv1); + // tv2 {i0*i1} + auto tv2 = reshape(tv0, {2, 4}, {8}); + fusion->addOutput(tv2); + + // tv3 {b3, i0, i1} + auto tv3 = broadcast(tv0, {true, false, false}); + // tv4 {i2, i0, i1} + auto tv4 = add(tv1, tv3); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + // tv4 is not covered by tv2, because it misses i2 + EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv2)); + + // tv2 is covered by tv4 + EXPECT_TRUE(domain_map.testTargetCoverage(tv2, tv4)); + + // tv2 is not a valid reference + EXPECT_FALSE(domain_map.isValidReference(tv2)); + + // tv4 is a valid reference + EXPECT_TRUE(domain_map.isValidReference(tv4)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 4}, options); + at::Tensor t1 = at::randn({3, 2, 4}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapTestEg2) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {i0, i1} + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + // tv1 {i0, i1} + auto tv1 = relu(tv0); + fusion->addOutput(tv1); + // tv2 {i0, b2, i1} + auto tv2 = broadcast(tv1, {false, true, false}); + // tv3 {i0, b3{1 ex 4}, i1} + auto tv3 = expand( + tv2, + {tv2->axis(0)->extent(), + IrBuilder::create(4), + tv2->axis(2)->extent()}); + fusion->addOutput(tv3); + + DomainMapUnitTest domain_map(fusion); + // tv3 is covered by tv1, because the missing ID b3{1 ex 4} is broadcast and + // doesn't get resolved to a concrete broadcast ID. + EXPECT_TRUE(domain_map.testTargetCoverage(tv3, tv1)); + + // tv1 is covered by tv4 + EXPECT_TRUE(domain_map.testTargetCoverage(tv1, tv3)); + + // tv1 is a valid reference + EXPECT_TRUE(domain_map.isValidReference(tv1)); + + // tv3 is a valid reference + EXPECT_TRUE(domain_map.isValidReference(tv3)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({4, 7}, options); + std::vector aten_inputs = {t0}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapFactory) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv1 {i1} + TensorView* tv0 = makeContigTensor(1); + fusion->addInput(tv0); + // tv1 {i0, i1} + TensorView* tv1 = makeContigTensor(2); + fusion->addInput(tv1); + + // tv2 {b2, b3, i1} + auto tv2 = broadcast(tv0, {true, true, false}); + // NOTE tv1 will be broadcasted to {b2, i0, i1} before the add. + // tv3 {b2, i0, i1} + auto tv3 = add(tv2, tv1); + fusion->addOutput(tv3); + + auto size_val = IrBuilder::create(4.0, DataType::Int); + auto one_val = IrBuilder::create(1, DataType::Int); + // factory method creates an iter domain out of thin air + // tv4 {i4{4}, b4, i1} + auto tv4 = ones({size_val, one_val, tv0->axis(0)->extent()}, DataType::Float); + // tv5 {i4{4}, i0, i1} + auto tv5 = mul(tv2, tv4); + fusion->addOutput(tv5); + + DomainMapUnitTest domain_map(fusion); + + // tv4 is not covered by tv3, because it's missing i4{4} + EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv3)); + // tv1 is not covered by tv4, since it's missing i0 + EXPECT_FALSE(domain_map.testTargetCoverage(tv1, tv4)); + + EXPECT_FALSE(domain_map.isValidReference(tv3)); + // tv5 has the same IDs as tv4, and is not a valid reference. + EXPECT_FALSE(domain_map.isValidReference(tv5)); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input0 = at::empty_strided({25}, {1}, options); + at::Tensor input1 = at::empty_strided({7, 25}, {25, 1}, options); + auto cg_outputs = executor_cache.runFusionWithInputs({input0, input1}); + + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + SegmentedFusion* segmented_fusion = runtime->fusionSegments(); + // This fusion currently cannot be scheduled as a single kernel. It is + // expected to be segmented as: g{(pointwise) + // inputs: tv0, tv1 + // outputs: tv2, tv3 + // tv2 = broadcast(tv0) + // tv3 = add (tv2, broadcast(tv1)) + // } + // + // g{(pointwise) + // inputs: tv2 + // outputs: tv5 + // tv4 = full({4, 1, i0}) + // tv5 = mul(tv2, tv4) + // } + EXPECT_EQ(segmented_fusion->groups().size(), 2); + + for (SegmentedGroup* group : segmented_fusion->groups()) { + const std::vector& exprs = group->exprs(); + + size_t num_full = std::count_if(exprs.begin(), exprs.end(), [](Expr* expr) { + return expr->isA(); + }); + if (num_full != 0) { + // this is the segment contains the factory op. + EXPECT_EQ(exprs.size(), 2); + EXPECT_EQ(num_full, 1); + auto binary_op_iter = + std::find_if(exprs.begin(), exprs.end(), [](Expr* expr) { + return expr->isA(); + }); + EXPECT_EQ( + (*binary_op_iter)->as()->getBinaryOpType(), + BinaryOpType::Mul); + Fusion* group_fusion = group->getFusion(); + // validate that we have a valid reference in the segmented fusion + DomainMapUnitTest group_dm(group_fusion); + EXPECT_EQ(group_fusion->outputs().size(), 1); + EXPECT_TRUE(group_dm.isValidReference( + group_fusion->outputs()[0]->as())); + } else { + // validate segmentation has the correct ops + EXPECT_EQ(exprs.size(), 3); + EXPECT_EQ( + std::count_if( + exprs.begin(), + exprs.end(), + [](Expr* expr) { return expr->isA(); }), + 2); + EXPECT_EQ( + std::count_if( + exprs.begin(), + exprs.end(), + [](Expr* expr) { return expr->isA(); }), + 1); + Fusion* group_fusion = group->getFusion(); + auto output_add = std::find_if( + group_fusion->outputs().begin(), + group_fusion->outputs().end(), + [](Val* val) { return val->definition()->isA(); }); + EXPECT_TRUE(output_add != group_fusion->outputs().end()); + DomainMapUnitTest group_dm(group_fusion); + // validate that the segmented fusion choose the add output as the + // reference + EXPECT_TRUE(group_dm.isValidReference((*output_add)->as())); + } + } + + testValidate(fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapPad0) { + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {b1, i0} + TensorView* tv0 = TensorViewBuilder().shape({1, -1}).build(); + fusion->addInput(tv0); + // tv1 {i2, b1, i0} + TensorView* tv1 = TensorViewBuilder().shape({-1, 1, -1}).build(); + fusion->addInput(tv1); + // tv2 {i2, b1, i0} + auto tv2 = add(tv1, tv0); + fusion->addOutput(tv2); + // i3 = resize(b1 + 4 + 4) + // tv3 {i3, i0} + auto tv3 = + pad(tv0, + {IrBuilder::create(0L), + IrBuilder::create(0L), + IrBuilder::create(4L), + IrBuilder::create(4L)}); + // tv4 {i3*i0} + auto tv4 = reshape(tv3, {9, 5}, {45}); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + + // tv4 is covered by tv2, because i3 is produced by b1 + EXPECT_TRUE(domain_map.testTargetCoverage(tv4, tv2)); + // tv2 is not covered by tv4, it's missing i2 + EXPECT_FALSE(domain_map.testTargetCoverage(tv2, tv4)); + + EXPECT_FALSE(domain_map.isValidReference(tv4)); + EXPECT_TRUE(domain_map.isValidReference(tv2)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::empty_strided({1, 5}, {5, 1}, options); + at::Tensor t1 = at::empty_strided({7, 1, 5}, {5, 5, 1}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapPad1) { + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {b1, i0} + TensorView* tv0 = TensorViewBuilder().shape({1, -1}).build(); + fusion->addInput(tv0); + // tv1 {i2, i3, i4, b5} + TensorView* tv1 = TensorViewBuilder().shape({-1, -1, -1, 1}).build(); + fusion->addInput(tv1); + + // tv2 {b6, b7, b1, i0} + auto tv2 = broadcast(tv0, {true, true, false, false}); + // tv3 {i2, i3, i4, i0} + auto tv3 = add(tv1, tv2); + fusion->addOutput(tv3); + // i8 = resize(b1 + 4 + 4) + // tv4 {i8, i0} + auto tv4 = + pad(tv0, + {IrBuilder::create(0L), + IrBuilder::create(0L), + IrBuilder::create(4L), + IrBuilder::create(4L)}); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + + // tv4 is covered by tv3, because i8 is produced by b1, a broadcast dimension + // concretized as i4 + EXPECT_TRUE(domain_map.testTargetCoverage(tv4, tv3)); + // tv3 is not covered by tv4, it's missing i2 and i3 + EXPECT_FALSE(domain_map.testTargetCoverage(tv3, tv4)); + + EXPECT_FALSE(domain_map.isValidReference(tv4)); + EXPECT_TRUE(domain_map.isValidReference(tv3)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::empty_strided({1, 5}, {5, 1}, options); + at::Tensor t1 = at::empty_strided({2, 3, 4, 1}, {12, 4, 1, 1}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapSlice0) { + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {i1, i0} + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + // tv1 {i1, i0} + // use concrete tensor to avoid need of concretization + TensorView* tv1 = makeContigConcreteTensor({2, 4}); + fusion->addInput(tv1); + + // b3 = resize(i0 + 0 - 3) + // tv2 {i1, b2} + auto tv2 = slice( + tv1, + {Slice(), + {IrBuilder::create(0L), + IrBuilder::create(1L), + IrBuilder::create(1L)}}); + fusion->addOutput(tv2); + // tv3 {i1, i0} + auto tv3 = add(tv0, tv1); + // tv4 {i1*i0} + auto tv4 = reshape(tv3, {2, 4}, {8}); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + // tv2 and tv4 has the same source IDs, since b3 = resize(i0 + 0 - 3) + EXPECT_TRUE(domain_map.testTargetCoverage(tv4, tv2)); + EXPECT_TRUE(domain_map.testTargetCoverage(tv2, tv4)); + + EXPECT_TRUE(domain_map.isValidReference(tv2)); + EXPECT_TRUE(domain_map.isValidReference(tv4)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 4}, options); + at::Tensor t1 = at::randn({2, 4}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapSlice1) { + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {i2, i1, i0} + TensorView* tv0 = makeContigTensor(3); + fusion->addInput(tv0); + // tv1 {i1, i0} + // use concrete tensor to avoid need of concretization + TensorView* tv1 = makeContigConcreteTensor({2, 4}); + fusion->addInput(tv1); + + // b3 = resize(i0 + 0 - 3) + // tv2 {i1, b3} + auto tv2 = slice( + tv1, + {Slice(), + {IrBuilder::create(0L), + IrBuilder::create(1L), + IrBuilder::create(1L)}}); + fusion->addOutput(tv2); + // tv3 {i2, i1, i0} + auto tv3 = add(tv0, tv1); + // tv4 {i2, i1*i0} + auto tv4 = reshape(tv3, {2, 2, 4}, {2, 8}); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + // i2 is missing in tv2 + EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv2)); + EXPECT_TRUE(domain_map.testTargetCoverage(tv2, tv4)); + + EXPECT_FALSE(domain_map.isValidReference(tv2)); + EXPECT_TRUE(domain_map.isValidReference(tv4)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 2, 4}, options); + at::Tensor t1 = at::randn({2, 4}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + } // namespace nvfuser