-
Notifications
You must be signed in to change notification settings - Fork 79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Performance] Faster flatten_keys #727
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
vmoens
commented
Apr 18, 2024
# Conflicts: # tensordict/base.py
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_plain_set_nested | 41.2170μs | 17.6586μs | 56.6296 KOps/s | 56.9657 KOps/s | |
test_plain_set_stack_nested | 46.3370μs | 17.8527μs | 56.0138 KOps/s | 56.4005 KOps/s | |
test_plain_set_nested_inplace | 58.3590μs | 20.1003μs | 49.7504 KOps/s | 50.7481 KOps/s | |
test_plain_set_stack_nested_inplace | 77.9760μs | 20.1047μs | 49.7396 KOps/s | 50.0075 KOps/s | |
test_items | 37.1790μs | 2.5262μs | 395.8584 KOps/s | 406.2550 KOps/s | |
test_items_nested | 1.9384ms | 0.2732ms | 3.6605 KOps/s | 3.6946 KOps/s | |
test_items_nested_locked | 0.4537ms | 0.2858ms | 3.4992 KOps/s | 3.6638 KOps/s | |
test_items_nested_leaf | 0.2044ms | 78.1419μs | 12.7972 KOps/s | 12.9290 KOps/s | |
test_items_stack_nested | 0.3228ms | 0.2752ms | 3.6331 KOps/s | 3.6339 KOps/s | |
test_items_stack_nested_leaf | 0.1392ms | 75.3112μs | 13.2782 KOps/s | 12.9736 KOps/s | |
test_items_stack_nested_locked | 0.4856ms | 0.2761ms | 3.6217 KOps/s | 3.6027 KOps/s | |
test_keys | 34.9360μs | 3.7930μs | 263.6404 KOps/s | 258.6239 KOps/s | |
test_keys_nested | 0.3266ms | 0.1384ms | 7.2235 KOps/s | 6.9943 KOps/s | |
test_keys_nested_locked | 0.7898ms | 0.1399ms | 7.1501 KOps/s | 6.8142 KOps/s | |
test_keys_nested_leaf | 0.2041ms | 0.1145ms | 8.7324 KOps/s | 8.0811 KOps/s | |
test_keys_stack_nested | 0.2525ms | 0.1338ms | 7.4714 KOps/s | 6.8932 KOps/s | |
test_keys_stack_nested_leaf | 0.2063ms | 0.1133ms | 8.8254 KOps/s | 7.8947 KOps/s | |
test_keys_stack_nested_locked | 0.2696ms | 0.1381ms | 7.2419 KOps/s | 6.6474 KOps/s | |
test_values | 23.9045μs | 1.1595μs | 862.4136 KOps/s | 861.9713 KOps/s | |
test_values_nested | 0.1142ms | 50.5380μs | 19.7871 KOps/s | 19.9309 KOps/s | |
test_values_nested_locked | 0.1051ms | 50.9054μs | 19.6443 KOps/s | 19.9264 KOps/s | |
test_values_nested_leaf | 0.1107ms | 45.5222μs | 21.9673 KOps/s | 22.0482 KOps/s | |
test_values_stack_nested | 0.1237ms | 51.9357μs | 19.2546 KOps/s | 19.4256 KOps/s | |
test_values_stack_nested_leaf | 0.1002ms | 45.3305μs | 22.0602 KOps/s | 22.4447 KOps/s | |
test_values_stack_nested_locked | 0.1112ms | 51.4457μs | 19.4380 KOps/s | 19.7086 KOps/s | |
test_membership | 37.5910μs | 1.3673μs | 731.3905 KOps/s | 777.3359 KOps/s | |
test_membership_nested | 36.2370μs | 3.4741μs | 287.8451 KOps/s | 296.5440 KOps/s | |
test_membership_nested_leaf | 25.2770μs | 3.4942μs | 286.1871 KOps/s | 295.7432 KOps/s | |
test_membership_stacked_nested | 46.6370μs | 3.4945μs | 286.1603 KOps/s | 298.2834 KOps/s | |
test_membership_stacked_nested_leaf | 27.5720μs | 3.4595μs | 289.0576 KOps/s | 293.0608 KOps/s | |
test_membership_nested_last | 40.5050μs | 4.2728μs | 234.0398 KOps/s | 239.5227 KOps/s | |
test_membership_nested_leaf_last | 41.9290μs | 4.2866μs | 233.2855 KOps/s | 233.8806 KOps/s | |
test_membership_stacked_nested_last | 44.2730μs | 13.4345μs | 74.4350 KOps/s | 172.0594 KOps/s | |
test_membership_stacked_nested_leaf_last | 44.4430μs | 13.4339μs | 74.4387 KOps/s | 172.3182 KOps/s | |
test_nested_getleaf | 88.6050μs | 10.8557μs | 92.1171 KOps/s | 94.3460 KOps/s | |
test_nested_get | 40.7670μs | 10.0976μs | 99.0336 KOps/s | 99.1298 KOps/s | |
test_stacked_getleaf | 43.6710μs | 10.6102μs | 94.2488 KOps/s | 94.6514 KOps/s | |
test_stacked_get | 71.6840μs | 10.1587μs | 98.4373 KOps/s | 99.9837 KOps/s | |
test_nested_getitemleaf | 47.4790μs | 11.2098μs | 89.2076 KOps/s | 90.1139 KOps/s | |
test_nested_getitem | 47.3590μs | 10.2827μs | 97.2507 KOps/s | 96.6618 KOps/s | |
test_stacked_getitemleaf | 41.9680μs | 11.0440μs | 90.5468 KOps/s | 90.6824 KOps/s | |
test_stacked_getitem | 74.2290μs | 10.2327μs | 97.7255 KOps/s | 95.9078 KOps/s | |
test_lock_nested | 58.2529ms | 0.4159ms | 2.4041 KOps/s | 2.9550 KOps/s | |
test_lock_stack_nested | 0.4824ms | 0.3030ms | 3.3005 KOps/s | 3.3253 KOps/s | |
test_unlock_nested | 0.1019s | 0.4595ms | 2.1762 KOps/s | 2.2504 KOps/s | |
test_unlock_stack_nested | 0.5029ms | 0.3111ms | 3.2140 KOps/s | 3.2450 KOps/s | |
test_flatten_speed | 0.6082ms | 91.8322μs | 10.8894 KOps/s | 10.9993 KOps/s | |
test_unflatten_speed | 0.8223ms | 0.4053ms | 2.4670 KOps/s | 2.4865 KOps/s | |
test_common_ops | 4.5387ms | 0.7487ms | 1.3357 KOps/s | 1.3871 KOps/s | |
test_creation | 0.1153ms | 1.8052μs | 553.9597 KOps/s | 554.8262 KOps/s | |
test_creation_empty | 47.4980μs | 11.4409μs | 87.4056 KOps/s | 87.3943 KOps/s | |
test_creation_nested_1 | 48.5610μs | 14.1545μs | 70.6490 KOps/s | 70.0839 KOps/s | |
test_creation_nested_2 | 57.7380μs | 17.7079μs | 56.4721 KOps/s | 58.1530 KOps/s | |
test_clone | 0.1667ms | 14.1572μs | 70.6355 KOps/s | 75.6108 KOps/s | |
test_getitem[int] | 37.6200μs | 11.8847μs | 84.1420 KOps/s | 89.8048 KOps/s | |
test_getitem[slice_int] | 55.7250μs | 22.9652μs | 43.5441 KOps/s | 46.8985 KOps/s | |
test_getitem[range] | 0.2320ms | 42.7814μs | 23.3747 KOps/s | 24.9482 KOps/s | |
test_getitem[tuple] | 74.8400μs | 18.9569μs | 52.7512 KOps/s | 55.5917 KOps/s | |
test_getitem[list] | 0.3324ms | 38.1470μs | 26.2144 KOps/s | 26.7849 KOps/s | |
test_setitem_dim[int] | 86.5720μs | 36.4853μs | 27.4083 KOps/s | 29.8857 KOps/s | |
test_setitem_dim[slice_int] | 0.1116ms | 62.7674μs | 15.9318 KOps/s | 16.8976 KOps/s | |
test_setitem_dim[range] | 0.1407ms | 81.7130μs | 12.2380 KOps/s | 13.0080 KOps/s | |
test_setitem_dim[tuple] | 77.3650μs | 51.7488μs | 19.3241 KOps/s | 20.7359 KOps/s | |
test_setitem | 0.1325ms | 22.0146μs | 45.4244 KOps/s | 49.5131 KOps/s | |
test_set | 0.1777ms | 21.4839μs | 46.5466 KOps/s | 50.2191 KOps/s | |
test_set_shared | 2.2951ms | 0.1513ms | 6.6072 KOps/s | 7.0634 KOps/s | |
test_update | 0.1782ms | 23.7635μs | 42.0813 KOps/s | 45.1465 KOps/s | |
test_update_nested | 0.1903ms | 32.3988μs | 30.8653 KOps/s | 33.2724 KOps/s | |
test_update__nested | 0.1462ms | 25.7958μs | 38.7660 KOps/s | 40.9142 KOps/s | |
test_set_nested | 0.1736ms | 23.7439μs | 42.1161 KOps/s | 46.0792 KOps/s | |
test_set_nested_new | 0.1679ms | 27.4375μs | 36.4464 KOps/s | 38.5667 KOps/s | |
test_select | 0.1538ms | 42.4735μs | 23.5441 KOps/s | 24.8289 KOps/s | |
test_select_nested | 0.1247ms | 60.5053μs | 16.5275 KOps/s | 17.1432 KOps/s | |
test_exclude_nested | 0.2244ms | 0.1215ms | 8.2338 KOps/s | 8.5277 KOps/s | |
test_empty[True] | 0.6262ms | 0.3943ms | 2.5359 KOps/s | 2.4159 KOps/s | |
test_empty[False] | 7.3558μs | 1.0727μs | 932.2125 KOps/s | 983.6447 KOps/s | |
test_unbind_speed | 1.7495ms | 0.2588ms | 3.8638 KOps/s | 4.0215 KOps/s | |
test_unbind_speed_stack0 | 0.4004ms | 0.2464ms | 4.0592 KOps/s | 4.1255 KOps/s | |
test_unbind_speed_stack1 | 0.1469s | 0.7012ms | 1.4262 KOps/s | 1.4759 KOps/s | |
test_split | 1.7748ms | 1.5231ms | 656.5521 Ops/s | 602.1056 Ops/s | |
test_chunk | 0.1451s | 1.7799ms | 561.8283 Ops/s | 684.9268 Ops/s | |
test_creation[device0] | 6.2002ms | 0.1085ms | 9.2141 KOps/s | 9.8194 KOps/s | |
test_creation_from_tensor | 0.3128ms | 84.3574μs | 11.8543 KOps/s | 12.0007 KOps/s | |
test_add_one[memmap_tensor0] | 0.1906ms | 5.7311μs | 174.4861 KOps/s | 182.4240 KOps/s | |
test_contiguous[memmap_tensor0] | 25.6380μs | 0.6358μs | 1.5729 MOps/s | 1.5766 MOps/s | |
test_stack[memmap_tensor0] | 38.8120μs | 3.9714μs | 251.8005 KOps/s | 279.0343 KOps/s | |
test_memmaptd_index | 1.2068ms | 0.2441ms | 4.0958 KOps/s | 4.3798 KOps/s | |
test_memmaptd_index_astensor | 0.5688ms | 0.3068ms | 3.2595 KOps/s | 3.2679 KOps/s | |
test_memmaptd_index_op | 1.2065ms | 0.6593ms | 1.5168 KOps/s | 1.6408 KOps/s | |
test_serialize_model | 0.1122s | 0.1036s | 9.6517 Ops/s | 8.6982 Ops/s | |
test_serialize_model_pickle | 0.5890s | 0.3868s | 2.5854 Ops/s | 2.6135 Ops/s | |
test_serialize_weights | 0.1050s | 99.1323ms | 10.0875 Ops/s | 10.0000 Ops/s | |
test_serialize_weights_returnearly | 0.1319s | 0.1250s | 7.9996 Ops/s | 8.1036 Ops/s | |
test_serialize_weights_pickle | 1.1066s | 0.5686s | 1.7588 Ops/s | 1.5513 Ops/s | |
test_serialize_weights_filesystem | 91.8738ms | 88.8080ms | 11.2603 Ops/s | 11.1692 Ops/s | |
test_serialize_model_filesystem | 0.1033s | 94.0672ms | 10.6307 Ops/s | 10.7787 Ops/s | |
test_reshape_pytree | 48.7910μs | 21.2871μs | 46.9767 KOps/s | 49.1469 KOps/s | |
test_reshape_td | 68.8490μs | 32.7807μs | 30.5057 KOps/s | 32.0559 KOps/s | |
test_view_pytree | 54.3010μs | 21.0891μs | 47.4178 KOps/s | 48.9576 KOps/s | |
test_view_td | 0.1224s | 63.0740μs | 15.8544 KOps/s | 16.0418 KOps/s | |
test_unbind_pytree | 66.2540μs | 24.4584μs | 40.8857 KOps/s | 40.8435 KOps/s | |
test_unbind_td | 0.5306ms | 37.3979μs | 26.7395 KOps/s | 27.7590 KOps/s | |
test_split_pytree | 54.1110μs | 24.4954μs | 40.8239 KOps/s | 43.4516 KOps/s | |
test_split_td | 0.1218ms | 41.0355μs | 24.3691 KOps/s | 25.5260 KOps/s | |
test_add_pytree | 79.6590μs | 30.1942μs | 33.1190 KOps/s | 33.5007 KOps/s | |
test_add_td | 0.1171ms | 57.1346μs | 17.5025 KOps/s | 17.9561 KOps/s | |
test_distributed | 0.2594ms | 99.7724μs | 10.0228 KOps/s | 9.8067 KOps/s | |
test_tdmodule | 33.4430μs | 17.7709μs | 56.2719 KOps/s | 56.3376 KOps/s | |
test_tdmodule_dispatch | 70.5020μs | 36.0207μs | 27.7618 KOps/s | 27.6773 KOps/s | |
test_tdseq | 45.3450μs | 20.7877μs | 48.1054 KOps/s | 48.7830 KOps/s | |
test_tdseq_dispatch | 65.9730μs | 40.6388μs | 24.6070 KOps/s | 24.8227 KOps/s | |
test_instantiation_functorch | 1.4182ms | 1.3031ms | 767.3957 Ops/s | 778.5141 Ops/s | |
test_instantiation_td | 1.6620ms | 1.0093ms | 990.8087 Ops/s | 1.0106 KOps/s | |
test_exec_functorch | 0.2871ms | 0.1572ms | 6.3598 KOps/s | 6.3004 KOps/s | |
test_exec_functional_call | 0.2804ms | 0.1470ms | 6.8030 KOps/s | 6.7665 KOps/s | |
test_exec_td | 0.2269ms | 0.1416ms | 7.0643 KOps/s | 7.1038 KOps/s | |
test_exec_td_decorator | 0.7598ms | 0.1965ms | 5.0878 KOps/s | 5.1653 KOps/s | |
test_vmap_mlp_speed[True-True] | 1.2726ms | 0.4879ms | 2.0497 KOps/s | 2.1187 KOps/s | |
test_vmap_mlp_speed[True-False] | 0.7031ms | 0.4785ms | 2.0901 KOps/s | 2.1313 KOps/s | |
test_vmap_mlp_speed[False-True] | 0.7205ms | 0.3879ms | 2.5778 KOps/s | 2.6379 KOps/s | |
test_vmap_mlp_speed[False-False] | 0.5993ms | 0.3834ms | 2.6082 KOps/s | 2.6365 KOps/s | |
test_vmap_mlp_speed_decorator[True-True] | 1.2410ms | 0.5008ms | 1.9966 KOps/s | 2.0237 KOps/s | |
test_vmap_mlp_speed_decorator[True-False] | 0.7336ms | 0.4975ms | 2.0099 KOps/s | 2.0341 KOps/s | |
test_vmap_mlp_speed_decorator[False-True] | 1.1350ms | 0.4085ms | 2.4477 KOps/s | 2.4989 KOps/s | |
test_vmap_mlp_speed_decorator[False-False] | 0.7120ms | 0.4036ms | 2.4777 KOps/s | 2.5068 KOps/s | |
test_to_module_speed[True] | 2.2566ms | 1.4340ms | 697.3707 Ops/s | 726.2776 Ops/s | |
test_to_module_speed[False] | 1.5380ms | 1.3924ms | 718.1925 Ops/s | 727.6538 Ops/s |
This PR has broken the following script from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.modules import VDNMixer, MultiAgentMLP, QValueModule, SafeSequential
from torchrl.objectives import QMixerLoss
net = MultiAgentMLP(
n_agent_inputs=3,
n_agent_outputs=4,
n_agents=2,
centralised=False,
share_params=True,
device="cpu",
depth=2,
num_cells=256,
activation_class=nn.Tanh,
)
module = TensorDictModule(
net, in_keys=[("agents", "observation")], out_keys=[("agents", "action_value")]
)
value_module = QValueModule(
action_value_key=("agents", "action_value"),
out_keys=[
("agents", "action"),
("agents", "action_value"),
("agents", "chosen_action_value"),
],
action_space="categorical",
)
qnet = SafeSequential(module, value_module)
mixer = TensorDictModule(
module=VDNMixer(
n_agents=2,
device="cpu",
),
in_keys=[("agents", "chosen_action_value")],
out_keys=["chosen_action_value"],
)
loss_module = QMixerLoss(qnet, mixer, delay_value=True, action_space="categorical")
loss_module.state_dict() Traceback (most recent call last):
File "/Users/Matteo/PycharmProjects/torchrl/torchrl/prova.py", line 44, in <module>
loss_module.state_dict()
File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1895, in state_dict
module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
File "/Users/Matteo/PycharmProjects/tensordict/tensordict/nn/params.py", line 996, in state_dict
return self._param_td.state_dict(
File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 1903, in state_dict
source = source.flatten_keys(".")
File "/Users/Matteo/PycharmProjects/tensordict/tensordict/utils.py", line 1060, in newfun
return fun(_self, *args, **kwargs)
File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 6215, in flatten_keys
return self._flatten_keys_outplace(separator=separator, is_leaf=is_leaf)
File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 6220, in _flatten_keys_outplace
all_leaves, all_vals = zip(
ValueError: not enough values to unpack (expected 2, got 0) tested to work before the PR |
Got it |
(additionally i would post it as an issue and not a comment, there's a high probability I miss comments on closed PRs) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Labels
CLA Signed
This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Performance
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
No description provided.