Skip to content

Commit

Permalink
upgrade pytorch to 2.5.0
Browse files Browse the repository at this point in the history
  • Loading branch information
guocuimi committed Oct 17, 2024
1 parent 1c742ae commit b3176e4
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 16 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/build_wheel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@ jobs:
matrix:
python: ["3.8", "3.9", "3.10", "3.11", "3.12"]
cuda: ["11.8", "12.1", "12.4"]
torch: ["2.2.2", "2.3.1", "2.4.1"]
torch: ["2.3.1", "2.4.1", "2.5.0"]
exclude:
- cuda: "12.4"
torch: "2.3.1"
- cuda: "12.4"
torch: "2.2.2"
runs-on: [self-hosted, linux, release]
env:
PYTHON_VERSION: ${{ matrix.python }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/package_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
matrix:
python: ["3.12"]
cuda: ["12.4"]
torch: ["2.4.1"]
torch: ["2.5.0"]
runs-on: [self-hosted, linux, build]
env:
PYTHON_VERSION: ${{ matrix.python }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/publish_wheel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
matrix:
python: ["3.8", "3.9", "3.10", "3.11", "3.12"]
cuda: ["12.1"]
torch: ["2.4.1"]
torch: ["2.5.0"]
runs-on: [self-hosted, linux, release]
env:
PYTHON_VERSION: ${{ matrix.python }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
matrix:
python: ["3.8", "3.9", "3.10", "3.11", "3.12"]
cuda: ["12.1"]
torch: ["2.4.1"]
torch: ["2.5.0"]
runs-on: [self-hosted, linux, release]
env:
PYTHON_VERSION: ${{ matrix.python }}
Expand Down
20 changes: 10 additions & 10 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -194,25 +194,25 @@ if (DEFINED ENV{LIBTORCH_ROOT})
else()
include(FetchContent)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.4)
# download libtorch 2.4.1 with cuda 12.4 from pytorch.org
# download libtorch 2.5.0 with cuda 12.4 from pytorch.org
if (USE_CXX11_ABI)
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.4.1%2Bcu124.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu124.zip")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-shared-with-deps-2.4.1%2Bcu124.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-shared-with-deps-2.5.0%2Bcu124.zip")
endif()
elseif(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.1)
# download libtorch 2.4.1 with cuda 12.1 from pytorch.org
# download libtorch 2.5.0 with cuda 12.1 from pytorch.org
if (USE_CXX11_ABI)
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.4.1%2Bcu121.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu121.zip")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu121/libtorch-shared-with-deps-2.4.1%2Bcu121.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu121/libtorch-shared-with-deps-2.5.0%2Bcu121.zip")
endif()
elseif(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 11.8)
# download libtorch 2.4.1 with cuda 11.8 from pytorch.org
# download libtorch 2.5.0 with cuda 11.8 from pytorch.org
if (USE_CXX11_ABI)
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.4.1%2Bcu118.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu118.zip")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-shared-with-deps-2.4.1%2Bcu118.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-shared-with-deps-2.5.0%2Bcu118.zip")
endif()
else()
# error out if cuda version is not supported
Expand All @@ -232,7 +232,7 @@ else()
FetchContent_MakeAvailable(libtorch)

find_package(Torch REQUIRED PATHS ${libtorch_SOURCE_DIR} NO_DEFAULT_PATH)
message(STATUS "Downloading and using libtorch 2.4.1 for cuda ${CUDA_VERSION} at ${libtorch_SOURCE_DIR}")
message(STATUS "Downloading and using libtorch 2.5.0 for cuda ${CUDA_VERSION} at ${libtorch_SOURCE_DIR}")
endif()

# check if USE_CXX11_ABI is set correctly
Expand Down

0 comments on commit b3176e4

Please sign in to comment.