[RLlib] Make 'get_num_parameters' more efficient. (#50923)
## Why are these changes needed?
The current parameter-count method uses multiple loops to count the
number of parameters. Furthermore, it relies on `np.prod` and `filter`. The
method proposed in this PR instead uses torch's native (C++-backed) `numel()`
to count parameters in a single loop, without filtering.
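
For illustration, a minimal sketch of the two counting strategies on a plain `torch.nn.Module` (the toy module below is made up for this example; only the counting logic mirrors this PR):

```python
import numpy as np
import torch
import torch.nn as nn

# Toy module with one frozen (non-trainable) layer, purely for illustration.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
for p in model[0].parameters():
    p.requires_grad = False

# Old approach: two generator passes, np.prod over each parameter's shape, plus filter().
num_all = sum(int(np.prod(p.size())) for p in model.parameters())
trainable = filter(lambda p: p.requires_grad, model.parameters())
num_trainable_old = sum(int(np.prod(p.size())) for p in trainable)
num_frozen_old = num_all - num_trainable_old

# New approach: a single pass, using torch's native Tensor.numel().
num_trainable_new, num_frozen_new = 0, 0
for p in model.parameters():
    n = p.numel()
    if p.requires_grad:
        num_trainable_new += n
    else:
        num_frozen_new += n

assert (num_trainable_old, num_frozen_old) == (num_trainable_new, num_frozen_new)
```

Both approaches produce the same counts; the new one simply avoids the extra passes over the parameters and the per-parameter `np.prod` call.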

## Related issue number


## Checks

- [ ] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- [ ] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for
https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a
method in Tune, I've added it in `doc/source/tune/api/` under the
corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a
few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [ ] Unit tests
   - [ ] Release tests
   - [ ] This PR is not tested :(

---------

Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
simonsays1980 authored Mar 4, 2025
1 parent eedd3a8 commit 2b56e60
Showing 2 changed files with 28 additions and 25 deletions.
26 changes: 12 additions & 14 deletions rllib/core/learner/torch/torch_learner.py
@@ -583,20 +583,18 @@ def _map_module_to_device(self, module: MultiRLModule) -> None:
     @override(Learner)
     def _log_trainable_parameters(self) -> None:
         # Log number of non-trainable and trainable parameters of our RLModule.
-        num_trainable_params = {
-            (mid, NUM_TRAINABLE_PARAMETERS): sum(
-                p.numel() for p in rlm.parameters() if p.requires_grad
-            )
-            for mid, rlm in self.module._rl_modules.items()
-            if isinstance(rlm, TorchRLModule)
-        }
-        num_non_trainable_params = {
-            (mid, NUM_NON_TRAINABLE_PARAMETERS): sum(
-                p.numel() for p in rlm.parameters() if not p.requires_grad
-            )
-            for mid, rlm in self.module._rl_modules.items()
-            if isinstance(rlm, TorchRLModule)
-        }
+        num_trainable_params = defaultdict(int)
+        num_non_trainable_params = defaultdict(int)
+        for mid, rlm in self.module._rl_modules.items():
+            if isinstance(rlm, TorchRLModule):
+                for p in rlm.parameters():
+                    n = p.numel()
+                    if p.requires_grad:
+                        num_trainable_params[(mid, NUM_TRAINABLE_PARAMETERS)] += n
+                    else:
+                        num_non_trainable_params[
+                            (mid, NUM_NON_TRAINABLE_PARAMETERS)
+                        ] += n

         self.metrics.log_dict(
             {
27 changes: 16 additions & 11 deletions rllib/core/models/torch/base.py
@@ -2,8 +2,6 @@
 import logging
 from typing import Tuple, Union

-import numpy as np
-
 from ray.rllib.core.models.base import Model
 from ray.rllib.core.models.configs import ModelConfig
 from ray.rllib.utils.annotations import override
@@ -80,18 +78,25 @@ def forward(

     @override(Model)
     def get_num_parameters(self) -> Tuple[int, int]:
-        num_all_params = sum(int(np.prod(p.size())) for p in self.parameters())
-        trainable_params = filter(lambda p: p.requires_grad, self.parameters())
-        num_trainable_params = sum(int(np.prod(p.size())) for p in trainable_params)
-        return (
-            num_trainable_params,
-            num_all_params - num_trainable_params,
-        )
+        num_trainable_parameters = 0
+        num_frozen_parameters = 0
+        for p in self.parameters():
+            n = p.numel()
+            if p.requires_grad:
+                num_trainable_parameters += n
+            else:
+                num_frozen_parameters += n
+        return num_trainable_parameters, num_frozen_parameters

     @override(Model)
     def _set_to_dummy_weights(self, value_sequence=(-0.02, -0.01, 0.01, 0.02)):
-        trainable_weights = [p for p in self.parameters() if p.requires_grad]
-        non_trainable_weights = [p for p in self.parameters() if not p.requires_grad]
+        trainable_weights = []
+        non_trainable_weights = []
+        for p in self.parameters():
+            if p.requires_grad:
+                trainable_weights.append(p)
+            else:
+                non_trainable_weights.append(p)
         for i, w in enumerate(trainable_weights + non_trainable_weights):
             fill_val = value_sequence[i % len(value_sequence)]
             with torch.no_grad():
