Skip to content

Commit

Permalink
Merge pull request #20119 from ThomasHoffmann77/20240314161733_new_pr…
Browse files Browse the repository at this point in the history
…_jax0425

{tools}[gfbf/2023a] jax v0.4.25 w/ CUDA 12.1.1
  • Loading branch information
lexming authored Jul 24, 2024
2 parents c591ee1 + c6518ee commit 3ea4e45
Showing 1 changed file with 144 additions and 0 deletions.
144 changes: 144 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild
# Author: Denis Kristak
# Updated by: Alex Domingo (Vrije Universiteit Brussel)
# Updated by: Thomas Hoffmann (EMBL Heidelberg)
easyblock = 'PythonBundle'

name = 'jax'
version = '0.4.25'
versionsuffix = '-CUDA-%(cudaver)s'

homepage = 'https://pypi.python.org/pypi/jax'
description = """Composable transformations of Python+NumPy programs:
differentiate, vectorize, JIT to GPU/TPU, and more"""

toolchain = {'name': 'gfbf', 'version': '2023a'}
cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"]

builddependencies = [
('Bazel', '6.3.1'),
('pytest-xdist', '3.3.1'),
# git 2.x required to fetch repository 'io_bazel_rules_docker'
('git', '2.41.0', '-nodocs'),
('matplotlib', '3.7.2'), # required for tests/lobpcg_test.py
('poetry', '1.5.1'),
('pybind11', '2.11.1'),
]

dependencies = [
('CUDA', '12.1.1', '', SYSTEM),
('cuDNN', '8.9.2.26', versionsuffix, SYSTEM),
('NCCL', '2.18.3', versionsuffix),
('zlib', '1.2.13'),
('Python', '3.11.3'),
('SciPy-bundle', '2023.07'),
('flatbuffers-python', '23.5.26'),
('ml_dtypes', '0.3.2'),
]

# downloading xla and other tarballs to avoid that Bazel downloads it during the build
# note: this *must* be the exact same commit as used in third_party/{xla,"other"}/workspace.bzl
local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4'
local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25'

local_extract_cmd = 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives'
local_repo_opt = '--bazel_options="--distdir=%(builddir)s/archives" '
local_repo_opt += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" '
local_repo_opt += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include" '


# Some tests require an isolated run:
local_isolated_tests = [
'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp',
'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc',
'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest' +
'::testScipySpecialFun_gammainc_s_2x1x4_float32_float32',
]
# deliberately not testing in parallel, as that results in (additional) failing tests;
# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing,
# see https://github.com/google/jax/issues/7323 and
# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst;
# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs;
# use NVIDIA_TF32_OVERRIDE=0 to avoid loosing numerical precision by disabling TF32 Tensor Cores;
local_test_exports = [
"NVIDIA_TF32_OVERRIDE=0",
"CUDA_VISIBLE_DEVICES=0",
"XLA_PYTHON_CLIENT_ALLOCATOR=platform",
"JAX_ENABLE_X64=true",
]
local_test = ''.join(['export %s;' % x for x in local_test_exports])
# run all tests at once except for local_isolated_tests:
local_test += "pytest -vv tests %s && " % ' '.join(['--deselect %s' % x for x in local_isolated_tests])
# run remaining local_isolated_tests separately:
local_test += ' && '.join(['pytest -vv %s' % x for x in local_isolated_tests])

use_pip = True

default_easyblock = 'PythonPackage'
default_component_specs = {
'sources': [SOURCE_TAR_GZ],
'source_urls': [PYPI_SOURCE],
'start_dir': '%(name)s-%(version)s',
'use_pip': True,
'sanity_pip_check': True,
'download_dep_fail': True,
}

components = [
('absl-py', '2.1.0', {
'options': {'modulename': 'absl'},
'checksums': ['7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff'],
}),
('jaxlib', version, {
'sources': [
'%(name)s-v%(version)s.tar.gz',
{
'download_filename': '%s.tar.gz' % local_xla_commit,
'filename': 'xla-%s.tar.gz' % local_xla_commit,
'extract_cmd': local_extract_cmd,
},
{
'download_filename': '%s.tar.gz' % local_tfrt_commit,
'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit,
'extract_cmd': local_extract_cmd,
},
],
'source_urls': [
'https://github.com/google/jax/archive/',
'https://github.com/tensorflow/runtime/archive',
'https://github.com/openxla/xla/archive'
],
'patches': ['jax-0.4.25_fix-pybind11-systemlib.patch'],
'checksums': [
{'jaxlib-v0.4.25.tar.gz':
'fc1197c401924942eb14185a61688d0c476e3e81ff71f9dc95e620b57c06eec8'},
{'xla-4ccfe33c71665ddcbca5b127fefe8baa3ed632d4.tar.gz':
'8a59b9af7d0850059d7043f7043c780066d61538f3af536e8a10d3d717f35089'},
{'tf_runtime-0aeefb1660d7e37964b2bb71b1f518096bda9a25.tar.gz':
'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
{'jax-0.4.25_fix-pybind11-systemlib.patch':
'daad5b726d1a138431b05eb60ecf4c89c7b5148eb939721800bdf43d804ca033'},
],
'start_dir': 'jax-jaxlib-v%(version)s',
# Avoid warning (treated as error) in upb/table.c
'buildopts': local_repo_opt + ' --bazel_options="--copt=-Wno-maybe-uninitialized"'
}),
]

exts_list = [
(name, version, {
'patches': ['jax-0.4.25_fix_env_test_no_log_spam.patch'],
'runtest': local_test,
'source_tmpl': '%(name)s-v%(version)s.tar.gz',
'source_urls': ['https://github.com/google/jax/archive/'],
'checksums': [
{'jax-v0.4.25.tar.gz': '8b30af49688c0c13b82c6f5ce992727c00b5fc6d04a4c6962012f4246fa664eb'},
{'jax-0.4.25_fix_env_test_no_log_spam.patch':
'a18b5f147569d9ad41025124333a0f04fd0d0e0f9e4309658d7f6b9b838e2e2a'},
],
}),
]

sanity_pip_check = True

moduleclass = 'tools'

0 comments on commit 3ea4e45

Please sign in to comment.