Skip to content

Commit

Permalink
Merge pull request #17421 from branfosj/20230223170804_new_pr_jax044
Browse files Browse the repository at this point in the history
{tools}[foss/2022a] jax v0.4.4 w/ Python 3.10.4 w/ CUDA 11.7.0
  • Loading branch information
boegel authored Mar 1, 2023
2 parents e7aa6b6 + dea0cb6 commit f4ff8b3
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 0 deletions.
116 changes: 116 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.4.4-foss-2022a-CUDA-11.7.0.eb
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild
# Author: Denis Kristak
# Updated by: Alex Domingo (Vrije Universiteit Brussel)
easyblock = 'PythonBundle'

name = 'jax'
version = '0.4.4'
versionsuffix = '-CUDA-%(cudaver)s'

homepage = 'https://pypi.python.org/pypi/jax'
description = """Composable transformations of Python+NumPy programs:
differentiate, vectorize, JIT to GPU/TPU, and more"""

toolchain = {'name': 'foss', 'version': '2022a'}

builddependencies = [
('Bazel', '5.1.1'),
('pytest-xdist', '2.5.0'),
# git 2.x required to fetch repository 'io_bazel_rules_docker'
('git', '2.36.0', '-nodocs'),
('matplotlib', '3.5.2'), # required for tests/lobpcg_test.py
]

dependencies = [
('CUDA', '11.7.0', '', SYSTEM),
('cuDNN', '8.4.1.50', versionsuffix, SYSTEM),
('NCCL', '2.12.12', versionsuffix),
('Python', '3.10.4'),
('SciPy-bundle', '2022.05'),
('flatbuffers-python', '2.0'),
]

# downloading TensorFlow tarball to avoid that Bazel downloads it during the build
# note: this *must* be the exact same commit as used in WORKSPACE
local_tf_commit = '43e9d313548ded301fa54f25a4192d3bcb123330'
local_tf_dir = 'tensorflow-%s' % local_tf_commit
local_tf_builddir = '%(builddir)s/' + local_tf_dir

# replace remote TensorFlow repository with the local one from EB
local_jax_prebuildopts = "sed -i -f jaxlib_local-tensorflow-repo.sed WORKSPACE && "
local_jax_prebuildopts += "sed -i 's|EB_TF_REPOPATH|%s|' WORKSPACE && " % local_tf_builddir

use_pip = True

default_easyblock = 'PythonPackage'
default_component_specs = {
'sources': [SOURCE_TAR_GZ],
'source_urls': [PYPI_SOURCE],
'start_dir': '%(name)s-%(version)s',
'use_pip': True,
'sanity_pip_check': True,
'download_dep_fail': True,
}

components = [
('absl-py', '1.4.0', {
'options': {'modulename': 'absl'},
'checksums': ['d2c244d01048ba476e7c080bd2c6df5e141d211de80223460d5b3b8a2a58433d'],
}),
('jaxlib', '0.4.4', {
'sources': [
'%(name)s-v%(version)s.tar.gz',
{
'download_filename': '%s.tar.gz' % local_tf_commit,
'filename': 'tensorflow-%s.tar.gz' % local_tf_commit,
}
],
'source_urls': [
'https://github.com/google/jax/archive/',
'https://github.com/tensorflow/tensorflow/archive/'
],
'patches': [
('jaxlib_local-tensorflow-repo.sed', '.'),
('TensorFlow-2.7.0_cuda-noncanonical-include-paths.patch', '../' + local_tf_dir),
],
'checksums': [
{'jaxlib-v0.4.4.tar.gz': '881f402c7983b56b185e182d5315dd64c9f5320be96213d0415996ece1826806'},
{'tensorflow-43e9d313548ded301fa54f25a4192d3bcb123330.tar.gz':
'23aae276b2705bfbdaea3c472da24130598f13ac0439cfb7149befb781d97a8f'},
{'jaxlib_local-tensorflow-repo.sed': 'abb5c3b97f4e317bce9f22ed3eeea3b9715365818d8b50720d937e2d41d5c4e5'},
{'TensorFlow-2.7.0_cuda-noncanonical-include-paths.patch':
'0a759010c253d49755955cd5f028e75de4a4c447dcc8f5a0d9f47cce6881a9db'},
],
'start_dir': 'jax-jaxlib-v%(version)s',
'prebuildopts': local_jax_prebuildopts,
}),
]

exts_list = [
('opt_einsum', '3.3.0', {
'checksums': ['59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549'],
}),
('etils', '1.0.0', {
'checksums': ['d10982f7702422bea8635d5284b8bed629f51919fc122ac1e1e4abf45ec8f785'],
}),
(name, version, {
'patches': [
'jax-0.3.23_relax-testPoly5-tolerance.patch',
'jax-0.4.4_cusparse.patch',
],
'runtest': "NVIDIA_TF32_OVERRIDE=0 CUDA_VISIBLE_DEVICES=0 XLA_PYTHON_CLIENT_ALLOCATOR=platform " +
"JAX_ENABLE_X64=true pytest -vv tests",
'source_tmpl': '%(name)s-v%(version)s.tar.gz',
'source_urls': ['https://github.com/google/jax/archive/'],
'checksums': [
{'jax-v0.4.4.tar.gz': '755eb9b12ab4880e78690f28fc7bd2b491be4e551d8b966e6974753c878dd2c0'},
{'jax-0.3.23_relax-testPoly5-tolerance.patch':
'be64bf36dde4884a97b6c8bb22c6b14ab5b24033cd40bfe7ce18363c55c30e87'},
{'jax-0.4.4_cusparse.patch': '7414115533cce9f37c60850c09c69563a0ed6477c73f03c4132b9c2ae75ae60f'},
],
}),
]

sanity_pip_check = True

moduleclass = 'tools'
39 changes: 39 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.4.4_cusparse.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
Skip tests:
** On entry to cusparseSpMM_bufferSize(): matrix B and C must be in column-major order
cuSparseTest.test_coo_matmat1
cuSparseTest.test_coo_matmat3
cuSparseTest.test_coo_matmat5
test_coo_sorted_indices_gpu_lowerings
** On entry to cusparseSpMM_bufferSize(): opA != CUSPARSE_OPERATION_NON_TRANSPOSE is not supported with CUSPARSE_SPMM_COO_ALG2
cuSparseTest.test_coo_matmat6
cuSparseTest.test_coo_matmat7
** On entry to cusparseSpMM_bufferSize(): CUSPARSE_SPMM_COO_ALG2 does not support 64-bit indices
test_coo_matmat_layout
Patch by Simon Branford (University of Birmingham)
--- tests/sparse_test.py.orig 2023-02-23 08:59:00.381238000 +0000
+++ tests/sparse_test.py 2023-02-23 10:39:59.054404886 +0000
@@ -445,6 +445,8 @@
with self.gpu_matmul_dtype_warning_context(dtype):
self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

+ @unittest.skip("""1, 3, 5: On entry to cusparseSpMM_bufferSize(): matrix B and C must be in column-major order;
+6, 7: On entry to cusparseSpMM_bufferSize(): opA != CUSPARSE_OPERATION_NON_TRANSPOSE is not supported with CUSPARSE_SPMM_COO_ALG2""")
@jtu.sample_product(
shape=[(5, 8), (8, 5), (5, 5), (8, 8)],
dtype=all_dtypes,
@@ -465,6 +467,7 @@
with self.gpu_matmul_dtype_warning_context(dtype):
self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

+ @unittest.skip("On entry to cusparseSpMM_bufferSize(): CUSPARSE_SPMM_COO_ALG2 does not support 64-bit indices")
def test_coo_matmat_layout(self):
# Regression test for https://github.com/google/jax/issues/7533
d = jnp.array([1.0, 2.0, 3.0, 4.0])
@@ -494,6 +497,7 @@

@unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
@unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
+ @unittest.skip("On entry to cusparseSpMM_bufferSize(): matrix B and C must be in column-major order")
def test_coo_sorted_indices_gpu_lowerings(self):
dtype = jnp.float32

0 comments on commit f4ff8b3

Please sign in to comment.