
Port manual changes back to prehip files #9

Merged · 17 commits into nod-ai:hipify-inplace · Jan 29, 2025

Conversation

@GMNGeoffrey (Collaborator) commented on Jan 18, 2025

Modify the .prehip files so that they yield the desired source post hipification.

With these changes, running script/hipify-inplace.sh results in no diff.

Also flipped the conditional on the preprocessor define for DGL_WARP_SIZE so that it tests for DGL_USE_ROCM instead of DGL_USE_CUDA: hipification replaces DGL_USE_CUDA with DGL_USE_ROCM, so the check has to be on the define that actually exists in the hipified source.
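A minimal sketch of what that check looks like (illustrative values; the exact definition in DGL may differ):

```cpp
// Hipification rewrites DGL_USE_CUDA to DGL_USE_ROCM, so only
// DGL_USE_ROCM is ever defined in the hipified source.
#ifdef DGL_USE_ROCM
#define DGL_WARP_SIZE 64  // wavefront size on gfx9* GPUs
#else
#define DGL_WARP_SIZE 32  // warp size on NVIDIA GPUs
#endif
```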

I made hipify-inplace.sh run hipify-tensoradapter.py as well, so there is only one command to run, and added unhipify.sh, which copies the .prehip files back to their original locations.

GMNGeoffrey and others added 11 commits January 16, 2025 13:59
This is just the output of the hipify-inplace.sh and
hipify-tensoradapter.py scripts with no further modifications. I think
it's easier to review changes to the actual HIP source than to reason about
what the hipify script will do, and since there are a fair number of changes
to review, that seems worth it. In the end we
should have a bunch of hip source files and .prehip cuda files that they
can be generated from. Then we can handle organization however we want:
restore the originals and have hipification be part of a build process,
have the hip versions on a separate branch, etc.

I'll check in the .prehip files in a separate commit to keep things a
bit cleaner.
These all get the .prehip extension appended.
In my porting, I was finding it really annoying that everything in DGL
was hardcoded to the directory build/ and that it created build/
directories in various source subdirectories (which meant cleaning and
rebuilding was fraught). I modified things so that all sub-builds happen
in the main build directory.

There were also some bugs in the shell scripts and I cleaned them up a
bit to make them more robust.

Not all of this is strictly required for ROCM to build, so we might want
to strip it out. I already stripped out various warning silencing for
that reason.
This is what I get for not copying exactly what I had.
There are definitely still runtime issues, but this handles all of the
compilation failures (requires clang-19 and bleeding-edge ROCM). Mostly
straightforward; the ones that aren't:

- The AtomicFPOp was adapted from
[PyTorch](https://github.com/pytorch/pytorch/blob/bef103934a25d848838a7642a8d6a2f523e7dfff/aten/src/ATen/cuda/Atomic.cuh#L39); a minimal sketch of the pattern follows this list.
- The handling of legacy cusparse is something I messed up originally.
  The various CUDA version checks need to be in sync because the legacy
  version creates a transposed output that then needs to be flipped
  back, so I factored out a shared macro.
- There is something weird where HIP doesn't like the logging function.
  I was never able to get it to work, but it doesn't seem hugely
  important, so I think it's OK to punt on it for now.
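As referenced above, a sketch of the CAS-loop pattern behind PyTorch's AtomicFPOp, simplified here to plain `float` (the linked Atomic.cuh also handles 16-bit types by operating on the containing 32-bit word; `AtomicFPApply` is a hypothetical name):

```cpp
// Build an arbitrary atomic floating-point operation out of atomicCAS.
// `op` is any binary functor, e.g. fmaxf for an atomic max.
template <typename Op>
__device__ float AtomicFPApply(float* address, float val, Op op) {
  unsigned int* address_as_ui = reinterpret_cast<unsigned int*>(address);
  unsigned int old = *address_as_ui;
  unsigned int assumed;
  do {
    assumed = old;
    float updated = op(__uint_as_float(assumed), val);
    old = atomicCAS(address_as_ui, assumed, __float_as_uint(updated));
  } while (assumed != old);  // retry if another thread won the race
  return __uint_as_float(old);
}
```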

A couple of these changes might not be strictly necessary to make the
build work (like adding `__HIP_DEVICE_COMPILE__` to some of the
`__CUDA_ARCH__` checks) because I grabbed changes from my fully working
draft by file and just reverted the obviously more complicated ones. It
didn't seem worth reverting these uncomplicated ones.
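For concreteness, the kind of guard change being described, as a sketch (the actual call sites vary):

```cpp
// Before: only true when nvcc is compiling device code.
// #if defined(__CUDA_ARCH__)

// After: also true when the HIP compiler is compiling device code.
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
// ... device-only code path ...
#endif
```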
Maybe these only show up with clang-19, but they made it really hard to look for actual problems. Hopefully this is something we could upstream. The only one that I think we really need for hipification is `expansion-to-defined`, since that is actually undefined behavior (see the sketch after the list below).

I handled warnings in one of these ways:
1. Where the warning was from within the project and there was an obvious simple fix, I modified the source code:
     -  `-Wexpansion-to-defined`
     -  `-Wunused`
     -  `-Wmismatched-tags`
     -  `-Winconsistent-missing-override`
     -  `-Wparentheses`
     -  `-Wcuda-compat`
2. Where the warning was in a third_party library and relatively harmless to leave in third_party code, I silenced the warning for that library:
     - `-Wunused`
     - `-Wabsolute-value`
     - `-Wimplicit-exception-spec-mismatch`
     - `-Wdeprecated`
3. Explicit template instantiations are widespread across the project and look intentional, so I silenced the warning project-wide.
     - `-Winstantiation-after-specialization`
4. There are a few warnings that I think should probably be fixed in the project, but they are not egregious and aren't really relevant for our work, so I just silenced them in my own CMakePresets.json.
     - `-Wunneeded-internal-declaration`
     - `-Wdeprecated`
     - `-Wvla-cxx-extension`
5. "no case matching constant switch condition" can't be turned off, so I updated the third_party source (which is copied in-tree anyway). Maybe we can upstream that at some point.
6. There's one other warning that looks like it could be related to hipification. Now we can actually see it! Based on discussions like https://gitlab.com/libeigen/eigen/-/issues/1751, I'm not sure whether these warnings are a problem or how best to handle them, so I haven't fixed them.
     - `-Wpass-failed=transform-warning`
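To illustrate the `expansion-to-defined` fix from item 1 (`HAS_FOO`/`FOO` are made-up names, not DGL's code): using `defined` inside a macro body is undefined behavior, so the fix evaluates `defined` at the macro's definition instead:

```cpp
// Before: undefined behavior; `defined` appears via macro expansion.
// #define HAS_FOO defined(FOO)
// #if HAS_FOO  // triggers -Wexpansion-to-defined
// #endif

// After: evaluate `defined` directly, store a plain 0/1 value.
#if defined(FOO)
#define HAS_FOO 1
#else
#define HAS_FOO 0
#endif

#if HAS_FOO
// ...
#endif
```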
This isn't really the right way to write portable HIP code, which should instead query the warp size at runtime (see https://rocm.docs.amd.com/projects/HIP/en/latest/how-to/hip_cpp_language_extensions.html#warpsize). But warp size is used as a compile-time constant in a number of places, so that change is more invasive, and for now we mostly care about gfx9*, so I'm punting on the right way.
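A minimal sketch of the runtime query the linked HIP docs recommend (`GetWarpSize` is a hypothetical helper, with error handling elided):

```cpp
#include <hip/hip_runtime.h>

// Query the wavefront/warp size of a device at runtime instead of
// hardcoding a compile-time constant (64 on gfx9*, 32 on e.g. RDNA).
int GetWarpSize(int device) {
  hipDeviceProp_t props;
  hipGetDeviceProperties(&props, device);
  return props.warpSize;
}
```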
The hip_bf16.h header is apparently not written to allow its inclusion in code not compiled by HIP, which seems like it may be a bug, so we should perhaps file an issue with HIP.

One thing I'm not totally sure about: I replaced the use of `__float2bfloat16_rn` to obtain zero, inf, and -inf bfloat16 values with the [constant macros](https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH__INTRINSIC__BFLOAT16__CONSTANTS.html). It's not clear to me why DGL was specifically using the round-to-nearest-even variant here, but as starting points for sum, min, and max, the constants should work fine.
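The substitution looks roughly like this (a sketch; macro names are from the linked CUDA Math API page, and the actual DGL call sites differ):

```cpp
#include <cuda_bf16.h>

// Before: runtime float->bf16 conversion with round-to-nearest-even, e.g.
//   __nv_bfloat16 zero = __float2bfloat16_rn(0.0f);
// After: the documented constants. These values are exactly representable
// in bfloat16, so no rounding mode comes into play.
__device__ void InitBF16Reductions() {
  __nv_bfloat16 sum_init = CUDART_ZERO_BF16;         // identity for sum
  __nv_bfloat16 min_init = CUDART_INF_BF16;          // identity for min
  __nv_bfloat16 max_init = __hneg(CUDART_INF_BF16);  // identity for max
  (void)sum_init; (void)min_init; (void)max_init;
}
```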
Everything in DGL is set up for CUDA and uses the `kDGLCUDA` device
type. Our hipification has modified the CUDA code to work with HIP,
rather than creating a whole separate device type. But external
libraries like PyTorch and DLPack correctly identify the device as ROCM,
so at the interfaces we need to do a conversion.
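A minimal sketch of that boundary conversion, using current DLPack enum names (`ToInternalDevice` is a hypothetical helper; DGL's actual interface code is more involved):

```cpp
#include <dlpack/dlpack.h>

// External libraries report AMD GPUs as kDLROCM, while the hipified DGL
// internals still run everything under the CUDA device type, so translate
// at the boundary.
DLDevice ToInternalDevice(DLDevice dev) {
  if (dev.device_type == kDLROCM) {
    dev.device_type = kDLCUDA;  // map ROCM to DGL's notion of "CUDA"
  }
  return dev;
}
```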

With this change, all of the Python tests that don't require Graphbolt
and multi-GPU distribution pass.

That is, the following commands pass (DGL tests are set up in a somewhat odd
way that prevents running all of them at once). This assumes you've built the
project in `build/` (or, as in my case, symlinked your build directory there).

The GPU tests:

```shell
bash ./script/run_pytest.sh -g \
    tests/python/pytorch/ \
    --ignore=tests/python/pytorch/graphbolt \
    --ignore=tests/python/pytorch/distributed \
    --deselect=tests/python/pytorch/dataloading/test_dataloader.py::test_distributed_dataloaders
bash ./script/run_pytest.sh -g \
    tests/python/common/ \
    --ignore=tests/python/common/test_partition.py
```

The CPU tests (for good measure):

```shell
bash ./script/run_pytest.sh -c \
    tests/python/pytorch/ \
    --ignore=tests/python/pytorch/graphbolt \
    --ignore=tests/python/pytorch/distributed \
    --deselect=tests/python/pytorch/dataloading/test_dataloader.py::test_distributed_dataloaders
bash ./script/run_pytest.sh -c \
    tests/python/common/ \
    --ignore=tests/python/common/test_partition.py
```
@GMNGeoffrey (Collaborator, Author) commented on Jan 18, 2025

@jeffdaily PTAL. These are the changes to review.

I accidentally committed this. One of the QoL changes I made to DGL in
https://github.com/GMNGeoffrey/dgl/tree/hipify-inplace-squash beyond
hipification was making this test not write this file into the source
directory. I'll see if I can upstream it.
We turn `DGL_USE_CUDA` into `DGL_USE_ROCM`, so we need to do our check on
`DGL_USE_ROCM` instead.
@GMNGeoffrey changed the base branch from master to hipify-inplace on January 22, 2025 at 17:53
@GMNGeoffrey requested a review from jeffdaily on January 22, 2025 at 18:03
@GMNGeoffrey merged commit 0b6dfca into nod-ai:hipify-inplace on Jan 29, 2025
GMNGeoffrey added a commit that referenced this pull request on Jan 29, 2025
These changes are on the original DGL source code. These plus running `script/hipify-inplace.sh` yields the hipified version of DGL which is identical to the code currently in #9. The version in this commit should still run with CUDA.

This obviously shouldn't be merged into the same branch as PRs #1 through #9. The idea is that this would be the PR we would need for upstream (although I'm guessing they would actually want it in smaller chunks).