Skip to content

Commit

Permalink
feat: explicit version number and testing (#39)
Browse files Browse the repository at this point in the history
alaterre authored Nov 7, 2022
1 parent ff6fd78 commit b1015e6
Showing 3 changed files with 126 additions and 20 deletions.
6 changes: 2 additions & 4 deletions docs/guides/registration.md
Original file line number Diff line number Diff line change
@@ -40,17 +40,15 @@ you can register it as follows:
from jumanji import register

register(
id="CustomEnv-v0", # format: (env_name)[-v(version)]
id="CustomEnv-v0", # format: (env_name)-v(version)
entry_point="path.to.your.package:CustomEnv", # class constructor
kwargs={...}, # environment configuration
)
```

To successfully register your environment, make sure to provide the right path to your class constructor.
The `kwargs` argument is there to configurate the environment and allow you to register scenarios with a specific set of arguments.
The environment ID must respect the format `(env_name)[-v(version)]`.
The version number is optional as Jumanji automatically appends `v0` to it if omitted.
In that case, the environment can be retrieved with or without the `-v0` suffix.
The environment ID must respect the format `(env_name)-v(version)`, where the version number starts at `v0`.

For examples on how to register environments, please see our `__init__.py` file.

40 changes: 24 additions & 16 deletions jumanji/registration.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
import importlib
import re
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, Set, Tuple

from jumanji.env import Environment

@@ -25,7 +25,7 @@
def parse_env_id(id: str) -> Tuple[str, int]:
"""Parse an environment name.
The format must obey the following structure: {env-name}-v{version-number}.
The format must obey the following structure: <env-name>-v<version>.
Args:
id: The environment ID to parse.
@@ -40,19 +40,23 @@ def parse_env_id(id: str) -> Tuple[str, int]:
if not match:
raise ValueError(
f"Malformed environment name: {id}."
f"All ID's must be of the form (env-name)[-v(version-number)]."
"All env ID's must be of the form <env-name>-v<version>."
)

name, version = match.group("name", "version")

# default the version to zero if not provided
version = int(version) if version is not None else 0
# missing version number
if version is None:
raise ValueError(
f"Version missing, got name={name} and version={version}. "
"All env ID's must be of the form <env-name>-v<version>."
)

return name, version
return name, int(version)


def get_env_id(name: str, version: Optional[int] = None) -> str:
"""Get the full env ID given a name and (optional) version.
def get_env_id(name: str, version: int) -> str:
"""Get the full env ID given a name and version.
Args:
name: The environment name.
@@ -61,10 +65,7 @@ def get_env_id(name: str, version: Optional[int] = None) -> str:
Returns:
The environment ID.
"""
version = version or 0
full_name = name + f"-v{version}"

return full_name
return name + f"-v{version}"


@dataclass
@@ -103,17 +104,24 @@ def _check_registration_is_allowed(spec: EnvSpec) -> None:
if spec.id in _REGISTRY:
raise ValueError(f"Trying to override the registered environment {spec.id}.")

# Verify that version v-1 exist when trying to register version v (except 0)
latest_version = max(
(_spec.version for _spec in _REGISTRY.values() if _spec.name == spec.name),
default=None, # if no version of the environment is registered
)

# the first version of an env must be zero.
if (latest_version is None) and spec.version != 0:
raise ValueError(
f"The first version of an unregistered environment must be 0, got {spec.version}"
)

# Verify that version v-1 exists when trying to register version v (except 0)
if (latest_version is not None) and latest_version != (spec.version - 1):
raise ValueError(
f"Trying to register version {spec.version} of {spec.name}. "
f"However, the latest registered version of {spec.name} is {latest_version}."
)


def register(
id: str,
@@ -123,7 +131,7 @@ def register(
"""Register an environment.
Args:
id: environment ID, formatted as `(env_name)[-v(version)]`.
id: environment ID, formatted as `<env-name>-v<version>`.
entry_point: module and class constructor for the environment.
**kwargs: extra arguments that will be passed to the environment constructor at
instantiation.
@@ -185,5 +193,5 @@ def make(id: str, *args: Any, **kwargs: Any) -> Environment:
return env_fn(*args, **env_fn_kwargs)


def registered_environments() -> List[str]:
return list(_REGISTRY.keys())
def registered_environments() -> Set[str]:
return set(_REGISTRY.keys())
100 changes: 100 additions & 0 deletions jumanji/registration_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import pytest

from jumanji import registration
from jumanji.testing.fakes import FakeEnvironment


@pytest.fixture(autouse=True)
def mock_global_registry(mocker): # type: ignore
mocker.patch("jumanji.registration._REGISTRY", {})
return mocker


class TestParser:
@pytest.mark.parametrize("env_id", ("Env", "Env_v0"))
def test_parser__wrong_version(self, env_id: str) -> None:
with pytest.raises(ValueError):
registration.parse_env_id(env_id)

@pytest.mark.parametrize(
"env_id, expected",
[("Env-v0", ("Env", 0)), ("Env-test-v10", ("Env-test", 10))],
)
def test_parser__name_version(self, env_id: str, expected: Tuple[str, int]) -> None:
assert registration.parse_env_id(env_id) == expected


class TestRegistrationRules:
def test_registration__first_version(self) -> None:
# Env-v0 must exist to register v1
env_spec = registration.EnvSpec(id="Env-v1", entry_point="")
with pytest.raises(ValueError, match="first version"):
registration._check_registration_is_allowed(env_spec)

# the first version must be zero
env_spec = registration.EnvSpec(id="Env-v0", entry_point="")
registration._check_registration_is_allowed(env_spec)

def test_registration__next_version(self) -> None:
# check that the next registrable version is v+1
registration.register("Env-v0", entry_point="")

env_spec = registration.EnvSpec(id="Env-v2", entry_point="")
with pytest.raises(ValueError):
registration._check_registration_is_allowed(env_spec)

env_spec = registration.EnvSpec(id="Env-v1", entry_point="")
registration._check_registration_is_allowed(env_spec)

def test_registration__already_registered(self) -> None:
env_spec = registration.EnvSpec(id="Env-v0", entry_point="")
registration.register(env_spec.id, entry_point=env_spec.entry_point)
with pytest.raises(ValueError, match="override the registered environment"):
registration._check_registration_is_allowed(env_spec)


def test_register() -> None:
env_ids = ("Cat-v0", "Dog-v0", "Fish-v0", "Cat-v1")
for env_id in env_ids:
registration.register(env_id, entry_point="")
registered_envs = registration.registered_environments()
assert all(env_id in registered_envs for env_id in env_ids)


def test_register__instantiate_registered_env() -> None:
env_id = "Fake-v0"
registration.register(
id=env_id,
entry_point="jumanji.testing.fakes:FakeEnvironment",
)
env = registration.make(env_id)
assert isinstance(env, FakeEnvironment)


def test_register__override_kwargs() -> None:
env_id = "Fake-v0"
obs_shape = (11, 17)
registration.register(
id=env_id,
entry_point="jumanji.testing.fakes:FakeEnvironment",
)
env: FakeEnvironment = registration.make( # type: ignore
env_id, observation_shape=obs_shape
)
assert env.observation_spec().shape == obs_shape

0 comments on commit b1015e6

Please sign in to comment.