Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dtype choice in step type/functions #262

Merged
merged 3 commits into from
Nov 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions jumanji/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def restart(
observation: Observation,
extras: Optional[Dict] = None,
shape: Union[int, Sequence[int]] = (),
dtype: Union[jnp.dtype, type] = float,
sash-a marked this conversation as resolved.
Show resolved Hide resolved
) -> TimeStep:
"""Returns a `TimeStep` with `step_type` set to `StepType.FIRST`.

Expand All @@ -107,15 +108,17 @@ def restart(
shape: optional parameter to specify the shape of the rewards and discounts.
Allows multi-agent environment compatibility. Defaults to () for
scalar reward and discount.
dtype: Optional parameter to specify the data type of the rewards and discounts.
Defaults to `float`.

Returns:
TimeStep identified as a reset.
"""
extras = extras or {}
return TimeStep(
step_type=StepType.FIRST,
reward=jnp.zeros(shape, dtype=float),
discount=jnp.ones(shape, dtype=float),
reward=jnp.zeros(shape, dtype=dtype),
discount=jnp.ones(shape, dtype=dtype),
observation=observation,
extras=extras,
)
Expand All @@ -127,6 +130,7 @@ def transition(
discount: Optional[Array] = None,
extras: Optional[Dict] = None,
shape: Union[int, Sequence[int]] = (),
dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
"""Returns a `TimeStep` with `step_type` set to `StepType.MID`.

Expand All @@ -141,11 +145,13 @@ def transition(
shape: optional parameter to specify the shape of the rewards and discounts.
Allows multi-agent environment compatibility. Defaults to () for
scalar reward and discount.
dtype: Optional parameter to specify the data type of the default discount, which
is only created when `discount` is not provided. Defaults to `float`.

Returns:
TimeStep identified as a transition.
"""
discount = discount if discount is not None else jnp.ones(shape, dtype=float)
discount = discount if discount is not None else jnp.ones(shape, dtype=dtype)
extras = extras or {}
return TimeStep(
step_type=StepType.MID,
Expand All @@ -161,6 +167,7 @@ def termination(
observation: Observation,
extras: Optional[Dict] = None,
shape: Union[int, Sequence[int]] = (),
dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
"""Returns a `TimeStep` with `step_type` set to `StepType.LAST`.

Expand All @@ -174,6 +181,8 @@ def termination(
shape: optional parameter to specify the shape of the rewards and discounts.
Allows multi-agent environment compatibility. Defaults to () for
scalar reward and discount.
dtype: Optional parameter to specify the data type of the discounts. Defaults
to `float`.

Returns:
TimeStep identified as the termination of an episode.
Expand All @@ -182,7 +191,7 @@ def termination(
return TimeStep(
step_type=StepType.LAST,
reward=reward,
discount=jnp.zeros(shape, dtype=float),
discount=jnp.zeros(shape, dtype=dtype),
observation=observation,
extras=extras,
)
Expand All @@ -194,6 +203,7 @@ def truncation(
discount: Optional[Array] = None,
extras: Optional[Dict] = None,
shape: Union[int, Sequence[int]] = (),
dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
"""Returns a `TimeStep` with `step_type` set to `StepType.LAST`.

Expand All @@ -208,10 +218,13 @@ def truncation(
shape: optional parameter to specify the shape of the rewards and discounts.
Allows multi-agent environment compatibility. Defaults to () for
scalar reward and discount.
dtype: Optional parameter to specify the data type of the default discount, which
is only created when `discount` is not provided. Defaults to `float`.

Returns:
TimeStep identified as the truncation of an episode.
"""
discount = discount if discount is not None else jnp.ones(shape, dtype=float)
discount = discount if discount is not None else jnp.ones(shape, dtype=dtype)
extras = extras or {}
return TimeStep(
step_type=StepType.LAST,
Expand Down