
Commit 1f502d4
[CI] Don't include ATen/cuda/CUDAContext.h to avoid cusparse.h
tridao committed Dec 6, 2024
1 parent 5d39d51 commit 1f502d4
Showing 3 changed files with 5 additions and 9 deletions.
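
What the commit is doing: ATen/cuda/CUDAContext.h transitively includes CUDA library headers such as cusparse.h and cublas_v2.h, so compiling this extension required those headers to be present even though the kernels never call cuSPARSE. The C++ diff below swaps it for the two narrower c10 headers that cover everything the file actually uses. A minimal sketch of the resulting pattern (illustrative only; launch_on_device_of is a made-up name, not part of the commit):

    // The lean preamble the diff moves to: neither header drags in
    // cusparse.h the way ATen/cuda/CUDAContext.h does.
    #include <c10/cuda/CUDAGuard.h>
    #include <c10/cuda/CUDAStream.h>
    #include <torch/python.h>

    // Hypothetical helper mirroring the pattern in csrc/causal_conv1d.cpp.
    void launch_on_device_of(const at::Tensor &x) {
        // Guard so kernels run on x's device rather than cuda:0.
        at::cuda::CUDAGuard device_guard{x.device()};
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        // ... enqueue kernels on `stream` ...
    }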
1 change: 0 additions & 1 deletion .github/workflows/publish.yaml
@@ -109,7 +109,6 @@ jobs:
           # We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
           # not just nvcc
           sub-packages: '["nvcc"]'
-          non-cuda-sub-packages: '["libcublas", "libcusparse"]'
 
       - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
         run: |
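
The deleted workflow line is the payoff of the C++ change further down: once csrc/causal_conv1d.cpp stops including ATen/cuda/CUDAContext.h, nothing in the extension pulls in cusparse.h, so the CI toolchain can get by with the nvcc sub-package alone. Roughly what the compiler saw before this commit (illustrative, not from the diff):

    // Pre-commit, the extension's translation unit effectively did this:
    #include <ATen/cuda/CUDAContext.h>  // transitively #includes <cusparse.h>,
                                        // <cublas_v2.h>, etc., which is why the
                                        // CI needed the libcusparse/libcublas
                                        // sub-packages on top of nvcc.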
2 changes: 1 addition & 1 deletion causal_conv1d/__init__.py
@@ -1,3 +1,3 @@
-__version__ = "1.5.0.post3"
+__version__ = "1.5.0.post4"
 
 from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
11 changes: 4 additions & 7 deletions csrc/causal_conv1d.cpp
@@ -2,8 +2,8 @@
  * Copyright (c) 2024, Tri Dao.
  ******************************************************************************/
 
-#include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#include <c10/cuda/CUDAStream.h>
 #include <torch/python.h>
 #include <vector>

@@ -221,8 +221,7 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
     }
 
     // Otherwise the kernel will be launched from cuda:0 device
-    // Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
+    at::cuda::CUDAGuard device_guard{x.device()};
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
         DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] {
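
The same one-line guard rewrite repeats in all three hunks. Why the old code needed a cast (a sketch based on the c10 guard API as of recent PyTorch releases): at::cuda::CUDAGuard has one constructor taking c10::DeviceIndex, a narrow signed integer (int8_t at the time of writing), and one taking c10::Device. Tensor::get_device() returns int64_t, so the index overload triggered a narrowing warning, hence the (char) cast; Tensor::device() returns a c10::Device and matches the other overload with no cast:

    #include <c10/cuda/CUDAGuard.h>
    #include <torch/python.h>

    // Hypothetical side-by-side of the two styles (guard_styles is made up).
    void guard_styles(const at::Tensor &x) {
        // Old: int64_t narrowed to DeviceIndex, silenced with a (char) cast.
        at::cuda::CUDAGuard by_index{(char)x.get_device()};
        // New: the Device overload, no narrowing and no cast.
        at::cuda::CUDAGuard by_device{x.device()};
    }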
@@ -308,8 +307,7 @@ causal_conv1d_bwd(const at::Tensor &x, const at::Tensor &weight,
     }
 
     // Otherwise the kernel will be launched from cuda:0 device
-    // Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
+    at::cuda::CUDAGuard device_guard{x.device()};
 
     at::Tensor dweight = torch::zeros_like(weight, weight.options().dtype(at::kFloat));
     at::Tensor dbias;
@@ -462,8 +460,7 @@ causal_conv1d_update(const at::Tensor &x,
     }
 
     // Otherwise the kernel will be launched from cuda:0 device
-    // Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
+    at::cuda::CUDAGuard device_guard{x.device()};
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
         DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] {
