diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index 9b5068c04c8..5dc0d761091 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -493,11 +493,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { int64_t max_dims = 0; for (auto inp : input_tvs) { - max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims); + max_dims = std::max(pointwise_utils::nLogicalDims(inp), max_dims); } for (auto out : output_tvs) { - max_dims = std::max(pointwise_utils::nRootDims(out), max_dims); + max_dims = std::max(pointwise_utils::nLogicalDims(out), max_dims); } // If everything is zero dim tensors, just return. diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index be100a9f54d..20dbf6d8af5 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -19,7 +19,7 @@ TensorView* PointwiseDomainMap::findReferenceTensor( if (isValidReference(output_tv) && hasMinimumSize(output_tv, minimum_num_axes) && !output_tv->isFusionInput()) { - int64_t n_dims = pointwise_utils::nRootDims(output_tv); + int64_t n_dims = nLogicalDims(output_tv); if (n_dims > max_dims) { result = output_tv; max_dims = n_dims; diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index cc9a43d5c0f..35af977b61e 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -19,7 +19,7 @@ namespace pointwise_utils { // Returns number of non-reduction/non-broadcas/non-device dims in logical // domain -inline int64_t nRootDims(const TensorView* tv) { +inline int64_t nLogicalDims(const TensorView* tv) { auto logical_dom = tv->getLogicalDomain(); int64_t tv_n_dims = 0; for (auto dim : logical_dom) { diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index 0a713d346a0..473d3821bd2 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -137,6 +137,137 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) return in_concrete_ids.empty(); } +// Note: ideally we would want to check that reference_tv contains all iter +// domains in target_tv, so that transformation applied on reference_tv can be +// propagated to target_tv. But we don't have an easy way to check that. Instead +// of that, this function checks that all source iter domains involved in +// transformation on target_tv is covered by reference_tv. Source iter domains +// of TensorViews are IDs that doesn't have an definition and are producers of +// any IDs on the logical domain of the given TensorView. +// +// ------ +// +// e.g 0. +// T34 [i0, i1] +// T185 [i0, b2, i1] = broadcast(T34) +// T192 [i0, b3(ex), i1] = expand(T185) +// T198 [i0, b3(ex)*i1] = reshape(T192) +// output(T34) +// output(T198) +// +// if we consider taking T34 as reference_tv. T198 is the target_tv. We can't +// replay T34's transform of merging all the dimensions to T198, since b3(ex)*i1 +// can't be reversed. The check in this function would give us T34 with source +// i0, i1; where T198 would have source i0, b3, i1, where b3 isn't contained in +// T34. Hence we'll reject this reference_tv. +// +// ------ +// +// e.g 1. +// T0 [i0, i1] +// T1 [i2, i0, i1] +// T2 [i0*i1] = reshape(T0) +// T3 [b3, i0, i1] = broadcast(T0) +// T4 [i2, i0, i1] = add(T1, T3) +// output(T2) +// output(T4) +// +// the example above should be able to pick T4 as reference_tv. T2's source i0, +// i1 are both contained by the source of T4, so this example could be scheduled +// as a single fusion. +bool DomainMap::areAllTargetIdsCoveredBy( + TensorView* target_tv, + TensorView* reference_tv) const { + auto get_source_iter_domains = [this](TensorView* tv) { + // traverse back to collect all disjoint set producer IDs for each ID in the + // logical domain of tv. + VectorOfUniqueEntries>> + all_producer_sets; + std::for_each( + tv->getLogicalDomain().begin(), + tv->getLogicalDomain().end(), + [&](IterDomain* tv_logical_id) { + all_producer_sets.pushBack( + ca_map_.disjointSetOf(tv_logical_id, IdMappingMode::EXACT)); + }); + all_producer_sets.pushBack( + ca_map_.getAllDisjointSetProducers(all_producer_sets)); + + std::vector source_ids; + // filtering all producer IDs with empty definition to get source iter + // domains + std::for_each( + all_producer_sets.vector().begin(), + all_producer_sets.vector().end(), + [&source_ids, + this](const std::shared_ptr>& + producer_set_ptr) { + IterDomain* producer_id = producer_set_ptr->front(); + if (ca_map_.uniqueExactDefinitions(producer_id).empty()) { + source_ids.push_back(producer_id); + } + }); + return source_ids; + }; + + // this contains all source iter domain that's covered by reference_tv, so + // it's safe for target_tv to have them. + std::unordered_set covered_source_ids; + for (IterDomain* source_id_ref : get_source_iter_domains(reference_tv)) { + covered_source_ids.insert(source_id_ref); + } + // It's safe to have unmapped broadcast IterDomain. There're quite a few tests + // expecting pointwise scheduler to handle this pattern + for (IterDomain* id_out : target_tv->getLogicalDomain()) { + if (id_out->isBroadcast()) { + NVF_ERROR( + id_out->definition() == nullptr || + id_out->definition()->isA()); + + // Note that ideally we should also be able to handle merge/split on + // broadcast IDs, so we should really move this skip inside the loop below + // `get_source_iter_domains(target_tv)` and skip broadcast source IDs. + // currently we have the issue that split/merge does not preserve expanded + // broadcasts, see issue: https://github.com/NVIDIA/Fuser/issues/1126 + covered_source_ids.insert(id_out); + } + } + // Note: there's certain cases where it's safe to have dangling IDs, + // e.g + // T34 [i0, i1] + // T185 [i0, b2, i1] = broadcast(T34) + // T192 [i0, b3(ex), i1] = expand(T185) + // It's safe to propagate T34 to T192, since b3(ex) is not involved in the + // propagation. But this isn't generally safe. If the above example is changed + // to e.g + // T34 [i0, i1] + // T185 [i0, b2, i1] = broadcast(T34) + // T186 [i0, i4, i1] = ones({i0, i4, i1}) + // T193 [i0, i4, i1] = add(T185, T186) + // It's unsafe to propagate from T34 to T193, see issue + // https://github.com/NVIDIA/Fuser/issues/3542 + + // Check all source iter domain involved in producing target_tv + for (IterDomain* source_id_out : get_source_iter_domains(target_tv)) { + // NOTE: we use concrete id instead. This allows us to link indirect + // broadcast. So in the example below: T2[i0, i1] = T0[i0, b0] + T1[i0, i1] + // T3[i0, i9] = pad(T0[i0, b0]) + // We have i9 in T3 + // -> source ID b0 + // -> concrete map to i1 + // So T3 is contained by T2. See test `PointwiseTest.DomainMapPad1` + auto concrete_source_id_out = + ca_map_.getConcreteMappedID(source_id_out, IdMappingMode::PERMISSIVE); + // if we find any source_id_out that's not contained, it's possible our + // propagation would fail since transformation involving this iter domain + // can't be resolved. + if (!getMappedInputConcreteID(covered_source_ids, concrete_source_id_out)) { + return false; + } + } + return true; +} + // Reference domains must exactly match with the input domains. See // also PR #661 IterDomain* DomainMap::getMappedInputConcreteID( @@ -228,7 +359,7 @@ IterDomain* DomainMap::anyMapped( } // Determine if output TensorView is a valid reference tensor for this fusion. -// The reference tensor must map to all the iterDomains in each input. +// The reference tensor must map to all the iterDomains in each input and output bool DomainMap::isValidReference(TensorView* tv) const { for (auto input_tv : ir_utils::filterByType(fusion_->inputs())) { if (input_tv->uses().empty()) { @@ -240,6 +371,20 @@ bool DomainMap::isValidReference(TensorView* tv) const { return false; } } + // The check on outputs are optional, transpose scheduler might propose a + // secondary reference that only applies to a subset of IO tensors. Ideally we + // should have a more robust check and consider the IO groups instead of + // blindly skip outputs. + for (auto output_tv : + ir_utils::filterByType(fusion_->outputs())) { + // no need to check for self. + if (output_tv == tv) { + continue; + } + if (!areAllTargetIdsCoveredBy(output_tv, tv)) { + return false; + } + } return true; } diff --git a/csrc/scheduler/tools/domain_map.h b/csrc/scheduler/tools/domain_map.h index 88dadcba721..61833eda33d 100644 --- a/csrc/scheduler/tools/domain_map.h +++ b/csrc/scheduler/tools/domain_map.h @@ -32,7 +32,8 @@ class DomainMap { } // Determine if a TensorView is a valid reference tensor for this fusion. - // The reference tensor must map to all the iterDomains in each input. + // The reference tensor must map to all the iterDomains in each input and + // output. bool isValidReference(TensorView* tv) const; protected: @@ -40,6 +41,12 @@ class DomainMap { bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv) const; + // Determine if all source IterDomains in target_tv are contained by the + // reference_tv, this ensures transformations from reference_tv can be + // propagated to target_tv + bool areAllTargetIdsCoveredBy(TensorView* target_tv, TensorView* reference_tv) + const; + virtual IterDomain* getMappedInputConcreteID( const std::unordered_set& in_concrete_ids, IterDomain* out_id) const; diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 95100c7fc5f..1db2a3cc85b 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -170,8 +170,16 @@ class TransposeDomainMap : public scheduler_tools::DomainMap { TensorView* result = nullptr; int64_t max_dims = -1; for (auto tv : group) { + // since transpose scheduler have different set of reference, we skip IDs + // coverage check of the reference on outputs of the fusion. Note that + // this is not ideal, we would want to instead have reference tensor + // checked against all its target IO tensors. + // TODO: open an issue for this one. transpose scheduler is not supposed + // to reuse pointwise_utils::DomainMap::isValidRefrence. This function is + // too restrictive and doesn't align well with the scheme of transpose + // scheduler if (isValidReference(tv)) { - int64_t dims = (int64_t)pointwise_utils::nRootDims(tv); + int64_t dims = (int64_t)pointwise_utils::nLogicalDims(tv); if (dims > max_dims) { result = tv; max_dims = dims; @@ -992,12 +1000,12 @@ std::unique_ptr getTransposeHeuristics( << "max_io_dtype_size: " << max_io_dtype_size << "\n" << "group 1: " << ir_utils::toString(grouped_inputs_outputs[0]) << "\n" - << "reference1: " << reference1 << "\n" + << "reference1: " << reference1->toString() << "\n" << "inner_most_id1 position: " << inner_most_pos1_in_ref1 << " (in reference 1)\n" << "group 2: " << ir_utils::toString(grouped_inputs_outputs[1]) << "\n" - << "reference2: " << reference2 << "\n" + << "reference2: " << reference2->toString() << "\n" << "inner_most_id2 position: " << inner_most_pos2_in_ref1 << " (in reference 1)" << std::endl; if (hasSmallTransposeDimensions(tparams)) { @@ -1047,11 +1055,11 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { int64_t max_dims = 0; for (auto inp : input_tvs) { - max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims); + max_dims = std::max(pointwise_utils::nLogicalDims(inp), max_dims); } for (auto out : output_tvs) { - max_dims = std::max(pointwise_utils::nRootDims(out), max_dims); + max_dims = std::max(pointwise_utils::nLogicalDims(out), max_dims); } // If everything is zero dim tensors, just return. diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index cbd51d97cfb..e132bb6208e 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3993,4 +3993,67 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5)); } +TEST_F(HopperMatmulTest, HSH_NT_UseScheduler) { + Fusion fusion; + FusionGuard fg(&fusion); + + constexpr int64_t M = 2048, N = 2048, K = 8192; + const auto dtype = DataType::Half; + + auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); + auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = fusedMultiplySum(tv0, tv1, {0}); + + // Reorder the accumulator as [M, N, K] + // [K, M, N] -> [M, N, K] + tv2->reorder({{-3, -1}}); + tv2->commitLeafToLogical(); + + auto tv3 = castOp(DataType::Half, tv2); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto a_ref = at::randn({K, M, 1}, options); + auto b_ref = at::randn({K, 1, N}, options); + auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 256, 16); + gemm_tile.warp_tile = GemmTile(64, 256, 16); + + MatmulParams mparams; + mparams.supported_vec_size = {8, 8, 8}; + mparams.mma_macro = MmaMacro::Hopper_64_256_16; + mparams.tile_sizes = gemm_tile; + mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor; + mparams.async_gmem_load_operands = true; + mparams.circular_buffer_options.circular_buffer_smem_write = true; + mparams.circular_buffer_options.circular_buffer_smem_read = false; + mparams.circular_buffer_options.smem_circular_buffer_stage = 4; + mparams.splitk_factor = 1; + mparams.use_smem_epilogue = true; + mparams.cluster_dims = {2, 1, 1}; + mparams.promote_prologue_smem_reuse = true; + + std::cout << mparams.toString() << std::endl; + + SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) + ->schedule(&fusion, &mparams); + + std::vector inputs = {a_ref, b_ref}; + + KernelExecutor ke; + ke.compile(&fusion, inputs); + EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty()); + auto cg_outputs = ke.run(inputs); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel())); + + // Relax tolerance for larger sum due to large K + EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K)); +} + } // namespace nvfuser diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index bb1c6bd7bfb..e24f37d7348 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -49,6 +50,15 @@ bool hasVectorizationCache(TensorView* tv) { return false; } +class DomainMapUnitTest : public scheduler_tools::DomainMap { + public: + DomainMapUnitTest(Fusion* fusion) : scheduler_tools::DomainMap(fusion) {}; + bool testTargetCoverage(TensorView* target_tv, TensorView* reference_tv) + const { + return areAllTargetIdsCoveredBy(target_tv, reference_tv); + } +}; + } // namespace TEST_F(PointwiseTest, VectorizeStrideContiguity2D) { @@ -306,7 +316,7 @@ TEST_F(PointwiseTest, Issue1567VectorizeAllocationDomain) { at::Tensor input1 = at::empty_strided({1, 128, 1}, {128, 1, 128}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); auto pparams = cg_results.heuristic_params->as(); @@ -340,7 +350,7 @@ TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase0) { at::Tensor input1 = at::randn({1024, 2, 512}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs, false); auto pparams = cg_results.heuristic_params->as(); @@ -374,7 +384,7 @@ TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase1) { at::Tensor input1 = at::randn({1024, 512, 2}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); auto pparams = cg_results.heuristic_params->as(); @@ -414,7 +424,7 @@ TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase2) { at::Tensor input1 = at::empty_strided({1024, 512, 2}, {2, 2048, 1}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); auto pparams = cg_results.heuristic_params->as(); @@ -451,7 +461,7 @@ TEST_F(PointwiseTest, VIssue1567ectorizationFactorAnalysisCase3) { at::Tensor input1 = at::randn({512, 1024, 2}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); auto pparams = cg_results.heuristic_params->as(); @@ -773,4 +783,463 @@ TEST_F(PointwiseTest, VectorizePadLoweringPermuted) { EXPECT_TRUE(found_vectorize); testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } + +TEST_F(PointwiseTest, DomainMapTestEg0) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {i0, i1} + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + // tv1 {i0, i1} + auto tv1 = relu(tv0); + fusion->addOutput(tv1); + // tv2 {i0, b2, i1} + auto tv2 = broadcast(tv1, {false, true, false}); + // tv3 {i0, b3{1 ex 4}, i1} + auto tv3 = expand( + tv2, + {tv2->axis(0)->extent(), + IrBuilder::create(4), + tv2->axis(2)->extent()}); + // NOTE hat currently expand doesn't introduce an iter domain operation, so + // we don't see that i4 is produced by realizing the expanded extent of b3{1 + // ex 4} tv4 {i0, i4*i1} + auto tv4 = reshape(tv3, {2, 4, 3}, {2, 12}); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + // tv4 is not covered by tv1, because the expanded ID i4 participates in + // transformation + EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv1)); + + // tv3 is not covered by tv1, because the missing ID b3{1 ex 4} is concretized + // as i4, which is not mapped on tv1 + EXPECT_FALSE(domain_map.testTargetCoverage(tv3, tv1)); + + // tv1 is covered by tv4 + EXPECT_TRUE(domain_map.testTargetCoverage(tv1, tv4)); + + // tv1 is not a valid reference + EXPECT_FALSE(domain_map.isValidReference(tv1)); + + // tv4 is a valid reference + EXPECT_TRUE(domain_map.isValidReference(tv4)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({4, 7}, options); + std::vector aten_inputs = {t0}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapTestEg1) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {i0, i1} + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + // tv1 {i2, i0, i1} + TensorView* tv1 = makeContigTensor(3); + fusion->addInput(tv1); + // tv2 {i0*i1} + auto tv2 = reshape(tv0, {2, 4}, {8}); + fusion->addOutput(tv2); + + // tv3 {b3, i0, i1} + auto tv3 = broadcast(tv0, {true, false, false}); + // tv4 {i2, i0, i1} + auto tv4 = add(tv1, tv3); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + // tv4 is not covered by tv2, because it misses i2 + EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv2)); + + // tv2 is covered by tv4 + EXPECT_TRUE(domain_map.testTargetCoverage(tv2, tv4)); + + // tv2 is not a valid reference + EXPECT_FALSE(domain_map.isValidReference(tv2)); + + // tv4 is a valid reference + EXPECT_TRUE(domain_map.isValidReference(tv4)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 4}, options); + at::Tensor t1 = at::randn({3, 2, 4}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapTestEg2) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {i0, i1} + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + // tv1 {i0, i1} + auto tv1 = relu(tv0); + fusion->addOutput(tv1); + // tv2 {i0, b2, i1} + auto tv2 = broadcast(tv1, {false, true, false}); + // tv3 {i0, b3{1 ex 4}, i1} + auto tv3 = expand( + tv2, + {tv2->axis(0)->extent(), + IrBuilder::create(4), + tv2->axis(2)->extent()}); + fusion->addOutput(tv3); + + DomainMapUnitTest domain_map(fusion); + // tv3 is covered by tv1, because the missing ID b3{1 ex 4} is broadcast and + // doesn't get resolved to a concrete broadcast ID. + EXPECT_TRUE(domain_map.testTargetCoverage(tv3, tv1)); + + // tv1 is covered by tv4 + EXPECT_TRUE(domain_map.testTargetCoverage(tv1, tv3)); + + // tv1 is a valid reference + EXPECT_TRUE(domain_map.isValidReference(tv1)); + + // tv3 is a valid reference + EXPECT_TRUE(domain_map.isValidReference(tv3)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({4, 7}, options); + std::vector aten_inputs = {t0}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapFactory) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv1 {i1} + TensorView* tv0 = makeContigTensor(1); + fusion->addInput(tv0); + // tv1 {i0, i1} + TensorView* tv1 = makeContigTensor(2); + fusion->addInput(tv1); + + // tv2 {b2, b3, i1} + auto tv2 = broadcast(tv0, {true, true, false}); + // NOTE tv1 will be broadcasted to {b2, i0, i1} before the add. + // tv3 {b2, i0, i1} + auto tv3 = add(tv2, tv1); + fusion->addOutput(tv3); + + auto size_val = IrBuilder::create(4.0, DataType::Int); + auto one_val = IrBuilder::create(1, DataType::Int); + // factory method creates an iter domain out of thin air + // tv4 {i4{4}, b4, i1} + auto tv4 = ones({size_val, one_val, tv0->axis(0)->extent()}, DataType::Float); + // tv5 {i4{4}, i0, i1} + auto tv5 = mul(tv2, tv4); + fusion->addOutput(tv5); + + DomainMapUnitTest domain_map(fusion); + + // tv4 is not covered by tv3, because it's missing i4{4} + EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv3)); + // tv1 is not covered by tv4, since it's missing i0 + EXPECT_FALSE(domain_map.testTargetCoverage(tv1, tv4)); + + EXPECT_FALSE(domain_map.isValidReference(tv3)); + // tv5 has the same IDs as tv4, and is not a valid reference. + EXPECT_FALSE(domain_map.isValidReference(tv5)); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input0 = at::empty_strided({25}, {1}, options); + at::Tensor input1 = at::empty_strided({7, 25}, {25, 1}, options); + auto cg_outputs = executor_cache.runFusionWithInputs({input0, input1}); + + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + SegmentedFusion* segmented_fusion = runtime->fusionSegments(); + // This fusion currently cannot be scheduled as a single kernel. It is + // expected to be segmented as: g{(pointwise) + // inputs: tv0, tv1 + // outputs: tv2, tv3 + // tv2 = broadcast(tv0) + // tv3 = add (tv2, broadcast(tv1)) + // } + // + // g{(pointwise) + // inputs: tv2 + // outputs: tv5 + // tv4 = full({4, 1, i0}) + // tv5 = mul(tv2, tv4) + // } + EXPECT_EQ(segmented_fusion->groups().size(), 2); + + for (SegmentedGroup* group : segmented_fusion->groups()) { + const std::vector& exprs = group->exprs(); + + size_t num_full = std::count_if(exprs.begin(), exprs.end(), [](Expr* expr) { + return expr->isA(); + }); + if (num_full != 0) { + // this is the segment contains the factory op. + EXPECT_EQ(exprs.size(), 2); + EXPECT_EQ(num_full, 1); + auto binary_op_iter = + std::find_if(exprs.begin(), exprs.end(), [](Expr* expr) { + return expr->isA(); + }); + EXPECT_EQ( + (*binary_op_iter)->as()->getBinaryOpType(), + BinaryOpType::Mul); + Fusion* group_fusion = group->getFusion(); + // validate that we have a valid reference in the segmented fusion + DomainMapUnitTest group_dm(group_fusion); + EXPECT_EQ(group_fusion->outputs().size(), 1); + EXPECT_TRUE(group_dm.isValidReference( + group_fusion->outputs()[0]->as())); + } else { + // validate segmentation has the correct ops + EXPECT_EQ(exprs.size(), 3); + EXPECT_EQ( + std::count_if( + exprs.begin(), + exprs.end(), + [](Expr* expr) { return expr->isA(); }), + 2); + EXPECT_EQ( + std::count_if( + exprs.begin(), + exprs.end(), + [](Expr* expr) { return expr->isA(); }), + 1); + Fusion* group_fusion = group->getFusion(); + auto output_add = std::find_if( + group_fusion->outputs().begin(), + group_fusion->outputs().end(), + [](Val* val) { return val->definition()->isA(); }); + EXPECT_TRUE(output_add != group_fusion->outputs().end()); + DomainMapUnitTest group_dm(group_fusion); + // validate that the segmented fusion choose the add output as the + // reference + EXPECT_TRUE(group_dm.isValidReference((*output_add)->as())); + } + } + + testValidate(fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapPad0) { + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {b1, i0} + TensorView* tv0 = TensorViewBuilder().shape({1, -1}).build(); + fusion->addInput(tv0); + // tv1 {i2, b1, i0} + TensorView* tv1 = TensorViewBuilder().shape({-1, 1, -1}).build(); + fusion->addInput(tv1); + // tv2 {i2, b1, i0} + auto tv2 = add(tv1, tv0); + fusion->addOutput(tv2); + // i3 = resize(b1 + 4 + 4) + // tv3 {i3, i0} + auto tv3 = + pad(tv0, + {IrBuilder::create(0L), + IrBuilder::create(0L), + IrBuilder::create(4L), + IrBuilder::create(4L)}); + // tv4 {i3*i0} + auto tv4 = reshape(tv3, {9, 5}, {45}); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + + // tv4 is covered by tv2, because i3 is produced by b1 + EXPECT_TRUE(domain_map.testTargetCoverage(tv4, tv2)); + // tv2 is not covered by tv4, it's missing i2 + EXPECT_FALSE(domain_map.testTargetCoverage(tv2, tv4)); + + EXPECT_FALSE(domain_map.isValidReference(tv4)); + EXPECT_TRUE(domain_map.isValidReference(tv2)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::empty_strided({1, 5}, {5, 1}, options); + at::Tensor t1 = at::empty_strided({7, 1, 5}, {5, 5, 1}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapPad1) { + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {b1, i0} + TensorView* tv0 = TensorViewBuilder().shape({1, -1}).build(); + fusion->addInput(tv0); + // tv1 {i2, i3, i4, b5} + TensorView* tv1 = TensorViewBuilder().shape({-1, -1, -1, 1}).build(); + fusion->addInput(tv1); + + // tv2 {b6, b7, b1, i0} + auto tv2 = broadcast(tv0, {true, true, false, false}); + // tv3 {i2, i3, i4, i0} + auto tv3 = add(tv1, tv2); + fusion->addOutput(tv3); + // i8 = resize(b1 + 4 + 4) + // tv4 {i8, i0} + auto tv4 = + pad(tv0, + {IrBuilder::create(0L), + IrBuilder::create(0L), + IrBuilder::create(4L), + IrBuilder::create(4L)}); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + + // tv4 is covered by tv3, because i8 is produced by b1, a broadcast dimension + // concretized as i4 + EXPECT_TRUE(domain_map.testTargetCoverage(tv4, tv3)); + // tv3 is not covered by tv4, it's missing i2 and i3 + EXPECT_FALSE(domain_map.testTargetCoverage(tv3, tv4)); + + EXPECT_FALSE(domain_map.isValidReference(tv4)); + EXPECT_TRUE(domain_map.isValidReference(tv3)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::empty_strided({1, 5}, {5, 1}, options); + at::Tensor t1 = at::empty_strided({2, 3, 4, 1}, {12, 4, 1, 1}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapSlice0) { + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {i1, i0} + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + // tv1 {i1, i0} + // use concrete tensor to avoid need of concretization + TensorView* tv1 = makeContigConcreteTensor({2, 4}); + fusion->addInput(tv1); + + // b3 = resize(i0 + 0 - 3) + // tv2 {i1, b2} + auto tv2 = slice( + tv1, + {Slice(), + {IrBuilder::create(0L), + IrBuilder::create(1L), + IrBuilder::create(1L)}}); + fusion->addOutput(tv2); + // tv3 {i1, i0} + auto tv3 = add(tv0, tv1); + // tv4 {i1*i0} + auto tv4 = reshape(tv3, {2, 4}, {8}); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + // tv2 and tv4 has the same source IDs, since b3 = resize(i0 + 0 - 3) + EXPECT_TRUE(domain_map.testTargetCoverage(tv4, tv2)); + EXPECT_TRUE(domain_map.testTargetCoverage(tv2, tv4)); + + EXPECT_TRUE(domain_map.isValidReference(tv2)); + EXPECT_TRUE(domain_map.isValidReference(tv4)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 4}, options); + at::Tensor t1 = at::randn({2, 4}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapSlice1) { + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {i2, i1, i0} + TensorView* tv0 = makeContigTensor(3); + fusion->addInput(tv0); + // tv1 {i1, i0} + // use concrete tensor to avoid need of concretization + TensorView* tv1 = makeContigConcreteTensor({2, 4}); + fusion->addInput(tv1); + + // b3 = resize(i0 + 0 - 3) + // tv2 {i1, b3} + auto tv2 = slice( + tv1, + {Slice(), + {IrBuilder::create(0L), + IrBuilder::create(1L), + IrBuilder::create(1L)}}); + fusion->addOutput(tv2); + // tv3 {i2, i1, i0} + auto tv3 = add(tv0, tv1); + // tv4 {i2, i1*i0} + auto tv4 = reshape(tv3, {2, 2, 4}, {2, 8}); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + // i2 is missing in tv2 + EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv2)); + EXPECT_TRUE(domain_map.testTargetCoverage(tv2, tv4)); + + EXPECT_FALSE(domain_map.isValidReference(tv2)); + EXPECT_TRUE(domain_map.isValidReference(tv4)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 2, 4}, options); + at::Tensor t1 = at::randn({2, 4}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + } // namespace nvfuser