Skip to content

Commit

Permalink
Merge pull request #14
Browse files Browse the repository at this point in the history
  • Loading branch information
wcxve authored Mar 26, 2024
2 parents 1f5d252 + 4261266 commit 4ea70d2
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 6 deletions.
8 changes: 7 additions & 1 deletion src/elisa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
DistParameter as DistParameter,
UniformParameter as UniformParameter,
)
from .util import jax_enable_x64, set_cpu_cores
from .util import (
jax_debug_nans as jax_debug_nans,
jax_enable_x64,
set_cpu_cores,
set_jax_platform,
)

jax_enable_x64(True)
set_jax_platform('cpu')
set_cpu_cores(4)
2 changes: 2 additions & 0 deletions src/elisa/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .config import (
jax_debug_nans as jax_debug_nans,
jax_enable_x64 as jax_enable_x64,
set_cpu_cores as set_cpu_cores,
set_jax_platform as set_jax_platform,
)
74 changes: 69 additions & 5 deletions src/elisa/util/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@

from __future__ import annotations

import os
import re
import warnings
from multiprocessing import cpu_count
from typing import TYPE_CHECKING

from numpyro import enable_x64, set_host_device_count
import jax

if TYPE_CHECKING:
from typing import Literal


def jax_enable_x64(use_x64: bool) -> None:
Expand All @@ -16,16 +22,54 @@ def jax_enable_x64(use_x64: bool) -> None:
use_x64 : bool
When `True`, JAX arrays will use 64 bits else 32 bits.
"""
enable_x64(bool(use_x64))
if not use_x64:
use_x64 = os.getenv('JAX_ENABLE_X64', 0)
jax.config.update('jax_enable_x64', bool(use_x64))


def set_jax_platform(platform: Literal['cpu', 'gpu', 'tpu'] | None = None):
"""Set JAX platform to CPU, GPU, or TPU.
.. warning::
This utility takes effect only before running any JAX program.
Parameters
----------
platform : {'cpu', 'gpu', 'tpu'}, optional
Either 'cpu', 'gpu', or 'tpu'.
"""
if platform is None:
platform = os.getenv('JAX_PLATFORM_NAME', 'cpu')

assert platform in {'cpu', 'gpu', 'tpu', None}

jax.config.update('jax_platform_name', platform)

if platform == 'gpu':
# see https://jax.readthedocs.io/en/latest/gpu_performance_tips.html
xla_gpu_flags = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
'--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true'
)
xla_flags = os.getenv('XLA_FLAGS', '')
if xla_gpu_flags not in xla_flags:
os.environ['XLA_FLAGS'] = f'{xla_flags} {xla_gpu_flags}'


def set_cpu_cores(n: int) -> None:
"""Set CPU number to use, should be called before running JAX codes.
"""Set device number to use in JAX.
.. warning::
This utility takes effect only for CPU platform and before running any
JAX program.
Parameters
----------
n : int
CPU number to use.
Device number to use.
"""
n = int(n)
total_cores = cpu_count()
Expand All @@ -36,4 +80,24 @@ def set_cpu_cores(n: int) -> None:
warnings.warn(msg, Warning)
n = total_cores - 1

set_host_device_count(n)
xla_flags = os.getenv('XLA_FLAGS', '')
xla_flags = re.sub(
r'--xla_force_host_platform_device_count=\S+', '', xla_flags
).split()
os.environ['XLA_FLAGS'] = ' '.join(
[f'--xla_force_host_platform_device_count={n}'] + xla_flags
)


def jax_debug_nans(flag: bool):
"""Automatically detect when NaNs are produced when running JAX codes.
See JAX `docs <https://jax.readthedocs.io/en/latest/debugging/flags.html>`_
for details.
Parameters
----------
flag : bool
When `True`, raises an error when NaNs is detected.
"""
jax.config.update('jax_debug_nans', bool(flag))

0 comments on commit 4ea70d2

Please sign in to comment.