Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Clean up deprecated concat_samples calls #31391

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions rllib/evaluation/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from ray.rllib.offline import InputReader
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.sample_batch import SampleBatch, concat_samples
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
Expand Down Expand Up @@ -91,10 +91,9 @@ class SamplerInput(InputReader, metaclass=ABCMeta):
def next(self) -> SampleBatchType:
batches = [self.get_data()]
batches.extend(self.get_extra_batches())
if len(batches) > 1:
return batches[0].concat_samples(batches)
else:
return batches[0]
if len(batches) == 0:
raise RuntimeError("No data available from sampler.")
return concat_samples(batches)

@abstractmethod
@DeveloperAPI
Expand Down
6 changes: 3 additions & 3 deletions rllib/execution/buffers/mixin_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ray.util.timer import _Timer
from ray.rllib.execution.replay_ops import SimpleReplayBuffer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, concat_samples
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode
from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES
Expand Down Expand Up @@ -155,7 +155,7 @@ def replay(

# No replay desired -> Return here.
if self.replay_ratio == 0.0:
return SampleBatch.concat_samples(output_batches)
return concat_samples(output_batches)
# Only replay desired -> Return a (replayed) sample from the
# buffer.
elif self.replay_ratio == 1.0:
Expand All @@ -168,7 +168,7 @@ def replay(
while random.random() < num_new * replay_proportion:
replay_proportion -= 1
output_batches.append(buffer.replay())
return SampleBatch.concat_samples(output_batches)
return concat_samples(output_batches)

def get_host(self) -> str:
"""Returns the computer's network name.
Expand Down
4 changes: 2 additions & 2 deletions rllib/policy/sample_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def concat(self, other: "SampleBatch") -> "SampleBatch":
>>> print(b1.concat(b2)) # doctest: +SKIP
{"a": np.array([1, 2, 3, 4, 5])}
"""
return self.concat_samples([self, other])
return concat_samples([self, other])

@PublicAPI
def copy(self, shallow: bool = False) -> "SampleBatch":
Expand Down Expand Up @@ -1376,7 +1376,7 @@ def wrap_as_needed(

@staticmethod
@PublicAPI
@Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=False)
@Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=True)
def concat_samples(samples: List["MultiAgentBatch"]) -> "MultiAgentBatch":
return concat_samples_into_ma_batch(samples)

Expand Down
14 changes: 9 additions & 5 deletions rllib/policy/tests/test_sample_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@

import ray
from ray.rllib.models.repeated_values import RepeatedValues
from ray.rllib.policy.sample_batch import SampleBatch, attempt_count_timesteps
from ray.rllib.policy.sample_batch import (
SampleBatch,
attempt_count_timesteps,
concat_samples,
)
from ray.rllib.utils.compression import is_compressed
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.test_utils import check
Expand Down Expand Up @@ -84,7 +88,7 @@ def test_right_zero_padding(self):
)

def test_concat(self):
"""Tests, SampleBatches.concat() and ...concat_samples()."""
"""Tests, SampleBatches.concat() and concat_samples()."""
s1 = SampleBatch(
{
"a": np.array([1, 2, 3]),
Expand All @@ -97,7 +101,7 @@ def test_concat(self):
"b": {"c": np.array([5, 6, 7])},
}
)
concatd = SampleBatch.concat_samples([s1, s2])
concatd = concat_samples([s1, s2])
check(concatd["a"], [1, 2, 3, 2, 3, 4])
check(concatd["b"]["c"], [4, 5, 6, 5, 6, 7])
check(next(concatd.rows()), {"a": 1, "b": {"c": 4}})
Expand Down Expand Up @@ -129,11 +133,11 @@ def test_concat_max_seq_len(self):
}
)

concatd = SampleBatch.concat_samples([s1, s2])
concatd = concat_samples([s1, s2])
check(concatd.max_seq_len, s2.max_seq_len)

with self.assertRaises(ValueError):
SampleBatch.concat_samples([s1, s2, s3])
concat_samples([s1, s2, s3])

def test_rows(self):
s1 = SampleBatch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
from ray.rllib.policy.sample_batch import (
DEFAULT_POLICY_ID,
MultiAgentBatch,
SampleBatch,
concat_samples,
concat_samples_into_ma_batch,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import (
Expand Down Expand Up @@ -291,7 +290,7 @@ def round_up_or_down(value, ratio):

# No replay desired
if self.replay_ratio == 0.0:
return concat_samples(output_batches)
return concat_samples_into_ma_batch(output_batches)
# Only replay desired
elif self.replay_ratio == 1.0:
return _buffer.sample(num_items, **kwargs)
Expand All @@ -315,7 +314,7 @@ def round_up_or_down(value, ratio):
# Depending on the implementation of underlying buffers, samples
# might be SampleBatches
output_batches = [batch.as_multi_agent() for batch in output_batches]
return MultiAgentBatch.concat_samples(output_batches)
return concat_samples_into_ma_batch(output_batches)

def check_buffer_is_ready(_policy_id):
if (
Expand Down Expand Up @@ -344,7 +343,7 @@ def check_buffer_is_ready(_policy_id):
if check_buffer_is_ready(policy_id):
samples.append(mix_batches(policy_id).as_multi_agent())

return MultiAgentBatch.concat_samples(samples)
return concat_samples_into_ma_batch(samples)

@DeveloperAPI
@override(MultiAgentPrioritizedReplayBuffer)
Expand Down
5 changes: 2 additions & 3 deletions rllib/utils/replay_buffers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import ray # noqa F401
import psutil

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.sample_batch import SampleBatch, concat_samples
from ray.rllib.utils.actor_manager import FaultAwareApply
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.window_stat import WindowStat
Expand Down Expand Up @@ -376,8 +376,7 @@ def _encode_sample(self, idxes: List[int]) -> SampleBatchType:

if samples:
# We assume all samples are of same type
sample_type = type(samples[0])
out = sample_type.concat_samples(samples)
out = concat_samples(samples)
else:
out = SampleBatch()
out.decompress_if_needed()
Expand Down