Skip to content

Commit

Permalink
Migrate TFP to support JAX typed PRNG keys
Browse files Browse the repository at this point in the history
The context is described more fully in [JEP 9263](jax-ml/jax#17297).
If you have comments on the JEP, we'd love to hear them!

PiperOrigin-RevId: 566664782
  • Loading branch information
vanderplas authored and tensorflower-gardener committed Sep 19, 2023
1 parent a204ec8 commit 6efcda9
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 20 deletions.
18 changes: 9 additions & 9 deletions discussion/adaptive_malt/adaptive_malt.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def adaptive_mcmc_step(
target_log_prob_fn: fun_mc.PotentialFn,
num_mala_steps: int,
num_adaptation_steps: int,
seed: jax.random.KeyArray,
seed: jax.Array,
method: str = 'hmc',
damping: Optional[jnp.ndarray] = None,
scalar_step_size: Optional[jnp.ndarray] = None,
Expand Down Expand Up @@ -778,7 +778,7 @@ def adaptive_nuts_step(
target_log_prob_fn: fun_mc.PotentialFn,
num_mala_steps: int,
num_adaptation_steps: int,
seed: jax.random.KeyArray,
seed: jax.Array,
scalar_step_size: Optional[jnp.ndarray] = None,
vector_step_size: Optional[jnp.ndarray] = None,
rvar_factor: int = 8,
Expand Down Expand Up @@ -1040,7 +1040,7 @@ class MeadsExtra(NamedTuple):


def meads_init(state: jnp.ndarray, target_log_prob_fn: fun_mc.PotentialFn,
num_folds: int, seed: jax.random.KeyArray):
num_folds: int, seed: jax.Array):
"""Initializes MEADS."""
num_dimensions = state.shape[-1]
num_chains = state.shape[0]
Expand All @@ -1062,7 +1062,7 @@ def meads_init(state: jnp.ndarray, target_log_prob_fn: fun_mc.PotentialFn,

def meads_step(meads_state: MeadsState,
target_log_prob_fn: fun_mc.PotentialFn,
seed: jax.random.KeyArray,
seed: jax.Array,
vector_step_size: Optional[jnp.ndarray] = None,
damping: Optional[jnp.ndarray] = None,
step_size_multiplier: float = 0.5,
Expand Down Expand Up @@ -1221,7 +1221,7 @@ def run_adaptive_mcmc_on_target(
init_step_size: jnp.ndarray,
num_adaptation_steps: int,
num_results: int,
seed: jax.random.KeyArray,
seed: jax.Array,
num_mala_steps: int = 100,
rvar_smoothing: int = 0,
trajectory_opt_kwargs: Mapping[str, Any] = immutabledict.immutabledict({
Expand Down Expand Up @@ -1358,7 +1358,7 @@ def run_adaptive_nuts_on_target(
init_step_size: jnp.ndarray,
num_adaptation_steps: int,
num_results: int,
seed: jax.random.KeyArray,
seed: jax.Array,
num_mala_steps: int = 100,
rvar_smoothing: int = 0,
num_chains: Optional[int] = None,
Expand Down Expand Up @@ -1478,7 +1478,7 @@ def run_meads_on_target(
num_adaptation_steps: int,
num_results: int,
thinning: int,
seed: jax.random.KeyArray,
seed: jax.Array,
num_folds: int,
num_chains: Optional[int] = None,
init_x: Optional[jnp.ndarray] = None,
Expand Down Expand Up @@ -1596,7 +1596,7 @@ def run_fixed_mcmc_on_target(
target: gym.targets.Model,
init_x: jnp.ndarray,
method: str,
seed: jax.random.KeyArray,
seed: jax.Array,
num_warmup_steps: int,
num_results: int,
scalar_step_size: jnp.ndarray,
Expand Down Expand Up @@ -1706,7 +1706,7 @@ def run_vi_on_target(
init_x: jnp.ndarray,
num_steps: int,
learning_rate: float,
seed: jax.random.KeyArray,
seed: jax.Array,
):
"""Run VI on a target.
Expand Down
4 changes: 3 additions & 1 deletion spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def make_tensor_seed(seed):
"""Converts a seed to a `Tensor` seed."""
if seed is None:
raise ValueError('seed must not be None when using JAX')
if isinstance(seed, jax.random.PRNGKeyArray):
if hasattr(seed, 'dtype') and jax.dtypes.issubdtype(
seed.dtype, jax.dtypes.prng_key
):
return seed
return jnp.asarray(seed, jnp.uint32)

Expand Down
8 changes: 6 additions & 2 deletions tensorflow_probability/python/internal/backend/numpy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,14 @@ def _default_convert_to_tensor(value, dtype=None):
"""Default tensor conversion function for array, bool, int, float, and complex."""
if JAX_MODE:
# TODO(b/223267515): We shouldn't need to specialize here.
if 'PRNGKeyArray' in str(type(value)):
if hasattr(value, 'dtype') and jax.dtypes.issubdtype(
value.dtype, jax.dtypes.prng_key
):
return value
if isinstance(value, (list, tuple)) and value:
if 'PRNGKeyArray' in str(type(value[0])):
if hasattr(value[0], 'dtype') and jax.dtypes.issubdtype(
value[0].dtype, jax.dtypes.prng_key
):
return np.stack(value, axis=0)

inferred_dtype = _infer_dtype(value, np.float32)
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_probability/python/internal/loop_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def _convert_variables_to_tensors(values):

def tensor_array_from_element(elem, size=None, **kwargs):
"""Construct a tf.TensorArray of elements with the dtype + shape of `elem`."""
if JAX_MODE and isinstance(elem, jax.random.PRNGKeyArray):
# If `trace_elt` is a `PRNGKeyArray`, then then it is not possible to create
if JAX_MODE and jax.dtypes.issubdtype(elem.dtype, jax.dtypes.prng_key):
# If `trace_elt` is a typed prng key, then then it is not possible to create
# a matching (i.e., with the same custom PRNG) instance/array inside
# `TensorArray.__init__` given just a `dtype`, `size`, and `shape`.
#
Expand Down
20 changes: 14 additions & 6 deletions tensorflow_probability/python/internal/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,12 @@ def evaluate(self, x):
def _evaluate(x):
if x is None:
return x
# TODO(b/223267515): Improve handling of JAX PRNGKeyArray objects.
if JAX_MODE and isinstance(x, jax.random.PRNGKeyArray):
# TODO(b/223267515): Improve handling of JAX typed PRNG keys.
if (
JAX_MODE
and hasattr(x, 'dtype')
and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key)
):
return x
return np.array(x)
return tf.nest.map_structure(_evaluate, x, expand_composites=True)
Expand All @@ -177,11 +181,15 @@ def _GetNdArray(self, a):
def _evaluateTensors(self, a, b):
if JAX_MODE:
import jax # pylint: disable=g-import-not-at-top
# HACK: In assertions (like self.assertAllClose), convert PRNGKeyArrays
# to "normal" arrays so they can be compared with our existing machinery.
if isinstance(a, jax.random.PRNGKeyArray):
# HACK: In assertions (like self.assertAllClose), convert typed PRNG keys
# to raw arrays so they can be compared with our existing machinery.
if hasattr(a, 'dtype') and jax.dtypes.issubdtype(
a.dtype, jax.dtypes.prng_key
):
a = jax.random.key_data(a)
if isinstance(b, jax.random.PRNGKeyArray):
if hasattr(b, 'dtype') and jax.dtypes.issubdtype(
b.dtype, jax.dtypes.prng_key
):
b = jax.random.key_data(b)
if tf.is_tensor(a) and tf.is_tensor(b):
(a, b) = self.evaluate([a, b])
Expand Down

0 comments on commit 6efcda9

Please sign in to comment.