Refactor tensor subclass API to also use parametrization #146
Conversation
Force-pushed from a5b6dda to a9e5563
Force-pushed from d0b9c23 to 25abb31
Force-pushed from 578b4f0 to a906c53
Force-pushed from 2efcc92 to c18e2f6
@@ -493,7 +493,7 @@ def quant_int8_dynamic_per_token_linear(
         x_vals_int8, x_scales, w_vals_int8_t, w_scales, out_dtype
     )
     if bias is not None:
-        mm_out += bias
+        mm_out = mm_out + bias
Why is this needed?
There is some issue with this in AOT Inductor, I think. cc @desertfire
I think @cpuhrsch's question is why we are rewriting "+=". I am not aware of any AOTI restriction that needs this rewrite.
@desertfire the error I'm getting with "+=" is this: https://gist.github.com/jerryzh168/d4ea2fb8138376cff903c38aaef8f5ef. Is this expected?
 ).reshape(w.shape[0], -1)


 def pack_tinygemm_scales_and_zeros(scales, zeros):
     assert scales.shape == zeros.shape
-    assert scales.dtype == torch.bfloat16
-    assert zeros.dtype == torch.bfloat16
+    assert scales.dtype == torch.bfloat16, f" got dtype: {scales.dtype}"
Will this also show what dtype was expected? It seems like an opportunity for a dtype guard decorator or some such, e.g.:
def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None):
    if dtype is not None and tensor_arg.dtype != dtype:
        raise ValueError(f"Expected Tensor argument {arg_name} to have dtype {dtype}, but got {tensor_arg.dtype} instead.")
    if size is not None and tensor_arg.size() != size:
        raise ValueError(f"Expected Tensor argument {arg_name} to have size {size}, but got {tensor_arg.size()} instead.")
guard_dtype_size(scales, "scales", torch.bfloat16, zeros.size())
guard_dtype_size(zeros, "zeros", torch.bfloat16)
See the ValueError reference manual for why I chose ValueError here.
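For the decorator idea mentioned above, here is a rough sketch of what such a guard could look like; guard_dtype is a hypothetical name, not an existing torchao or PyTorch API:

import functools
import inspect

import torch

def guard_dtype(arg_name, dtype):
    # Hypothetical decorator: check the dtype of one named Tensor argument
    # before calling the wrapped function.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            bound = inspect.signature(fn).bind(*args, **kwargs)
            tensor_arg = bound.arguments[arg_name]
            if tensor_arg.dtype != dtype:
                raise ValueError(
                    f"Expected Tensor argument {arg_name} to have dtype {dtype}, "
                    f"but got {tensor_arg.dtype} instead."
                )
            return fn(*args, **kwargs)
        return wrapper
    return decorator

@guard_dtype("scales", torch.bfloat16)
@guard_dtype("zeros", torch.bfloat16)
def pack_tinygemm_scales_and_zeros(scales, zeros):
    assert scales.shape == zeros.shape
    ...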
        self.kwargs = kwargs

    def forward(self, int_data, q_scales):
        return from_qtensor_components_int8dyn(int_data, q_scales, *self.args, **self.kwargs)
Can you use cls.__tensor_flatten__(*args) for this?
You mean __tensor_unflatten__? We can't use cls in forward because of pytorch/pytorch#124735 right now.
If you wrap it like

def create_parameterization_module(cls):
    class SubclassParameterization:
        [...]
        def forward(self, args):
            cls.[...](args)
    return SubclassParameterization

then cls is given as an argument to create_parameterization_module and you return an instance of SubclassParameterization where cls is that argument. Essentially a module factory function. These methods also shouldn't be static.
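A minimal sketch of this factory pattern, assuming the standard __tensor_flatten__/__tensor_unflatten__ subclass protocol; the class name, the inner-tensor names ("int_data", "q_scales"), and the extra metadata arguments are illustrative rather than the PR's actual implementation:

import torch

def create_parameterization_module(cls, meta=None, outer_size=None, outer_stride=None):
    # The subclass type is captured by closure, so forward() never has to
    # reference it as a class attribute (the dynamo limitation discussed below).
    class SubclassParameterization(torch.nn.Module):
        def forward(self, int_data, q_scales):
            # Rebuild the tensor subclass from its plain-tensor components.
            inner_tensors = {"int_data": int_data, "q_scales": q_scales}
            return cls.__tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride)

        def right_inverse(self, subclass_instance):
            # Flatten back into plain tensors for the parametrization to store.
            attr_names, _ = subclass_instance.__tensor_flatten__()
            return tuple(getattr(subclass_instance, name) for name in attr_names)

    return SubclassParameterization()

The returned instance could then be passed to torch.nn.utils.parametrize.register_parametrization for the weight, similar to the quant_api change later in this thread.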
Isn't this using cls in forward? I tried this before, with @torch._dynamo.allow_in_graph for the constructor function, and it fails because we can't use a class variable in dynamo right now, I think.
Are you suggesting something like this: 25abb31#diff-bf4d50867e3d649de2d89146592bf47d2f258c4c19126c8acf0e120ee904b726R134 (but using cls instead of hardcoding the class)?
Yes, exactly, and using __tensor_unflatten__ instead of from_qtensor_components.
For reference, using cls in forward is not supported until pytorch/pytorch#123350 is landed, according to Brian.
torchao/quantization/subclass.py
Outdated
        return from_qtensor_components_int8dyn(int_data, q_scales, *self.args, **self.kwargs)

    def right_inverse(self, tensor_subclass_instance):
        return tensor_subclass_instance.int_data, tensor_subclass_instance.q_scales
Can you use return self.__tensor_flatten__ for this?
This works, thanks. I'll create a parent class to host __init__ and right_inverse.
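A hedged sketch of what such a shared parent class might look like; the class name is illustrative, and it assumes every subclass implements __tensor_flatten__:

import torch

class ConstructTensorSubclass(torch.nn.Module):
    # Hosts __init__ and the shared right_inverse; subclass-specific modules
    # then only need to override forward().
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.args = args
        self.kwargs = kwargs

    def right_inverse(self, tensor_subclass_instance):
        # __tensor_flatten__ reports the names of the inner plain tensors, so
        # this works for any subclass implementing the flatten protocol.
        attr_names, _ = tensor_subclass_instance.__tensor_flatten__()
        return tuple(getattr(tensor_subclass_instance, name) for name in attr_names)

The ConstructTensorSubclassInt8Dyn module from the diff below would then subclass this and keep only its forward.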
torchao/quantization/quant_api.py
Outdated
    if enable_parametrization:
        lin.weight = torch.nn.Parameter(cls.from_float(lin.weight), requires_grad=False)
        _, args = lin.weight.__tensor_flatten__()
        parametrize.register_parametrization(lin, "weight", getattr(cls, constructor)(cls, *args))
noob question - why do we want to enable this parameterization support?
This is for supporting export of the tensor subclass model; it's needed by aot_compile and also torch.export.export.
Tensor subclasses don't work with AOTI
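For context, a rough sketch (not the PR's actual test) of the export flow this parametrization is meant to unblock, assuming torchao's change_linear_weights_to_int8_dqtensors API and an available CUDA device:

import torch
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors

model = torch.nn.Sequential(torch.nn.Linear(32, 64)).eval().to("cuda")
example_inputs = (torch.randn(16, 32, device="cuda"),)

# Swap Linear weights for the int8 dynamically quantized tensor subclass.
# With parametrization enabled, the registered parameters are plain tensors,
# which is what torch.export and AOT Inductor can handle today.
change_linear_weights_to_int8_dqtensors(model)

# torch.export path (used by executorch) and AOT Inductor compilation.
exported_program = torch.export.export(model, example_inputs)
so_path = torch._export.aot_compile(model, example_inputs)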
torchao/quantization/subclass.py
Outdated
        **kwargs,
    )


class ConstructTensorSubclassInt8Dyn(torch.nn.Module):
Can this be made generic for all tensor subclasses?
Can't do it now because of pytorch/pytorch#124735; we should be able to do it after that is fixed.
Force-pushed from d7ba6af to a50fea5
test/integration/test_integration.py
Outdated
def wrapper(*args, **kwargs):
    if args[2] == "cuda" and not torch.cuda.is_available():
        assert len(args) >= 3, f"Not enough args. Expected more than or equal to 3, but got {len(args)}"
btw @cpuhrsch we need to use checks + skip test here, I think, otherwise this test would fail:
FAIL: test_aoti (__main__.TestAOTI)
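A hedged sketch of the checks-plus-skip approach; the wrapper name is illustrative, and args[2] being the device follows the helper shown above:

import unittest

import torch

def run_on_supported_devices(test_method):
    # Skip (rather than fail) when the parametrized device is "cuda" but no
    # GPU is available, so test_aoti does not show up as a failure.
    def wrapper(*args, **kwargs):
        assert len(args) >= 3, f"Not enough args. Expected more than or equal to 3, but got {len(args)}"
        if args[2] == "cuda" and not torch.cuda.is_available():
            raise unittest.SkipTest("Need CUDA available.")
        return test_method(*args, **kwargs)
    return wrapper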
@@ -141,11 +149,11 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None):
     )

     _replace_with_custom_fn_if_matches_filter(
-        model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight), filter_fn
+        model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn
I would expect we use parametrization only for AOTI, as some kind of "pre-processing" there? Especially given that my understanding of the long-term plan is that AOTI will do this pre-processing itself, and we will be able to remove it from there.
We also need this for torch.export (used by executorch); I'll add a test in the next PR. Also, we want to have a consistent code path for all backends/runtimes, I think. Are there any problems with enabling this for all use cases?
@albanD do you think that, long term, we want export to do the pre-processing?
I think if that's the case, then we might just want to figure out that story now (it might be less work than getting dynamo to handle parametrizations).
The main contentious bit is probably just where this pre-processing should live. One possible answer is that it should happen transparently as part of torch.export.export(): automatically search the created state dict for subclasses and flatten them (although this might be a problem if the user expects the state dict of the ExportedProgram to alias the original model's state dict).
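To make the "search the state dict and flatten subclasses" idea concrete, a hypothetical sketch; the function is illustrative, not an existing torch.export feature:

import torch

def flatten_subclasses_in_state_dict(state_dict):
    # Replace each tensor-subclass entry with its plain inner tensors, keyed by
    # the names reported by __tensor_flatten__. Note this builds a new dict, so
    # it would not alias the original model's state dict.
    flat = {}
    for key, value in state_dict.items():
        if isinstance(value, torch.Tensor) and hasattr(value, "__tensor_flatten__"):
            attr_names, _ = value.__tensor_flatten__()
            for name in attr_names:
                flat[f"{key}.{name}"] = getattr(value, name)
        else:
            flat[key] = value
    return flat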
Force-pushed from 45543c0 to aff0c5b
* tiktoken integration, part 1 * update tests
Summary:
Also added tests for the tensor subclass API + AOTI compilation
Test Plan:
python test/integration/test_integration.py -k test_aoti
Two issues right now:
Reviewers:
Subscribers:
Tasks:
Tags: