Skip to content

Commit

Permalink
[RLlib] Clean up deprecated concat_samples calls (#31391)
Browse files Browse the repository at this point in the history
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
  • Loading branch information
ArturNiederfahrenhorst authored Jan 5, 2023
1 parent 1abbbd7 commit f44f578
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 23 deletions.
9 changes: 4 additions & 5 deletions rllib/evaluation/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from ray.rllib.offline import InputReader
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.sample_batch import SampleBatch, concat_samples
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
Expand Down Expand Up @@ -91,10 +91,9 @@ class SamplerInput(InputReader, metaclass=ABCMeta):
def next(self) -> SampleBatchType:
batches = [self.get_data()]
batches.extend(self.get_extra_batches())
if len(batches) > 1:
return batches[0].concat_samples(batches)
else:
return batches[0]
if len(batches) == 0:
raise RuntimeError("No data available from sampler.")
return concat_samples(batches)

@abstractmethod
@DeveloperAPI
Expand Down
6 changes: 3 additions & 3 deletions rllib/execution/buffers/mixin_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ray.util.timer import _Timer
from ray.rllib.execution.replay_ops import SimpleReplayBuffer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, concat_samples
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode
from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES
Expand Down Expand Up @@ -155,7 +155,7 @@ def replay(

# No replay desired -> Return here.
if self.replay_ratio == 0.0:
return SampleBatch.concat_samples(output_batches)
return concat_samples(output_batches)
# Only replay desired -> Return a (replayed) sample from the
# buffer.
elif self.replay_ratio == 1.0:
Expand All @@ -168,7 +168,7 @@ def replay(
while random.random() < num_new * replay_proportion:
replay_proportion -= 1
output_batches.append(buffer.replay())
return SampleBatch.concat_samples(output_batches)
return concat_samples(output_batches)

def get_host(self) -> str:
"""Returns the computer's network name.
Expand Down
4 changes: 2 additions & 2 deletions rllib/policy/sample_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def concat(self, other: "SampleBatch") -> "SampleBatch":
>>> print(b1.concat(b2)) # doctest: +SKIP
{"a": np.array([1, 2, 3, 4, 5])}
"""
return self.concat_samples([self, other])
return concat_samples([self, other])

@PublicAPI
def copy(self, shallow: bool = False) -> "SampleBatch":
Expand Down Expand Up @@ -1376,7 +1376,7 @@ def wrap_as_needed(

@staticmethod
@PublicAPI
@Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=False)
@Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=True)
def concat_samples(samples: List["MultiAgentBatch"]) -> "MultiAgentBatch":
return concat_samples_into_ma_batch(samples)

Expand Down
14 changes: 9 additions & 5 deletions rllib/policy/tests/test_sample_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@

import ray
from ray.rllib.models.repeated_values import RepeatedValues
from ray.rllib.policy.sample_batch import SampleBatch, attempt_count_timesteps
from ray.rllib.policy.sample_batch import (
SampleBatch,
attempt_count_timesteps,
concat_samples,
)
from ray.rllib.utils.compression import is_compressed
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.test_utils import check
Expand Down Expand Up @@ -84,7 +88,7 @@ def test_right_zero_padding(self):
)

def test_concat(self):
"""Tests, SampleBatches.concat() and ...concat_samples()."""
"""Tests, SampleBatches.concat() and concat_samples()."""
s1 = SampleBatch(
{
"a": np.array([1, 2, 3]),
Expand All @@ -97,7 +101,7 @@ def test_concat(self):
"b": {"c": np.array([5, 6, 7])},
}
)
concatd = SampleBatch.concat_samples([s1, s2])
concatd = concat_samples([s1, s2])
check(concatd["a"], [1, 2, 3, 2, 3, 4])
check(concatd["b"]["c"], [4, 5, 6, 5, 6, 7])
check(next(concatd.rows()), {"a": 1, "b": {"c": 4}})
Expand Down Expand Up @@ -129,11 +133,11 @@ def test_concat_max_seq_len(self):
}
)

concatd = SampleBatch.concat_samples([s1, s2])
concatd = concat_samples([s1, s2])
check(concatd.max_seq_len, s2.max_seq_len)

with self.assertRaises(ValueError):
SampleBatch.concat_samples([s1, s2, s3])
concat_samples([s1, s2, s3])

def test_rows(self):
s1 = SampleBatch(
Expand Down
9 changes: 4 additions & 5 deletions rllib/utils/replay_buffers/multi_agent_mixin_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
from ray.rllib.policy.sample_batch import (
DEFAULT_POLICY_ID,
MultiAgentBatch,
SampleBatch,
concat_samples,
concat_samples_into_ma_batch,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import (
Expand Down Expand Up @@ -291,7 +290,7 @@ def round_up_or_down(value, ratio):

# No replay desired
if self.replay_ratio == 0.0:
return concat_samples(output_batches)
return concat_samples_into_ma_batch(output_batches)
# Only replay desired
elif self.replay_ratio == 1.0:
return _buffer.sample(num_items, **kwargs)
Expand All @@ -315,7 +314,7 @@ def round_up_or_down(value, ratio):
# Depending on the implementation of underlying buffers, samples
# might be SampleBatches
output_batches = [batch.as_multi_agent() for batch in output_batches]
return MultiAgentBatch.concat_samples(output_batches)
return concat_samples_into_ma_batch(output_batches)

def check_buffer_is_ready(_policy_id):
if (
Expand Down Expand Up @@ -344,7 +343,7 @@ def check_buffer_is_ready(_policy_id):
if check_buffer_is_ready(policy_id):
samples.append(mix_batches(policy_id).as_multi_agent())

return MultiAgentBatch.concat_samples(samples)
return concat_samples_into_ma_batch(samples)

@DeveloperAPI
@override(MultiAgentPrioritizedReplayBuffer)
Expand Down
5 changes: 2 additions & 3 deletions rllib/utils/replay_buffers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import ray # noqa F401
import psutil

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.sample_batch import SampleBatch, concat_samples
from ray.rllib.utils.actor_manager import FaultAwareApply
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.window_stat import WindowStat
Expand Down Expand Up @@ -376,8 +376,7 @@ def _encode_sample(self, idxes: List[int]) -> SampleBatchType:

if samples:
# We assume all samples are of same type
sample_type = type(samples[0])
out = sample_type.concat_samples(samples)
out = concat_samples(samples)
else:
out = SampleBatch()
out.decompress_if_needed()
Expand Down

0 comments on commit f44f578

Please sign in to comment.