Add noop detach for Nf4 tensor and enhance nf4 testing #40
Conversation
torchao/modules/nf4_linear.py
Outdated
from torchao.dtypes.nf4tensor import NF4Tensor, linear_nf4

class FrozenNF4Linear(nn.Linear):
Just out of curiosity: why can't this be done by overwriting the weight of an nn.Linear layer? As in here:
def apply_sparse(model):
    apply_fake_sparsity(model)
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear):
            mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
I think the answer is because we don't define the nn.Linear ourselves (see comment above), but if it works then I agree we should.
For TorchTune use cases, we may not want to overwrite every single linear layer in the model with a frozen NF4 linear. Basically we want to offer maximum flexibility to users: even if in QLoRA every base linear is currently overwritten, they might want to play around with this (for example, making a more granular tradeoff by only quantizing the qkv projections and not the feed-forwards).
Also, this won't work easily because our LoRA adapters are nn.Linears themselves, and we would not want to overwrite the LoRA adapters.
cc @ebsmothers for thoughts on UX as well
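To make that granular tradeoff concrete, here is a hypothetical helper (my sketch, not code from this PR; the q_proj/k_proj/v_proj and "lora" naming conventions are assumptions about the model's module names):

import torch
from torchao.dtypes.nf4tensor import NF4Tensor

def quantize_qkv_only(model: torch.nn.Module) -> None:
    # Quantize only the attention projection weights to NF4, leaving the
    # feed-forwards and the LoRA adapter nn.Linears (assumed to have "lora"
    # in their module names) untouched.
    for name, mod in model.named_modules():
        is_qkv = any(p in name for p in ("q_proj", "k_proj", "v_proj"))
        if isinstance(mod, torch.nn.Linear) and is_qkv and "lora" not in name:
            mod.weight = torch.nn.Parameter(
                NF4Tensor.from_tensor(mod.weight.data), requires_grad=False
            )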
Can you link your existing LoRALinears? I imagine the best UX for users/torchtune would be what Christian says: if you say "I want to qlora-ify the q projections", you would have a util that swaps the LoRALinear non-adapter weight for an NF4 tensor, and that should be all that is needed.
@drisspg Thanks for the suggestion! Here is the LoRALinear in torchtune: https://github.com/pytorch-labs/torchtune/blob/2fba15d18d35383f7b8ad4dac5369ca6646ae68e/torchtune/modules/peft/lora.py#L37
I'd imagine such a UX would be:
def swap_nf4(llama):
    for module in llama.modules():
        if isinstance(module, LoRALinear):
            # quantize to NF4 (wrap in a Parameter so nn.Module accepts the assignment)
            module.weight = torch.nn.Parameter(
                NF4Tensor.from_tensor(module.weight.data), requires_grad=False
            )
Although, as opposed to module/parameter swapping, torchtune prefers a componentized builder approach, where we build up models such as llama by plugging in the right nn.Module components depending on the config: nn.Linear for regular llama, LoRALinear for LoRA, and now NF4Linear for QLoRA. See an example of the builder pattern here: https://github.com/pytorch-labs/torchtune/blob/main/torchtune/models/llama2/_lora_llama2_builders.py#L135
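A toy version of that builder idea (hypothetical helper and flag names, not torchtune's actual builders):

import torch
import torch.nn as nn
from torchao.modules.nf4_linear import FrozenNF4Linear

def build_base_linear(in_dim: int, out_dim: int, *, quantize_base: bool = False) -> nn.Module:
    # The config decides which component gets plugged in at build time,
    # so no post-hoc module or parameter swapping is needed.
    if quantize_base:
        return FrozenNF4Linear(in_dim, out_dim, dtype=torch.bfloat16)
    return nn.Linear(in_dim, out_dim)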
Another point is that IMO NF4Linear eliminates a lot of complexity around state_dict save/load. When calling load_state_dict, I'm not sure whether there will be issues loading into a class that uses NF4 tensors. But if we have a specific NF4Linear that uses these NF4Tensors, we can attach load pre and post hooks to upcast / downcast the tensors appropriately.
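A rough sketch of what those hooks could look like (my illustration using PyTorch's underscore-prefixed module state_dict hooks; the hook names are made up and this is not code from this PR):

import torch
from torchao.dtypes.nf4tensor import NF4Tensor

def nf4_state_dict_hook(module, state_dict, prefix, local_metadata):
    # On save: replace the NF4 weight with its dequantized bf16 form.
    state_dict[prefix + "weight"] = module.weight.get_original_weight()

def nf4_load_state_dict_pre_hook(module, state_dict, prefix, *args):
    # On load: re-quantize an incoming bf16 weight on the fly.
    key = prefix + "weight"
    if key in state_dict and not isinstance(state_dict[key], NF4Tensor):
        state_dict[key] = NF4Tensor.from_tensor(state_dict[key])

# Registration would happen inside NF4Linear.__init__, e.g.:
#   self._register_state_dict_hook(nf4_state_dict_hook)
#   self._register_load_state_dict_pre_hook(nf4_load_state_dict_pre_hook, with_module=True)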
For example, here's what a load_state_dict flow for QLoRA might look like:
def load_checkpoint_nf4(model):
    # NOTE: this also enforces that ALL linear layers will always be quantized with QLoRA.
    # That might not always be the case if users want to customize and, for example, only
    # quantize some layers.

    # Convert all NF4 weights back to their original (dequantized) form
    for module in model.modules():
        if isinstance(module, nn.Linear):
            # wrap in a Parameter so nn.Module accepts the assignment;
            # would have to add support for bias as well
            module.weight = torch.nn.Parameter(module.weight.get_original_weight())

    load_checkpoint(model)

    # Now re-quantize
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.weight = torch.nn.Parameter(
                NF4Tensor.from_tensor(module.weight.data).to(module.weight.device),
                requires_grad=False,
            )
This is of course assuming we quantize before loading the state_dict; if we quantize after, that could bring down the complexity.
At least for my mental model it does make sense to inherit from nn.Linear here (as opposed to swapping out self.weight from an nn.Linear, though that is nice and simple). But FrozenNF4Linear is basically a constrained version of nn.Linear, right? The weight is a particular tensor subclass, and we also require that there be no gradient. So IMO we should inherit as a way of being explicit about these constraints. That way I can look at a FrozenNF4Linear and know that these conditions should hold, as opposed to having to try and figure them out across all my nn.Linears.
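Stitching together the diff hunks quoted in this thread, a condensed sketch of that constrained subclass might look roughly like this (simplified: bias is assumed unsupported, and details of the actual file beyond the quoted hunks are elided):

import torch
import torch.nn as nn
from torch import Tensor
from torchao.dtypes.nf4tensor import NF4Tensor, linear_nf4

class FrozenNF4Linear(nn.Linear):
    # A constrained nn.Linear: the weight is an NF4Tensor subclass and is frozen.
    def __init__(self, in_features: int, out_features: int, device=None, dtype=torch.bfloat16):
        super().__init__(in_features, out_features, bias=False, device=device, dtype=dtype)
        if self.weight.dtype != torch.bfloat16:
            raise RuntimeError("FrozenNF4Linear is only supported with bf16 parameter currently")
        # Quantize the dense bf16 weight and replace it with a frozen NF4 parameter.
        self.nf4_weight = NF4Tensor.from_tensor(self.weight.data).to(device).to(dtype)
        del self.weight
        self.weight = torch.nn.Parameter(self.nf4_weight, requires_grad=False)

    def forward(self, input: Tensor) -> Tensor:
        return linear_nf4(input=input, weight=self.weight)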
test/modules/test_nf4_linear.py
Outdated
def test_frozen_nf4_linear(self):
    nf4_linear = FrozenNF4Linear(512, 512, device='cpu', dtype=torch.bfloat16)
    self.assertTrue(isinstance(nf4_linear.weight, NF4Tensor))
    self.assertEqual(torch.bfloat16, nf4_linear.weight.get_original_weight().dtype)
If you keep both, won't that affect memory consumption? If the user wants both, they can decide to keep both around. Otherwise they could convert and re-assign an nn.Linear.weight Tensor like here
def apply_sparse(model):
    apply_fake_sparsity(model)
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear):
            mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
Oh, I see. Maybe a more standard API would be to support to(torch.bfloat16) or such?
"If you keep both, won't that affect memory consumption"
So IIUC we aren't keeping both here, but we can verify via looking at the memory allocation after creating an instance of FrozenNF4Linear.
get_original_weight actually runs the dequantization and restores the original weight - maybe a bit of a misnomer, since the name sort of implies it's stored somewhere and is just accessed.
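One way to check that, assuming a CUDA device and the constructor shown in the tests (a rough sketch; exact numbers depend on block sizes and allocator behavior):

import torch
from torchao.modules.nf4_linear import FrozenNF4Linear

before = torch.cuda.memory_allocated()
nf4_linear = FrozenNF4Linear(512, 512, device='cuda', dtype=torch.bfloat16)
after = torch.cuda.memory_allocated()

# A dense 512x512 bf16 weight is ~512 KiB. If the bf16 copy were kept alongside
# the NF4 data, the allocation would be much larger than the quantized weight
# alone (NF4 is roughly 4x smaller than bf16, plus scale metadata).
print(f"allocated by FrozenNF4Linear: {(after - before) / 1024:.1f} KiB")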
yeah maybe "build_original_weight"
But can you actually restore the original weight? Hasn't some fidelity been lost after converting to nf4? Hence my suggestion to just overwrite to_dtype here.
That makes sense, it does seem we lose fidelity and we can't get the exact original weight as intuitively expected (@drisspg - just checking my understanding is correct).
So should we update NF4Tensor to get rid of get_original_weight and just have a .to() API? @drisspg, can this be done in a separate PR?
That is Christian's suggestion, not mine lol, but regardless I think this PR markedly increases the testing coverage, and like you said, if we want to do the full switch over to subclasses and do everything through torch dispatch, I think it would make sense to do that in a follow-up PR. cc @cpuhrsch
torchao/modules/nf4_linear.py
Outdated
# types.

def forward(self, input: Tensor) -> Tensor:
    return linear_nf4(input=input, weight=self.weight)
I'm surprised you can't just put this into the usual F.linear.
Maybe it's worth updating the ops table https://github.com/pytorch-labs/ao/blob/687f0f0eae8594f90afc447e0b5b52b524cb3fa6/torchao/dtypes/nf4tensor.py#L417-L439
When I first wrote this, I didn't make this a subclass because it didn't support compile. I think I left a comment somewhere that we should likely do this, and just make sure that we indeed get the right thing saved for backwards.
If I understand the 2 options correctly, it's:
- Use the torch_dispatch mechanism to "correctly" implement F.linear in this case, where "correctly" means saving the right tensors and avoiding saving extra tensors for the backward pass.
- Stick with the current autograd function implementation.
The reason I'm a bit of a proponent of sticking with the autograd function implementation is that it's a bit more battle-tested by @drisspg, and the torch_dispatch support is a relatively new introduction. We could also switch over to it in the future.
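For context, a minimal sketch of the autograd-function option being described (my illustration, assuming get_original_weight dequantizes back to bf16 as discussed above; this is not the actual linear_nf4 implementation):

import torch

class LinearNF4(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight):
        # Save only the compact NF4 weight for backward, not a dequantized copy.
        ctx.save_for_backward(weight)
        return torch.nn.functional.linear(input, weight.get_original_weight())

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        # The NF4 base weight is frozen, so only the input gradient is needed.
        return grad_output @ weight.get_original_weight(), None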
drisspg/transformer_nuggets#24
It's Friday, so there might be an "easy" way to fix this, but I will leave these as a future-me thing.
torchao/modules/nf4_linear.py
Outdated
del self.weight
self.weight = torch.nn.Parameter(self.nf4_weight, requires_grad=False)

# TODO: likely need to handle state_dict save & load via hooks to properly manage
NF4Tensor might already support that as a Tensor subclass
Awesome! Will probably test this out as part of follow-up work. The main thing I want to figure out is: if we call load_state_dict with base model parameters in bf16 and try to load into an NF4Tensor, do we crash, raise a type mismatch issue, or just-in-time quantize the incoming weight and update the data? Will probably learn more about this when I begin the state_dict experimentation.
torchao/modules/nf4_linear.py
Outdated
if self.weight.dtype != torch.bfloat16:
    raise RuntimeError("FrozenNF4Linear is only supported with bf16 parameter currently")

self.nf4_weight = NF4Tensor.from_tensor(self.weight.data).to(device).to(dtype)
I wonder if there could also be use for a to_nf4 factory function. Then it'd follow the pattern of torch.Tensor.to(<torch dtype>), but as a standalone function (which it'll have to be, unless we somehow open up dtypes for open registration).
It's then quite similar to other memory/dtype/device oriented functions. In a nutshell, just because we now use nf4 instead of bfloat16, the Tensor's behavior etc. hasn't changed (of course individual values might have changed, since nf4 has a different range etc.).
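Something like this, perhaps (a hypothetical sketch of that standalone factory, not an existing torchao API):

import torch
from torchao.dtypes.nf4tensor import NF4Tensor

def to_nf4(tensor: torch.Tensor) -> NF4Tensor:
    # Standalone analogue of torch.Tensor.to(dtype) for the nf4 "dtype".
    return NF4Tensor.from_tensor(tensor)

# usage, mirroring `weight.to(torch.bfloat16)`:
# nf4_weight = to_nf4(linear.weight.data)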
This makes sense. Is this something we'd like to build in this PR or more as a longer-term follow up item? cc @drisspg
TBH I don't know if I follow this
test/modules/test_nf4_linear.py
Outdated
bnb_nf4_linear = self._build_bnb_linear(input_weight=orig_weight)

inp = torch.randn(2, 512, dtype=torch.bfloat16, device='cuda')
self.assertEqual(nf4_linear(inp).sum(), bnb_nf4_linear(inp).sum())
IMO this is kind of a weird test; I think reconstruction accuracy is better:
- compared to the original: https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L45
- compared against bnb ("make sure we aren't worse"): https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65
Will def add reconstruction accuracy test. Curious why it's not as valuable to test exact parity w/BNB though?
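For reference, a minimal version of such a reconstruction-error check might look like this (a sketch only; the shapes and tolerance are illustrative assumptions, not values from the linked tests):

import torch
from torchao.dtypes.nf4tensor import NF4Tensor

def test_reconstruction_error(self):
    torch.manual_seed(0)
    orig_weight = torch.randn(512, 512, dtype=torch.bfloat16, device='cpu')
    nf4_weight = NF4Tensor.from_tensor(orig_weight.clone())
    # Dequantize and compare against the original bf16 weight; the error should
    # be small but nonzero, since NF4 quantization is lossy.
    err = (nf4_weight.get_original_weight() - orig_weight).abs().mean().item()
    # loose, illustrative threshold; in practice also compare against the bnb baseline
    self.assertLess(err, 1e-1)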
torchao/modules/nf4_linear.py
Outdated
if self.weight.dtype != torch.bfloat16:
    raise RuntimeError("FrozenNF4Linear is only supported with bf16 parameter currently")
nit: can't you just check self.dtype first, before even initializing the parent class?
Looks good!
Add noop detach for Nf4 tensor and enhance nf4 testing
Note: things like state_dict save/load, .parameters() returning expected data, etc. are not addressed / in scope for this PR. We will need to ensure all of these work robustly as part of using this in torchtune, since we'll need to load base model parameters into this layer before quantizing (or quantize on the fly).
Not sure if this is covered OOTB w/ CI, but tested locally with:
python test/modules/test_nf4_linear.py -v