[Feature] Saving metadata of tensorclass #582

Merged: 11 commits, Dec 5, 2023

vmoens (Contributor) commented Nov 28, 2023

A simple script to test the feature:

import tempfile
from tensordict import tensorclass, TensorDict
import torch
import os

def print_directory_tree(path, indent=""):
    """
    Print the directory tree starting from the specified path.

    Parameters:
    - path (str): The path of the directory to print.
    - indent (str): The current indentation level for formatting.
    """
    if os.path.isdir(path):
        print(indent + os.path.basename(path) + "/")
        indent += "    "
        for item in os.listdir(path):
            print_directory_tree(os.path.join(path, item), indent)
    else:
        print(indent + os.path.basename(path))

@tensorclass
class MyClass:
    X: torch.Tensor
    td: TensorDict
    integer: int
    string: str
    dictionary: dict

@tensorclass
class MyOtherClass:
    Y: torch.Tensor

data = MyClass(
    X=torch.randn(10, 3),
    td=TensorDict({"y": torch.randn(10)}, batch_size=[10]),
    integer=3,
    string="a string",
    dictionary={"some_data": "a"},
    batch_size=[],
)

with tempfile.TemporaryDirectory() as tmpdir:
    print(tmpdir)
    data.memmap_(tmpdir)
    print_directory_tree(tmpdir)
    data2 = MyClass.load_memmap(tmpdir)
    print(data2)

    # The original class is recovered even when loading through a
    # different tensorclass.
    data3 = MyOtherClass.load_memmap(tmpdir)
    assert isinstance(data3, MyClass)

cc @janeyx99
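As a rough illustration of how a class can be recovered from saved metadata in general, here is a minimal stdlib-only sketch. It is not tensordict's implementation: the `REGISTRY`, `register`, `save` and `load` names, the `Point` class and the `_type` key are all hypothetical and chosen only for this example.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass

# Hypothetical registry mapping a stored class name to the class itself.
REGISTRY = {}


def register(cls):
    """Record the class under its name so it can be looked up on load."""
    REGISTRY[cls.__name__] = cls
    return cls


@register
@dataclass
class Point:
    x: int
    y: int


def save(obj, path):
    """Write the fields together with the name of the class that owns them."""
    with open(path, "w") as f:
        json.dump({"_type": type(obj).__name__, "fields": asdict(obj)}, f)


def load(path):
    """Rebuild the object using the class name stored in the file."""
    with open(path) as f:
        meta = json.load(f)
    return REGISTRY[meta["_type"]](**meta["fields"])


with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "meta.json")
    save(Point(x=1, y=2), path)
    restored = load(path)

print(type(restored).__name__, restored)
```

Because the class name travels with the data, the caller of `load` does not need to know which class was saved, which is the same property the assertion in the script above checks.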

@facebook-github-bot added the CLA Signed label Nov 28, 2023

github-actions bot commented Nov 28, 2023

⚠ Warning: Result of CPU Benchmark Tests

Total Benchmarks: 113. Improved: 5. Worsened: 6.

Name Max Mean Ops Ops on Repo HEAD Change
test_plain_set_nested 28.7440μs 16.2737μs 61.4490 KOps/s 64.2106 KOps/s $\color{#d91a1a}-4.30\%$
test_plain_set_stack_nested 0.1812ms 0.1424ms 7.0215 KOps/s 7.0543 KOps/s $\color{#d91a1a}-0.47\%$
test_plain_set_nested_inplace 61.2640μs 18.1501μs 55.0961 KOps/s 56.1017 KOps/s $\color{#d91a1a}-1.79\%$
test_plain_set_stack_nested_inplace 0.2419ms 0.1753ms 5.7040 KOps/s 5.6925 KOps/s $\color{#35bf28}+0.20\%$
test_items 24.7060μs 2.4554μs 407.2660 KOps/s 393.8116 KOps/s $\color{#35bf28}+3.42\%$
test_items_nested 0.3264ms 0.2635ms 3.7944 KOps/s 3.6875 KOps/s $\color{#35bf28}+2.90\%$
test_items_nested_locked 0.3367ms 0.2672ms 3.7426 KOps/s 3.7914 KOps/s $\color{#d91a1a}-1.29\%$
test_items_nested_leaf 0.5972ms 0.1630ms 6.1359 KOps/s 6.0971 KOps/s $\color{#35bf28}+0.64\%$
test_items_stack_nested 1.6540ms 1.4973ms 667.8511 Ops/s 671.1486 Ops/s $\color{#d91a1a}-0.49\%$
test_items_stack_nested_leaf 1.4634ms 1.3568ms 737.0235 Ops/s 738.1289 Ops/s $\color{#d91a1a}-0.15\%$
test_items_stack_nested_locked 1.9051ms 0.7623ms 1.3118 KOps/s 1.3135 KOps/s $\color{#d91a1a}-0.13\%$
test_keys 40.2450μs 3.9898μs 250.6404 KOps/s 252.7933 KOps/s $\color{#d91a1a}-0.85\%$
test_keys_nested 0.5341ms 0.1404ms 7.1220 KOps/s 6.5694 KOps/s $\textbf{\color{#35bf28}+8.41\%}$
test_keys_nested_locked 0.2023ms 0.1403ms 7.1261 KOps/s 7.0515 KOps/s $\color{#35bf28}+1.06\%$
test_keys_nested_leaf 0.3952ms 0.1403ms 7.1255 KOps/s 7.0022 KOps/s $\color{#35bf28}+1.76\%$
test_keys_stack_nested 1.5291ms 1.4106ms 708.9294 Ops/s 708.7019 Ops/s $\color{#35bf28}+0.03\%$
test_keys_stack_nested_leaf 1.5258ms 1.4135ms 707.4785 Ops/s 708.8014 Ops/s $\color{#d91a1a}-0.19\%$
test_keys_stack_nested_locked 0.8114ms 0.6773ms 1.4764 KOps/s 1.4882 KOps/s $\color{#d91a1a}-0.79\%$
test_values 10.1740μs 1.1334μs 882.2921 KOps/s 848.9607 KOps/s $\color{#35bf28}+3.93\%$
test_values_nested 0.1094ms 49.7868μs 20.0856 KOps/s 19.9977 KOps/s $\color{#35bf28}+0.44\%$
test_values_nested_locked 97.4110μs 50.4433μs 19.8242 KOps/s 20.2633 KOps/s $\color{#d91a1a}-2.17\%$
test_values_nested_leaf 60.3020μs 44.4683μs 22.4879 KOps/s 22.4959 KOps/s $\color{#d91a1a}-0.04\%$
test_values_stack_nested 1.3248ms 1.2026ms 831.5071 Ops/s 824.4695 Ops/s $\color{#35bf28}+0.85\%$
test_values_stack_nested_leaf 1.8181ms 1.1984ms 834.4151 Ops/s 829.2339 Ops/s $\color{#35bf28}+0.62\%$
test_values_stack_nested_locked 0.8919ms 0.5129ms 1.9496 KOps/s 1.9672 KOps/s $\color{#d91a1a}-0.89\%$
test_membership 10.0590μs 1.3539μs 738.6326 KOps/s 758.6119 KOps/s $\color{#d91a1a}-2.63\%$
test_membership_nested 40.1650μs 2.8022μs 356.8587 KOps/s 359.0941 KOps/s $\color{#d91a1a}-0.62\%$
test_membership_nested_leaf 20.9590μs 2.8358μs 352.6392 KOps/s 362.9227 KOps/s $\color{#d91a1a}-2.83\%$
test_membership_stacked_nested 47.2380μs 11.6598μs 85.7646 KOps/s 86.5218 KOps/s $\color{#d91a1a}-0.88\%$
test_membership_stacked_nested_leaf 41.7380μs 11.6148μs 86.0969 KOps/s 86.6417 KOps/s $\color{#d91a1a}-0.63\%$
test_membership_nested_last 52.2980μs 5.9547μs 167.9335 KOps/s 169.5395 KOps/s $\color{#d91a1a}-0.95\%$
test_membership_nested_leaf_last 33.4630μs 5.9893μs 166.9642 KOps/s 170.3411 KOps/s $\color{#d91a1a}-1.98\%$
test_membership_stacked_nested_last 0.3392ms 0.1671ms 5.9862 KOps/s 6.1017 KOps/s $\color{#d91a1a}-1.89\%$
test_membership_stacked_nested_leaf_last 63.5090μs 13.6722μs 73.1413 KOps/s 73.9451 KOps/s $\color{#d91a1a}-1.09\%$
test_nested_getleaf 51.2560μs 10.7564μs 92.9683 KOps/s 95.2261 KOps/s $\color{#d91a1a}-2.37\%$
test_nested_get 45.7460μs 10.1691μs 98.3371 KOps/s 100.7458 KOps/s $\color{#d91a1a}-2.39\%$
test_stacked_getleaf 0.7492ms 0.6575ms 1.5209 KOps/s 1.5393 KOps/s $\color{#d91a1a}-1.19\%$
test_stacked_get 1.5174ms 0.6175ms 1.6194 KOps/s 1.6271 KOps/s $\color{#d91a1a}-0.47\%$
test_nested_getitemleaf 49.3620μs 10.6886μs 93.5576 KOps/s 95.5839 KOps/s $\color{#d91a1a}-2.12\%$
test_nested_getitem 57.8780μs 10.0976μs 99.0337 KOps/s 101.1489 KOps/s $\color{#d91a1a}-2.09\%$
test_stacked_getitemleaf 0.7880ms 0.6539ms 1.5293 KOps/s 1.5523 KOps/s $\color{#d91a1a}-1.48\%$
test_stacked_getitem 1.0303ms 0.6237ms 1.6034 KOps/s 1.6308 KOps/s $\color{#d91a1a}-1.68\%$
test_lock_nested 67.0470ms 0.6250ms 1.6001 KOps/s 1.7797 KOps/s $\textbf{\color{#d91a1a}-10.09\%}$
test_lock_stack_nested 7.9061ms 5.0921ms 196.3810 Ops/s 195.8173 Ops/s $\color{#35bf28}+0.29\%$
test_unlock_nested 0.9450ms 0.4406ms 2.2695 KOps/s 2.2627 KOps/s $\color{#35bf28}+0.30\%$
test_unlock_stack_nested 80.5529ms 7.2300ms 138.3132 Ops/s 137.1413 Ops/s $\color{#35bf28}+0.85\%$
test_flatten_speed 0.5616ms 0.2817ms 3.5497 KOps/s 3.7241 KOps/s $\color{#d91a1a}-4.68\%$
test_unflatten_speed 0.5693ms 0.4601ms 2.1732 KOps/s 2.1984 KOps/s $\color{#d91a1a}-1.14\%$
test_common_ops 6.9304ms 0.6934ms 1.4421 KOps/s 1.4602 KOps/s $\color{#d91a1a}-1.24\%$
test_creation 27.5310μs 2.4374μs 410.2777 KOps/s 403.1889 KOps/s $\color{#35bf28}+1.76\%$
test_creation_empty 22.6220μs 8.7099μs 114.8116 KOps/s 126.0347 KOps/s $\textbf{\color{#d91a1a}-8.90\%}$
test_creation_nested_1 71.9740μs 12.0785μs 82.7915 KOps/s 89.9500 KOps/s $\textbf{\color{#d91a1a}-7.96\%}$
test_creation_nested_2 42.8700μs 15.6544μs 63.8800 KOps/s 67.5940 KOps/s $\textbf{\color{#d91a1a}-5.49\%}$
test_clone 0.2037ms 13.6627μs 73.1922 KOps/s 72.4913 KOps/s $\color{#35bf28}+0.97\%$
test_getitem[int] 0.1302ms 13.4846μs 74.1589 KOps/s 75.7599 KOps/s $\color{#d91a1a}-2.11\%$
test_getitem[slice_int] 87.7650μs 26.0940μs 38.3231 KOps/s 38.8536 KOps/s $\color{#d91a1a}-1.37\%$
test_getitem[range] 96.0100μs 44.7544μs 22.3442 KOps/s 21.8251 KOps/s $\color{#35bf28}+2.38\%$
test_getitem[tuple] 44.5330μs 21.2905μs 46.9693 KOps/s 48.0401 KOps/s $\color{#d91a1a}-2.23\%$
test_getitem[list] 0.4325ms 39.3110μs 25.4382 KOps/s 24.9132 KOps/s $\color{#35bf28}+2.11\%$
test_setitem_dim[int] 66.2740μs 29.4199μs 33.9906 KOps/s 35.5329 KOps/s $\color{#d91a1a}-4.34\%$
test_setitem_dim[slice_int] 0.1054ms 53.0011μs 18.8675 KOps/s 18.9317 KOps/s $\color{#d91a1a}-0.34\%$
test_setitem_dim[range] 0.1169ms 71.9452μs 13.8995 KOps/s 13.7703 KOps/s $\color{#35bf28}+0.94\%$
test_setitem_dim[tuple] 81.6420μs 43.1190μs 23.1916 KOps/s 23.9315 KOps/s $\color{#d91a1a}-3.09\%$
test_setitem 0.1514ms 19.1290μs 52.2767 KOps/s 51.7916 KOps/s $\color{#35bf28}+0.94\%$
test_set 0.1489ms 19.3224μs 51.7533 KOps/s 54.2354 KOps/s $\color{#d91a1a}-4.58\%$
test_set_shared 2.0548ms 0.1400ms 7.1412 KOps/s 6.8000 KOps/s $\textbf{\color{#35bf28}+5.02\%}$
test_update 0.1860ms 19.6954μs 50.7732 KOps/s 51.2903 KOps/s $\color{#d91a1a}-1.01\%$
test_update_nested 0.1895ms 27.2413μs 36.7090 KOps/s 36.9952 KOps/s $\color{#d91a1a}-0.77\%$
test_set_nested 0.1962ms 20.4529μs 48.8928 KOps/s 48.9899 KOps/s $\color{#d91a1a}-0.20\%$
test_set_nested_new 0.1322ms 25.5162μs 39.1908 KOps/s 39.5089 KOps/s $\color{#d91a1a}-0.81\%$
test_select 0.2377ms 51.9839μs 19.2367 KOps/s 19.7978 KOps/s $\color{#d91a1a}-2.83\%$
test_unbind_speed 0.4893ms 0.3698ms 2.7041 KOps/s 2.6655 KOps/s $\color{#35bf28}+1.45\%$
test_unbind_speed_stack0 73.2037ms 4.5281ms 220.8417 Ops/s 209.9519 Ops/s $\textbf{\color{#35bf28}+5.19\%}$
test_unbind_speed_stack1 2.7106μs 0.6245μs 1.6014 MOps/s 1.5762 MOps/s $\color{#35bf28}+1.60\%$
test_split 66.3130ms 1.8030ms 554.6161 Ops/s 562.8290 Ops/s $\color{#d91a1a}-1.46\%$
test_chunk 60.3435ms 1.7362ms 575.9601 Ops/s 567.3770 Ops/s $\color{#35bf28}+1.51\%$
test_creation[device0] 0.5829ms 0.2916ms 3.4293 KOps/s 3.3840 KOps/s $\color{#35bf28}+1.34\%$
test_creation_from_tensor 3.8100ms 0.3334ms 2.9995 KOps/s 3.0024 KOps/s $\color{#d91a1a}-0.10\%$
test_add_one[memmap_tensor0] 61.1762ms 34.5601μs 28.9351 KOps/s 37.8753 KOps/s $\textbf{\color{#d91a1a}-23.60\%}$
test_contiguous[memmap_tensor0] 22.8930μs 5.9456μs 168.1928 KOps/s 171.8392 KOps/s $\color{#d91a1a}-2.12\%$
test_stack[memmap_tensor0] 0.1280ms 19.6560μs 50.8752 KOps/s 50.7886 KOps/s $\color{#35bf28}+0.17\%$
test_memmaptd_index 0.4871ms 0.2073ms 4.8234 KOps/s 4.9132 KOps/s $\color{#d91a1a}-1.83\%$
test_memmaptd_index_astensor 0.5406ms 0.2609ms 3.8329 KOps/s 3.8224 KOps/s $\color{#35bf28}+0.27\%$
test_memmaptd_index_op 0.5849ms 0.4972ms 2.0112 KOps/s 1.9640 KOps/s $\color{#35bf28}+2.40\%$
test_reshape_pytree 64.8610μs 23.2738μs 42.9667 KOps/s 42.9610 KOps/s $\color{#35bf28}+0.01\%$
test_reshape_td 89.4970μs 32.3899μs 30.8738 KOps/s 29.7906 KOps/s $\color{#35bf28}+3.64\%$
test_view_pytree 77.5750μs 23.2931μs 42.9311 KOps/s 43.5693 KOps/s $\color{#d91a1a}-1.46\%$
test_view_td 28.7840μs 4.9332μs 202.7077 KOps/s 205.8932 KOps/s $\color{#d91a1a}-1.55\%$
test_unbind_pytree 99.7970μs 26.3279μs 37.9825 KOps/s 38.4212 KOps/s $\color{#d91a1a}-1.14\%$
test_unbind_td 0.1201ms 60.6599μs 16.4854 KOps/s 17.0580 KOps/s $\color{#d91a1a}-3.36\%$
test_split_pytree 55.7140μs 26.1863μs 38.1880 KOps/s 38.7786 KOps/s $\color{#d91a1a}-1.52\%$
test_split_td 92.9840μs 47.3135μs 21.1356 KOps/s 21.7737 KOps/s $\color{#d91a1a}-2.93\%$
test_add_pytree 0.1007ms 31.8190μs 31.4277 KOps/s 30.6729 KOps/s $\color{#35bf28}+2.46\%$
test_add_td 0.1355ms 45.9085μs 21.7825 KOps/s 21.6582 KOps/s $\color{#35bf28}+0.57\%$
test_distributed 30.4370μs 6.2305μs 160.5010 KOps/s 169.6269 KOps/s $\textbf{\color{#d91a1a}-5.38\%}$
test_tdmodule 0.1711ms 21.2654μs 47.0248 KOps/s 43.7354 KOps/s $\textbf{\color{#35bf28}+7.52\%}$
test_tdmodule_dispatch 0.1763ms 38.9501μs 25.6739 KOps/s 25.3235 KOps/s $\color{#35bf28}+1.38\%$
test_tdseq 0.1160ms 24.1674μs 41.3780 KOps/s 41.8363 KOps/s $\color{#d91a1a}-1.10\%$
test_tdseq_dispatch 0.1448ms 44.2630μs 22.5922 KOps/s 23.0120 KOps/s $\color{#d91a1a}-1.82\%$
test_instantiation_functorch 2.0155ms 1.3070ms 765.0878 Ops/s 766.1025 Ops/s $\color{#d91a1a}-0.13\%$
test_instantiation_td 1.6829ms 1.0255ms 975.1148 Ops/s 963.5388 Ops/s $\color{#35bf28}+1.20\%$
test_exec_functorch 0.2538ms 0.1575ms 6.3476 KOps/s 6.1321 KOps/s $\color{#35bf28}+3.52\%$
test_exec_functional_call 0.2407ms 0.1483ms 6.7443 KOps/s 6.5678 KOps/s $\color{#35bf28}+2.69\%$
test_exec_td 0.2931ms 0.1419ms 7.0496 KOps/s 6.7262 KOps/s $\color{#35bf28}+4.81\%$
test_exec_td_decorator 0.8732ms 0.1770ms 5.6487 KOps/s 4.9746 KOps/s $\textbf{\color{#35bf28}+13.55\%}$
test_vmap_mlp_speed[True-True] 1.0624ms 0.8881ms 1.1260 KOps/s 1.1001 KOps/s $\color{#35bf28}+2.36\%$
test_vmap_mlp_speed[True-False] 0.6072ms 0.4691ms 2.1319 KOps/s 2.1221 KOps/s $\color{#35bf28}+0.46\%$
test_vmap_mlp_speed[False-True] 1.0158ms 0.7760ms 1.2886 KOps/s 1.2554 KOps/s $\color{#35bf28}+2.64\%$
test_vmap_mlp_speed[False-False] 0.6150ms 0.3848ms 2.5989 KOps/s 2.5726 KOps/s $\color{#35bf28}+1.03\%$
test_vmap_mlp_speed_decorator[True-True] 2.3785ms 1.7866ms 559.7232 Ops/s 555.3001 Ops/s $\color{#35bf28}+0.80\%$
test_vmap_mlp_speed_decorator[True-False] 1.0612ms 0.5192ms 1.9259 KOps/s 1.9152 KOps/s $\color{#35bf28}+0.56\%$
test_vmap_mlp_speed_decorator[False-True] 2.0463ms 1.4894ms 671.3891 Ops/s 663.6186 Ops/s $\color{#35bf28}+1.17\%$
test_vmap_mlp_speed_decorator[False-False] 0.8320ms 0.4006ms 2.4963 KOps/s 2.4847 KOps/s $\color{#35bf28}+0.47\%$
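
The columns in the table above appear to be related by simple arithmetic: Ops looks like the reciprocal of Mean, and Change looks like the relative difference between Ops and Ops on Repo HEAD. The sketch below reproduces that arithmetic. The formulas and the 5% threshold for counting a benchmark as improved or worsened are inferred from the reported rows, not taken from the benchmark tooling, and the function names are made up for this example.

```python
def ops_per_second(mean_seconds):
    """Throughput implied by a mean runtime (assumed: Ops = 1 / Mean)."""
    return 1.0 / mean_seconds


def relative_change(ops, ops_on_head):
    """Percentage change of this PR's throughput relative to repo HEAD."""
    return 100.0 * (ops - ops_on_head) / ops_on_head


def classify(change_pct, threshold=5.0):
    """Label a benchmark; the 5% cutoff is an inference from the tables."""
    if change_pct >= threshold:
        return "improved"
    if change_pct <= -threshold:
        return "worsened"
    return "unchanged"


# Values copied from the CPU table (KOps/s).
plain_set_nested = relative_change(61.4490, 64.2106)
add_one_memmap = relative_change(28.9351, 37.8753)

print(f"test_plain_set_nested: {plain_set_nested:+.2f}% ({classify(plain_set_nested)})")
print(f"test_add_one[memmap_tensor0]: {add_one_memmap:+.2f}% ({classify(add_one_memmap)})")

# Mean of 16.2737 microseconds, expressed in KOps/s.
print(f"{ops_per_second(16.2737e-6) / 1e3:.3f} KOps/s")
```

Under this reading, only the rows whose change exceeds the threshold in either direction are counted in the "Improved" and "Worsened" totals.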


github-actions bot commented Nov 28, 2023

⚠ Warning: Result of GPU Benchmark Tests

Total Benchmarks: 127. Improved: 1. Worsened: 10.

Name Max Mean Ops Ops on Repo HEAD Change
test_plain_set_nested 0.5970ms 12.7698μs 78.3097 KOps/s 78.7771 KOps/s $\color{#d91a1a}-0.59\%$
test_plain_set_stack_nested 0.1383ms 0.1144ms 8.7410 KOps/s 8.3414 KOps/s $\color{#35bf28}+4.79\%$
test_plain_set_nested_inplace 31.7410μs 14.0911μs 70.9668 KOps/s 71.2213 KOps/s $\color{#d91a1a}-0.36\%$
test_plain_set_stack_nested_inplace 0.1774ms 0.1430ms 6.9925 KOps/s 6.9164 KOps/s $\color{#35bf28}+1.10\%$
test_items 17.7710μs 4.6719μs 214.0478 KOps/s 212.9830 KOps/s $\color{#35bf28}+0.50\%$
test_items_nested 0.3735ms 0.3388ms 2.9520 KOps/s 2.9570 KOps/s $\color{#d91a1a}-0.17\%$
test_items_nested_locked 0.3698ms 0.3411ms 2.9314 KOps/s 2.9158 KOps/s $\color{#35bf28}+0.53\%$
test_items_nested_leaf 0.2485ms 0.1988ms 5.0298 KOps/s 4.9610 KOps/s $\color{#35bf28}+1.39\%$
test_items_stack_nested 1.5524ms 1.4895ms 671.3648 Ops/s 678.8585 Ops/s $\color{#d91a1a}-1.10\%$
test_items_stack_nested_leaf 1.3782ms 1.3056ms 765.9350 Ops/s 762.9752 Ops/s $\color{#35bf28}+0.39\%$
test_items_stack_nested_locked 0.8728ms 0.8187ms 1.2214 KOps/s 1.2101 KOps/s $\color{#35bf28}+0.93\%$
test_keys 21.1900μs 4.5802μs 218.3297 KOps/s 218.8120 KOps/s $\color{#d91a1a}-0.22\%$
test_keys_nested 3.3296ms 90.9159μs 10.9992 KOps/s 11.0575 KOps/s $\color{#d91a1a}-0.53\%$
test_keys_nested_locked 0.1145ms 90.8565μs 11.0064 KOps/s 11.1392 KOps/s $\color{#d91a1a}-1.19\%$
test_keys_nested_leaf 41.7128ms 87.3951μs 11.4423 KOps/s 12.2223 KOps/s $\textbf{\color{#d91a1a}-6.38\%}$
test_keys_stack_nested 1.3879ms 1.2860ms 777.6140 Ops/s 765.2629 Ops/s $\color{#35bf28}+1.61\%$
test_keys_stack_nested_leaf 1.3142ms 1.2655ms 790.2071 Ops/s 770.8048 Ops/s $\color{#35bf28}+2.52\%$
test_keys_stack_nested_locked 0.6919ms 0.6257ms 1.5981 KOps/s 1.5916 KOps/s $\color{#35bf28}+0.41\%$
test_values 8.5333μs 1.8900μs 529.0989 KOps/s 520.2745 KOps/s $\color{#35bf28}+1.70\%$
test_values_nested 69.2710μs 43.3092μs 23.0898 KOps/s 23.1620 KOps/s $\color{#d91a1a}-0.31\%$
test_values_nested_locked 66.0300μs 45.2939μs 22.0780 KOps/s 21.8965 KOps/s $\color{#35bf28}+0.83\%$
test_values_nested_leaf 57.6410μs 37.3269μs 26.7903 KOps/s 26.6355 KOps/s $\color{#35bf28}+0.58\%$
test_values_stack_nested 1.1922ms 1.1451ms 873.3032 Ops/s 877.6354 Ops/s $\color{#d91a1a}-0.49\%$
test_values_stack_nested_leaf 1.2803ms 1.1840ms 844.6241 Ops/s 895.2949 Ops/s $\textbf{\color{#d91a1a}-5.66\%}$
test_values_stack_nested_locked 0.5510ms 0.4975ms 2.0102 KOps/s 1.9760 KOps/s $\color{#35bf28}+1.73\%$
test_membership 5.7960μs 0.9543μs 1.0479 MOps/s 1.0318 MOps/s $\color{#35bf28}+1.55\%$
test_membership_nested 18.7900μs 2.2314μs 448.1582 KOps/s 453.2557 KOps/s $\color{#d91a1a}-1.12\%$
test_membership_nested_leaf 31.4600μs 2.1336μs 468.7015 KOps/s 474.7868 KOps/s $\color{#d91a1a}-1.28\%$
test_membership_stacked_nested 40.5800μs 11.0756μs 90.2889 KOps/s 90.0381 KOps/s $\color{#35bf28}+0.28\%$
test_membership_stacked_nested_leaf 44.5410μs 10.9852μs 91.0315 KOps/s 91.1310 KOps/s $\color{#d91a1a}-0.11\%$
test_membership_nested_last 18.0010μs 4.6312μs 215.9287 KOps/s 216.8197 KOps/s $\color{#d91a1a}-0.41\%$
test_membership_nested_leaf_last 20.4910μs 4.6340μs 215.7971 KOps/s 217.3207 KOps/s $\color{#d91a1a}-0.70\%$
test_membership_stacked_nested_last 0.1772ms 0.1354ms 7.3863 KOps/s 7.3868 KOps/s $-0.01\%$
test_membership_stacked_nested_leaf_last 36.8600μs 12.7071μs 78.6960 KOps/s 78.5321 KOps/s $\color{#35bf28}+0.21\%$
test_nested_getleaf 27.3800μs 8.4308μs 118.6129 KOps/s 119.2614 KOps/s $\color{#d91a1a}-0.54\%$
test_nested_get 30.5010μs 7.9507μs 125.7751 KOps/s 125.8496 KOps/s $\color{#d91a1a}-0.06\%$
test_stacked_getleaf 0.6274ms 0.5698ms 1.7549 KOps/s 1.7726 KOps/s $\color{#d91a1a}-1.00\%$
test_stacked_get 0.6193ms 0.5331ms 1.8758 KOps/s 1.8615 KOps/s $\color{#35bf28}+0.76\%$
test_nested_getitemleaf 28.6300μs 8.4616μs 118.1810 KOps/s 118.3637 KOps/s $\color{#d91a1a}-0.15\%$
test_nested_getitem 28.8900μs 7.9963μs 125.0573 KOps/s 125.3783 KOps/s $\color{#d91a1a}-0.26\%$
test_stacked_getitemleaf 0.7557ms 0.5764ms 1.7350 KOps/s 1.7524 KOps/s $\color{#d91a1a}-0.99\%$
test_stacked_getitem 0.5877ms 0.5319ms 1.8801 KOps/s 1.8872 KOps/s $\color{#d91a1a}-0.37\%$
test_lock_nested 3.3890ms 0.5605ms 1.7840 KOps/s 1.7729 KOps/s $\color{#35bf28}+0.63\%$
test_lock_stack_nested 81.7906ms 7.2695ms 137.5608 Ops/s 135.5569 Ops/s $\color{#35bf28}+1.48\%$
test_unlock_nested 2.3514ms 0.4350ms 2.2990 KOps/s 2.2796 KOps/s $\color{#35bf28}+0.85\%$
test_unlock_stack_nested 67.4362ms 6.2858ms 159.0886 Ops/s 159.4763 Ops/s $\color{#d91a1a}-0.24\%$
test_flatten_speed 0.2250ms 0.1871ms 5.3437 KOps/s 5.3409 KOps/s $\color{#35bf28}+0.05\%$
test_unflatten_speed 0.3971ms 0.3642ms 2.7457 KOps/s 2.7308 KOps/s $\color{#35bf28}+0.54\%$
test_common_ops 1.1266ms 0.5979ms 1.6725 KOps/s 1.6453 KOps/s $\color{#35bf28}+1.66\%$
test_creation 14.3600μs 2.1217μs 471.3131 KOps/s 477.9226 KOps/s $\color{#d91a1a}-1.38\%$
test_creation_empty 36.4900μs 7.0943μs 140.9589 KOps/s 138.3931 KOps/s $\color{#35bf28}+1.85\%$
test_creation_nested_1 27.2500μs 9.5486μs 104.7274 KOps/s 104.6265 KOps/s $\color{#35bf28}+0.10\%$
test_creation_nested_2 29.9600μs 12.1762μs 82.1273 KOps/s 82.1269 KOps/s $+0.00\%$
test_clone 84.8710μs 14.7867μs 67.6283 KOps/s 70.2480 KOps/s $\color{#d91a1a}-3.73\%$
test_getitem[int] 34.0410μs 12.6463μs 79.0745 KOps/s 81.0499 KOps/s $\color{#d91a1a}-2.44\%$
test_getitem[slice_int] 51.1710μs 24.5388μs 40.7517 KOps/s 42.8886 KOps/s $\color{#d91a1a}-4.98\%$
test_getitem[range] 89.1810μs 44.4115μs 22.5167 KOps/s 24.8645 KOps/s $\textbf{\color{#d91a1a}-9.44\%}$
test_getitem[tuple] 40.8510μs 21.8383μs 45.7911 KOps/s 49.0900 KOps/s $\textbf{\color{#d91a1a}-6.72\%}$
test_getitem[list] 0.2009ms 39.9372μs 25.0393 KOps/s 26.9642 KOps/s $\textbf{\color{#d91a1a}-7.14\%}$
test_setitem_dim[int] 52.8710μs 28.1004μs 35.5866 KOps/s 37.7968 KOps/s $\textbf{\color{#d91a1a}-5.85\%}$
test_setitem_dim[slice_int] 81.8010μs 49.1376μs 20.3510 KOps/s 21.5410 KOps/s $\textbf{\color{#d91a1a}-5.52\%}$
test_setitem_dim[range] 97.8120μs 67.7512μs 14.7599 KOps/s 15.5957 KOps/s $\textbf{\color{#d91a1a}-5.36\%}$
test_setitem_dim[tuple] 61.0510μs 41.6833μs 23.9904 KOps/s 24.7914 KOps/s $\color{#d91a1a}-3.23\%$
test_setitem 92.7220μs 19.2251μs 52.0154 KOps/s 55.3849 KOps/s $\textbf{\color{#d91a1a}-6.08\%}$
test_set 85.7610μs 18.8829μs 52.9581 KOps/s 56.7778 KOps/s $\textbf{\color{#d91a1a}-6.73\%}$
test_set_shared 2.8503ms 0.1045ms 9.5661 KOps/s 8.8714 KOps/s $\textbf{\color{#35bf28}+7.83\%}$
test_update 84.9010μs 18.4889μs 54.0864 KOps/s 52.9114 KOps/s $\color{#35bf28}+2.22\%$
test_update_nested 76.9710μs 24.8099μs 40.3065 KOps/s 38.9386 KOps/s $\color{#35bf28}+3.51\%$
test_set_nested 83.5510μs 18.6695μs 53.5633 KOps/s 52.4032 KOps/s $\color{#35bf28}+2.21\%$
test_set_nested_new 88.7220μs 23.0889μs 43.3109 KOps/s 42.8503 KOps/s $\color{#35bf28}+1.07\%$
test_select 97.5710μs 44.6531μs 22.3949 KOps/s 21.4309 KOps/s $\color{#35bf28}+4.50\%$
test_to 74.8410μs 53.0663μs 18.8443 KOps/s 18.4198 KOps/s $\color{#35bf28}+2.30\%$
test_to_nonblocking 62.8310μs 34.9748μs 28.5920 KOps/s 27.3588 KOps/s $\color{#35bf28}+4.51\%$
test_unbind_speed 0.3955ms 0.3648ms 2.7409 KOps/s 2.7367 KOps/s $\color{#35bf28}+0.16\%$
test_unbind_speed_stack0 62.5744ms 4.3389ms 230.4716 Ops/s 229.4911 Ops/s $\color{#35bf28}+0.43\%$
test_unbind_speed_stack1 1.3600μs 0.5376μs 1.8600 MOps/s 1.8795 MOps/s $\color{#d91a1a}-1.04\%$
test_split 54.0074ms 1.8056ms 553.8354 Ops/s 557.3246 Ops/s $\color{#d91a1a}-0.63\%$
test_chunk 53.8213ms 1.7987ms 555.9622 Ops/s 566.4871 Ops/s $\color{#d91a1a}-1.86\%$
test_creation[device0] 0.5344ms 0.3092ms 3.2341 KOps/s 3.2435 KOps/s $\color{#d91a1a}-0.29\%$
test_creation[device1] 0.8631ms 0.3119ms 3.2062 KOps/s 3.2156 KOps/s $\color{#d91a1a}-0.29\%$
test_creation_from_tensor 0.6190ms 0.3353ms 2.9821 KOps/s 2.9745 KOps/s $\color{#35bf28}+0.25\%$
test_add_one[memmap_tensor0] 62.8120μs 23.0823μs 43.3232 KOps/s 42.2368 KOps/s $\color{#35bf28}+2.57\%$
test_add_one[memmap_tensor1] 0.2297ms 73.8011μs 13.5499 KOps/s 13.3125 KOps/s $\color{#35bf28}+1.78\%$
test_contiguous[memmap_tensor0] 25.8810μs 5.7341μs 174.3946 KOps/s 174.9146 KOps/s $\color{#d91a1a}-0.30\%$
test_contiguous[memmap_tensor1] 45.5410μs 21.6607μs 46.1666 KOps/s 44.8580 KOps/s $\color{#35bf28}+2.92\%$
test_stack[memmap_tensor0] 48.2410μs 19.2425μs 51.9684 KOps/s 51.8817 KOps/s $\color{#35bf28}+0.17\%$
test_stack[memmap_tensor1] 0.1500ms 73.3730μs 13.6290 KOps/s 13.1463 KOps/s $\color{#35bf28}+3.67\%$
test_memmaptd_index 0.2754ms 0.2329ms 4.2934 KOps/s 4.1395 KOps/s $\color{#35bf28}+3.72\%$
test_memmaptd_index_astensor 0.3544ms 0.2904ms 3.4441 KOps/s 3.3665 KOps/s $\color{#35bf28}+2.30\%$
test_memmaptd_index_op 0.6260ms 0.5534ms 1.8071 KOps/s 1.7698 KOps/s $\color{#35bf28}+2.10\%$
test_reshape_pytree 37.7300μs 21.0902μs 47.4155 KOps/s 47.7509 KOps/s $\color{#d91a1a}-0.70\%$
test_reshape_td 56.5510μs 30.6206μs 32.6577 KOps/s 33.3237 KOps/s $\color{#d91a1a}-2.00\%$
test_view_pytree 40.8400μs 20.6699μs 48.3795 KOps/s 48.6545 KOps/s $\color{#d91a1a}-0.57\%$
test_view_td 26.4100μs 4.0219μs 248.6373 KOps/s 249.0094 KOps/s $\color{#d91a1a}-0.15\%$
test_unbind_pytree 42.0910μs 25.3568μs 39.4371 KOps/s 38.9757 KOps/s $\color{#35bf28}+1.18\%$
test_unbind_td 78.9010μs 56.6001μs 17.6678 KOps/s 17.4901 KOps/s $\color{#35bf28}+1.02\%$
test_split_pytree 45.7000μs 24.1700μs 41.3736 KOps/s 41.7030 KOps/s $\color{#d91a1a}-0.79\%$
test_split_td 0.5776ms 44.7041μs 22.3693 KOps/s 22.7399 KOps/s $\color{#d91a1a}-1.63\%$
test_add_pytree 54.4000μs 32.2551μs 31.0029 KOps/s 30.8904 KOps/s $\color{#35bf28}+0.36\%$
test_add_td 62.7710μs 44.4654μs 22.4894 KOps/s 21.9578 KOps/s $\color{#35bf28}+2.42\%$
test_distributed 18.6700μs 5.5762μs 179.3348 KOps/s 176.9380 KOps/s $\color{#35bf28}+1.35\%$
test_tdmodule 86.7810μs 16.4395μs 60.8292 KOps/s 58.8837 KOps/s $\color{#35bf28}+3.30\%$
test_tdmodule_dispatch 0.2228ms 33.0269μs 30.2784 KOps/s 29.9361 KOps/s $\color{#35bf28}+1.14\%$
test_tdseq 35.3810μs 19.9628μs 50.0932 KOps/s 50.1458 KOps/s $\color{#d91a1a}-0.10\%$
test_tdseq_dispatch 53.9500μs 36.0734μs 27.7213 KOps/s 27.1750 KOps/s $\color{#35bf28}+2.01\%$
test_instantiation_functorch 1.7570ms 1.6719ms 598.1126 Ops/s 601.1145 Ops/s $\color{#d91a1a}-0.50\%$
test_instantiation_td 1.6611ms 1.1668ms 857.0197 Ops/s 853.4764 Ops/s $\color{#35bf28}+0.42\%$
test_exec_functorch 0.1859ms 0.1549ms 6.4577 KOps/s 6.3800 KOps/s $\color{#35bf28}+1.22\%$
test_exec_functional_call 0.2116ms 0.1552ms 6.4447 KOps/s 6.4320 KOps/s $\color{#35bf28}+0.20\%$
test_exec_td 0.1782ms 0.1458ms 6.8607 KOps/s 6.6805 KOps/s $\color{#35bf28}+2.70\%$
test_exec_td_decorator 0.9100ms 0.1814ms 5.5121 KOps/s 5.3355 KOps/s $\color{#35bf28}+3.31\%$
test_vmap_mlp_speed[True-True] 1.2137ms 1.0723ms 932.6071 Ops/s 925.8279 Ops/s $\color{#35bf28}+0.73\%$
test_vmap_mlp_speed[True-False] 0.6869ms 0.6188ms 1.6161 KOps/s 1.6080 KOps/s $\color{#35bf28}+0.50\%$
test_vmap_mlp_speed[False-True] 1.0795ms 0.9810ms 1.0194 KOps/s 1.0134 KOps/s $\color{#35bf28}+0.59\%$
test_vmap_mlp_speed[False-False] 0.6449ms 0.5500ms 1.8183 KOps/s 1.8178 KOps/s $\color{#35bf28}+0.03\%$
test_vmap_mlp_speed_decorator[True-True] 3.1503ms 2.0616ms 485.0644 Ops/s 490.5235 Ops/s $\color{#d91a1a}-1.11\%$
test_vmap_mlp_speed_decorator[True-False] 1.1295ms 0.6612ms 1.5123 KOps/s 1.5047 KOps/s $\color{#35bf28}+0.51\%$
test_vmap_mlp_speed_decorator[False-True] 2.2543ms 1.7704ms 564.8397 Ops/s 552.8448 Ops/s $\color{#35bf28}+2.17\%$
test_vmap_mlp_speed_decorator[False-False] 1.0357ms 0.5650ms 1.7699 KOps/s 1.7715 KOps/s $\color{#d91a1a}-0.09\%$
test_vmap_transformer_speed[True-True] 13.2299ms 12.6719ms 78.9149 Ops/s 78.6779 Ops/s $\color{#35bf28}+0.30\%$
test_vmap_transformer_speed[True-False] 8.6597ms 8.3059ms 120.3962 Ops/s 119.0795 Ops/s $\color{#35bf28}+1.11\%$
test_vmap_transformer_speed[False-True] 12.6001ms 12.4449ms 80.3544 Ops/s 79.3973 Ops/s $\color{#35bf28}+1.21\%$
test_vmap_transformer_speed[False-False] 8.3872ms 8.2413ms 121.3395 Ops/s 120.0734 Ops/s $\color{#35bf28}+1.05\%$
test_vmap_transformer_speed_decorator[True-True] 0.1394s 69.0418ms 14.4840 Ops/s 14.2261 Ops/s $\color{#35bf28}+1.81\%$
test_vmap_transformer_speed_decorator[True-False] 22.5952ms 20.3228ms 49.2059 Ops/s 48.7752 Ops/s $\color{#35bf28}+0.88\%$
test_vmap_transformer_speed_decorator[False-True] 62.3430ms 59.5928ms 16.7806 Ops/s 16.9042 Ops/s $\color{#d91a1a}-0.73\%$
test_vmap_transformer_speed_decorator[False-False] 21.1743ms 19.9661ms 50.0848 Ops/s 49.8109 Ops/s $\color{#35bf28}+0.55\%$

vmoens (Contributor, Author) commented Nov 28, 2023

Another script for optimizer state_dict:

import os
import tempfile
from typing import Any

import torch

from tensordict import tensorclass, TensorDict

model = torch.nn.Linear(3, 4)
optim = torch.optim.Adam(list(model.parameters()), lr=1e-2, amsgrad=True)
model(torch.randn(3)).sum().backward()
optim.step()
optim.zero_grad()

sd = optim.state_dict()
print(sd)


def print_directory_tree(path, indent=""):
    """
    Print the directory tree starting from the specified path.

    Parameters:
    - path (str): The path of the directory to print.
    - indent (str): The current indentation level for formatting.
    """
    if os.path.isdir(path):
        print(indent + os.path.basename(path) + "/")
        indent += "    "
        for item in os.listdir(path):
            print_directory_tree(os.path.join(path, item), indent)
    else:
        print(indent + os.path.basename(path))


@tensorclass
class ParamGroup:
    params: list
    lr: float
    betas: tuple
    eps: float
    weight_decay: float
    amsgrad: bool
    maximize: bool
    foreach: Any
    capturable: bool
    differentiable: bool
    fused: Any


tc = ParamGroup(**sd["param_groups"][0], batch_size=[])
print(tc)
td = TensorDict(
    {
        "state": {str(key): val for key, val in sd["state"].items()},
        "metadata": tc,
    },
    batch_size=[],
)
print(td)

with tempfile.TemporaryDirectory() as tempdir:
    td.memmap_(tempdir)
    print_directory_tree(tempdir)
    print("loaded", TensorDict.load_memmap(tempdir))

This produces a tensordict with the following structure:

TensorDict(
    fields={
        metadata: ParamGroup(
            params=[0, 1],
            lr=0.01,
            betas=[0.9, 0.999],
            eps=1e-08,
            weight_decay=0,
            amsgrad=True,
            maximize=False,
            foreach=None,
            capturable=False,
            differentiable=False,
            fused=None,
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        state: TensorDict(
            fields={
                0: TensorDict(
                    fields={
                        exp_avg: MemoryMappedTensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=True),
                        exp_avg_sq: MemoryMappedTensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=True),
                        max_exp_avg_sq: MemoryMappedTensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=True),
                        step: MemoryMappedTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=True)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                1: TensorDict(
                    fields={
                        exp_avg: MemoryMappedTensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=True),
                        exp_avg_sq: MemoryMappedTensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=True),
                        max_exp_avg_sq: MemoryMappedTensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=True),
                        step: MemoryMappedTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=True)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

The directory tree on disk looks like this:

tmpbf33lwvd/
    state/
        0/
            max_exp_avg_sq.memmap
            exp_avg_sq.memmap
            step.memmap
            exp_avg.memmap
            meta.json
        1/
            max_exp_avg_sq.memmap
            exp_avg_sq.memmap
            step.memmap
            exp_avg.memmap
            meta.json
        meta.json
    metadata/
        _tensordict/
            meta.json
        meta.json
    meta.json
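
The `meta.json` files in a tree like this one can be read without tensordict or torch. The following is a minimal sketch that walks a directory and parses every `meta.json` it finds. The `collect_meta` helper is hypothetical, and the demo builds a stand-in tree with placeholder content, since the actual schema of the files is not shown here.

```python
import json
import os
import tempfile


def collect_meta(root):
    """Map each meta.json path (relative to root) to its parsed content."""
    found = {}
    for dirpath, _dirnames, filenames in os.walk(root):
        if "meta.json" in filenames:
            full = os.path.join(dirpath, "meta.json")
            with open(full) as f:
                found[os.path.relpath(full, root)] = json.load(f)
    return found


# Stand-in tree with placeholder content; real files would come from memmap_.
with tempfile.TemporaryDirectory() as root:
    for sub in ("", "state", os.path.join("state", "0")):
        os.makedirs(os.path.join(root, sub), exist_ok=True)
        with open(os.path.join(root, sub, "meta.json"), "w") as f:
            json.dump({"placeholder": sub}, f)
    found = collect_meta(root)

for rel_path in sorted(found):
    print(rel_path, found[rel_path])
```

Pointing `collect_meta` at the directory passed to `memmap_` should list one entry per level of the saved structure.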

In this example I lay out the data in a way that mirrors state_dict, but we could also imagine something different, e.g.:

@tensorclass
class AdamState:
    params: TensorDict # contains the parameters - the `params` list isn't needed anymore because we have the tensordict keys
    lr: float
    betas: tuple
    eps: float
    weight_decay: float
    amsgrad: bool
    maximize: bool
    foreach: Any
    capturable: bool
    differentiable: bool
    fused: Any

@vmoens added the enhancement label Nov 29, 2023

vmoens (Contributor, Author) commented Nov 29, 2023

cc @shagunsodhani this may be relevant

The goal is to explore whether tensorclass could be a good backend for optimizer checkpointing.
Advantages:

  • It does not rely on pickle (which is "unsafe" when saving anything other than tensors).
  • It can load a partial state dict: if you build 2 groups in your optimizer, you can load just one of them (since each tensor is saved separately).
  • The data structure on disk is self-explanatory, and the hparams are accessible in JSON files (easy to parse from the outside).
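
To make the pickle point concrete, here is a small stdlib-only sketch of why loading a pickle is considered unsafe while loading JSON is not. The `Payload` class is made up for the demonstration, and the JSON string is a hypothetical set of hyperparameters, not the actual `meta.json` schema.

```python
import json
import pickle


class Payload:
    """Object whose pickle stream asks the loader to call a function."""

    def __reduce__(self):
        # pickle stores (callable, args) and pickle.loads calls callable(*args).
        # eval on a harmless expression stands in for arbitrary code here.
        return (eval, ("6 * 7",))


blob = pickle.dumps(Payload())
unpickled = pickle.loads(blob)  # runs eval("6 * 7") while loading

# Hypothetical hyperparameters; parsing only builds plain Python containers.
text = '{"lr": 0.01, "betas": [0.9, 0.999], "amsgrad": true}'
hparams = json.loads(text)

print(unpickled)
print(hparams["lr"], hparams["betas"], hparams["amsgrad"])
```

Loading the pickle returns the result of the function call rather than a `Payload` instance, which shows that code chosen by whoever wrote the file runs at load time. The JSON parser has no such hook.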

@vmoens marked this pull request as ready for review December 5, 2023 08:35
@vmoens merged commit b130fc4 into main Dec 5, 2023
42 of 45 checks passed
@vmoens deleted the save-metadata-tc branch December 5, 2023 10:45