-
Notifications
You must be signed in to change notification settings - Fork 706
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #20119 from ThomasHoffmann77/20240314161733_new_pr…
…_jax0425 {tools}[gfbf/2023a] jax v0.4.25 w/ CUDA 12.1.1
- Loading branch information
Showing
1 changed file
with
144 additions
and
0 deletions.
There are no files selected for viewing
144 changes: 144 additions & 0 deletions
144
easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild | ||
# Author: Denis Kristak | ||
# Updated by: Alex Domingo (Vrije Universiteit Brussel) | ||
# Updated by: Thomas Hoffmann (EMBL Heidelberg) | ||
easyblock = 'PythonBundle' | ||
|
||
name = 'jax' | ||
version = '0.4.25' | ||
versionsuffix = '-CUDA-%(cudaver)s' | ||
|
||
homepage = 'https://pypi.python.org/pypi/jax' | ||
description = """Composable transformations of Python+NumPy programs: | ||
differentiate, vectorize, JIT to GPU/TPU, and more""" | ||
|
||
toolchain = {'name': 'gfbf', 'version': '2023a'} | ||
cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"] | ||
|
||
builddependencies = [ | ||
('Bazel', '6.3.1'), | ||
('pytest-xdist', '3.3.1'), | ||
# git 2.x required to fetch repository 'io_bazel_rules_docker' | ||
('git', '2.41.0', '-nodocs'), | ||
('matplotlib', '3.7.2'), # required for tests/lobpcg_test.py | ||
('poetry', '1.5.1'), | ||
('pybind11', '2.11.1'), | ||
] | ||
|
||
dependencies = [ | ||
('CUDA', '12.1.1', '', SYSTEM), | ||
('cuDNN', '8.9.2.26', versionsuffix, SYSTEM), | ||
('NCCL', '2.18.3', versionsuffix), | ||
('zlib', '1.2.13'), | ||
('Python', '3.11.3'), | ||
('SciPy-bundle', '2023.07'), | ||
('flatbuffers-python', '23.5.26'), | ||
('ml_dtypes', '0.3.2'), | ||
] | ||
|
||
# download the XLA and other tarballs up front so that Bazel does not fetch them during the build | ||
# note: these *must* be the exact same commits as used in third_party/{xla,"other"}/workspace.bzl | ||
local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4' | ||
local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' | ||
|
||
local_extract_cmd = 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives' | ||
local_repo_opt = '--bazel_options="--distdir=%(builddir)s/archives" ' | ||
local_repo_opt += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" ' | ||
local_repo_opt += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include" ' | ||
|
||
|
||
# Some tests require an isolated run: | ||
local_isolated_tests = [ | ||
'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp', | ||
'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc', | ||
'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest' + | ||
'::testScipySpecialFun_gammainc_s_2x1x4_float32_float32', | ||
] | ||
# deliberately not testing in parallel, as that results in (additional) failing tests; | ||
# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing, | ||
# see https://github.com/google/jax/issues/7323 and | ||
# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst; | ||
# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs; | ||
# use NVIDIA_TF32_OVERRIDE=0 to disable TF32 Tensor Cores and thereby avoid losing numerical precision; | ||
local_test_exports = [ | ||
"NVIDIA_TF32_OVERRIDE=0", | ||
"CUDA_VISIBLE_DEVICES=0", | ||
"XLA_PYTHON_CLIENT_ALLOCATOR=platform", | ||
"JAX_ENABLE_X64=true", | ||
] | ||
local_test = ''.join(['export %s;' % x for x in local_test_exports]) | ||
# run all tests at once except for local_isolated_tests: | ||
local_test += "pytest -vv tests %s && " % ' '.join(['--deselect %s' % x for x in local_isolated_tests]) | ||
# run remaining local_isolated_tests separately: | ||
local_test += ' && '.join(['pytest -vv %s' % x for x in local_isolated_tests]) | ||
|
||
use_pip = True | ||
|
||
default_easyblock = 'PythonPackage' | ||
default_component_specs = { | ||
'sources': [SOURCE_TAR_GZ], | ||
'source_urls': [PYPI_SOURCE], | ||
'start_dir': '%(name)s-%(version)s', | ||
'use_pip': True, | ||
'sanity_pip_check': True, | ||
'download_dep_fail': True, | ||
} | ||
|
||
components = [ | ||
('absl-py', '2.1.0', { | ||
'options': {'modulename': 'absl'}, | ||
'checksums': ['7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff'], | ||
}), | ||
('jaxlib', version, { | ||
'sources': [ | ||
'%(name)s-v%(version)s.tar.gz', | ||
{ | ||
'download_filename': '%s.tar.gz' % local_xla_commit, | ||
'filename': 'xla-%s.tar.gz' % local_xla_commit, | ||
'extract_cmd': local_extract_cmd, | ||
}, | ||
{ | ||
'download_filename': '%s.tar.gz' % local_tfrt_commit, | ||
'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit, | ||
'extract_cmd': local_extract_cmd, | ||
}, | ||
], | ||
'source_urls': [ | ||
'https://github.com/google/jax/archive/', | ||
'https://github.com/tensorflow/runtime/archive', | ||
'https://github.com/openxla/xla/archive' | ||
], | ||
'patches': ['jax-0.4.25_fix-pybind11-systemlib.patch'], | ||
'checksums': [ | ||
{'jaxlib-v0.4.25.tar.gz': | ||
'fc1197c401924942eb14185a61688d0c476e3e81ff71f9dc95e620b57c06eec8'}, | ||
{'xla-4ccfe33c71665ddcbca5b127fefe8baa3ed632d4.tar.gz': | ||
'8a59b9af7d0850059d7043f7043c780066d61538f3af536e8a10d3d717f35089'}, | ||
{'tf_runtime-0aeefb1660d7e37964b2bb71b1f518096bda9a25.tar.gz': | ||
'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'}, | ||
{'jax-0.4.25_fix-pybind11-systemlib.patch': | ||
'daad5b726d1a138431b05eb60ecf4c89c7b5148eb939721800bdf43d804ca033'}, | ||
], | ||
'start_dir': 'jax-jaxlib-v%(version)s', | ||
# Avoid warning (treated as error) in upb/table.c | ||
'buildopts': local_repo_opt + ' --bazel_options="--copt=-Wno-maybe-uninitialized"' | ||
}), | ||
] | ||
|
||
exts_list = [ | ||
(name, version, { | ||
'patches': ['jax-0.4.25_fix_env_test_no_log_spam.patch'], | ||
'runtest': local_test, | ||
'source_tmpl': '%(name)s-v%(version)s.tar.gz', | ||
'source_urls': ['https://github.com/google/jax/archive/'], | ||
'checksums': [ | ||
{'jax-v0.4.25.tar.gz': '8b30af49688c0c13b82c6f5ce992727c00b5fc6d04a4c6962012f4246fa664eb'}, | ||
{'jax-0.4.25_fix_env_test_no_log_spam.patch': | ||
'a18b5f147569d9ad41025124333a0f04fd0d0e0f9e4309658d7f6b9b838e2e2a'}, | ||
], | ||
}), | ||
] | ||
|
||
sanity_pip_check = True | ||
|
||
moduleclass = 'tools' |