pointwise scheduler fails to validate reference tv (#3513)
Fixes: #3512

When picking the reference tv, the pointwise scheduler fails to validate that
the transformations on the reference tv can be safely propagated to all
outputs in the fusion. The issue occurs when an IterDomain that is not present
in the reference tv is merged with another dimension in an output tv,
preventing the corresponding merge on the reference tv from being propagated
to that target.

This PR adds an optional check, `areAllTargetIdsCoveredBy`, invoked from
`DomainMap::isValidReference` (in `csrc/scheduler/tools/domain_map.cpp`).

The added check verifies that all source IterDomains producing the
IterDomains on each output are covered by the reference tv. This is safe for
the pointwise scheduler, since the scheduler verifies that no reversible view
is present in the fusion.

The check is optional and is disabled by the transpose scheduler, where the
reference_tv is not supposed to cover the entire fusion, but only a subset of
the fusion's I/O tensors. We should extend the check to the transpose
scheduler in future PRs.
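
For illustration, the problematic pattern (e.g. 0 in the new comment in
csrc/scheduler/tools/domain_map.cpp) can be sketched with the C++ test API
roughly as below. This is only an illustrative sketch: the fixture and test
names are hypothetical, and the helpers (makeContigConcreteTensor, broadcast,
expand, reshape) are assumed to keep their usual signatures.

// Sketch only (assumed nvFuser C++ test helpers, hypothetical test name).
TEST_F(PointwiseTest, DomainMapBroadcastExpandReshapeSketch) {
  auto fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr;
  FusionGuard fg(&fusion);

  // T34 [i0, i1]
  TensorView* tv_in = makeContigConcreteTensor({2, 3});
  fusion.addInput(tv_in);
  TensorView* tv34 = add(tv_in, IrBuilder::create<Val>(1.0));

  // T185 [i0, b2, i1] = broadcast(T34)
  TensorView* tv185 = broadcast(tv34, {false, true, false});
  // T192 [i0, b3(ex), i1] = expand(T185); -1 keeps the original extent
  TensorView* tv192 = expand(
      tv185,
      {IrBuilder::create<Val>(-1L),
       IrBuilder::create<Val>(4L),
       IrBuilder::create<Val>(-1L)});
  // T198 [i0, b3(ex)*i1] = reshape(T192): the expanded broadcast is merged
  // into i1, and that merge cannot be replayed on T34.
  TensorView* tv198 = reshape(tv192, {2, 4, 3}, {2, 12});

  fusion.addOutput(tv34);
  fusion.addOutput(tv198);

  // With this PR, T34 is rejected as a reference: its source IDs {i0, i1} do
  // not cover b3 among T198's source IDs {i0, b3, i1}.
}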

---------

Co-authored-by: Naoya Maruyama <naoyam@users.noreply.github.com>
Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
3 people committed Dec 18, 2024
1 parent ffd186e commit 01687df
Showing 8 changed files with 708 additions and 16 deletions.
4 changes: 2 additions & 2 deletions csrc/scheduler/pointwise.cpp
@@ -493,11 +493,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {

int64_t max_dims = 0;
for (auto inp : input_tvs) {
max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims);
max_dims = std::max(pointwise_utils::nLogicalDims(inp), max_dims);
}

for (auto out : output_tvs) {
max_dims = std::max(pointwise_utils::nRootDims(out), max_dims);
max_dims = std::max(pointwise_utils::nLogicalDims(out), max_dims);
}

// If everything is zero dim tensors, just return.
2 changes: 1 addition & 1 deletion csrc/scheduler/pointwise_utils.cpp
@@ -19,7 +19,7 @@ TensorView* PointwiseDomainMap::findReferenceTensor(
if (isValidReference(output_tv) &&
hasMinimumSize(output_tv, minimum_num_axes) &&
!output_tv->isFusionInput()) {
int64_t n_dims = pointwise_utils::nRootDims(output_tv);
int64_t n_dims = nLogicalDims(output_tv);
if (n_dims > max_dims) {
result = output_tv;
max_dims = n_dims;
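For context, the function around this changed line (collapsed in the diff) has
roughly the following shape; this is a sketch reconstructed from the visible
hunk, not a verbatim copy of the file, and hasMinimumSize is assumed to keep
its existing signature.

// Approximate shape of PointwiseDomainMap::findReferenceTensor (sketch):
// pick the valid reference among fusion outputs with the most logical dims.
TensorView* PointwiseDomainMap::findReferenceTensor(
    int64_t minimum_num_axes) const {
  TensorView* result = nullptr;
  int64_t max_dims = -1;
  for (auto output_tv :
       ir_utils::filterByType<TensorView>(fusion_->outputs())) {
    if (isValidReference(output_tv) &&
        hasMinimumSize(output_tv, minimum_num_axes) &&
        !output_tv->isFusionInput()) {
      int64_t n_dims = nLogicalDims(output_tv);
      if (n_dims > max_dims) {
        result = output_tv;
        max_dims = n_dims;
      }
    }
  }
  return result;
}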
2 changes: 1 addition & 1 deletion csrc/scheduler/pointwise_utils.h
@@ -19,7 +19,7 @@ namespace pointwise_utils {

// Returns the number of non-reduction/non-broadcast/non-device dims in the
// logical domain
inline int64_t nRootDims(const TensorView* tv) {
inline int64_t nLogicalDims(const TensorView* tv) {
auto logical_dom = tv->getLogicalDomain();
int64_t tv_n_dims = 0;
for (auto dim : logical_dom) {
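The body the diff truncates can be completed along these lines; this is a
sketch consistent with the comment above (IterDomain::isReduction,
isBroadcast, and isDeviceDim are assumed to be the relevant predicates), not
an exact copy of the file.

// Sketch of the full helper: count logical-domain IDs that are not
// reduction, broadcast, or device-parallel dims.
inline int64_t nLogicalDims(const TensorView* tv) {
  auto logical_dom = tv->getLogicalDomain();
  int64_t tv_n_dims = 0;
  for (auto dim : logical_dom) {
    if (!dim->isReduction() && !dim->isBroadcast() && !dim->isDeviceDim()) {
      tv_n_dims++;
    }
  }
  return tv_n_dims;
}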
147 changes: 146 additions & 1 deletion csrc/scheduler/tools/domain_map.cpp
@@ -137,6 +137,137 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv)
return in_concrete_ids.empty();
}

// Note: ideally we would want to check that reference_tv contains all iter
// domains in target_tv, so that transformations applied on reference_tv can be
// propagated to target_tv. But we don't have an easy way to check that.
// Instead, this function checks that all source iter domains involved in
// transformations on target_tv are covered by reference_tv. The source iter
// domains of a TensorView are the IDs that don't have a definition and are
// producers of the IDs on the logical domain of the given TensorView.
//
// ------
//
// e.g 0.
// T34 [i0, i1]
// T185 [i0, b2, i1] = broadcast(T34)
// T192 [i0, b3(ex), i1] = expand(T185)
// T198 [i0, b3(ex)*i1] = reshape(T192)
// output(T34)
// output(T198)
//
// If we take T34 as reference_tv, then T198 is the target_tv. We can't replay
// T34's transform of merging all the dimensions onto T198, since b3(ex)*i1
// can't be reversed. The check in this function gives us T34 with sources
// i0, i1, while T198 has sources i0, b3, i1, where b3 isn't contained in
// T34. Hence we reject this reference_tv.
//
// ------
//
// e.g 1.
// T0 [i0, i1]
// T1 [i2, i0, i1]
// T2 [i0*i1] = reshape(T0)
// T3 [b3, i0, i1] = broadcast(T0)
// T4 [i2, i0, i1] = add(T1, T3)
// output(T2)
// output(T4)
//
// The example above should be able to pick T4 as reference_tv. T2's sources
// i0 and i1 are both contained in the sources of T4, so this example can be
// scheduled as a single fusion.
bool DomainMap::areAllTargetIdsCoveredBy(
TensorView* target_tv,
TensorView* reference_tv) const {
auto get_source_iter_domains = [this](TensorView* tv) {
// traverse back to collect all disjoint set producer IDs for each ID in the
// logical domain of tv.
VectorOfUniqueEntries<std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>>
all_producer_sets;
std::for_each(
tv->getLogicalDomain().begin(),
tv->getLogicalDomain().end(),
[&](IterDomain* tv_logical_id) {
all_producer_sets.pushBack(
ca_map_.disjointSetOf(tv_logical_id, IdMappingMode::EXACT));
});
all_producer_sets.pushBack(
ca_map_.getAllDisjointSetProducers(all_producer_sets));

std::vector<IterDomain*> source_ids;
// filtering all producer IDs with empty definition to get source iter
// domains
std::for_each(
all_producer_sets.vector().begin(),
all_producer_sets.vector().end(),
[&source_ids,
this](const std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>&
producer_set_ptr) {
IterDomain* producer_id = producer_set_ptr->front();
if (ca_map_.uniqueExactDefinitions(producer_id).empty()) {
source_ids.push_back(producer_id);
}
});
return source_ids;
};

// This set contains all source iter domains covered by reference_tv, so it's
// safe for target_tv to have them.
std::unordered_set<IterDomain*> covered_source_ids;
for (IterDomain* source_id_ref : get_source_iter_domains(reference_tv)) {
covered_source_ids.insert(source_id_ref);
}
// It's safe to have an unmapped broadcast IterDomain. There are quite a few
// tests expecting the pointwise scheduler to handle this pattern.
for (IterDomain* id_out : target_tv->getLogicalDomain()) {
if (id_out->isBroadcast()) {
NVF_ERROR(
id_out->definition() == nullptr ||
id_out->definition()->isA<Resize>());

// Note that ideally we should also be able to handle merge/split on
// broadcast IDs, so we should really move this skip inside the loop below
// over `get_source_iter_domains(target_tv)` and skip broadcast source IDs.
// Currently we have the issue that split/merge does not preserve expanded
// broadcasts; see https://github.com/NVIDIA/Fuser/issues/1126
covered_source_ids.insert(id_out);
}
}
// Note: there are certain cases where it's safe to have dangling IDs,
// e.g
// T34 [i0, i1]
// T185 [i0, b2, i1] = broadcast(T34)
// T192 [i0, b3(ex), i1] = expand(T185)
// It's safe to propagate T34 to T192, since b3(ex) is not involved in the
// propagation. But this isn't generally safe. If the above example is changed
// to e.g
// T34 [i0, i1]
// T185 [i0, b2, i1] = broadcast(T34)
// T186 [i0, i4, i1] = ones({i0, i4, i1})
// T193 [i0, i4, i1] = add(T185, T186)
// It's unsafe to propagate from T34 to T193, see issue
// https://github.com/NVIDIA/Fuser/issues/3542

// Check all source iter domains involved in producing target_tv
for (IterDomain* source_id_out : get_source_iter_domains(target_tv)) {
// NOTE: we use the concrete ID instead. This allows us to link indirect
// broadcasts. In the example below:
//   T2[i0, i1] = T0[i0, b0] + T1[i0, i1]
//   T3[i0, i9] = pad(T0[i0, b0])
// i9 in T3
//   -> source ID b0
//   -> concrete mapped to i1
// so T3 is contained by T2. See test `PointwiseTest.DomainMapPad1`.
auto concrete_source_id_out =
ca_map_.getConcreteMappedID(source_id_out, IdMappingMode::PERMISSIVE);
// If we find any source_id_out that's not contained, it's possible our
// propagation would fail, since transformations involving this iter domain
// can't be resolved.
if (!getMappedInputConcreteID(covered_source_ids, concrete_source_id_out)) {
return false;
}
}
return true;
}

// Reference domains must exactly match with the input domains. See
// also PR #661
IterDomain* DomainMap::getMappedInputConcreteID(
@@ -228,7 +359,7 @@ IterDomain* DomainMap::anyMapped(
}

// Determine if output TensorView is a valid reference tensor for this fusion.
// The reference tensor must map to all the iterDomains in each input.
// The reference tensor must map to all the iterDomains in each input and output.
bool DomainMap::isValidReference(TensorView* tv) const {
for (auto input_tv : ir_utils::filterByType<TensorView>(fusion_->inputs())) {
if (input_tv->uses().empty()) {
@@ -240,6 +371,20 @@ bool DomainMap::isValidReference(TensorView* tv) const {
return false;
}
}
// The check on outputs is optional; the transpose scheduler might propose a
// secondary reference that only applies to a subset of IO tensors. Ideally we
// should have a more robust check and consider the IO groups instead of
// blindly skipping outputs.
for (auto output_tv :
ir_utils::filterByType<TensorView>(fusion_->outputs())) {
// no need to check for self.
if (output_tv == tv) {
continue;
}
if (!areAllTargetIdsCoveredBy(output_tv, tv)) {
return false;
}
}
return true;
}

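To complement the rejected case above, the accepted pattern from e.g. 1 in the
comment can be sketched the same way; the fixture and test name are
hypothetical, and the helpers are the same assumed test API as in the earlier
sketch.

// Sketch only: T4's source IDs {i2, i0, i1} cover T2's source IDs {i0, i1},
// so T4 is accepted by isValidReference and the fusion can be scheduled as a
// single pointwise kernel.
TEST_F(PointwiseTest, DomainMapReshapePlusBroadcastAddSketch) {
  auto fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr;
  FusionGuard fg(&fusion);

  // T0 [i0, i1] and T1 [i2, i0, i1]
  TensorView* tv0 = makeContigConcreteTensor({4, 8});
  TensorView* tv1 = makeContigConcreteTensor({2, 4, 8});
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // T2 [i0*i1] = reshape(T0)
  TensorView* tv2 = reshape(tv0, {4, 8}, {32});
  // T3 [b3, i0, i1] = broadcast(T0); T4 [i2, i0, i1] = add(T1, T3)
  TensorView* tv3 = broadcast(tv0, {true, false, false});
  TensorView* tv4 = add(tv1, tv3);

  fusion.addOutput(tv2);
  fusion.addOutput(tv4);
}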
9 changes: 8 additions & 1 deletion csrc/scheduler/tools/domain_map.h
@@ -32,14 +32,21 @@ class DomainMap {
}

// Determine if a TensorView is a valid reference tensor for this fusion.
// The reference tensor must map to all the iterDomains in each input.
// The reference tensor must map to all the iterDomains in each input and
// output.
bool isValidReference(TensorView* tv) const;

protected:
// Determine if all IterDomains are mapped between input and the given tvs
bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv)
const;

// Determine if all source IterDomains in target_tv are contained by the
// reference_tv; this ensures transformations from reference_tv can be
// propagated to target_tv.
bool areAllTargetIdsCoveredBy(TensorView* target_tv, TensorView* reference_tv)
const;

virtual IterDomain* getMappedInputConcreteID(
const std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const;
18 changes: 13 additions & 5 deletions csrc/scheduler/transpose.cpp
@@ -170,8 +170,16 @@ class TransposeDomainMap : public scheduler_tools::DomainMap {
TensorView* result = nullptr;
int64_t max_dims = -1;
for (auto tv : group) {
// Since the transpose scheduler has a different set of references, we skip
// the ID coverage check of the reference on the outputs of the fusion. Note
// that this is not ideal; we would instead want the reference tensor to be
// checked against all its target IO tensors.
// TODO: open an issue for this. The transpose scheduler is not supposed to
// reuse pointwise_utils::DomainMap::isValidReference. That function is too
// restrictive and doesn't align well with the scheme of the transpose
// scheduler.
if (isValidReference(tv)) {
int64_t dims = (int64_t)pointwise_utils::nRootDims(tv);
int64_t dims = (int64_t)pointwise_utils::nLogicalDims(tv);
if (dims > max_dims) {
result = tv;
max_dims = dims;
@@ -992,12 +1000,12 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
<< "max_io_dtype_size: " << max_io_dtype_size << "\n"
<< "group 1: " << ir_utils::toString(grouped_inputs_outputs[0])
<< "\n"
<< "reference1: " << reference1 << "\n"
<< "reference1: " << reference1->toString() << "\n"
<< "inner_most_id1 position: " << inner_most_pos1_in_ref1
<< " (in reference 1)\n"
<< "group 2: " << ir_utils::toString(grouped_inputs_outputs[1])
<< "\n"
<< "reference2: " << reference2 << "\n"
<< "reference2: " << reference2->toString() << "\n"
<< "inner_most_id2 position: " << inner_most_pos2_in_ref1
<< " (in reference 1)" << std::endl;
if (hasSmallTransposeDimensions(tparams)) {
@@ -1047,11 +1055,11 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {

int64_t max_dims = 0;
for (auto inp : input_tvs) {
max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims);
max_dims = std::max(pointwise_utils::nLogicalDims(inp), max_dims);
}

for (auto out : output_tvs) {
max_dims = std::max(pointwise_utils::nRootDims(out), max_dims);
max_dims = std::max(pointwise_utils::nLogicalDims(out), max_dims);
}

// If everything is zero dim tensors, just return.
63 changes: 63 additions & 0 deletions tests/cpp/test_matmul.cpp
@@ -3993,4 +3993,67 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) {
EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
}

TEST_F(HopperMatmulTest, HSH_NT_UseScheduler) {
Fusion fusion;
FusionGuard fg(&fusion);

constexpr int64_t M = 2048, N = 2048, K = 8192;
const auto dtype = DataType::Half;

auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype);
fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = fusedMultiplySum(tv0, tv1, {0});

// Reorder the accumulator as [M, N, K]
// [K, M, N] -> [M, N, K]
tv2->reorder({{-3, -1}});
tv2->commitLeafToLogical();

auto tv3 = castOp(DataType::Half, tv2);
fusion.addOutput(tv3);

auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto a_ref = at::randn({K, M, 1}, options);
auto b_ref = at::randn({K, 1, N}, options);
auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 256, 16);
gemm_tile.warp_tile = GemmTile(64, 256, 16);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
mparams.mma_macro = MmaMacro::Hopper_64_256_16;
mparams.tile_sizes = gemm_tile;
mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = false;
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
mparams.splitk_factor = 1;
mparams.use_smem_epilogue = true;
mparams.cluster_dims = {2, 1, 1};
mparams.promote_prologue_smem_reuse = true;

std::cout << mparams.toString() << std::endl;

SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
->schedule(&fusion, &mparams);

std::vector<c10::IValue> inputs = {a_ref, b_ref};

KernelExecutor ke;
ke.compile(&fusion, inputs);
EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty());
auto cg_outputs = ke.run(inputs);
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel()));

// Relax tolerance for larger sum due to large K
EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}

} // namespace nvfuser
