
TMem check the stride of outer dims #4070

Open · wants to merge 9 commits into main
Conversation


@zasdfgbnm zasdfgbnm commented Mar 13, 2025

Validating that a TMem ld/st is scheduled as a warp collective is fundamentally the same problem as validating the vectorization of a slice op with a vectorization factor of 32. This validation requires checking the inner dim's size as well as the outer dims' strides. Currently, we only check the inner dim's size. This PR adds the missing stride check.

See discussion: #4015 (comment)
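The warp-collective condition described above can be restated as a standalone check. The sketch below is illustrative only, not nvfuser's implementation: `isWarpCollective`, its parameters, and the flat extent/stride representation are hypothetical. It captures the rule that the innermost thread dim must span exactly one warp contiguously, and every nontrivial outer thread dim must step by whole warps in the linearized thread index.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical standalone sketch (not the nvfuser API): a TMem ld/st is a
// warp collective, so the participating threads must form whole warps in
// the linearized thread index. This holds when the innermost thread dim has
// extent 32 with stride 1, and every outer thread dim has a stride (in
// linear thread ids) that is a multiple of 32.
bool isWarpCollective(
    const std::vector<long>& extents,   // loop extents, outermost first
    const std::vector<long>& strides) { // linear-thread-id strides
  assert(extents.size() == strides.size());
  if (extents.empty()) {
    return false;
  }
  // Inner dim: must cover exactly one warp contiguously.
  if (extents.back() != 32 || strides.back() != 1) {
    return false;
  }
  // Outer dims: each step must jump by whole warps. A trivial dim
  // (extent 1) never moves, so its stride is irrelevant.
  for (std::size_t i = 0; i + 1 < extents.size(); ++i) {
    if (extents[i] != 1 && strides[i] % 32 != 0) {
      return false;
    }
  }
  return true;
}
```

For example, a loop domain [TIDy{2}, TIDx{32}] in a kernel with blockDim.x = 33 gives the TIDy dim a stride of 33, which fails the outer-stride check even though the inner dim's size is 32.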


github-actions bot commented Mar 13, 2025

Review updated until commit 3b44b2f

Description

  • Added helper function isSharedMemory to check whether a tensor resides in shared memory.

  • Updated needsPredicateSharedMemAccess to use isSharedMemory.

  • Modified getThreadParallelTypesMergedByContiguity to return strides.

  • Added stride checks for outer parallel types in TMem load/store.


Changes walkthrough 📝

Relevant files

Enhancement

predicate_elimination.cpp: Add shared memory check helper
csrc/device_lower/analysis/predicate_elimination.cpp
  • Added isSharedMemory function.
  • Updated needsPredicateSharedMemAccess to use isSharedMemory.
  +11/-2

tensor_memory.cpp: Enhance thread parallel types with strides
csrc/device_lower/analysis/tensor_memory.cpp
  • Modified getThreadParallelTypesMergedByContiguity to return strides.
  • Added stride checks for outer parallel types in TMem load/store.
  +62/-16

utils.cpp: Add TMem load/store check
csrc/device_lower/utils.cpp
  • Added isLdStTMem function to check for TMem load/store operations.
  +8/-0

predicate_compute.cpp: Skip predicate for TMem load/store
csrc/predicate_compute.cpp
  • Updated getInlinePredicate to skip predicate for TMem load/store.
  +4/-1

tensor_view.cpp: Validate tensor memory type in setTMemDimSepPos
csrc/tensor_view.cpp
  • Added check in setTMemDimSepPos to ensure tensor memory type.
  +3/-0

utils.h: Declare TMem load/store check
csrc/device_lower/utils.h
  • Added declaration for isLdStTMem.
  +4/-0

abstract_tensor.h: Add reverse method to AbstractTensorWithInfo
csrc/scheduler/tools/abstract_tensor.h
  • Added reverse method to AbstractTensorWithInfo.
  +5/-0

Tests

test_memory.cpp: Add TMem stride and memory type tests
tests/cpp/test_memory.cpp
  • Added tests for setTMemDimSepPos with non-TMem tensor.
  • Added tests for TMem load/store with incorrect stride.
  • Added tests for TMem load/store with inexact parallel types.
  +120/-0

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Stride Calculation

    Ensure that the stride calculation is correct and that it handles all edge cases, especially when the extent of a parallel type is 1.

    // The strides returned are the stride of each edge of the box in that 3D CTA
    // lattice. For example, if the parallel dimension sizes of the kernel are:
    //   TIDz: 32, TIDy: 8, TIDx: 8
    // and the loop domain is:
    //   I0: TIDz, extent 32
    //   I1: TIDy, extent 7
    //   I2: TIDx, extent 8
    // then the strides are [8*8, 8, 1].
    std::pair<AbstractTensor, std::vector<Val*>>
    getThreadParallelTypesMergedByContiguity(const Expr* expr) {
      auto& id_graph = GpuLower::current()->tensorIndexer().traversalGraph();
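The stride rule in the quoted comment can be restated as a tiny standalone helper. This is a hedged sketch, not nvfuser code: `threadStrides` and its parameters are hypothetical, and in the PR the strides are computed symbolically as `Val*`. The point is that each parallel type's stride in the linearized CTA thread index is fixed by the kernel's parallel dimension sizes, independent of the (possibly smaller) extent the loop domain uses for that type.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical sketch of the stride rule in the comment above (not nvfuser
// code): for a CTA lattice of blockDim = [bdimx, bdimy, bdimz], the linear
// thread id is threadIdx.x + threadIdx.y * bdimx + threadIdx.z * bdimx *
// bdimy, so the strides of TIDx/TIDy/TIDz are 1, bdimx, and bdimx * bdimy.
std::map<std::string, long> threadStrides(long bdimx, long bdimy) {
  return {
      {"TIDx", 1},
      {"TIDy", bdimx},
      {"TIDz", bdimx * bdimy},
  };
}
```

With the comment's example sizes (TIDz: 32, TIDy: 8, TIDx: 8), this yields strides [8*8, 8, 1] regardless of the loop domain's TIDy extent being 7.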
    Memory Type Check

    Verify that the memory type check in setTMemDimSepPos is comprehensive and that it handles all possible scenarios.

    void TensorView::setTMemDimSepPos(int64_t pos) {
      NVF_CHECK(
          getMemoryType() == MemoryType::Tensor,
          "TMem dimension separator is only supported for tensor memory");
    Test Coverage

    Ensure that the added tests cover all possible scenarios, including edge cases, and that they effectively validate the changes made in the PR.

    using TMemTestCompileOnly = NVFuserTest;
    
    TEST_F(TMemTestCompileOnly, SetTMemDimSepPosNonTMem) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto tv0 = makeContigConcreteTensor({2, 33});
      fusion.addInput(tv0);
      auto tv1 = set(tv0);
      fusion.addOutput(tv1);
    
      EXPECT_THAT(
          [&]() { tv1->setTMemDimSepPos(-1); },
          ::testing::ThrowsMessage<nvfuser::nvfError>(::testing::HasSubstr(
              "TMem dimension separator is only supported for tensor memory")));
    }
    
    // Test that we are checking the stride of the "outer parallel types".
    // If in a kernel, the parallel dimension map is [TIDy, TIDx] = [2, 33],
    // But in the TMem load/store's loop domain, Ix (the ID parallelized on TIDx)
    // has extent 32. Then we will generate code like:
    //   if (threadIdx.x < 32) {
    //     tmem::load
    //   }
    // For threadIdx.y == 0, it is correct. But for threadIdx.y == 1, it is wrong
    // because we are using thread ids 33-64 for the load, which is not a warp.
    TEST_F(TMemTestCompileOnly, WrongStride) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto tv0 = makeContigConcreteTensor({2, 33});
      fusion.addInput(tv0);
      auto tv1 = set(tv0); // gmem
      auto tv2 = set(tv1); // register
      auto tv3 = set(tv2); // tmem
      auto tv4 = set(tv3); // register
      auto tv5 = set(tv4); // gmem
      fusion.addOutput(tv5);
    
      tv1->setMemoryType(MemoryType::Global);
      tv3->setMemoryType(MemoryType::Tensor);
      tv3->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::StTMem);
      tv4->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::LdTMem);
    
      // [TIDy{2}, TIDx{33}]
      tv1->axis(0)->parallelize(ParallelType::TIDy);
      tv1->axis(1)->parallelize(ParallelType::TIDx);
    
      // [TIDy{2}, Serial{2}, TIDx{32}]
      for (auto tv : {tv2, tv3, tv4, tv5}) {
        tv->split(1, 32);
        tv->axis(0)->parallelize(ParallelType::TIDy);
        tv->axis(-1)->parallelize(ParallelType::TIDx);
      }
    
      tv3->setAllocationDomain(tv3->getLoopDomain(), true);
      tv3->setTMemDimSepPos(-1);
    
      inlineMost();
    
      KernelExecutor ke;
    
      EXPECT_THAT(
          [&]() { ke.compile(&fusion); },
          ::testing::ThrowsMessage<nvfuser::nvfError>(::testing::HasSubstr(
              "Invalid data access pattern in TMem load/store: "
              "Outer parallel types' strides must be a multiple of 32.")));
    }
    
    // This test is a variant of the WrongStride test, but this test is valid.
    // Test a case where the parallel types are not exact. The parallel dimension
    // map is [TIDy, TIDx] = [2, 33], but in the TMem load/store's loop domain,
    // we have Iy{1}, Ix{32}. The generated code will be like:
    //   if (threadIdx.x < 32 && threadIdx.y < 1) {
    //     tmem::load
    //   }
    // This is valid because we are using a whole warp for the load.
    TEST_F(TMemTest, InexactParallelType) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto tv0 = makeContigConcreteTensor({2, 33});
      fusion.addInput(tv0);
      auto tv1 = set(tv0); // gmem
      auto tv2 = set(tv1); // register
      auto tv3 = set(tv2); // tmem
      auto tv4 = set(tv3); // register
      auto tv5 = set(tv4); // gmem
      fusion.addOutput(tv5);
    
      tv1->setMemoryType(MemoryType::Global);
      tv3->setMemoryType(MemoryType::Tensor);
      tv3->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::StTMem);
      tv4->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::LdTMem);
    
      // [TIDy{2}, TIDx{33}]
      tv1->axis(0)->parallelize(ParallelType::TIDy);
      tv1->axis(1)->parallelize(ParallelType::TIDx);
    
      // [Serial{2}, TIDy{1}, Serial{2}, TIDx{32}]
      for (auto tv : {tv2, tv3, tv4, tv5}) {
        tv->split(1, 32);
        tv->split(0, 1);
        tv->axis(1)->parallelize(ParallelType::TIDy);
        tv->axis(-1)->parallelize(ParallelType::TIDx);
      }
    
      tv3->setAllocationDomain(tv3->getLoopDomain(), true);
      tv3->setTMemDimSepPos(-1);
    
      inlineMost();
    
      KernelExecutor ke;
      ke.compile(&fusion);
      auto t0 = at::randn(
          {2, 33}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0));
      auto cg_outputs = ke.run({t0});
      testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
    }
    
    using LdMatrixTestParam = std::tuple<MmaMacro, MmaOperand>;
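The failure mode the WrongStride test guards against can be checked numerically with a few lines of standalone C++. This is a hypothetical illustration, not part of the PR: it enumerates the linear thread ids selected by the predicate `threadIdx.x < 32` for a given threadIdx.y and blockDim.x, and checks whether they form one whole warp.

```cpp
#include <cassert>
#include <set>

// Standalone illustration of the WrongStride scenario (blockDim = [33, 2]):
// the predicate `threadIdx.x < 32` selects 32 threads for each threadIdx.y,
// but only for threadIdx.y == 0 do those threads form a warp.
// Returns true iff the selected threads for a given y are one whole warp,
// i.e. the 32 consecutive linear ids [32*w, 32*w + 31] for some w.
bool selectedThreadsFormWarp(int bdimx, int y) {
  std::set<int> ids;
  for (int x = 0; x < 32; ++x) { // predicate: threadIdx.x < 32
    ids.insert(y * bdimx + x);   // linear thread id in the CTA
  }
  int first = *ids.begin();
  return first % 32 == 0 && *ids.rbegin() == first + 31;
}
```

With bdimx = 33, threadIdx.y == 1 selects ids 33-64, which straddles two warps; with bdimx = 32 (TIDy stride a multiple of 32), every y selects a whole warp.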

    @zasdfgbnm
    Collaborator Author

    !test

Collaborator Author

I decided not to put the new tests into the tutorial. They are too complicated to be good educational material. I want readers of the tutorial to focus on the most important and basic concepts.

    @zasdfgbnm zasdfgbnm marked this pull request as ready for review March 13, 2025 06:21
    @zasdfgbnm zasdfgbnm requested a review from naoyam March 13, 2025 06:21