
[Performance] Random speedups #728

Merged: 4 commits merged into pytorch:main on Apr 18, 2024
Conversation

@albanD (Contributor) commented Apr 18, 2024

This is a proposal for speedup improvements to the TensorDict struct by adding fast paths for common use cases.
I only tested test/test_tensordict.py locally (the subset that works on main), so there are most likely a few things I missed!

Runtime is tested with

from torch.nn import Transformer
from tensordict import TensorDict
import torch.utils.benchmark
t = Transformer()
td = TensorDict.from_module(t)

print(torch.utils.benchmark.Timer("td.flatten_keys()", globals=globals()).adaptive_autorange())
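# locking the TensorDict lets the flatten_keys() result be cached, which is why the timings below drop from µs to ns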
td.lock_()
print(torch.utils.benchmark.Timer("td.flatten_keys()", globals=globals()).adaptive_autorange())
print(torch.utils.benchmark.Timer("""td.flatten_keys(".")""", globals=globals()).adaptive_autorange())

Result before:

<torch.utils.benchmark.utils.common.Measurement object at 0x7f5c966e6a90>
td.flatten_keys()
  Median: 477.80 us
  IQR:    2.82 us (476.36 to 479.17)
  4 measurements, 1000 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5c95c67910>
td.flatten_keys()
  Median: 524.48 ns
  IQR:    3.03 ns (523.41 to 526.44)
  4 measurements, 100000 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5c95c677d0>
td.flatten_keys(".")
  Median: 727.95 ns
  IQR:    8.25 ns (724.57 to 732.82)
  4 measurements, 100000 runs per measurement, 1 thread

Result after:

<torch.utils.benchmark.utils.common.Measurement object at 0x7f1811c61250>
td.flatten_keys()
  Median: 226.33 us
  IQR:    0.10 us (226.25 to 226.34)
  4 measurements, 1000 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f1810d1b1d0>
td.flatten_keys()
  Median: 159.65 ns
  IQR:    0.13 ns (159.57 to 159.70)
  4 measurements, 1000000 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f1810d1aa10>
td.flatten_keys(".")
  Median: 216.46 ns
  IQR:    0.56 ns (216.21 to 216.77)
  4 measurements, 1000000 runs per measurement, 1 thread

albanD added 3 commits April 17, 2024 18:05
Most cached functions take no arguments, and flatten_keys is mostly used with either no argument or a single string argument.
Avoid one iteration over all the elements by creating the result in one go.
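
The two commit messages above give the gist of the optimization: pick a cheap cache key for the common argument shapes (no arguments, or a single string separator) and build the flattened result in a single pass. Below is a rough sketch of that idea on a toy container; the names (_cache_on_locked, ToyDict) are made up for illustration, and this is not the actual tensordict code, which lives in the PR diff.

import functools

def _cache_on_locked(fun):
    # Cache results only while the container is locked, with cheap keys for
    # the two common call patterns: no arguments, or a single string argument.
    @functools.wraps(fun)
    def wrapper(self, *args, **kwargs):
        if not self._locked:
            return fun(self, *args, **kwargs)
        if not args and not kwargs:
            key = fun.__name__                      # fast path: no arguments
        elif len(args) == 1 and not kwargs and isinstance(args[0], str):
            key = (fun.__name__, args[0])           # fast path: single string argument
        else:
            key = (fun.__name__, args, tuple(sorted(kwargs.items())))
        try:
            return self._cache[key]
        except KeyError:
            result = self._cache[key] = fun(self, *args, **kwargs)
            return result
    return wrapper

class ToyDict:
    # Toy nested mapping used only to illustrate the flatten_keys fast path.
    def __init__(self, data):
        self._data = data
        self._locked = False
        self._cache = {}

    def lock_(self):
        self._locked = True
        return self

    @_cache_on_locked
    def flatten_keys(self, separator="."):
        # Build the flat result in a single traversal rather than creating an
        # empty container and filling it with a second pass over the elements.
        out = {}
        stack = [("", self._data)]
        while stack:
            prefix, node = stack.pop()
            for k, v in node.items():
                name = k if not prefix else prefix + separator + k
                if isinstance(v, dict):
                    stack.append((name, v))
                else:
                    out[name] = v
        return out

td = ToyDict({"a": {"b": 1, "c": 2}, "d": 3})
print(td.flatten_keys())     # uncached call
td.lock_()
print(td.flatten_keys())     # computed once, then served from the cache
print(td.flatten_keys("."))  # the single-string argument gets its own fast key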
facebook-github-bot added the CLA Signed label on Apr 18, 2024
@vmoens (Contributor) commented Apr 18, 2024

For some reason the benchmarks don't appear in this PR, but you can check them here.
They look pretty good! Thanks so much!

vmoens changed the title from "Random speedups" to "[Performance] Random speedups" on Apr 18, 2024
vmoens merged commit 45f0853 into pytorch:main on Apr 18, 2024
44 of 48 checks passed
albanD deleted the random_speedups branch on April 18, 2024 13:04
Labels: CLA Signed, Performance
3 participants