Skip to content

Commit

Permalink
[RLlib] Fix connector registry (ray-project#30635)
Browse files Browse the repository at this point in the history
Introduce an API to force re-registration of all connectors in case the global registry is lost.

Signed-off-by: Jun Gong <jungong@anyscale.com>
Signed-off-by: tmynn <hovhannes.tamoyan@gmail.com>
  • Loading branch information
kouroshHakha authored and tamohannes committed Jan 25, 2023
1 parent c6e2556 commit f72fc3f
Show file tree
Hide file tree
Showing 18 changed files with 101 additions and 61 deletions.
3 changes: 3 additions & 0 deletions rllib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@ def _setup_logger():
def _register_all():
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.registry import ALGORITHMS, _get_algorithm_class
from ray.rllib.connectors.registry import _register_all_connectors
from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS

_register_all_connectors()

for key, get_trainable_class_and_config in list(ALGORITHMS.items()) + list(
CONTRIBUTED_ALGORITHMS.items()
):
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/action/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from ray.rllib.connectors.connector import (
ActionConnector,
ConnectorContext,
register_connector,
)
from ray.rllib.connectors.registry import register_connector
from ray.rllib.utils.spaces.space_utils import clip_action, get_base_struct_from_space
from ray.rllib.utils.typing import ActionConnectorDataType
from ray.util.annotations import PublicAPI
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/action/immutable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from ray.rllib.connectors.connector import (
ActionConnector,
ConnectorContext,
register_connector,
)
from ray.rllib.connectors.registry import register_connector
from ray.rllib.utils.numpy import make_action_immutable
from ray.rllib.utils.typing import ActionConnectorDataType
from ray.util.annotations import PublicAPI
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/action/lambdas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from ray.rllib.connectors.connector import (
ActionConnector,
ConnectorContext,
register_connector,
)
from ray.rllib.connectors.registry import register_connector
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.typing import (
ActionConnectorDataType,
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/action/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from ray.rllib.connectors.connector import (
ActionConnector,
ConnectorContext,
register_connector,
)
from ray.rllib.connectors.registry import register_connector
from ray.rllib.utils.spaces.space_utils import (
get_base_struct_from_space,
unsquash_action,
Expand Down
5 changes: 2 additions & 3 deletions rllib/connectors/action/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
Connector,
ConnectorContext,
ConnectorPipeline,
get_connector,
register_connector,
)
from ray.rllib.connectors.registry import get_connector, register_connector
from ray.rllib.utils.typing import ActionConnectorDataType
from ray.util.annotations import PublicAPI

Expand Down Expand Up @@ -47,7 +46,7 @@ def from_state(ctx: ConnectorContext, params: Any):
for state in params:
try:
name, subparams = state
connectors.append(get_connector(ctx, name, subparams))
connectors.append(get_connector(name, ctx, subparams))
except Exception as e:
logger.error(f"Failed to de-serialize connector state: {state}")
raise e
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/agent/clip_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from ray.rllib.connectors.connector import (
AgentConnector,
ConnectorContext,
register_connector,
)
from ray.rllib.connectors.registry import register_connector
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentConnectorDataType
from ray.util.annotations import PublicAPI
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/agent/lambdas.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from ray.rllib.connectors.connector import (
AgentConnector,
ConnectorContext,
register_connector,
)
from ray.rllib.connectors.registry import register_connector
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import (
AgentConnectorDataType,
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/agent/mean_std_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from ray.rllib.connectors.connector import AgentConnector
from ray.rllib.connectors.connector import (
ConnectorContext,
register_connector,
)
from ray.rllib.connectors.registry import register_connector
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.filter import MeanStdFilter, ConcurrentMeanStdFilter
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/agent/obs_preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from ray.rllib.connectors.connector import (
AgentConnector,
ConnectorContext,
register_connector,
)
from ray.rllib.connectors.registry import register_connector
from ray.rllib.models.preprocessors import get_preprocessor, NoPreprocessor
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentConnectorDataType
Expand Down
5 changes: 2 additions & 3 deletions rllib/connectors/agent/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
Connector,
ConnectorContext,
ConnectorPipeline,
get_connector,
register_connector,
)
from ray.rllib.connectors.registry import get_connector, register_connector
from ray.rllib.utils.typing import ActionConnectorDataType, AgentConnectorDataType
from ray.util.annotations import PublicAPI

Expand Down Expand Up @@ -58,7 +57,7 @@ def from_state(ctx: ConnectorContext, params: List[Any]):
for state in params:
try:
name, subparams = state
connectors.append(get_connector(ctx, name, subparams))
connectors.append(get_connector(name, ctx, subparams))
except Exception as e:
logger.error(f"Failed to de-serialize connector state: {state}")
raise e
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/agent/state_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from ray.rllib.connectors.connector import (
AgentConnector,
ConnectorContext,
register_connector,
)
from ray.rllib.connectors.registry import register_connector
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import ActionConnectorDataType, AgentConnectorDataType
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/agent/view_requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from ray.rllib.connectors.connector import (
AgentConnector,
ConnectorContext,
register_connector,
)
from ray.rllib.connectors.registry import register_connector
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import (
AgentConnectorDataType,
Expand Down
31 changes: 0 additions & 31 deletions rllib/connectors/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
AlgorithmConfigDict,
TensorType,
)
from ray.tune.registry import RLLIB_CONNECTOR, _global_registry
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
Expand Down Expand Up @@ -461,33 +460,3 @@ def __getitem__(self, key: Union[str, int, type]):
results.append(c)

return results


@PublicAPI(stability="alpha")
def register_connector(name: str, cls: Connector):
    """Register a connector class in the global registry for use with RLlib.

    Args:
        name: Name to register the connector under.
        cls: The connector class to register; must be a subclass of
            ``Connector``.

    Raises:
        TypeError: If ``cls`` is not a subclass of ``Connector``.
    """
    is_connector_type = issubclass(cls, Connector)
    if not is_connector_type:
        raise TypeError("Can only register Connector type.", cls)

    # Stored under the RLLIB_CONNECTOR category of the global registry.
    _global_registry.register(RLLIB_CONNECTOR, name, cls)


@PublicAPI(stability="alpha")
def get_connector(ctx: ConnectorContext, name: str, params: Tuple[Any]) -> Connector:
    """Look up a registered connector by name and rebuild it from its state.

    Args:
        ctx: Connector context, passed through to the class's ``from_state``.
        name: Name the connector was registered under.
        params: Serialized parameters of the connector.

    Returns:
        The connector constructed by ``cls.from_state(ctx, params)``.

    Raises:
        NameError: If no connector is registered under ``name``.
    """
    if _global_registry.contains(RLLIB_CONNECTOR, name):
        connector_cls = _global_registry.get(RLLIB_CONNECTOR, name)
        return connector_cls.from_state(ctx, params)

    raise NameError("connector not found.", name)
67 changes: 67 additions & 0 deletions rllib/connectors/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Registry of connector names for global access."""

import importlib
import sys

from typing import Any
from ray.tune.registry import RLLIB_CONNECTOR, _global_registry

from ray.util.annotations import PublicAPI
from ray.rllib.connectors.connector import Connector, ConnectorContext


ALL_CONNECTORS = set()


@PublicAPI(stability="alpha")
def register_connector(name: str, cls: Connector):
    """Register a connector class for use with RLlib.

    If ``name`` is already present in the global registry, the call is a
    no-op: nothing is validated, recorded, or overwritten.

    Args:
        name: Name to register the connector under.
        cls: The connector class to register; must be a subclass of
            ``Connector``.

    Raises:
        TypeError: If ``name`` is not yet registered and ``cls`` is not a
            subclass of ``Connector``.
    """
    already_registered = _global_registry.contains(RLLIB_CONNECTOR, name)
    if not already_registered:
        if not issubclass(cls, Connector):
            raise TypeError("Can only register Connector type.", cls)

        # Remember the defining module locally so that everything can be
        # registered again in the global registry later, for example after
        # a cluster restart.
        ALL_CONNECTORS.add(cls.__module__)

        # Publish the class in the global registry.
        _global_registry.register(RLLIB_CONNECTOR, name, cls)


@PublicAPI(stability="alpha")
def get_connector(name: str, ctx: ConnectorContext, params: Any = None) -> Connector:
    # TODO(jungong) : switch the order of parameters man!!
    """Look up a registered connector by name and rebuild it from its state.

    Args:
        name: Name the connector was registered under.
        ctx: Connector context, passed through to the class's ``from_state``.
        params: Serialized parameters of the connector. Defaults to None.

    Returns:
        The connector constructed by the registered class's
        ``from_state(ctx, params)``.

    Raises:
        NameError: If no connector is registered under ``name``.
    """
    if _global_registry.contains(RLLIB_CONNECTOR, name):
        connector_cls = _global_registry.get(RLLIB_CONNECTOR, name)
        return connector_cls.from_state(ctx, params)

    raise NameError("connector not found.", name)


def _register_all_connectors():
    """Force register all connectors again.

    In case the cluster has been restarted, and the global registry
    has to be rebuilt.

    Every module name recorded in ``ALL_CONNECTORS`` is executed again:
    reloaded if it is already present in ``sys.modules``, imported otherwise.
    Presumably this re-runs the ``register_connector`` calls made by those
    modules.
    """
    # Iterate over a snapshot rather than the live set. register_connector()
    # adds to ALL_CONNECTORS, so if executing one of these modules records a
    # module that is not in the set yet, iterating the set directly would
    # fail with "RuntimeError: Set changed size during iteration".
    for module_name in tuple(ALL_CONNECTORS):
        if module_name in sys.modules:
            importlib.reload(sys.modules[module_name])
        else:
            importlib.import_module(module_name)
13 changes: 7 additions & 6 deletions rllib/connectors/tests/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from ray.rllib.connectors.action.lambdas import ConvertToNumpyConnector
from ray.rllib.connectors.action.normalize import NormalizeActionsConnector
from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline
from ray.rllib.connectors.connector import ConnectorContext, get_connector
from ray.rllib.connectors.connector import ConnectorContext
from ray.rllib.connectors.registry import get_connector
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ActionConnectorDataType

Expand All @@ -21,7 +22,7 @@ def test_connector_pipeline(self):
connectors = [ConvertToNumpyConnector(ctx)]
pipeline = ActionConnectorPipeline(ctx, connectors)
name, params = pipeline.to_state()
restored = get_connector(ctx, name, params)
restored = get_connector(name, ctx, params)
self.assertTrue(isinstance(restored, ActionConnectorPipeline))
self.assertTrue(isinstance(restored.connectors[0], ConvertToNumpyConnector))

Expand All @@ -33,7 +34,7 @@ def test_convert_to_numpy_connector(self):

self.assertEqual(name, "ConvertToNumpyConnector")

restored = get_connector(ctx, name, params)
restored = get_connector(name, ctx, params)
self.assertTrue(isinstance(restored, ConvertToNumpyConnector))

action = torch.Tensor([8, 9])
Expand All @@ -53,7 +54,7 @@ def test_normalize_action_connector(self):
name, params = c.to_state()
self.assertEqual(name, "NormalizeActionsConnector")

restored = get_connector(ctx, name, params)
restored = get_connector(name, ctx, params)
self.assertTrue(isinstance(restored, NormalizeActionsConnector))

ac_data = ActionConnectorDataType(0, 1, {}, (0.5, [], {}))
Expand All @@ -70,7 +71,7 @@ def test_clip_action_connector(self):
name, params = c.to_state()
self.assertEqual(name, "ClipActionsConnector")

restored = get_connector(ctx, name, params)
restored = get_connector(name, ctx, params)
self.assertTrue(isinstance(restored, ClipActionsConnector))

ac_data = ActionConnectorDataType(0, 1, {}, (8.8, [], {}))
Expand All @@ -87,7 +88,7 @@ def test_immutable_action_connector(self):
name, params = c.to_state()
self.assertEqual(name, "ImmutableActionsConnector")

restored = get_connector(ctx, name, params)
restored = get_connector(name, ctx, params)
self.assertTrue(isinstance(restored, ImmutableActionsConnector))

ac_data = ActionConnectorDataType(0, 1, {}, (np.array([8.8]), [], {}))
Expand Down
13 changes: 7 additions & 6 deletions rllib/connectors/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
from ray.rllib.connectors.agent.state_buffer import StateBufferConnector
from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector
from ray.rllib.connectors.connector import ConnectorContext, get_connector
from ray.rllib.connectors.connector import ConnectorContext
from ray.rllib.connectors.registry import get_connector
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.test_utils import check
Expand All @@ -32,7 +33,7 @@ def test_connector_pipeline(self):
connectors = [ClipRewardAgentConnector(ctx, False, 1.0)]
pipeline = AgentConnectorPipeline(ctx, connectors)
name, params = pipeline.to_state()
restored = get_connector(ctx, name, params)
restored = get_connector(name, ctx, params)
self.assertTrue(isinstance(restored, AgentConnectorPipeline))
self.assertTrue(isinstance(restored.connectors[0], ClipRewardAgentConnector))

Expand All @@ -50,7 +51,7 @@ def test_obs_preprocessor_connector(self):
c = ObsPreprocessorConnector(ctx)
name, params = c.to_state()

restored = get_connector(ctx, name, params)
restored = get_connector(name, ctx, params)
self.assertTrue(isinstance(restored, ObsPreprocessorConnector))

obs = obs_space.sample()
Expand Down Expand Up @@ -81,7 +82,7 @@ def test_clip_reward_connector(self):
self.assertEqual(name, "ClipRewardAgentConnector")
self.assertAlmostEqual(params["limit"], 2.0)

restored = get_connector(ctx, name, params)
restored = get_connector(name, ctx, params)
self.assertTrue(isinstance(restored, ClipRewardAgentConnector))

d = AgentConnectorDataType(
Expand All @@ -102,7 +103,7 @@ def test_flatten_data_connector(self):
c = FlattenDataAgentConnector(ctx)

name, params = c.to_state()
restored = get_connector(ctx, name, params)
restored = get_connector(name, ctx, params)
self.assertTrue(isinstance(restored, FlattenDataAgentConnector))

sample_batch = {
Expand Down Expand Up @@ -494,7 +495,7 @@ def test_connector_pipline_with_view_requirement(self):
agent_connector = AgentConnectorPipeline(ctx, connectors)

name, params = agent_connector.to_state()
restored = get_connector(ctx, name, params)
restored = get_connector(name, ctx, params)
self.assertTrue(isinstance(restored, AgentConnectorPipeline))
for cidx, c in enumerate(connectors):
check(restored.connectors[cidx].to_state(), c.to_state())
Expand Down
5 changes: 3 additions & 2 deletions rllib/connectors/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
from ray.rllib.connectors.agent.state_buffer import StateBufferConnector
from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector
from ray.rllib.connectors.connector import Connector, ConnectorContext, get_connector
from ray.rllib.connectors.connector import Connector, ConnectorContext
from ray.rllib.connectors.registry import get_connector
from ray.rllib.connectors.agent.mean_std_filter import (
MeanStdObservationFilterAgentConnector,
ConcurrentMeanStdObservationFilterAgentConnector,
Expand Down Expand Up @@ -117,7 +118,7 @@ def restore_connectors_for_policy(
"""
ctx: ConnectorContext = ConnectorContext.from_policy(policy)
name, params = connector_config
return get_connector(ctx, name, params)
return get_connector(name, ctx, params)


# We need this filter selection mechanism temporarily to remain compatible to old API
Expand Down

0 comments on commit f72fc3f

Please sign in to comment.