Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DDPWrapper #2479

Open
wants to merge 28 commits into
base: juliagmt/test
Choose a base branch
from
Open

Conversation

juliagmt-google
Copy link
Collaborator

Benchmark improvements: #2468

anijain2305 and others added 4 commits September 26, 2024 00:50
Summary:
This reverts commit 7743149b2be4a9eba7e0997ccdc6abe552bec266.

Reverts
* pytorch/pytorch#135503
* pytorch/pytorch#135502
* pytorch/pytorch#135422

This passes this test. Earlier, the getitem would stay like a getitem in the Fx graph. But now the fake tensor propagations fails saying that .item is called. It seems that torch function is not getting triggered while fake tensor propagation.

```
import torch
from torch.nn.attention.flex_attention import BlockMask, _mask_mod_signature, _score_mod_signature, flex_attention
from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.virtualized import ops
from torch.nn.attention.flex_attention import create_block_mask

torch.set_default_device('cuda')

flex_attention = torch.compile(flex_attention, dynamic=False)

prefix_lengths = torch.arange(8)
def prefix_lm(b, h, q, kv):
    return prefix_lengths[b] >= kv

mask = create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True)
```

X-link: pytorch/pytorch#136590
Approved by: https://github.com/Chillee

Reviewed By: atalman

Differential Revision: D63431470

Pulled By: anijain2305

fbshipit-source-id: 60915b30336121b845af71f423582c22a6c65c3f
Summary: Add new metric `--metric nsys` to collect nsys trace.

Reviewed By: htyu

Differential Revision: D63274918

fbshipit-source-id: 0536310df6290ea5f5a02d85cc0ad6d342d45dbd
Summary:
pytorch#2458

Pull Request resolved: pytorch#2459

Reviewed By: xuzhao9

Differential Revision: D63476542

Pulled By: kit1980

fbshipit-source-id: 01e9db9cb03d34e82a773897417df2ccda410634
Summary: Pull Request resolved: pytorch#2473

Reviewed By: xuzhao9

Differential Revision: D63543625

Pulled By: bertmaher

fbshipit-source-id: 1693e15875544bda0f5f6c69daa5597fffd80509
Copy link
Collaborator Author

@juliagmt-google juliagmt-google left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test

Summary: Pull Request resolved: pytorch#2475

Reviewed By: htyu

Differential Revision: D63653081

Pulled By: xuzhao9

fbshipit-source-id: 8d840986779b6124cbccc2425c24e2b892d55ce4
Summary: We had the imports wrong for the internal port.

Reviewed By: xuzhao9, adamomainz

Differential Revision: D63643617

fbshipit-source-id: 04a49d419fede71d2681dedbfb55112a67cb4d55
Summary:
We have an old triton internally that doesn't have the cublasLt
bindings

Reviewed By: adamomainz

Differential Revision: D63643619

fbshipit-source-id: 39aece74b52f7747fe2100d7bb905bad49ba1fa0
Summary:
X-link: facebookresearch/FBGEMM#301

X-link: pytorch/FBGEMM#3202

Printing warnings to stdout mucks up the output of various tools/benchmarks

Reviewed By: xuzhao9, htyu

Differential Revision: D63643615

fbshipit-source-id: 1f34508a7fd36f5aa421e11bddd5ce77fc13038a
Summary: FBGEMM has changed how it declares its Cutlass-based blockwise gemm.

Reviewed By: htyu, sijiac, adamomainz

Differential Revision: D63643618

fbshipit-source-id: e46e3bbd2e07be0653f7c7fa6bd080b6c8db171e
Summary:
We have a big list of interesting shapes for blockwise/rowwise scaled
gemm.  A lot of these are variants of llama.  We might want to use them for
gemm and fp8_gemm (unscaled) as well, but for now just do blockwise/rowwise

Reviewed By: xuzhao9, adamomainz

Differential Revision: D63643616

fbshipit-source-id: 328961fe8c91e66428fcd1e5b72c89813f58a5a3
Summary:
We were only benchmarking `row-major x row-major` gemms (also called
`TT` or `transpose-transpose`, because FORTRAN), which is actually not the
common case; `nn.Linear` will use column-major layouts for weights, which means
`TN` is actually much more common.

Reviewed By: adamomainz

Differential Revision: D63714661

fbshipit-source-id: 735c25c59ddeb6596afd9b19f463af92036a830b
Summary: Pull Request resolved: pytorch#2483

Reviewed By: karthik-man

Differential Revision: D63726031

fbshipit-source-id: dc410e503f918d83362fb38005ac4a6db5dc1e68
FindHao and others added 5 commits October 4, 2024 14:07
Summary:
Allow users benchmark multiple ops in a single run. The ops can be split by commas, `--op fp8_gemm,addmm`

Example output:
```
% python run_benchmark.py triton --op fp8_gemm,addmm --num-inputs 1
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.12s/it]
             x_val    torch_fp8_gemm-gbps    torch_fp8_gemm-gbps    torch_fp8_gemm-latency    torch_fp8_gemm-tflops    triton_fp8_gemm-gbps    triton_fp8_gemm-gbps    triton_fp8_gemm-latency    triton_fp8_gemm-tflops
------------------  ---------------------  ---------------------  ------------------------  -----------------------  ----------------------  ----------------------  -------------------------  ------------------------
(1024, 1024, 1024)                462.202                462.202                0.00907462                  236.647                  630.43                  630.43                 0.00665309                    322.78
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.90s/it]
         (M, N, K)    aten_addmm-best_config    aten_addmm-gbps    aten_addmm-tflops                                                                                       triton_addmm-best_config    triton_addmm-gbps    triton_addmm-tflops    pt2_triton_matmul-best_config    pt2_triton_matmul-gbps    pt2_triton_matmul-tflops
------------------  ------------------------  -----------------  -------------------  -------------------------------------------------------------------------------------------------------------  -------------------  ---------------------  -------------------------------  ------------------------  --------------------------
(20120, 512, 1536)                                      818.112              247.544  {'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8, 'num_warps': 8, 'num_ctas': 1, 'num_stages': 3}              911.569                275.823                                                    889.125                     269.031
```

Pull Request resolved: pytorch#2490

Reviewed By: xuzhao9

Differential Revision: D63862548

Pulled By: FindHao

fbshipit-source-id: 9d4afa6051d4191bc2e3288f59e2820627647b91
Summary:
As discussed in pytorch/pytorch#136168, I'm going to migrate implementations of operator benchmarking. This PR adds different implementations for FusedLinearCrossEntropy as a starting example.

Execution command:
```
python run_benchmark.py triton --op FusedLinearCrossEntropy
```
Example output:
```
x_val    LMHeadCE-latency    LigerLMHeadCE-latency    inductor_fused_linear_cross_entropy-latency
-------  ------------------  -----------------------  ---------------------------------------------
      0             98.0041                  389.87                                         95.0412
      1            196.12                    652.619                                       193.219
      2            417.242                  1248.75                                        416.725
      3            824.906                  2356.25                                        809.56
```

Pull Request resolved: pytorch#2485

Reviewed By: xuzhao9

Differential Revision: D63859871

Pulled By: FindHao

fbshipit-source-id: 4b73a2144702c1f8f3ae5ed15e76112d03f12b87
…orch#2489)

Summary: Pull Request resolved: pytorch#2489

Reviewed By: xuzhao9

Differential Revision: D63898689

Pulled By: atalman

fbshipit-source-id: 3cd430911aadd5972f1393e3548ef7d52b93b661
Summary:
Remove nvidia-cuda-nvcc-cu12 as not required. Install time.

Pull Request resolved: pytorch#2493

Reviewed By: xuzhao9

Differential Revision: D63987509

Pulled By: atalman

fbshipit-source-id: 07298ddb569da7f7c3fe22d73da72a4ceab256f5
Summary:
Add a PR CI on Tritonbench that installs the latest Triton nightly package

Pull Request resolved: pytorch#2494

Reviewed By: chenyang78

Differential Revision: D63998525

Pulled By: xuzhao9

fbshipit-source-id: a26633de040bdf324e9ae5c9b130ec1a58dfd409
jamesjwu and others added 6 commits October 8, 2024 10:54
Summary:
X-link: pytorch/pytorch#137431

Log the current compilation id for all relevant samples for these two tables, so we can have a 1:1 analog with dynamo_compile.
ghstack-source-id: 246618821
exported-using-ghexport

Reviewed By: oulgen

Differential Revision: D63900826

fbshipit-source-id: 3f2896287777c94344960e7cad131f71aaf0210f
Summary:
This PR implements tracing of with contexts with TorchFunction modes which have the default enter/exit behavior (ie pushing/popping the mode)

Typically the bytecode for a context manager looks like this during a graph break:
1. graph call
2. enter context
3. unsupported code
4. exit context
5. resume call

resume fn structure:
1. enter context
2. jump
...
3. exit context

The issue with torch function modes is that side effects will replay any mutations to the torch function stack performed during tracing. So, we do not need to enter and exit around the unsupported code in the original function (doing so would result in a duplicate torch function mode entry during execution of the unsupported code), and we don't need to enter again in the resume function (the mode that was pushed from the side effects bytecode would still be on the stack).

So for torch function modes the structure of our output code is this:

1. graph call
2. mutate tf mode stack to replay mutations
4. unsupported code
5. on exception restore stack
6. resume function

Then our resume fn looks like this:

1. no-op enter torch function mode
2. jump
3.  exit tf mode

To implement the no-op enter of the torch function mode I added torch function mode in polyfill which no-op enters, but normally exits. This is needed because we still want to trace the with context in the resume function, and exit properly (the exit instructions will still be in the function, so we need to generate instructions to set up the context).

Separately from the bytecode, dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Naturally at a graph break, we exit these block stacks to properly reset the contexts entirely, so that we can re-enter around the unsupported code soundly. However once again, in the torch function mode case, in the event of a graph we do not want to perform any exit side effects because we want to preserve the state of the mode stack as is so that we will properly update the stack with bytecode mentioned in the first section. If we exited here, dynamo would pop the mode off of the symbolic stack, and not update the true python torch function mode stack with the suffix bytecode. All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly.
Approved by: https://github.com/williamwen42
ghstack dependencies: #134732, #133137, #135443, #135444

X-link: pytorch/pytorch#137114
Approved by: https://github.com/yanboliang

Reviewed By: jovianjaison

Differential Revision: D64088005

Pulled By: mlazos

fbshipit-source-id: 156b9bf38a535933f8dd966ee96ed3099d7b4be2
Summary:
Approved by: https://github.com/anijain2305
ghstack dependencies: #134732, #133137, #135443, #135444, #135422

X-link: pytorch/pytorch#137115
Approved by: https://github.com/yanboliang
ghstack dependencies: #137114

Reviewed By: jovianjaison

Differential Revision: D64088016

Pulled By: mlazos

fbshipit-source-id: 53efb5a6e689d4fb6112a6462851ee7e81b28c24
…s (#137119)

Summary:
X-link: pytorch/pytorch#137119
Approved by: https://github.com/williamwen42, https://github.com/anijain2305
ghstack dependencies: #137114, #137115, #137116, #137117, #137120, #137227

Reviewed By: jovianjaison

Differential Revision: D64088048

Pulled By: mlazos

fbshipit-source-id: 34fe09f7fa6292d89a438b780852f00e042ec950
Summary:
adding new configs for servicelab + logging to scuba

Follow up diff coming up to add aggregates into logging (ie harmonic mean)

Reviewed By: xuzhao9

Differential Revision: D64126688

fbshipit-source-id: 0c3705e82071f1399cfc53ff496d130adf237b73
…#2482)

Summary:
pytorch#2468

Pull Request resolved: pytorch#2482

Reviewed By: xuzhao9

Differential Revision: D64139543

Pulled By: atalman

fbshipit-source-id: 2d030c66d856387b6a2451b26c89fd40e79e0e53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.