[CI] Change torch 2.6.0.dev20241010 to dev20241001 for nvcr 24.11
tridao committed Dec 6, 2024
1 parent 802d59d · commit 715ff10
Showing 3 changed files with 9 additions and 7 deletions.
10 changes: 6 additions & 4 deletions .github/workflows/publish.yaml
@@ -45,7 +45,7 @@ jobs:
os: [ubuntu-20.04]
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241010']
- cuda-version: ['11.8.0', '12.4.1']
+ cuda-version: ['11.8.0', '12.3.2']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
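The ABI mismatch described above can be checked against a given torch build; a minimal sketch (assuming any working PyTorch install) that prints whether torch was compiled with the C++11 ABI:

# Minimal check of the ABI the installed PyTorch was built with. If this disagrees
# with the ABI flag used when compiling the extension, imports fail with undefined
# C++ symbols like the _ZN3c105Error... one mentioned above.
import torch
print(torch.compiled_with_cxx11_abi())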
@@ -80,6 +80,7 @@ jobs:
echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
echo "WHEEL_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
- name: Free up disk space
if: ${{ runner.os == 'Linux' }}
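For reference, the awk pipelines above reduce the matrix values to compact version tags; a small Python sketch of the same transformation, with matrix values chosen purely for illustration:

# Python equivalent of the awk pipelines above; the matrix values below
# (cuda-version 12.3.2, torch-version 2.5.1, python-version 3.11) are assumptions for illustration.
cuda, torch_v, py = "12.3.2", "2.5.1", "3.11"
print("MATRIX_CUDA_VERSION=" + "".join(cuda.split(".")[:2]))        # 123
print("MATRIX_TORCH_VERSION=" + ".".join(torch_v.split(".")[:2]))   # 2.5
print("WHEEL_CUDA_VERSION=" + cuda.split(".")[0])                   # 12
print("MATRIX_PYTHON_VERSION=" + "".join(py.split(".")[:2]))        # 311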
@@ -106,8 +107,6 @@ jobs:
# default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
# method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
method: 'network'
- # We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
- # not just nvcc
sub-packages: '["nvcc"]'

- name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
@@ -127,7 +126,10 @@ jobs:
print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
)
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
- pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
+ # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
+ # --no-deps because we can't install old versions of pytorch-triton
+ pip install jinja2
+ pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
else
pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
fi
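The new dev branch downloads a specific wheel file instead of resolving torch==<version> through the nightly index. A sketch of the URL it assembles follows; MATRIX_PYTHON_VERSION=312 and TORCH_CUDA_VERSION=124 are assumed values, since minv/maxv are set outside the lines shown in this hunk:

# Sketch of the nightly wheel URL built in the step above. The two variables below
# are assumptions for illustration (python-version 3.12, TORCH_CUDA_VERSION resolved to 124).
matrix_python_version = "312"
torch_cuda_version = "124"
torch_version = "2.6.0.dev20241010"
url = (
    f"https://download.pytorch.org/whl/nightly/cu{torch_cuda_version}/"
    f"torch-{torch_version}%2Bcu{torch_cuda_version}"
    f"-cp{matrix_python_version}-cp{matrix_python_version}-linux_x86_64.whl"
)
print(url)
# https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20241010%2Bcu124-cp312-cp312-linux_x86_64.whl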
2 changes: 1 addition & 1 deletion causal_conv1d/__init__.py
@@ -1,3 +1,3 @@
__version__ = "1.5.0.post5"
__version__ = "1.5.0.post6"

from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
4 changes: 2 additions & 2 deletions setup.py
@@ -267,9 +267,9 @@ def get_wheel_url():
# We're using the CUDA version used to build torch, not the one currently installed
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
torch_cuda_version = parse(torch.version.cuda)
- # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.4
+ # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
# to save CI time. Minor versions should be compatible.
torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.4")
torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
cuda_version = f"{torch_cuda_version.major}"

gpu_compute_version = hip_version if HIP_BUILD else cuda_version
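The pinning above means the wheel URL only ever encodes CUDA 11.8 or 12.3, regardless of the exact minor version torch was built with; a quick sketch of that mapping, using packaging.version.parse as setup.py does:

# Sketch of the CUDA pinning in get_wheel_url(): every 11.x build maps to 11.8
# and every 12.x build maps to 12.3 (example inputs chosen for illustration).
from packaging.version import parse

for built_with in ("11.8", "12.1", "12.4"):
    v = parse(built_with)
    pinned = parse("11.8") if v.major == 11 else parse("12.3")
    print(built_with, "->", pinned)
# 11.8 -> 11.8, 12.1 -> 12.3, 12.4 -> 12.3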
