diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 03abce007ca23..0f6221e595348 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -56,7 +56,10 @@ def attempt_count_timesteps(tensor_dict: dict): and not (tf and tf.is_tensor(tensor_dict[SampleBatch.SEQ_LENS])) and len(tensor_dict[SampleBatch.SEQ_LENS]) > 0 ): - return sum(tensor_dict[SampleBatch.SEQ_LENS]) + if torch and torch.is_tensor(tensor_dict[SampleBatch.SEQ_LENS]): + return tensor_dict[SampleBatch.SEQ_LENS].sum().item() + else: + return sum(tensor_dict[SampleBatch.SEQ_LENS]) for k, v in copy_.items(): assert isinstance(k, str), tensor_dict @@ -269,7 +272,10 @@ def __init__(self, *args, **kwargs): and not (tf and tf.is_tensor(seq_lens_)) and len(seq_lens_) > 0 ): - self.max_seq_len = max(seq_lens_) + if torch and torch.is_tensor(seq_lens_): + self.max_seq_len = seq_lens_.max().item() + else: + self.max_seq_len = max(seq_lens_) if self._is_training is None: self._is_training = self.pop("is_training", False)