-
Notifications
You must be signed in to change notification settings - Fork 2.9k
64 lines (60 loc) · 2.43 KB
/
wheel_tests_nightly_release.yml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# CI - Wheel Tests (Nightly/Release)
#
# This workflow runs CPU/CUDA tests against nightly/release JAX artifacts that were built by
# internal CI jobs and are downloaded from GCS (see the `gcs_download_uri` input). It does not
# build anything itself.
#
# It orchestrates the following:
# 1. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow, which downloads the jaxlib wheel that
#    was built by internal CI jobs and runs the CPU tests.
# 2. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow, which downloads the jaxlib and CUDA
#    artifacts that were built by internal CI jobs and runs the CUDA tests.
name: CI - Wheel Tests (Nightly/Release)

# Manually triggered only; the caller chooses which set of prebuilt artifacts to test.
on:
  workflow_dispatch:
    inputs:
      gcs_download_uri:
        description: "GCS location URI from where the artifacts should be downloaded"
        required: true
        default: 'gs://jax-nightly-release-transient/nightly/latest'
        type: string

# One in-flight run per (workflow, ref, artifact location). Re-dispatching for the same GCS URI
# cancels the older run, but dispatches that test different artifacts on the same ref (e.g.
# nightly vs. a release candidate) no longer cancel each other.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}-${{ inputs.gcs_download_uri }}
  cancel-in-progress: true
jobs:
  # CPU tests, delegated to the reusable pytest_cpu.yml workflow.
  run-pytest-cpu:
    uses: ./.github/workflows/pytest_cpu.yml
    strategy:
      # Keep the other matrix jobs running when one of them fails.
      fail-fast: false
      matrix:
        # Runner OS and Python values need to match the matrix strategy of our internal CI
        # jobs that build the wheels.
        runner:
          - "linux-x86-n2-64"
          - "linux-arm64-t2a-48"
          - "windows-x86-n2-64"
        python:
          - "3.10"
          - "3.11"
          - "3.12"
          - "3.13"
        enable-x64:
          - 0
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      enable-x64: ${{ matrix.enable-x64 }}
      # Skip installing "jax" from the current commit; the nightly/release "jax" wheels in the
      # GCS bucket are installed instead.
      install-jax-current-commit: 0
      gcs_download_uri: ${{ inputs.gcs_download_uri }}

  # CUDA tests, delegated to the reusable pytest_cuda.yml workflow.
  run-pytest-cuda:
    uses: ./.github/workflows/pytest_cuda.yml
    strategy:
      # Keep the other matrix jobs running when one of them fails.
      fail-fast: false
      matrix:
        # Runner OS and Python values need to match the matrix strategy of our internal CI
        # jobs that build the wheels.
        runner:
          - "linux-x86-g2-48-l4-4gpu"
        python:
          - "3.10"
          - "3.11"
          - "3.12"
          - "3.13"
        cuda:
          - "12.3"
          - "12.1"
        enable-x64:
          - 0
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      cuda: ${{ matrix.cuda }}
      enable-x64: ${{ matrix.enable-x64 }}
      # Skip installing "jax" from the current commit; the nightly/release "jax" wheels in the
      # GCS bucket are installed instead.
      install-jax-current-commit: 0
      gcs_download_uri: ${{ inputs.gcs_download_uri }}