From 2b92b555e551f427b6a6e25507e8e0624dced779 Mon Sep 17 00:00:00 2001 From: Thomas Hirtz Date: Tue, 5 Nov 2024 22:36:11 +0100 Subject: [PATCH 1/2] add dtype --- jumanji/types.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/jumanji/types.py b/jumanji/types.py index ecbb069a7..c104114e2 100644 --- a/jumanji/types.py +++ b/jumanji/types.py @@ -95,6 +95,7 @@ def restart( observation: Observation, extras: Optional[Dict] = None, shape: Union[int, Sequence[int]] = (), + dtype: Union[jnp.dtype, type] = float, ) -> TimeStep: """Returns a `TimeStep` with `step_type` set to `StepType.FIRST`. @@ -114,8 +115,8 @@ def restart( extras = extras or {} return TimeStep( step_type=StepType.FIRST, - reward=jnp.zeros(shape, dtype=float), - discount=jnp.ones(shape, dtype=float), + reward=jnp.zeros(shape, dtype=dtype), + discount=jnp.ones(shape, dtype=dtype), observation=observation, extras=extras, ) @@ -127,6 +128,7 @@ def transition( discount: Optional[Array] = None, extras: Optional[Dict] = None, shape: Union[int, Sequence[int]] = (), + dtype: Union[jnp.dtype, type] = float, ) -> TimeStep: """Returns a `TimeStep` with `step_type` set to `StepType.MID`. @@ -145,7 +147,7 @@ def transition( Returns: TimeStep identified as a transition. """ - discount = discount if discount is not None else jnp.ones(shape, dtype=float) + discount = discount if discount is not None else jnp.ones(shape, dtype=dtype) extras = extras or {} return TimeStep( step_type=StepType.MID, @@ -161,6 +163,7 @@ def termination( observation: Observation, extras: Optional[Dict] = None, shape: Union[int, Sequence[int]] = (), + dtype: Union[jnp.dtype, type] = float, ) -> TimeStep: """Returns a `TimeStep` with `step_type` set to `StepType.LAST`. @@ -182,7 +185,7 @@ def termination( return TimeStep( step_type=StepType.LAST, reward=reward, - discount=jnp.zeros(shape, dtype=float), + discount=jnp.zeros(shape, dtype=dtype), observation=observation, extras=extras, ) @@ -194,6 +197,7 @@ def truncation( discount: Optional[Array] = None, extras: Optional[Dict] = None, shape: Union[int, Sequence[int]] = (), + dtype: Union[jnp.dtype, type] = float, ) -> TimeStep: """Returns a `TimeStep` with `step_type` set to `StepType.LAST`. @@ -211,7 +215,7 @@ def truncation( Returns: TimeStep identified as the truncation of an episode. """ - discount = discount if discount is not None else jnp.ones(shape, dtype=float) + discount = discount if discount is not None else jnp.ones(shape, dtype=dtype) extras = extras or {} return TimeStep( step_type=StepType.LAST, From 53f0ac9ffda7051841fb5f505113fa4c719a7eae Mon Sep 17 00:00:00 2001 From: Thomas Hirtz Date: Tue, 12 Nov 2024 09:13:10 +0100 Subject: [PATCH 2/2] update docstrings --- jumanji/types.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/jumanji/types.py b/jumanji/types.py index c104114e2..249851846 100644 --- a/jumanji/types.py +++ b/jumanji/types.py @@ -108,6 +108,8 @@ def restart( shape: optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. + dtype: Optional parameter to specify the data type of the rewards and discounts. + Defaults to `float`. Returns: TimeStep identified as a reset. @@ -143,6 +145,8 @@ def transition( shape: optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. + dtype: Optional parameter to specify the data type of the discounts. Defaults + to `float`. Returns: TimeStep identified as a transition. @@ -177,6 +181,8 @@ def termination( shape: optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. + dtype: Optional parameter to specify the data type of the discounts. Defaults + to `float`. Returns: TimeStep identified as the termination of an episode. @@ -212,6 +218,9 @@ def truncation( shape: optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. + dtype: Optional parameter to specify the data type of the discounts. Defaults + to `float`. + Returns: TimeStep identified as the truncation of an episode. """