Skip to content

Commit

Permalink
Reduce number of ops sample_halton_sequence adds to the graph.
Browse files Browse the repository at this point in the history
    - Turn loop over sampling and argmax in to a single sample and single argmax.

PiperOrigin-RevId: 566408146
  • Loading branch information
srvasude authored and tensorflower-gardener committed Sep 18, 2023
1 parent 200e753 commit a204ec8
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions tensorflow_probability/python/mcmc/sample_halton_sequence_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,23 +276,28 @@ def _get_permutations(num_results, dims, seed=None):
Args:
num_results: A positive scalar `Tensor` of integral type. The number of
draws from the discrete uniform distribution over the permutation groups.
dims: A 1D `Tensor` of the same dtype as `num_results`. The degree of the
dims: A 1D numpy array of the same dtype as `num_results`. The degree of the
permutation groups from which to sample.
seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
Returns:
permutations: A `Tensor` of shape `[num_results, sum(dims)]` and the same
dtype as `dims`.
"""
seeds = samplers.split_seed(seed, n=ps.size(dims))

def generate_one(dim, seed):
return tf.argsort(samplers.uniform(
[num_results, dim], seed=seed), axis=-1)

return tf.concat([generate_one(dim, seed)
for dim, seed in zip(dims, tf.unstack(seeds))],
axis=-1)
n = dims.size
max_size = np.max(dims)
samples = samplers.uniform([num_results, n, max_size], seed=seed)
should_mask = np.arange(max_size) >= dims[..., np.newaxis]
# Choose a number that does not affect the permutation and relative location.
samples = tf.where(
should_mask,
dtype_util.as_numpy_dtype(samples.dtype)(np.arange(max_size) + 10.),
samples)
samples = tf.argsort(samples, axis=-1)
# Generate the set of indices to gather.
should_mask = np.tile(should_mask, [num_results, 1, 1])
indices = np.stack(np.where(~should_mask), axis=-1)
return tf.gather_nd(samples, indices)


def _get_indices(num_results, sequence_indices, dtype, name=None):
Expand Down

0 comments on commit a204ec8

Please sign in to comment.