Skip to content

Commit

Permalink
Merge pull request #6581 from hawkinsp:libdevice
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 371151412
  • Loading branch information
jax authors committed Apr 29, 2021
2 parents 23cbcbe + c983d3c commit 6826021
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 2 deletions.
1 change: 1 addition & 0 deletions build/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ py_binary(
"//jaxlib:cusparse_kernels",
"//jaxlib:cuda_lu_pivot_kernels",
"//jaxlib:cuda_prng_kernels",
"@local_config_cuda//cuda:cuda-nvvm",
]) + if_rocm([
"//jaxlib:rocblas_kernels",
]),
Expand Down
4 changes: 4 additions & 0 deletions build/build_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ def prepare_wheel(sources_path):
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cuda_lu_pivot_kernels.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cuda_prng_kernels.pyd"))
if r.Rlocation("__main__/jaxlib/cusolver.py") is not None:
libdevice_dir = os.path.join(jaxlib_dir, "cuda", "nvvm", "libdevice")
os.makedirs(libdevice_dir)
copy_file(r.Rlocation("local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"),
dst_dir=libdevice_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cuda_linalg.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cuda_prng.py"))
Expand Down
9 changes: 9 additions & 0 deletions jax/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
# This module is largely a wrapper around `jaxlib` that performs version
# checking on import.

import os
from typing import Optional

__all__ = [
'cuda_linalg', 'cuda_prng', 'cusolver', 'rocsolver', 'jaxlib', 'lapack',
'pocketfft', 'pytree', 'tpu_client', 'version', 'xla_client'
Expand Down Expand Up @@ -99,3 +102,9 @@ def _check_jaxlib_version():
from jaxlib import tpu_client # pytype: disable=import-error
except:
tpu_client = None


cuda_path: Optional[str]
cuda_path = os.path.join(os.path.dirname(jaxlib.__file__), "cuda")
if not os.path.isdir(cuda_path):
cuda_path = None
7 changes: 6 additions & 1 deletion jax/lib/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
# Disable "WARNING: Logging before flag parsing goes to stderr." message
logging._warn_preinit_stderr = 0

import jax.lib
from .._src.config import flags
from jax._src import util, traceback_util
from jax._src import dtypes
Expand Down Expand Up @@ -113,8 +114,12 @@ def get_compile_options(
assert device_assignment.computation_count() == num_partitions
compile_options.device_assignment = device_assignment

debug_options = compile_options.executable_build_options.debug_options
if jax.lib.cuda_path is not None:
debug_options.xla_gpu_cuda_data_dir = jax.lib.cuda_path

if FLAGS.jax_disable_most_optimizations:
debug_options = compile_options.executable_build_options.debug_options

debug_options.xla_backend_optimization_level = 0
debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False
Expand Down
2 changes: 1 addition & 1 deletion jaxlib/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
url='https://github.com/google/jax',
license='Apache-2.0',
package_data={
'jaxlib': ['*.so', '*.pyd*', 'py.typed'],
'jaxlib': ['*.so', '*.pyd*', 'py.typed', 'cuda/nvvm/libdevice/libdevice*'],
'jaxlib.xla_extension-stubs': ['*.pyi'],
},
zip_safe=False,
Expand Down

0 comments on commit 6826021

Please sign in to comment.