diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index 153ebc1e66c6..66f521ad2a72 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -35,7 +35,7 @@ from ray.rllib.offline import InputReader from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy_map import PolicyMap -from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.sample_batch import SampleBatch, concat_samples from ray.rllib.utils.annotations import DeveloperAPI, override from ray.rllib.utils.debug import summarize from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE @@ -91,10 +91,9 @@ class SamplerInput(InputReader, metaclass=ABCMeta): def next(self) -> SampleBatchType: batches = [self.get_data()] batches.extend(self.get_extra_batches()) - if len(batches) > 1: - return batches[0].concat_samples(batches) - else: - return batches[0] + if len(batches) == 0: + raise RuntimeError("No data available from sampler.") + return concat_samples(batches) @abstractmethod @DeveloperAPI diff --git a/rllib/execution/buffers/mixin_replay_buffer.py b/rllib/execution/buffers/mixin_replay_buffer.py index 45c0b60cc309..bf23abdf6c10 100644 --- a/rllib/execution/buffers/mixin_replay_buffer.py +++ b/rllib/execution/buffers/mixin_replay_buffer.py @@ -5,7 +5,7 @@ from ray.util.timer import _Timer from ray.rllib.execution.replay_ops import SimpleReplayBuffer -from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, concat_samples from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES @@ -155,7 +155,7 @@ def replay( # No replay desired -> Return here. if self.replay_ratio == 0.0: - return SampleBatch.concat_samples(output_batches) + return concat_samples(output_batches) # Only replay desired -> Return a (replayed) sample from the # buffer. elif self.replay_ratio == 1.0: @@ -168,7 +168,7 @@ def replay( while random.random() < num_new * replay_proportion: replay_proportion -= 1 output_batches.append(buffer.replay()) - return SampleBatch.concat_samples(output_batches) + return concat_samples(output_batches) def get_host(self) -> str: """Returns the computer's network name. diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 358d08b4f7d9..03abce007ca2 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -348,7 +348,7 @@ def concat(self, other: "SampleBatch") -> "SampleBatch": >>> print(b1.concat(b2)) # doctest: +SKIP {"a": np.array([1, 2, 3, 4, 5])} """ - return self.concat_samples([self, other]) + return concat_samples([self, other]) @PublicAPI def copy(self, shallow: bool = False) -> "SampleBatch": @@ -1376,7 +1376,7 @@ def wrap_as_needed( @staticmethod @PublicAPI - @Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=False) + @Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=True) def concat_samples(samples: List["MultiAgentBatch"]) -> "MultiAgentBatch": return concat_samples_into_ma_batch(samples) diff --git a/rllib/policy/tests/test_sample_batch.py b/rllib/policy/tests/test_sample_batch.py index 6c59316697b6..e82ebb48f46b 100644 --- a/rllib/policy/tests/test_sample_batch.py +++ b/rllib/policy/tests/test_sample_batch.py @@ -9,7 +9,11 @@ import ray from ray.rllib.models.repeated_values import RepeatedValues -from ray.rllib.policy.sample_batch import SampleBatch, attempt_count_timesteps +from ray.rllib.policy.sample_batch import ( + SampleBatch, + attempt_count_timesteps, + concat_samples, +) from ray.rllib.utils.compression import is_compressed from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.test_utils import check @@ -84,7 +88,7 @@ def test_right_zero_padding(self): ) def test_concat(self): - """Tests, SampleBatches.concat() and ...concat_samples().""" + """Tests, SampleBatches.concat() and concat_samples().""" s1 = SampleBatch( { "a": np.array([1, 2, 3]), @@ -97,7 +101,7 @@ def test_concat(self): "b": {"c": np.array([5, 6, 7])}, } ) - concatd = SampleBatch.concat_samples([s1, s2]) + concatd = concat_samples([s1, s2]) check(concatd["a"], [1, 2, 3, 2, 3, 4]) check(concatd["b"]["c"], [4, 5, 6, 5, 6, 7]) check(next(concatd.rows()), {"a": 1, "b": {"c": 4}}) @@ -129,11 +133,11 @@ def test_concat_max_seq_len(self): } ) - concatd = SampleBatch.concat_samples([s1, s2]) + concatd = concat_samples([s1, s2]) check(concatd.max_seq_len, s2.max_seq_len) with self.assertRaises(ValueError): - SampleBatch.concat_samples([s1, s2, s3]) + concat_samples([s1, s2, s3]) def test_rows(self): s1 = SampleBatch( diff --git a/rllib/utils/replay_buffers/multi_agent_mixin_replay_buffer.py b/rllib/utils/replay_buffers/multi_agent_mixin_replay_buffer.py index 31914d8b2d2b..acb3663270f7 100644 --- a/rllib/utils/replay_buffers/multi_agent_mixin_replay_buffer.py +++ b/rllib/utils/replay_buffers/multi_agent_mixin_replay_buffer.py @@ -8,9 +8,8 @@ from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap from ray.rllib.policy.sample_batch import ( DEFAULT_POLICY_ID, - MultiAgentBatch, SampleBatch, - concat_samples, + concat_samples_into_ma_batch, ) from ray.rllib.utils.annotations import override from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import ( @@ -291,7 +290,7 @@ def round_up_or_down(value, ratio): # No replay desired if self.replay_ratio == 0.0: - return concat_samples(output_batches) + return concat_samples_into_ma_batch(output_batches) # Only replay desired elif self.replay_ratio == 1.0: return _buffer.sample(num_items, **kwargs) @@ -315,7 +314,7 @@ def round_up_or_down(value, ratio): # Depending on the implementation of underlying buffers, samples # might be SampleBatches output_batches = [batch.as_multi_agent() for batch in output_batches] - return MultiAgentBatch.concat_samples(output_batches) + return concat_samples_into_ma_batch(output_batches) def check_buffer_is_ready(_policy_id): if ( @@ -344,7 +343,7 @@ def check_buffer_is_ready(_policy_id): if check_buffer_is_ready(policy_id): samples.append(mix_batches(policy_id).as_multi_agent()) - return MultiAgentBatch.concat_samples(samples) + return concat_samples_into_ma_batch(samples) @DeveloperAPI @override(MultiAgentPrioritizedReplayBuffer) diff --git a/rllib/utils/replay_buffers/replay_buffer.py b/rllib/utils/replay_buffers/replay_buffer.py index 6eaf9c426b72..3001bcd524aa 100644 --- a/rllib/utils/replay_buffers/replay_buffer.py +++ b/rllib/utils/replay_buffers/replay_buffer.py @@ -9,7 +9,7 @@ import ray # noqa F401 import psutil -from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.sample_batch import SampleBatch, concat_samples from ray.rllib.utils.actor_manager import FaultAwareApply from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.metrics.window_stat import WindowStat @@ -376,8 +376,7 @@ def _encode_sample(self, idxes: List[int]) -> SampleBatchType: if samples: # We assume all samples are of same type - sample_type = type(samples[0]) - out = sample_type.concat_samples(samples) + out = concat_samples(samples) else: out = SampleBatch() out.decompress_if_needed()