[Feature] Load tensordicts on device, incl. meta #769

Merged: 8 commits merged on May 1, 2024

@vmoens (Contributor) commented Apr 30, 2024

Allows loading a tensordict (asynchronously) directly onto a device, or onto "meta" (i.e., creating a meta-tensordict by reading only the metadata, not the tensor data).

Test the feature:

import tempfile

from tensordict import TensorDict, is_tensor_collection
from torch.nn import Transformer
import torch

with tempfile.TemporaryDirectory() as tmpdir:
    t = Transformer(
        d_model=64, nhead=4, num_encoder_layers=3, dim_feedforward=128
    )

    state_dict = TensorDict.from_module(t)
    state_dict.data.zero_()

    state_dict.save(tmpdir)
    meta_state_dict = TensorDict.load(tmpdir, device="meta")


    def check_meta(tensor):
        assert tensor.device == torch.device("meta")


    meta_state_dict.apply(check_meta, filter_empty=True)

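    # Pick an accelerator if one is present; otherwise fall back to CPU
    # so that `device` is always defined before it is used below.
    if torch.cuda.is_available():
        device = "cuda:0"
    elif torch.backends.mps.is_available():
        device = "mps:0"
    else:
        device = "cpu"
    device_state_dict = TensorDict.load(tmpdir, device=device)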
    assert (device_state_dict == 0).all()


    def assert_device(item):
        assert item.device == torch.device(device), (device, item.device)
        if is_tensor_collection(item):
            item.apply(assert_device, filter_empty=True, call_on_nested=True)

    device_state_dict.apply(assert_device, filter_empty=True, call_on_nested=True)
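
For readers unfamiliar with the "meta" device, here is a minimal sketch of what a meta tensor gives you. It uses plain PyTorch only and does not go through the `TensorDict.load` path added in this PR; the helper name `describe_meta` is hypothetical and only there for illustration. The point is that shapes, dtypes and byte sizes are available even though no tensor data is allocated or read.

```python
import torch
from torch import nn


def describe_meta(module: nn.Module) -> dict:
    """Summarise parameters from metadata only (hypothetical helper)."""
    summary = {}
    for name, param in module.named_parameters():
        # Meta tensors carry shape and dtype but have no backing storage.
        assert param.device.type == "meta"
        nbytes = param.numel() * param.element_size()
        summary[name] = (tuple(param.shape), param.dtype, nbytes)
    return summary


# Build the layer directly on the meta device: no weights are allocated.
layer = nn.Linear(64, 128, device="meta")

summary = describe_meta(layer)
total_bytes = sum(nbytes for _, _, nbytes in summary.values())

for name, (shape, dtype, nbytes) in summary.items():
    print(name, shape, dtype, nbytes)
print("total bytes if materialized:", total_bytes)
```

The same kind of inspection should apply to a tensordict loaded with `device="meta"`, which is handy for sizing a checkpoint before deciding which device to load it on.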

@facebook-github-bot added the "CLA Signed" label Apr 30, 2024
@vmoens added the "enhancement" label Apr 30, 2024

github-actions bot commented Apr 30, 2024

⚠ Warning: Result of CPU Benchmark Tests

Total Benchmarks: 127. Improved: $\large\color{#35bf28}5$. Worsened: $\large\color{#d91a1a}10$.

Name Max Mean Ops Ops on Repo HEAD Change
test_plain_set_nested 37.4100μs 17.8984μs 55.8709 KOps/s 56.7475 KOps/s $\color{#d91a1a}-1.54\%$
test_plain_set_stack_nested 53.3310μs 17.9306μs 55.7704 KOps/s 55.5188 KOps/s $\color{#35bf28}+0.45\%$
test_plain_set_nested_inplace 0.6916ms 20.1364μs 49.6612 KOps/s 49.3946 KOps/s $\color{#35bf28}+0.54\%$
test_plain_set_stack_nested_inplace 60.5220μs 20.0654μs 49.8370 KOps/s 49.6040 KOps/s $\color{#35bf28}+0.47\%$
test_items 29.0740μs 2.6336μs 379.7088 KOps/s 381.5479 KOps/s $\color{#d91a1a}-0.48\%$
test_items_nested 0.5063ms 0.2667ms 3.7489 KOps/s 3.7184 KOps/s $\color{#35bf28}+0.82\%$
test_items_nested_locked 1.4138ms 0.2797ms 3.5755 KOps/s 3.6991 KOps/s $\color{#d91a1a}-3.34\%$
test_items_nested_leaf 0.2243ms 79.0806μs 12.6453 KOps/s 12.8897 KOps/s $\color{#d91a1a}-1.90\%$
test_items_stack_nested 1.3171ms 0.2723ms 3.6725 KOps/s 3.7065 KOps/s $\color{#d91a1a}-0.92\%$
test_items_stack_nested_leaf 0.1595ms 80.0140μs 12.4978 KOps/s 12.8574 KOps/s $\color{#d91a1a}-2.80\%$
test_items_stack_nested_locked 0.3370ms 0.2719ms 3.6784 KOps/s 3.7223 KOps/s $\color{#d91a1a}-1.18\%$
test_keys 21.0390μs 3.8061μs 262.7362 KOps/s 256.9631 KOps/s $\color{#35bf28}+2.25\%$
test_keys_nested 0.2667ms 0.1422ms 7.0324 KOps/s 7.1193 KOps/s $\color{#d91a1a}-1.22\%$
test_keys_nested_locked 0.8220ms 0.1474ms 6.7828 KOps/s 6.9119 KOps/s $\color{#d91a1a}-1.87\%$
test_keys_nested_leaf 0.3268ms 0.1221ms 8.1925 KOps/s 8.4375 KOps/s $\color{#d91a1a}-2.90\%$
test_keys_stack_nested 0.2270ms 0.1423ms 7.0252 KOps/s 7.0912 KOps/s $\color{#d91a1a}-0.93\%$
test_keys_stack_nested_leaf 0.2361ms 0.1213ms 8.2450 KOps/s 8.5230 KOps/s $\color{#d91a1a}-3.26\%$
test_keys_stack_nested_locked 0.2833ms 0.1481ms 6.7513 KOps/s 6.8880 KOps/s $\color{#d91a1a}-1.98\%$
test_values 6.2256μs 1.1020μs 907.4521 KOps/s 859.9119 KOps/s $\textbf{\color{#35bf28}+5.53\%}$
test_values_nested 0.2030ms 50.9368μs 19.6322 KOps/s 19.2019 KOps/s $\color{#35bf28}+2.24\%$
test_values_nested_locked 0.1328ms 51.0195μs 19.6004 KOps/s 18.9853 KOps/s $\color{#35bf28}+3.24\%$
test_values_nested_leaf 97.4010μs 45.6696μs 21.8964 KOps/s 21.3279 KOps/s $\color{#35bf28}+2.67\%$
test_values_stack_nested 90.2780μs 52.1365μs 19.1804 KOps/s 19.3124 KOps/s $\color{#d91a1a}-0.68\%$
test_values_stack_nested_leaf 0.1080ms 46.3708μs 21.5653 KOps/s 21.2239 KOps/s $\color{#35bf28}+1.61\%$
test_values_stack_nested_locked 0.1046ms 51.8885μs 19.2721 KOps/s 19.3371 KOps/s $\color{#d91a1a}-0.34\%$
test_membership 27.3100μs 1.3541μs 738.4861 KOps/s 725.4376 KOps/s $\color{#35bf28}+1.80\%$
test_membership_nested 33.2520μs 3.7024μs 270.0983 KOps/s 286.3204 KOps/s $\textbf{\color{#d91a1a}-5.67\%}$
test_membership_nested_leaf 35.1560μs 3.6796μs 271.7662 KOps/s 286.0181 KOps/s $\color{#d91a1a}-4.98\%$
test_membership_stacked_nested 43.4910μs 3.6692μs 272.5379 KOps/s 254.1346 KOps/s $\textbf{\color{#35bf28}+7.24\%}$
test_membership_stacked_nested_leaf 41.4370μs 3.6669μs 272.7078 KOps/s 288.5534 KOps/s $\textbf{\color{#d91a1a}-5.49\%}$
test_membership_nested_last 24.4150μs 4.4649μs 223.9678 KOps/s 237.7275 KOps/s $\textbf{\color{#d91a1a}-5.79\%}$
test_membership_nested_leaf_last 34.2840μs 4.4652μs 223.9536 KOps/s 239.2514 KOps/s $\textbf{\color{#d91a1a}-6.39\%}$
test_membership_stacked_nested_last 36.7890μs 4.4115μs 226.6809 KOps/s 238.8063 KOps/s $\textbf{\color{#d91a1a}-5.08\%}$
test_membership_stacked_nested_leaf_last 25.7980μs 4.5015μs 222.1471 KOps/s 238.9692 KOps/s $\textbf{\color{#d91a1a}-7.04\%}$
test_nested_getleaf 43.8310μs 10.6842μs 93.5958 KOps/s 93.3048 KOps/s $\color{#35bf28}+0.31\%$
test_nested_get 44.0510μs 9.9839μs 100.1616 KOps/s 99.6492 KOps/s $\color{#35bf28}+0.51\%$
test_stacked_getleaf 40.7250μs 10.6086μs 94.2629 KOps/s 94.9651 KOps/s $\color{#d91a1a}-0.74\%$
test_stacked_get 32.9510μs 10.0434μs 99.5679 KOps/s 99.8030 KOps/s $\color{#d91a1a}-0.24\%$
test_nested_getitemleaf 40.7860μs 11.1209μs 89.9211 KOps/s 89.8190 KOps/s $\color{#35bf28}+0.11\%$
test_nested_getitem 45.1540μs 10.3044μs 97.0458 KOps/s 97.3173 KOps/s $\color{#d91a1a}-0.28\%$
test_stacked_getitemleaf 38.9720μs 11.0852μs 90.2100 KOps/s 89.5007 KOps/s $\color{#35bf28}+0.79\%$
test_stacked_getitem 28.2820μs 10.2712μs 97.3595 KOps/s 98.0861 KOps/s $\color{#d91a1a}-0.74\%$
test_lock_nested 52.0684ms 0.4223ms 2.3679 KOps/s 2.8480 KOps/s $\textbf{\color{#d91a1a}-16.86\%}$
test_lock_stack_nested 0.4841ms 0.3297ms 3.0326 KOps/s 3.2067 KOps/s $\textbf{\color{#d91a1a}-5.43\%}$
test_unlock_nested 0.7055ms 0.3687ms 2.7124 KOps/s 2.5170 KOps/s $\textbf{\color{#35bf28}+7.76\%}$
test_unlock_stack_nested 0.4945ms 0.3375ms 2.9630 KOps/s 3.1348 KOps/s $\textbf{\color{#d91a1a}-5.48\%}$
test_flatten_speed 0.2077ms 96.0000μs 10.4167 KOps/s 10.5307 KOps/s $\color{#d91a1a}-1.08\%$
test_unflatten_speed 0.8602ms 0.4253ms 2.3514 KOps/s 2.4433 KOps/s $\color{#d91a1a}-3.76\%$
test_common_ops 2.8323ms 0.7427ms 1.3465 KOps/s 1.3583 KOps/s $\color{#d91a1a}-0.87\%$
test_creation 44.7440μs 1.8956μs 527.5454 KOps/s 512.6375 KOps/s $\color{#35bf28}+2.91\%$
test_creation_empty 77.5790μs 11.9266μs 83.8459 KOps/s 83.1597 KOps/s $\color{#35bf28}+0.83\%$
test_creation_nested_1 42.3980μs 14.7304μs 67.8867 KOps/s 68.2065 KOps/s $\color{#d91a1a}-0.47\%$
test_creation_nested_2 40.9960μs 18.1431μs 55.1173 KOps/s 55.5236 KOps/s $\color{#d91a1a}-0.73\%$
test_clone 65.8830μs 13.7219μs 72.8764 KOps/s 74.6507 KOps/s $\color{#d91a1a}-2.38\%$
test_getitem[int] 36.8590μs 11.4128μs 87.6206 KOps/s 85.8130 KOps/s $\color{#35bf28}+2.11\%$
test_getitem[slice_int] 50.2430μs 22.5617μs 44.3228 KOps/s 43.7811 KOps/s $\color{#35bf28}+1.24\%$
test_getitem[range] 79.2180μs 59.3226μs 16.8570 KOps/s 16.8484 KOps/s $\color{#35bf28}+0.05\%$
test_getitem[tuple] 52.0770μs 19.0347μs 52.5355 KOps/s 53.2406 KOps/s $\color{#d91a1a}-1.32\%$
test_getitem[list] 93.9750μs 41.0377μs 24.3678 KOps/s 24.3018 KOps/s $\color{#35bf28}+0.27\%$
test_setitem_dim[int] 58.4780μs 36.0431μs 27.7446 KOps/s 27.9092 KOps/s $\color{#d91a1a}-0.59\%$
test_setitem_dim[slice_int] 0.1436ms 63.9764μs 15.6308 KOps/s 15.5349 KOps/s $\color{#35bf28}+0.62\%$
test_setitem_dim[range] 0.1329ms 86.6285μs 11.5435 KOps/s 11.4277 KOps/s $\color{#35bf28}+1.01\%$
test_setitem_dim[tuple] 86.2810μs 50.8199μs 19.6773 KOps/s 19.2202 KOps/s $\color{#35bf28}+2.38\%$
test_setitem 65.9830μs 21.1709μs 47.2346 KOps/s 46.5475 KOps/s $\color{#35bf28}+1.48\%$
test_set 61.3950μs 20.8713μs 47.9126 KOps/s 48.3774 KOps/s $\color{#d91a1a}-0.96\%$
test_set_shared 1.6524ms 0.1409ms 7.0966 KOps/s 7.3489 KOps/s $\color{#d91a1a}-3.43\%$
test_update 91.7710μs 23.6372μs 42.3062 KOps/s 42.7398 KOps/s $\color{#d91a1a}-1.01\%$
test_update_nested 0.1147ms 32.6034μs 30.6716 KOps/s 30.9997 KOps/s $\color{#d91a1a}-1.06\%$
test_update__nested 78.1550μs 25.0180μs 39.9712 KOps/s 39.0325 KOps/s $\color{#35bf28}+2.40\%$
test_set_nested 62.9870μs 22.4765μs 44.4908 KOps/s 44.6204 KOps/s $\color{#d91a1a}-0.29\%$
test_set_nested_new 86.7520μs 26.9589μs 37.0935 KOps/s 37.9688 KOps/s $\color{#d91a1a}-2.31\%$
test_select 0.1156ms 42.0098μs 23.8039 KOps/s 24.0347 KOps/s $\color{#d91a1a}-0.96\%$
test_select_nested 0.1232ms 59.8422μs 16.7106 KOps/s 16.1709 KOps/s $\color{#35bf28}+3.34\%$
test_exclude_nested 0.2532ms 0.1206ms 8.2928 KOps/s 8.3479 KOps/s $\color{#d91a1a}-0.66\%$
test_empty[True] 1.1300ms 0.4036ms 2.4779 KOps/s 2.5006 KOps/s $\color{#d91a1a}-0.91\%$
test_empty[False] 6.4420μs 1.0862μs 920.6639 KOps/s 898.9760 KOps/s $\color{#35bf28}+2.41\%$
test_unbind_speed 0.3347ms 0.2710ms 3.6895 KOps/s 3.8260 KOps/s $\color{#d91a1a}-3.57\%$
test_unbind_speed_stack0 0.4502ms 0.2677ms 3.7350 KOps/s 3.8829 KOps/s $\color{#d91a1a}-3.81\%$
test_unbind_speed_stack1 68.0438ms 0.7833ms 1.2767 KOps/s 1.2397 KOps/s $\color{#35bf28}+2.98\%$
test_split 61.5446ms 1.5822ms 632.0393 Ops/s 608.2814 Ops/s $\color{#35bf28}+3.91\%$
test_chunk 72.6391ms 1.6140ms 619.5828 Ops/s 659.2262 Ops/s $\textbf{\color{#d91a1a}-6.01\%}$
test_creation[device0] 0.1793ms 0.1046ms 9.5594 KOps/s 9.7864 KOps/s $\color{#d91a1a}-2.32\%$
test_creation_from_tensor 3.6569ms 84.1832μs 11.8789 KOps/s 12.0541 KOps/s $\color{#d91a1a}-1.45\%$
test_add_one[memmap_tensor0] 72.5750μs 5.5094μs 181.5082 KOps/s 183.7081 KOps/s $\color{#d91a1a}-1.20\%$
test_contiguous[memmap_tensor0] 18.8150μs 0.6375μs 1.5686 MOps/s 1.5878 MOps/s $\color{#d91a1a}-1.21\%$
test_stack[memmap_tensor0] 30.0660μs 3.5058μs 285.2422 KOps/s 287.9859 KOps/s $\color{#d91a1a}-0.95\%$
test_memmaptd_index 1.0129ms 0.2385ms 4.1937 KOps/s 4.2104 KOps/s $\color{#d91a1a}-0.40\%$
test_memmaptd_index_astensor 0.9020ms 0.3194ms 3.1311 KOps/s 2.9337 KOps/s $\textbf{\color{#35bf28}+6.73\%}$
test_memmaptd_index_op 1.0264ms 0.6280ms 1.5922 KOps/s 1.5918 KOps/s $\color{#35bf28}+0.03\%$
test_serialize_model 0.1821s 0.1171s 8.5407 Ops/s 8.5607 Ops/s $\color{#d91a1a}-0.23\%$
test_serialize_model_pickle 0.4463s 0.3808s 2.6262 Ops/s 2.5975 Ops/s $\color{#35bf28}+1.11\%$
test_serialize_weights 0.1876s 0.1154s 8.6634 Ops/s 8.6937 Ops/s $\color{#d91a1a}-0.35\%$
test_serialize_weights_returnearly 0.1359s 0.1259s 7.9404 Ops/s 7.7088 Ops/s $\color{#35bf28}+3.00\%$
test_serialize_weights_pickle 0.5355s 0.4107s 2.4348 Ops/s 2.4209 Ops/s $\color{#35bf28}+0.58\%$
test_serialize_weights_filesystem 0.1012s 96.4884ms 10.3639 Ops/s 10.8393 Ops/s $\color{#d91a1a}-4.39\%$
test_serialize_model_filesystem 0.1025s 97.3914ms 10.2679 Ops/s 10.4796 Ops/s $\color{#d91a1a}-2.02\%$
test_reshape_pytree 69.0780μs 25.6150μs 39.0397 KOps/s 39.2921 KOps/s $\color{#d91a1a}-0.64\%$
test_reshape_td 88.7750μs 34.0663μs 29.3545 KOps/s 29.8108 KOps/s $\color{#d91a1a}-1.53\%$
test_view_pytree 59.3700μs 25.3696μs 39.4172 KOps/s 39.3962 KOps/s $\color{#35bf28}+0.05\%$
test_view_td 0.1012ms 37.6121μs 26.5872 KOps/s 26.6472 KOps/s $\color{#d91a1a}-0.23\%$
test_unbind_pytree 83.3650μs 29.1393μs 34.3179 KOps/s 34.4630 KOps/s $\color{#d91a1a}-0.42\%$
test_unbind_td 0.3736ms 39.7856μs 25.1347 KOps/s 26.2554 KOps/s $\color{#d91a1a}-4.27\%$
test_split_pytree 80.5000μs 29.5448μs 33.8469 KOps/s 34.7906 KOps/s $\color{#d91a1a}-2.71\%$
test_split_td 0.1369ms 40.3134μs 24.8056 KOps/s 24.4847 KOps/s $\color{#35bf28}+1.31\%$
test_add_pytree 81.2420μs 34.7709μs 28.7597 KOps/s 28.5663 KOps/s $\color{#35bf28}+0.68\%$
test_add_td 0.1426ms 57.7467μs 17.3170 KOps/s 17.5736 KOps/s $\color{#d91a1a}-1.46\%$
test_distributed 0.1797ms 98.7082μs 10.1309 KOps/s 9.8587 KOps/s $\color{#35bf28}+2.76\%$
test_tdmodule 63.7080μs 18.7400μs 53.3617 KOps/s 52.9901 KOps/s $\color{#35bf28}+0.70\%$
test_tdmodule_dispatch 70.6610μs 37.0445μs 26.9946 KOps/s 26.8974 KOps/s $\color{#35bf28}+0.36\%$
test_tdseq 41.7680μs 21.3622μs 46.8117 KOps/s 46.7671 KOps/s $\color{#35bf28}+0.10\%$
test_tdseq_dispatch 65.3720μs 41.7572μs 23.9479 KOps/s 23.5527 KOps/s $\color{#35bf28}+1.68\%$
test_instantiation_functorch 1.9533ms 1.2953ms 772.0140 Ops/s 757.0879 Ops/s $\color{#35bf28}+1.97\%$
test_instantiation_td 1.6506ms 1.0094ms 990.6463 Ops/s 983.3579 Ops/s $\color{#35bf28}+0.74\%$
test_exec_functorch 0.2650ms 0.1622ms 6.1663 KOps/s 6.1336 KOps/s $\color{#35bf28}+0.53\%$
test_exec_functional_call 0.2848ms 0.1526ms 6.5541 KOps/s 6.6486 KOps/s $\color{#d91a1a}-1.42\%$
test_exec_td 0.2154ms 0.1471ms 6.7991 KOps/s 6.6457 KOps/s $\color{#35bf28}+2.31\%$
test_exec_td_decorator 0.7478ms 0.2247ms 4.4507 KOps/s 4.4363 KOps/s $\color{#35bf28}+0.33\%$
test_vmap_mlp_speed[True-True] 0.9136ms 0.4870ms 2.0534 KOps/s 2.0435 KOps/s $\color{#35bf28}+0.48\%$
test_vmap_mlp_speed[True-False] 0.6913ms 0.4791ms 2.0871 KOps/s 2.0620 KOps/s $\color{#35bf28}+1.22\%$
test_vmap_mlp_speed[False-True] 0.6066ms 0.3898ms 2.5654 KOps/s 2.5374 KOps/s $\color{#35bf28}+1.10\%$
test_vmap_mlp_speed[False-False] 0.5996ms 0.3917ms 2.5527 KOps/s 2.5327 KOps/s $\color{#35bf28}+0.79\%$
test_vmap_mlp_speed_decorator[True-True] 1.0099ms 0.5562ms 1.7981 KOps/s 1.7780 KOps/s $\color{#35bf28}+1.13\%$
test_vmap_mlp_speed_decorator[True-False] 0.8779ms 0.5563ms 1.7975 KOps/s 1.6615 KOps/s $\textbf{\color{#35bf28}+8.18\%}$
test_vmap_mlp_speed_decorator[False-True] 0.5791ms 0.4542ms 2.2019 KOps/s 2.1756 KOps/s $\color{#35bf28}+1.21\%$
test_vmap_mlp_speed_decorator[False-False] 0.8531ms 0.4589ms 2.1790 KOps/s 2.2093 KOps/s $\color{#d91a1a}-1.37\%$
test_to_module_speed[True] 1.8513ms 1.7106ms 584.5839 Ops/s 572.6917 Ops/s $\color{#35bf28}+2.08\%$
test_to_module_speed[False] 1.7969ms 1.6961ms 589.6006 Ops/s 583.8643 Ops/s $\color{#35bf28}+0.98\%$

@vmoens merged commit f691a35 into main May 1, 2024
35 of 38 checks passed
@vmoens deleted the load-meta-device branch May 1, 2024 08:15