
Add noop detach for Nf4 tensor and enhance nf4 testing #40

Merged 1 commit into main on Mar 5, 2024

Conversation

rohan-varma (Member) commented Mar 1, 2024

  • Adds preliminary torch dispatch support, as prototyped by @drisspg, and a no-op detach so that NF4Tensor can be registered as an nn.Parameter (see the sketch after this list).
  • Enhances NF4Tensor testing in torchao.
  • Slight error message enhancement in nf4tensor.
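Roughly, the no-op detach has the following shape (a sketch of the dispatch-table pattern rather than the exact code in this diff; NF4_OPS_TABLE and implements are illustrative names):

import torch

NF4_OPS_TABLE = {}

def implements(aten_ops):
    # register a handler for one or more aten ops
    def decorator(func):
        for op in aten_ops:
            NF4_OPS_TABLE[op] = func
        return func
    return decorator

@implements([torch.ops.aten.detach.default])
def noop_detach(func, args, kwargs):
    # nn.Parameter(...) calls detach on the incoming tensor; returning the
    # NF4Tensor itself keeps the subclass (and its quantized payload) intact.
    return args[0]

# Inside NF4Tensor, the dispatch hook then consults the table:
# @classmethod
# def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
#     if func in NF4_OPS_TABLE:
#         return NF4_OPS_TABLE[func](func, args, kwargs)
#     raise NotImplementedError(f"{func} is not yet supported for NF4Tensor")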

Note: things like state_dict save/load, .parameters() returning the expected data, etc. are not addressed and are out of scope for this PR. We will need to ensure all of these work robustly as part of using this in torchtune, as we'll need to load base model parameters into this layer before quantizing (or quantize on the fly).

Not sure if this is covered out of the box by CI, but tested locally with python test/modules/test_nf4_linear.py -v

@rohan-varma rohan-varma requested a review from drisspg March 1, 2024 01:13
@facebook-github-bot facebook-github-bot added the CLA Signed label Mar 1, 2024
from torchao.dtypes.nf4tensor import NF4Tensor, linear_nf4


class FrozenNF4Linear(nn.Linear):
Contributor:

Just out of curiosity: Why can't this be done by overwriting the weight of an nn.Linear layer?

As in like here

https://github.com/pytorch-labs/segment-anything-fast/blob/387488bc4c7ab2ae311fb0632b34cab5cbfbab78/segment_anything_fast/sparse.py#L28-L32

def apply_sparse(model):
    apply_fake_sparsity(model)
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear):
            mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))

Contributor:

I think the answer is because we don't define the nn.Linear ourselves (see the comment above), but if it works then I agree we should

Member Author (rohan-varma):

For TorchTune use cases, we may not want to overwrite every single linear layer in the model with a frozen NF4 linear. Basically we wanna offer maximum flexibility to users: even if in QLoRA currently every base linear is overwritten, they might wanna play around with this (for example, making a more granular tradeoff by only quantizing the qkv projections and not the feed forwards; see the sketch below).

Also, this won't work easily because our LoRA adapters are nn.Linears themselves, and we would not want to overwrite the LoRA adapters.

cc @ebsmothers for thoughts on UX as well
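A hypothetical sketch of that more granular tradeoff (the "q_proj"/"k_proj"/"v_proj" name filter is just an illustrative assumption, not torchtune's actual API):

import torch
from torchao.dtypes.nf4tensor import NF4Tensor

def quantize_qkv_only(model):
    # only quantize the attention projections, leave feed-forward linears alone
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and any(
            proj in name for proj in ("q_proj", "k_proj", "v_proj")
        ):
            module.weight = torch.nn.Parameter(
                NF4Tensor.from_tensor(module.weight.data), requires_grad=False
            )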

Contributor:

Can you link your existing LoRALinears? I imagine the best UX for users/torchtune would be what Christian says, and if you say "I want to qloraify the q projections" you would have a util that swaps the LoRALinear non-adapter weight for an nf4 tensor, and that should be all that is needed

Member Author (rohan-varma):

@drisspg Thanks for the suggestion! Here is the LoRALinear in torchtune: https://github.com/pytorch-labs/torchtune/blob/2fba15d18d35383f7b8ad4dac5369ca6646ae68e/torchtune/modules/peft/lora.py#L37

I'd imagine such a UX would be -

def swap_nf4(llama):
    for module in llama.modules():
        if isinstance(module, LoRALinear):
            # quantize the frozen base weight to NF4
            module.weight = torch.nn.Parameter(
                NF4Tensor.from_tensor(module.weight.data), requires_grad=False
            )

Although, as opposed to module / parameter swapping, torchtune prefers to use a componentized builder approach where we build up models such as llama by plugging in the right nn.Module components depending on the config - nn.Linear for regular llama, LoRALinear for LoRA, and now NF4Linear for QLoRA. See an example of the builder pattern here: https://github.com/pytorch-labs/torchtune/blob/main/torchtune/models/llama2/_lora_llama2_builders.py#L135

Member Author (rohan-varma):

Another point is that IMO NF4Linear eliminates a lot of complexity around state_dict save/load. When calling load_state_dict, I'm not sure whether there will be issues loading into a class that uses NF4 tensors. But if we have a specific NF4Linear that uses these NF4Tensors, we can attach load pre and post hooks to upcast / downcast the tensors appropriately.
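A rough sketch of what those hooks could look like (this wiring is an assumption about the eventual design, using nn.Module's existing _register_state_dict_hook / _register_load_state_dict_pre_hook mechanisms; NF4Linear here is hypothetical):

import torch
from torch import nn
from torchao.dtypes.nf4tensor import NF4Tensor

class NF4Linear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._register_state_dict_hook(self._upcast_on_save)
        self._register_load_state_dict_pre_hook(self._quantize_on_load)

    @staticmethod
    def _upcast_on_save(module, state_dict, prefix, local_metadata):
        # dequantize before the weight leaves the module, so checkpoints stay bf16
        key = prefix + "weight"
        if hasattr(state_dict[key], "get_original_weight"):
            state_dict[key] = state_dict[key].get_original_weight()
        return state_dict

    def _quantize_on_load(self, state_dict, prefix, *args):
        # quantize incoming full-precision weights on the way in
        key = prefix + "weight"
        if key in state_dict and not hasattr(state_dict[key], "get_original_weight"):
            state_dict[key] = NF4Tensor.from_tensor(state_dict[key])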

Member Author (rohan-varma):

For example, here's what a load_state_dict for QLoRA might look like:

def load_checkpoint_nf4(model):
    # NOTE: this also enforces that ALL linear layers will always be quantized with QLoRA.
    # That might not always be the case if users want to customize and, for example, only
    # quantize some layers.
    # Convert all NF4 weights back to their original (dequantized) weights
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.weight = nn.Parameter(module.weight.get_original_weight())
            # would have to add support for bias as well

    load_checkpoint(model)
    # Now re-quantize
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.weight = nn.Parameter(
                NF4Tensor.from_tensor(module.weight.data).to(module.weight.device),
                requires_grad=False,
            )

This is of course assuming we quantize before loading the state_dict; if we quantize after, that could bring down the complexity.

Contributor:

At least for my mental model it does make sense to inherit from nn.Linear here (as opposed to swapping out self.weight from an nn.Linear, though that is nice and simple). But FrozenNF4Linear is basically a constrained version of nn.Linear, right? The weight is a particular tensor subclass, and we also require that there be no gradient. So imo we should inherit as a way of being explicit about these constraints. That way I can look at an FrozenNF4Linear and know that these conditions should hold, as opposed to having to try and figure them out across all my nn.Linears
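A minimal sketch of that constrained subclass (this paraphrases the pieces of FrozenNF4Linear quoted elsewhere in this review rather than reproducing the diff line for line):

import torch
from torch import nn, Tensor
from torchao.dtypes.nf4tensor import NF4Tensor, linear_nf4

class FrozenNF4Linear(nn.Linear):
    def __init__(self, in_dim: int, out_dim: int, device=None, dtype=None, **kwargs):
        super().__init__(in_dim, out_dim, device=device, dtype=dtype, **kwargs)
        if self.weight.dtype != torch.bfloat16:
            raise RuntimeError("FrozenNF4Linear is only supported with bf16 parameters currently")
        # Quantize the freshly initialized weight and freeze it.
        nf4_weight = NF4Tensor.from_tensor(self.weight.data)
        del self.weight
        self.weight = nn.Parameter(nf4_weight, requires_grad=False)

    def forward(self, input: Tensor) -> Tensor:
        return linear_nf4(input=input, weight=self.weight)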

def test_frozen_nf4_linear(self):
    nf4_linear = FrozenNF4Linear(512, 512, device='cpu', dtype=torch.bfloat16)
    self.assertTrue(isinstance(nf4_linear.weight, NF4Tensor))
    self.assertEqual(torch.bfloat16, nf4_linear.weight.get_original_weight().dtype)
Contributor:

If you keep both, won't that affect memory consumption? If the user wants both, they can decide to keep both around. Otherwise they could convert and re-assign an nn.Linear.weight Tensor like here

def apply_sparse(model):
    apply_fake_sparsity(model)
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear):
            mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))

Contributor:

Oh, I see. Maybe a more standard API would be to support to(torch.bfloat16) or such?

Member Author (rohan-varma):

"If you keep both, won't that affect memory consumption?"

So IIUC we aren't keeping both here, but we can verify by looking at the memory allocation after creating an instance of FrozenNF4Linear.

get_original_weight actually runs the dequantization and restores the original weight - maybe a bit of a misnomer, since the name sorta implies it's stored somewhere and is just accessed.

Contributor:

yeah maybe "build_original_weight"

Contributor:

But can you actually restore the original weight? Hasn't some fidelity been lost after converting to nf4? Hence my suggestion to just overwrite to_dtype here.

Member Author (rohan-varma):

That makes sense; it does seem we lose fidelity and can't get back the exact original weight as one might intuitively expect:

(screenshot omitted)

(@drisspg - just checking my understanding is correct).

So should we update NF4Tensor to get rid of get_original_weight and just have a .to() API? @drisspg, can this be done in a separate PR?
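A tiny sketch of the interim shape (nf4_to_dtype is a hypothetical helper, not an existing API):

import torch
from torchao.dtypes.nf4tensor import NF4Tensor

def nf4_to_dtype(weight: NF4Tensor, dtype: torch.dtype) -> torch.Tensor:
    # get_original_weight() dequantizes; values carry nf4 rounding error,
    # so this is "restore to dtype", not a bit-exact inverse.
    return weight.get_original_weight().to(dtype)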

Contributor:

That is Christian's suggestion, not mine lol. But regardless, I think this PR markedly increases the testing coverage, and like you said, if we want to do the full switch over to subclasses and do everything through torch dispatch, I think it would make sense to do that in a follow-up PR. cc @cpuhrsch

# types.

def forward(self, input: Tensor) -> Tensor:
    return linear_nf4(input=input, weight=self.weight)
Contributor:

I'm surprised you can't just put this into the usual F.linear.

Maybe it's worth updating the ops table https://github.com/pytorch-labs/ao/blob/687f0f0eae8594f90afc447e0b5b52b524cb3fa6/torchao/dtypes/nf4tensor.py#L417-L439

Contributor:

When I first wrote this, I didn't make this a subclass because it didn't support compile. I think I left a comment somewhere that we should likely do this, and just make sure that we indeed get the right thing saved for backwards.

Member Author (rohan-varma):

If I understand the 2 options correctly, it's:

  1. Use torch_dispatch mechanism to "correctly" implement F.linear in this case, where "correctly" means saving the right tensors and avoiding saving extra tensors for the backward pass.
  2. Stick with the current autograd function implementation.

The reason I'm a bit of a proponent of sticking w/the autograd function implementation is because it's a bit more battle tested by @drisspg and the torch_dispatch support is a relatively new introduction. Could also switch over to this in the future.

WDYT @drisspg @cpuhrsch ?
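For context, the autograd-function route (option 2) looks roughly like this (a sketch approximating the shape of linear_nf4, not a verbatim copy; what exactly gets saved for backward is precisely the open detail):

import torch
import torch.nn.functional as F
from torchao.dtypes.nf4tensor import NF4Tensor

class LinearNF4(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight: NF4Tensor):
        # Dequantize for the matmul; hold on to the compact NF4 weight for backward
        # instead of saving a full-precision copy.
        ctx.nf4_weight = weight
        return F.linear(input, weight.get_original_weight().to(input.dtype))

    @staticmethod
    def backward(ctx, grad_output):
        # The weight is frozen, so only the input needs a gradient.
        weight = ctx.nf4_weight
        return grad_output @ weight.get_original_weight().to(grad_output.dtype), None

def linear_nf4(input, weight):
    return LinearNF4.apply(input, weight)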

Contributor:

drisspg/transformer_nuggets#24

It's Friday, so there might be an "easy" way to fix this, but I will leave these as a future-me thing.

del self.weight
self.weight = torch.nn.Parameter(self.nf4_weight, requires_grad=False)

# TODO: likely need to handle state_dict save & load via hooks to properly manage
Contributor:

NF4Tensor might already support that as a Tensor subclass

Member Author (rohan-varma):

awesome! Will probably test this out as part of follow-up work. The main thing I wanna figure out is: if we call load_state_dict w/ base model parameters in bf16 and try to load into NF4Tensor, do we crash, raise a type mismatch issue, or just-in-time quantize the incoming weight and update the data? Will probably learn more about this when I begin the state_dict experimentation.
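A quick probe I have in mind (hypothetical sketch, not part of this PR; assumes FrozenNF4Linear is importable from the module this PR adds):

import torch

def probe_bf16_load():
    linear = FrozenNF4Linear(512, 512, device='cpu', dtype=torch.bfloat16)
    sd = {"weight": torch.randn(512, 512, dtype=torch.bfloat16)}
    try:
        # strict=False because we only care about the weight entry here
        linear.load_state_dict(sd, strict=False)
        print("loaded:", type(linear.weight), linear.weight.dtype)
    except Exception as e:
        print("load failed with:", type(e).__name__, e)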

if self.weight.dtype != torch.bfloat16:
    raise RuntimeError("FrozenNF4Linear is only supported with bf16 parameter currently")

self.nf4_weight = NF4Tensor.from_tensor(self.weight.data).to(device).to(dtype)
cpuhrsch (Contributor) commented Mar 1, 2024:

I wonder if there could also be use for something like a to_nf4 factory function. Then it'd follow the pattern of torch.Tensor.to(<torch dtype>), but as a standalone function (which it'll have to be unless we somehow open up dtypes for open registration).

It's then quite similar to other memory/dtype/device oriented functions. In a nutshell, just because we now use nf4 instead of bfloat16, the Tensor's behavior etc. hasn't changed (of course individual values might have changed, since nf4 has a different range etc.).
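i.e., something like this (to_nf4 is the proposed name; the factory below is a sketch, not an existing function in the repo):

import torch
from torchao.dtypes.nf4tensor import NF4Tensor

def to_nf4(tensor: torch.Tensor) -> NF4Tensor:
    # standalone analogue of tensor.to(<dtype>): quantize to nf4,
    # with the understanding that individual values may shift within nf4's range
    return NF4Tensor.from_tensor(tensor)

# usage: nf4_weight = to_nf4(linear.weight.data)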

Member Author (rohan-varma):

This makes sense. Is this something we'd like to build in this PR or more as a longer-term follow up item? cc @drisspg

Contributor:

TBH I don't know if I follow this

bnb_nf4_linear = self._build_bnb_linear(input_weight=orig_weight)

inp = torch.randn(2, 512, dtype=torch.bfloat16, device='cuda')
self.assertEqual(nf4_linear(inp).sum(), bnb_nf4_linear(inp).sum())

Member Author (rohan-varma):

Will def add reconstruction accuracy test. Curious why it's not as valuable to test exact parity w/BNB though?
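Something along these lines, perhaps (a sketch; the tolerance is a placeholder, not a measured bound):

def test_reconstruction_error(self):
    orig = torch.randn(512, 512, dtype=torch.bfloat16, device='cuda')
    nf4 = NF4Tensor.from_tensor(orig)
    # quantize -> dequantize round trip should stay close to the original, not exact
    err = (nf4.get_original_weight() - orig).abs().mean()
    self.assertLess(err.item(), 1e-1)  # placeholder bound, would need tuning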

@rohan-varma rohan-varma requested review from cpuhrsch and drisspg March 1, 2024 22:25
@rohan-varma rohan-varma force-pushed the nf4_linear branch 2 times, most recently from 1cbabc2 to 87740d0 Compare March 1, 2024 23:03
Comment on lines 23 to 24
if self.weight.dtype != torch.bfloat16:
    raise RuntimeError("FrozenNF4Linear is only supported with bf16 parameter currently")
Contributor:

nit: can't you just check self.dtype first before even initializing the parent class?
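i.e., roughly (the constructor signature here is an assumption about FrozenNF4Linear):

def __init__(self, in_dim, out_dim, device=None, dtype=None, **kwargs):
    # validate before paying for the parent class's weight initialization
    if dtype is not None and dtype != torch.bfloat16:
        raise RuntimeError("FrozenNF4Linear is only supported with bf16 parameters currently")
    super().__init__(in_dim, out_dim, device=device, dtype=dtype, **kwargs)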

@rohan-varma rohan-varma changed the title Add Nf4Linear and tests Add noop detach for Nf4 tensor and enhance nf4 testing Mar 5, 2024
@rohan-varma rohan-varma requested a review from ebsmothers March 5, 2024 19:37
@rohan-varma rohan-varma merged commit c9b397d into main Mar 5, 2024
2 checks passed
drisspg (Contributor) left a comment:

Looks good!

dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
Add noop detach for Nf4 tensor and enhance nf4 testing