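# Temp XLA JAX T5x Perf test workflow
#
# Manually dispatched (workflow_dispatch) workflow that builds the base and JAX
# images and, depending on TEST_T5X_PAXML, the T5X and/or PAXML images, then
# runs the corresponding upstream tests.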

name: Temp XLA JAX T5x Perf test workflow

on:
  workflow_dispatch:
    inputs:
      ARCHITECTURE:
        type: string
        description: 'Target architecture of the images to build, e.g. amd64 or arm64'
        required: true
      TEST_T5X_PAXML:
        type: string
        description: 'Which builds and perf tests to run: 0=none, 1=T5X only, 2=PAXML only, 3=both'
        default: '3'
        required: false
      BUILD_DATE:
        type: string
        description: 'Build date in YYYY-MM-DD format'
        required: false
        default: 'NOT SPECIFIED'
      PUBLISH:
        type: boolean
        description: "Publish dated images and update the 'latest' tag?"
        default: false
        required: false
      # Optional source overrides, forwarded to the base image build
      # (empty string by default).
      JAX_SRC_REF:
        description: 'JAX source url#branch/commit SHA'
        type: string
        required: false
        default: ''
      XLA_SRC_REF:
        description: 'XLA source url#branch/commit SHA'
        type: string
        required: false
        default: ''
      T5X_SRC_REF:
        description: 'T5X source url#branch/commit SHA'
        type: string
        required: false
        default: ''
      PAXML_SRC_REF:
        description: 'PAXML source url#branch/commit SHA'
        type: string
        required: false
        default: ''
      PRAXIS_SRC_REF:
        description: 'PRAXIS source url#branch/commit SHA'
        type: string
        required: false
        default: ''
# One concurrency group per workflow + PR head ref. github.head_ref is only
# populated for pull_request events, so manual dispatches fall back to the
# unique run_id and never collide with each other.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  # Superseded runs are cancelled everywhere except on main.
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

# Token scopes granted to the jobs of this workflow.
permissions:
  contents: read  # check out the repository code
  actions: write  # cancel earlier runs of this workflow
  packages: write  # push the built containers
jobs:
  # Resolves run-wide metadata (build date, publish flag) consumed by later jobs.
  metadata:
    runs-on: ubuntu-22.04
    outputs:
      BUILD_DATE: ${{ steps.date.outputs.BUILD_DATE }}
      PUBLISH: ${{ steps.if-publish.outputs.PUBLISH }}
    steps:
      - name: Set build date
        id: date
        shell: bash -x -e {0}
        env:
          INPUT_BUILD_DATE: ${{ inputs.BUILD_DATE }}
        run: |
          # Honor an explicitly supplied BUILD_DATE; otherwise use today's date
          # in US Pacific time. 'US/Los_Angeles' is not a tz database zone name
          # (date silently falls back to UTC for it); 'America/Los_Angeles' is.
          if [[ -n "${INPUT_BUILD_DATE}" && "${INPUT_BUILD_DATE}" != "NOT SPECIFIED" ]]; then
            BUILD_DATE="${INPUT_BUILD_DATE}"
          else
            BUILD_DATE=$(TZ='America/Los_Angeles' date '+%Y-%m-%d')
          fi
          echo "BUILD_DATE=${BUILD_DATE}" >> "$GITHUB_OUTPUT"
      - name: Determine whether results will be 'published'
        id: if-publish
        shell: bash -x -e {0}
        run: |
          echo "PUBLISH=${{ github.event_name == 'schedule' || inputs.PUBLISH }}" >> "$GITHUB_OUTPUT"
      - name: Log workflow inputs
        shell: bash -x -e {0}
        # Free-form string inputs are passed through env rather than being
        # interpolated with ${{ }} into the script text, so their contents
        # cannot be interpreted as shell code.
        env:
          JAX_SRC_REF: ${{ inputs.JAX_SRC_REF }}
          XLA_SRC_REF: ${{ inputs.XLA_SRC_REF }}
          T5X_SRC_REF: ${{ inputs.T5X_SRC_REF }}
          PAXML_SRC_REF: ${{ inputs.PAXML_SRC_REF }}
          PRAXIS_SRC_REF: ${{ inputs.PRAXIS_SRC_REF }}
          TEST_T5X_PAXML: ${{ inputs.TEST_T5X_PAXML }}
        run: |
          echo "JAX_SRC_REF: ${JAX_SRC_REF}"
          echo "XLA_SRC_REF: ${XLA_SRC_REF}"
          echo "T5X_SRC_REF: ${T5X_SRC_REF}"
          echo "PAXML_SRC_REF: ${PAXML_SRC_REF}"
          echo "PRAXIS_SRC_REF: ${PRAXIS_SRC_REF}"
          echo "TEST_T5X_PAXML: ${TEST_T5X_PAXML}"

  # All build jobs take BUILD_DATE from the metadata job so every image in a
  # run carries the same resolved date (never the raw 'NOT SPECIFIED' default).
  # A job can only read needs.<job>.outputs for jobs listed in its `needs`.
  build-base:
    needs: metadata
    uses: ./.github/workflows/_build_base.yaml
    with:
      ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BUMP_MANIFEST: false
      JAX_SRC_REF: ${{ inputs.JAX_SRC_REF }}
      XLA_SRC_REF: ${{ inputs.XLA_SRC_REF }}
      T5X_SRC_REF: ${{ inputs.T5X_SRC_REF }}
      PAXML_SRC_REF: ${{ inputs.PAXML_SRC_REF }}
      PRAXIS_SRC_REF: ${{ inputs.PRAXIS_SRC_REF }}
    secrets: inherit

  build-jax:
    needs: [metadata, build-base]
    uses: ./.github/workflows/_build_jax.yaml
    with:
      ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-base.outputs.DOCKER_TAG }}
    secrets: inherit

  build-t5x:
    needs: [metadata, build-jax]
    if: inputs.ARCHITECTURE == 'amd64' && (inputs.TEST_T5X_PAXML == '1' || inputs.TEST_T5X_PAXML == '3')  # T5X arm64 build is wip in PR 252
    uses: ./.github/workflows/_build.yaml
    with:
      ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
      ARTIFACT_NAME: artifact-t5x-build
      BADGE_FILENAME: badge-t5x-build
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
      CONTAINER_NAME: upstream-t5x
      DOCKERFILE: .github/container/Dockerfile.t5x.${{ inputs.ARCHITECTURE }}
    secrets: inherit

  build-pax:
    needs: [metadata, build-jax]
    if: inputs.TEST_T5X_PAXML == '2' || inputs.TEST_T5X_PAXML == '3'
    uses: ./.github/workflows/_build.yaml
    with:
      ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
      ARTIFACT_NAME: artifact-pax-build
      BADGE_FILENAME: badge-pax-build
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
      CONTAINER_NAME: upstream-pax
      DOCKERFILE: .github/container/Dockerfile.pax.${{ inputs.ARCHITECTURE }}
    secrets: inherit

  test-distribution:
    runs-on: ubuntu-22.04
    strategy:
      matrix:
        TEST_SCRIPT:
          - extra-only-distribution.sh
          - mirror-only-distribution.sh
          - upstream-only-distribution.sh
      fail-fast: false
    steps:
      - name: Print environment variables
        run: env
      - name: Set git login for tests
        run: |
          git config --global user.email "jax@nvidia.com"
          git config --global user.name "JAX-Toolbox CI"
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
      - name: Run integration test ${{ matrix.TEST_SCRIPT }}
        run: bash rosetta/tests/${{ matrix.TEST_SCRIPT }}

  test-upstream-t5x:
    needs: build-t5x
    if: inputs.ARCHITECTURE == 'amd64' && (inputs.TEST_T5X_PAXML == '1' || inputs.TEST_T5X_PAXML == '3')  # arm64 runners n/a
    uses: ./.github/workflows/_test_t5x.yaml
    with:
      T5X_IMAGE: ${{ needs.build-t5x.outputs.DOCKER_TAG_FINAL }}
    secrets: inherit

  test-upstream-pax:
    needs: build-pax
    if: inputs.ARCHITECTURE == 'amd64' && (inputs.TEST_T5X_PAXML == '2' || inputs.TEST_T5X_PAXML == '3')  # no images for arm64
    uses: ./.github/workflows/_test_pax.yaml
    with:
      PAX_IMAGE: ${{ needs.build-pax.outputs.DOCKER_TAG_FINAL }}
    secrets: inherit

  # Emits TARGET_TAGS: a flat JSON object with keys
  # "<image>-<flavor>-container-tag" and "<image>-<flavor>-tag-dated".
  # 'final' maps to latest / nightly; other flavors reuse the flavor name.
  publish-target-tags:
    runs-on: ubuntu-22.04
    outputs:
      TARGET_TAGS: ${{ steps.tags.outputs.TARGET_TAGS }}
    steps:
      - id: tags
        shell: bash -x -e {0}
        run: |
          # Bash array elements are whitespace-separated; a comma would become
          # part of the element (e.g. "jax," instead of "jax").
          TARGET_IMAGE=("jax" "test-upstream-t5x")
          FLAVOR=("mealkit" "final")
          ENTRIES=()
          for target in "${TARGET_IMAGE[@]}"; do
            for flavor in "${FLAVOR[@]}"; do
              if [[ "${flavor}" == "final" ]]; then
                CONTAINER_TAG=latest
                TAG_DATED=nightly
              else
                CONTAINER_TAG=${flavor}
                TAG_DATED=${flavor}
              fi
              ENTRIES+=("\"${target}-${flavor}-container-tag\":\"${CONTAINER_TAG}\"")
              ENTRIES+=("\"${target}-${flavor}-tag-dated\":\"${TAG_DATED}\"")
            done
          done
          # Join the entries with commas inside a subshell so IFS stays untouched.
          JSON="{$(IFS=,; echo "${ENTRIES[*]}")}"
          echo "TARGET_TAGS=${JSON}" | tee -a "$GITHUB_OUTPUT"

  publish:
    needs: [metadata, test-upstream-t5x, test-upstream-pax, publish-target-tags]
    if: false  # TODO: enable this after new image renaming proposal is approved
    # PUBLISH is emitted as the string 'true' or 'false', and any non-empty
    # string is truthy in expressions, so compare it explicitly:
    # if: ${{ !cancelled() && needs.metadata.outputs.PUBLISH == 'true' }}
    # NOTE(review): SOURCE_IMAGE reads needs.amd64 / needs.arm64, but no jobs
    # with those ids exist in this workflow (nor are they in `needs`); update
    # these references before enabling this job.
    strategy:
      fail-fast: false
      matrix:
        TARGET_IMAGE: [jax, test-upstream-t5x]
        FLAVOR: [mealkit, final]
    uses: ./.github/workflows/_publish_container.yaml
    with:
      SOURCE_IMAGE: |
        ${{ fromJson(needs.amd64.outputs.CONTAINER_TAGS)[format('tag-{0}-{1}', matrix.TARGET_IMAGE, matrix.FLAVOR)] }}
        ${{ fromJson(needs.arm64.outputs.CONTAINER_TAGS)[format('tag-{0}-{1}', matrix.TARGET_IMAGE, matrix.FLAVOR)] }}
      TARGET_IMAGE: ${{ matrix.TARGET_IMAGE }}
      TARGET_TAGS: |
        type=raw,value=${{ fromJson(needs.publish-target-tags.outputs.TARGET_TAGS)[format('{0}-{1}-container-tag', matrix.TARGET_IMAGE, matrix.FLAVOR)] }},priority=500
        type=raw,value=${{ fromJson(needs.publish-target-tags.outputs.TARGET_TAGS)[format('{0}-{1}-tag-dated', matrix.TARGET_IMAGE, matrix.FLAVOR)] }}-${{ needs.metadata.outputs.BUILD_DATE }},priority=500

  finalize:
    needs: [metadata, test-upstream-t5x, test-upstream-pax, publish-target-tags]
    if: "!cancelled()"
    uses: ./.github/workflows/_finalize.yaml
    with:
      PUBLISH_BADGE: false
    secrets: inherit