Experimental interface for torch ops #189

Open · alihassanijr wants to merge 4 commits into main from torch-ops-experimental
Conversation

alihassanijr (Member) commented Dec 10, 2024

See #184

Only supports the forward pass for now, due to current limitations of registering custom ops with torch compared to autograd functions. Some of those limitations are:

  • No stable interface for supporting autocasting to fp16/bf16,

    • Gradient scaling doesn't seem to be supported either, leading to training instability.
  • Ops cannot indicate that they expect contiguous operands, so they have to call `.contiguous()` internally; this incurs additional tensor copies and brings down throughput (in some cases it's hard to even tell the difference between compiled and eager). A minimal sketch of this follows below.
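
A minimal sketch of the contiguity limitation (the op name and body here are hypothetical placeholders, not NATTEN's actual registration): the custom-op schema cannot require contiguous operands, so the copies have to happen inside the op.

import torch

# Hypothetical op, only to illustrate the limitation above; not NATTEN's code.
@torch.library.custom_op("mylib::na2d_like_forward", mutates_args=())
def na2d_like_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # The schema can't require contiguous inputs, so we copy here; the extra
    # copies are paid whenever callers pass non-contiguous views.
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
    # Placeholder body; the real kernel would run neighborhood attention here.
    return torch.empty_like(q)

@na2d_like_forward.register_fake
def _(q, k, v):
    # Shape/dtype propagation for torch.compile / fake tensors.
    return torch.empty_like(q)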

TODOs:

  • Unit tests? (unsure how yet)
  • Confirm the graph doesn't break -- for some reason, when running with torch.no_grad, the compiled graph isn't dumped to file with TORCH_COMPILE_DEBUG=1.

alihassanijr force-pushed the torch-ops-experimental branch from 18ecf32 to 5e96411 on December 10, 2024 at 20:08
alihassanijr (Member, Author) commented Dec 10, 2024

@Birch-san could you check and see if the changes here resolve what you need when you get a chance?

Note: I kind of assumed you're using NATTEN ops (natten.functional) directly instead of the modules (natten.NeighborhoodAttention*D); in that case, all you'd need to do is import na*d from natten.experimental instead. You can also directly import the custom mappings: from natten.experimental import custom_mapping.
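
For example, a minimal sketch of the swap (the tensor shapes, layout, dtype, and kernel size here are just placeholders):

import torch
# from natten.functional import na2d   # what you may be using today
from natten.experimental import na2d    # experimental custom-op-backed interface (forward only)

# Placeholder inputs; treat the exact shape/layout and kernel size as assumptions.
q = torch.randn(1, 16, 16, 4, 32, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = na2d(q, k, v, 7)  # same call signature as natten.functional.na2d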

I'm still thinking about how to unit test the new FLOP counter and torch compile, but I have already verified that both work manually (fvcore reports exactly 0.5× what PyTorch reports, I imagine because it's set up to report MACs, not FLOPs).

Birch-san commented Dec 30, 2024

Hi @alihassanijr, thanks for implementing this, and sorry for the delay.

I've now tried invoking natten.experimental.na2d inside a FlopCounterMode context.

Dispatch works (FlopCounterMode sees a func_packet whose _qualified_op_name is 'natten::na2d_forward_op'), but no FLOPs are counted because no flop counter is registered for the op.
A user could manually pass a custom_mapping function when constructing their FlopCounter, like in #184 (comment), but it should be possible for NATTEN to register a flop counter globally so the user doesn't have to.
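
For reference, the manual route would look roughly like this (a sketch; it assumes custom_mapping has the shape that FlopCounterMode's custom_mapping argument expects, and the input shapes are placeholders):

import torch
from torch.utils.flop_counter import FlopCounterMode
from natten.experimental import na2d, custom_mapping  # exact contents of custom_mapping assumed

q = torch.randn(1, 16, 16, 8, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Manual route: the user wires NATTEN's mapping into the counter themselves.
with FlopCounterMode(custom_mapping=custom_mapping) as counter, torch.no_grad():
    na2d(q, k, v, 7)
print(counter.get_total_flops())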

I was able to fix this by declaring this decorated function before running my model under FlopCounterMode context:

import math
from typing import Optional, Sequence

import torch
from torch.utils.flop_counter import register_flop_formula

import natten.experimental  # noqa: F401 -- so torch.ops.natten.na2d_forward_op is registered


def fna_generic_flops_(
    q: torch.Size,
    k: torch.Size,
    v: torch.Size,
    has_bias: bool,
    kernel_size: Sequence[int],
) -> int:
    batch_size, heads, dim = (
        q[0],
        q[-2],
        q[-1],
    )

    spatial_extent: Sequence[int] = q[1 : len(kernel_size) + 1]
    spatial_extent_int = math.prod(spatial_extent)
    kernel_size_int = math.prod(kernel_size)

    flops = batch_size * heads * spatial_extent_int * dim * kernel_size_int  # QK

    # NOTE: PyTorch doesn't count softmax flops in SDPA;
    # Reference:
    # https://github.com/pytorch/pytorch/blob/7ced49d2ccf219ec896810e6d988709c3a3a2d9a/torch/utils/flop_counter.py#L241-L256
    # flops += batch_size * heads * spatial_extent_int * kernel_size_int  # softmax
    flops += batch_size * heads * spatial_extent_int * dim * kernel_size_int  # AV

    if has_bias:
        flops += batch_size * heads * spatial_extent_int * kernel_size_int  # RPB
    return flops

@register_flop_formula(torch.ops.natten.na2d_forward_op)
def na2d_flop(
    query: torch.Size,
    key: torch.Size,
    value: torch.Size,
    bias: Optional[torch.Size],
    kernel_size_: Sequence[int],
    dilation_: Sequence[int],
    is_causal_: Sequence[bool],
    scale: float,
    q_tiler_: Sequence[int],
    kv_tiler_: Sequence[int],
    *args,
    out_shape: Optional[tuple[torch.Size, torch.Size]] = None,
    **kwargs,
) -> int:
    return fna_generic_flops_(query, key, value, bias is not None, kernel_size_)

Now FLOPs are counted successfully:

Encoder: 1024x1024px, batch size 1
Module                            FLOP    % Total
---------------------------  ---------  ---------
Global                       4338.875B    100.00%
 - aten.convolution          4294.835B     98.98%
 - aten.bmm                    25.770B      0.59%
 - natten.na2d_forward_op      18.270B      0.42%
 Conv2d                         0.002B      0.00%
  - aten.convolution            0.002B      0.00%
 Encoder                     4338.873B    100.00%
  - aten.convolution         4294.833B     98.98%
  - aten.bmm                   25.770B      0.59%
  - natten.na2d_forward_op     18.270B      0.42%
  Encoder.conv_in               7.248B      0.17%
   - aten.convolution           7.248B      0.17%
  Encoder.conv_out              1.208B      0.03%
   - aten.convolution           1.208B      0.03%
  Encoder.mid_block           353.278B      8.14%
   - aten.convolution         309.238B      7.13%
   - aten.bmm                  25.770B      0.59%
   - natten.na2d_forward_op    18.270B      0.42%
|   Bsz | Wid x Hei px   |   Megapx | FLOP/s        |   ms/iter |   iter/s |
|-------|----------------|----------|---------------|-----------|----------|
|     1 | 1024x1024      |     1.05 |  89.9 TFLOP/s |      48.3 |     20.7 |

Decoder: 1024x1024px, batch size 1
Module                            FLOP    % Total
---------------------------  ---------  ---------
Global                       9930.317B    100.00%
 - aten.convolution          9886.277B     99.56%
 - aten.bmm                    25.770B      0.26%
 - natten.na2d_forward_op      18.270B      0.18%
 Conv2d                         0.001B      0.00%
  - aten.convolution            0.001B      0.00%
 Decoder                     9930.317B    100.00%
  - aten.convolution         9886.277B     99.56%
  - aten.bmm                   25.770B      0.26%
  - natten.na2d_forward_op     18.270B      0.18%
  Decoder.conv_in               0.604B      0.01%
   - aten.convolution           0.604B      0.01%
  Decoder.conv_out              7.248B      0.07%
   - aten.convolution           7.248B      0.07%
  Decoder.mid_block           353.278B      3.56%
   - aten.convolution         309.238B      3.11%
   - aten.bmm                  25.770B      0.26%
   - natten.na2d_forward_op    18.270B      0.18%

|   Bsz | Wid x Hei px   |   Megapx | FLOP/s        |   ms/iter |   iter/s |
|-------|----------------|----------|---------------|-----------|----------|
|     1 | 1024x1024      |     1.05 |  83.4 TFLOP/s |     119.1 |      8.4 |

I note that this flop count algorithm is different from the one that I wrote by guesswork in my original post, #184 (comment).

My own FLOP count came to exactly 2× that amount, 36.541B operations.
Maybe it's the difference between counting FLOs vs MACs?
I built mine on top of torch.utils.flop_counter.bmm_flop, which does indeed employ a * 2 term.

I think the torch counters return FLOs whereas your fna_generic_flops function returns MACs?
The only precedent I've seen where a paper reported MACs was openai/guided-diffusion / Diffusion Models Beat GANs on Image Synthesis, which counts FLOPs via THOP. I consider this esoteric. FLOP counting isn't really standardized, but I'd say the most useful approach would be to follow PyTorch FlopCounterMode's conventions, so that the numbers make sense in the context of all the other operations I'm measuring.
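
A quick way to see that factor of 2 (a sketch, assuming bmm_flop's current shape-based signature):

import torch
from torch.utils.flop_counter import bmm_flop

b, m, k, n = 2, 32, 64, 32
# torch counts each multiply-accumulate as 2 FLOPs (one multiply + one add),
# whereas a MAC count would omit the factor of 2.
assert bmm_flop(torch.Size([b, m, k]), torch.Size([b, k, n])) == 2 * b * m * k * n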

Also, I think it's problematic that users cannot import flops.py if they don't have fvcore installed; this is what forced me to write my own fna_generic_flops_ (see above).
torch's FlopCounterMode gives you tensor sizes and all the args with which the operation was invoked, so I think a lot of the gymnastics fna_generic_flops does to sleuth out kernel_size aren't necessary under FlopCounterMode's conventions.

I did initially try to delegate flop counting to your fna_generic_flops; the inputs and outputs can be constructed like this:

@register_flop_formula(torch.ops.natten.na2d_forward_op)
def na2d_flop(
    query: torch.Size,
    key: torch.Size,
    value: torch.Size,
    bias: Optional[torch.Size],
    kernel_size_: Sequence[int],
    dilation_: Sequence[int],
    is_causal_: Sequence[bool],
    scale: float,
    q_tiler_: Sequence[int],
    kv_tiler_: Sequence[int],
    *args,
    out_shape: Optional[tuple[torch.Size, torch.Size]] = None,
    **kwargs,
) -> int:
    inputs: list[torch.Size] = [query, key, value, *[bias]*(bias is not None)]
    outputs: list[torch.Size] = list(out_shape)

    return fna_generic_flops(inputs, outputs)

but it failed on this assertion: `assert hasattr(inputs[0], "uses") and callable(inputs[0].uses)`, since FlopCounterMode doesn't give me tensors, only torch.Size objects.

Anyway, I recommend exposing FLOP counter functions that can be used without having fvcore installed and that take Sizes as inputs rather than tensors, and I also recommend registering the flop counters via @register_flop_formula so that the user doesn't have to.

Birch-san commented Dec 30, 2024

As for graph breaks:

I slapped a @torch.compile(fullgraph=True) on Attention#forward, which includes a call to na2d.

Under natten.functional.na2d, I got:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py:700: UserWarning: Graph break due to unsupported builtin natten.libnatten.PyCapsule.na2d_forward. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.

Whereas under natten.experimental.na2d, I got no such warning and the model executed successfully. There was no measurable difference in benchmark speed, which isn't particularly surprising, and at least means nothing bad happened.

Birch-san commented

Regarding a FLOP counter unit test, how about:

import torch
from torch import no_grad
from torch._ops import OpOverloadPacket
from torch.utils.flop_counter import FlopCounterMode
from natten.experimental import na2d

bsz = 2
heads = 8
head_dim = 64
wid = 32
hei = 32
kernel_size = 7

dtype = torch.float16
device = torch.device('cuda')

gen = torch.Generator(device=device).manual_seed(42)
q = torch.randn(bsz, heads, hei, wid, head_dim, device=device, dtype=dtype, generator=gen)
k = torch.randn(bsz, heads, hei, wid, head_dim, device=device, dtype=dtype, generator=gen)
v = torch.randn(bsz, heads, hei, wid, head_dim, device=device, dtype=dtype, generator=gen)

counter = FlopCounterMode()
# we avoid FLOP counting under inference_mode() context because it's poorly-supported;
# even basics like matmuls dispatch different ops under inference_mode, for which no flop counter is registered
with counter, no_grad():
    na2d(q, k, v, kernel_size)
global_flops: dict[OpOverloadPacket, int] = counter.flop_counts['Global']
assert torch.ops.natten.na2d_forward_op in global_flops, "na2d FLOPs not counted"
na2d_flops: int = global_flops[torch.ops.natten.na2d_forward_op]
assert na2d_flops != 102760448, "na2d returned MACs instead of FLOs"
assert na2d_flops == 205520896, "na2d returned unexpected FLOP count"

Birch-san commented Dec 30, 2024

For the compile unit test, I think all you need to do is attempt to invoke the model under fullgraph compilation and check that the program doesn't explode.
This test will fail if you switch it to using natten.functional.

import torch
from torch import Tensor, inference_mode
from torch.nn import Module
from torch._dynamo.exc import Unsupported
from einops import rearrange
from natten.experimental import na2d

class Attention(Module):
    def __init__(
        self,
        in_dim: int,
        heads: int,
        head_dim: int,
        kernel_size: int,
        device: torch.device = torch.device('cuda'),
        dtype: torch.dtype = torch.float16,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.kernel_size = kernel_size
        self.head_dim = head_dim
        self.qkv_proj = torch.nn.Linear(in_dim, 3*head_dim*heads, bias=False, **factory_kwargs)

    @torch.compile(fullgraph=True)
    def forward(self, x: Tensor):
        qkv = self.qkv_proj(x)
        q, k, v = rearrange(qkv, "... h w (proj heads head_dim) -> proj ... h w heads head_dim", proj=3, head_dim=self.head_dim)
        return na2d(q, k, v, self.kernel_size)

dtype = torch.float16
device = torch.device('cuda')

bsz = 2
inner_dim = 320
wid = 32
hei = 32

gen = torch.Generator(device=device).manual_seed(42)
x = torch.randn(bsz, hei, wid, inner_dim, device=device, dtype=dtype, generator=gen)

attn = Attention(
    in_dim=inner_dim,
    heads=5,
    head_dim=64,
    kernel_size=7,
    device=device,
    dtype=dtype,
)

try:
    with inference_mode():
        attn(x)
except Unsupported as e:
    if 'graph break' in e.msg.lower():
        raise AssertionError('Test failure')
    raise e
print('Test success')

Birch-san commented Dec 30, 2024

And if you're thinking of using torch's FlopCounterMode yourself to benchmark NATTEN, beware this gotcha: FlopCounterMode seems to make torch.compile fall back to eager mode in torch 2.5+:
pytorch/pytorch#140909
The workaround for now is "measure the latency of the compiled model first, then count its FLOPs" (because its compiled perf will regress after being run inside a FlopCounterMode context).
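
In code, the ordering amounts to something like this (a sketch with a tiny stand-in model and input so it runs; substitute your own compiled NATTEN model):

import torch
from torch.utils.flop_counter import FlopCounterMode

# Stand-in model/input; replace with the model you actually benchmark.
model = torch.nn.Linear(64, 64)
x = torch.randn(8, 64)
compiled = torch.compile(model)

# 1) Benchmark the compiled model first, before it ever runs under FlopCounterMode.
for _ in range(10):
    compiled(x)

# 2) Only then count its FLOPs (compiled perf may regress afterwards on torch 2.5+).
with FlopCounterMode(display=False) as counter:
    compiled(x)
print(counter.get_total_flops())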

alihassanijr (Member, Author) commented

@Birch-san thank you so much for the feedback.

I'll check the FLOPs vs MACs issue -- I remember checking this, finding that fvcore computed one and torch the other, and adjusting everything accordingly. I think I even used torch's bmm_flop within NATTEN's counter just like you did, so I'm unsure what's happening there. I'll check it again today.

Thanks for the feedback on register_flop_formula and importing from natten.flops -- I'll try and work those in today as well.

The only thing that worries me is that, given fvcore and torch report different metrics (FLOPs vs MACs), I might have to rename a few things to make them less confusing. I'll have to add some documentation for that.

And thanks for the unit test idea, and verifying compilation isn't breaking the graph. I think we should be ready to merge this in soon.

alihassanijr (Member, Author) commented Jan 2, 2025

Okay, according to a dummy example I set up, I'm getting FLOPs reported by torch and fvcore as follows:

Total flops [PyTorch]: 21726208
Total flops [fvcore]: 10863104

After checking flops.py, I can confirm we're not multiplying by 2 either, so I guess it's safe to say that torch reports FLOPs and fvcore reports MACs (21726208 is exactly 2 × 10863104).

Back to the discrepancy you were observing: I noticed that you also used torch's bmm_flop in #184, which is no surprise since I started off of your code. So I'm unsure whether there's something else my flop counter computes differently from yours in #184.

I'll keep looking to see if I can find a difference, but could you also double-check on your end whether switching out the flop counters results in a difference?


UPDATE: sorry, I just caught a mistake I made in the new flop count. I'll push the fix soon.

+ Separate FLOPs doc
alihassanijr (Member, Author) commented

@Birch-san Turns out the "QK" part of the new FLOP counter was wrong. I didn't notice because my test case used a kernel size identical to the feature map size.

I just pushed a commit doing what should've been done from the get-go: a major refactor so that whether we use fvcore, torch, or count FLOPs manually, everything ends up calling the same underlying API, so as to prevent mistakes due to code duplication.

Now the FLOP counting is correct, the experimental ops use register_flop_formula so users don't have to manually add mappings, and there's more extensive documentation on the differences between fvcore and torch.
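
For illustration only (this is a hypothetical sketch, not NATTEN's actual code), the idea is one shape-based counting core with thin per-frontend adapters:

import math
from typing import Sequence

def na_mac_count(q_shape: Sequence[int], kernel_size: Sequence[int], has_bias: bool) -> int:
    # Shared core: multiply-accumulates for QK and AV, plus optional RPB.
    batch, heads, dim = q_shape[0], q_shape[-2], q_shape[-1]
    spatial = math.prod(q_shape[1 : len(kernel_size) + 1])
    window = math.prod(kernel_size)
    macs = batch * heads * spatial * dim * window * 2  # QK + AV
    if has_bias:
        macs += batch * heads * spatial * window       # RPB
    return macs

def torch_flop_formula(q_shape, kernel_size, has_bias, *args, **kwargs) -> int:
    # torch FlopCounterMode convention: FLOPs = 2 * MACs.
    return 2 * na_mac_count(q_shape, kernel_size, has_bias)

def fvcore_handle(q_shape, kernel_size, has_bias) -> int:
    # fvcore convention: report MACs directly.
    return na_mac_count(q_shape, kernel_size, has_bias)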

I'll work on adding those extra unit tests now, but I've already verified it's working as expected.
