fix(parametric_action_distribution): sum kl divergence over event_ndims in parametric action distribution (#142)
clement-bonnet authored May 26, 2023
1 parent 87ebb7f commit 84957a6
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions jumanji/training/networks/parametric_distribution.py
@@ -49,8 +49,11 @@ def __init__(
         """
         self._param_size = param_size
         self._postprocessor = postprocessor
+        if event_ndims not in [0, 1]:
+            raise ValueError(
+                f"Event ndims {event_ndims} is not supported, expected value in [0, 1]."
+            )
         self._event_ndims = event_ndims  # rank of events
-        assert event_ndims in [0, 1]
 
     @abc.abstractmethod
     def create_dist(self, parameters: chex.Array) -> Distribution:
@@ -91,8 +94,6 @@ def log_prob(self, parameters: chex.Array, raw_actions: chex.Array) -> chex.Array:
         log_probs -= self._postprocessor.forward_log_det_jacobian(raw_actions)
         if self._event_ndims == 1:
             log_probs = jnp.sum(log_probs, axis=-1)  # sum over action dimension
-        else:
-            assert self._event_ndims == 0
         return log_probs
 
     def entropy(self, parameters: chex.Array, seed: chex.PRNGKey) -> chex.Array:
@@ -102,17 +103,23 @@ def entropy(self, parameters: chex.Array, seed: chex.PRNGKey) -> chex.Array:
         entropy += self._postprocessor.forward_log_det_jacobian(dist.sample(seed=seed))
         if self._event_ndims == 1:
             entropy = jnp.sum(entropy, axis=-1)
-        else:
-            assert self._event_ndims == 0
         return entropy
 
     def kl_divergence(
         self, parameters: chex.Array, other_parameters: chex.Array
     ) -> chex.Array:
         """KL divergence is invariant with respect to transformation by the same bijector."""
+        if not isinstance(self._postprocessor, IdentityBijector):
+            raise ValueError(
+                f"The current post_processor used ({self._postprocessor}) is a non-identity"
+                " bijector which does not implement kl_divergence."
+            )
         dist = self.create_dist(parameters)
         other_dist = self.create_dist(other_parameters)
-        return dist.kl_divergence(other_dist)
+        kl_divergence = dist.kl_divergence(other_dist)
+        if self._event_ndims == 1:
+            kl_divergence = jnp.sum(kl_divergence, axis=-1)
+        return kl_divergence
 
 
 class CategoricalParametricDistribution(ParametricDistribution):
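As a minimal sketch of the reasoning behind the fix, using plain jax.numpy rather than Jumanji's distribution classes (the diag_gaussian_kl helper, shapes, and parameter values below are hypothetical), the snippet shows how summing the per-dimension KL over the last axis turns a (batch, action_dim) array into the (batch,) array that the patched kl_divergence returns when event_ndims == 1:

import jax.numpy as jnp

def diag_gaussian_kl(mean_p, std_p, mean_q, std_q):
    # Closed-form KL(N(mean_p, std_p**2) || N(mean_q, std_q**2)), computed elementwise
    # so that each entry is the KL contribution of one action dimension.
    return (
        jnp.log(std_q / std_p)
        + (std_p**2 + (mean_p - mean_q) ** 2) / (2.0 * std_q**2)
        - 0.5
    )

batch, action_dim = 4, 3  # hypothetical shapes
mean_p = jnp.zeros((batch, action_dim))
std_p = jnp.ones((batch, action_dim))
mean_q = jnp.full((batch, action_dim), 0.5)
std_q = jnp.full((batch, action_dim), 2.0)

per_dim_kl = diag_gaussian_kl(mean_p, std_p, mean_q, std_q)  # shape (batch, action_dim)
kl = jnp.sum(per_dim_kl, axis=-1)  # shape (batch,), matching log_prob and entropy

For a factorized action distribution the joint KL is the sum of the per-dimension KLs, so without the axis=-1 reduction kl_divergence would return one value per action dimension while log_prob and entropy return one value per batch element; the patch removes that mismatch for event_ndims == 1, and for event_ndims == 0 the array is already per-element and is returned unchanged.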
