Named Symbol not found (torchchat #1298) #1110
CUDA version: this ran on Google Colab. Detailed traceback/repro: https://colab.research.google.com/drive/1PRneJBaS5TlJaIgc4Lwv2muiePp6T9Ss?usp=sharing
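For reference, a quick way to capture the relevant runtime details on Colab (a minimal sketch; the linked notebook remains the authoritative repro):

```python
import torch

# Report the environment details relevant to this issue.
print(torch.__version__)                    # PyTorch version
print(torch.version.cuda)                   # CUDA version PyTorch was built against
print(torch.cuda.get_device_name(0))        # e.g. "Tesla T4" on a Colab T4 runtime
print(torch.cuda.get_device_capability(0))  # (7, 5) for sm75 / T4
```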
|
I believe this is because tinygemm does not support sm75, i.e. the T4. |
We could throw a better error message. |
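One way to surface a better error message would be a capability check before dispatching to the tinygemm path. A minimal sketch (an illustration only, not torchao's actual code; the function name is made up):

```python
import torch

def assert_int4_tinygemm_supported(device: str = "cuda") -> None:
    """Hypothetical guard: fail early with a readable message instead of a
    cryptic 'named symbol not found' coming back from the CUDA kernel."""
    major, minor = torch.cuda.get_device_capability(device)
    if (major, minor) < (8, 0):
        raise RuntimeError(
            f"int4_weight_only uses the tinygemm kernel, which needs a GPU with "
            f"compute capability >= 8.0 (native bf16). Found sm{major}{minor} "
            f"({torch.cuda.get_device_name(device)}); a Colab T4 is sm75."
        )
```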
I couldn't repro this on a fresh Google Colab T4 GPU, so it might be something more environment-specific in the linked notebook; it is likely just an issue with needing specific CUDA versions installed. In particular, please note that you can get torchao linked against a specific CUDA version by installing it from the PyTorch index (https://github.com/pytorch/ao#installation); otherwise, installing from source is generally less finicky.

```python
# -*- coding: utf-8 -*-
"""Untitled103.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1AgFr2Ofz4aEc3s_KFX-1LL4NaCpzCjv9
"""

# ! USE_CPP=0 pip install git+https://github.com/pytorch/ao.git@msaroufim/better-tinygemmwarning-for-google-colab --force-reinstall
! USE_CPP=0 pip install git+https://github.com/pytorch/ao.git --force-reinstall
! pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly

# prompt: toy pytorch model with a single linear layer
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(ToyModel, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)


# Example usage
input_size = 10
output_size = 1
model = ToyModel(input_size, output_size).cuda()

# Create some sample input data
input_data = torch.randn(1, input_size).cuda()

# Perform a forward pass
output = model(input_data)
print(output)

import torch
import torchao
from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int8_weight,
    int4_weight_only,
    int8_weight_only,
)

quantize_(model, int4_weight_only())
model(input_data)
```
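For example, installing a pre-built torchao wheel pinned to a specific CUDA version looks roughly like the following (a sketch assuming the usual PyTorch wheel-index layout; the `cu121` tag is just an illustration, see the README linked above for the exact supported versions):

```python
# In a Colab cell: install torchao built against a specific CUDA toolkit
# from the PyTorch package index instead of building from source.
! pip install torchao --index-url https://download.pytorch.org/whl/cu121
```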
|
The error comes from PyTorch's `_convert_weight_to_int4pack` op. I can reproduce it with the following snippet:

```python
import torch

x = torch.randint(0, 255, size=(1024, 1024), dtype=torch.uint8).cuda()
x = x.view(torch.int32)  # I think 2.4 expects int32, 2.5 expects uint8
torch.ops.aten._convert_weight_to_int4pack(x, innerKTiles=8)
```

@msaroufim Your example does not reproduce the error because the weight is too small. When I changed it to the following, the error is reproduced:

```python
# Example usage
input_size = 1024
output_size = 1024
model = ToyModel(input_size, output_size).cuda().bfloat16()  # also requires BF16

# Create some sample input data
input_data = torch.randn(1, input_size).cuda().bfloat16()
```
|
This might just be the leading edge of "bfloat16 doesn't work on T4 in PyTorch because it's not supported by the hardware"? Directionally, would we look at adding this support, or do we just put up a better error message? So, it's definitely bfloat16 causing this, as can be ascertained with this command:
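One direct way to confirm the hardware situation (a minimal sketch; not necessarily the exact command referred to above):

```python
import torch

# A T4 is sm75; the tinygemm int4 path assumes native bf16 support, which
# only exists on compute capability >= 8.0 (Ampere and newer).
major, minor = torch.cuda.get_device_capability()
print((major, minor))            # (7, 5) on a T4
print((major, minor) >= (8, 0))  # False on a T4 -> no native bf16
```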
Sadly, switching to float16 does not work either, and ditto for torch.float32 (same error). That would suggest that int4 linear quantization isn't available unless the hardware has support for BF16? (This would be much less of an issue if the bread and butter of Google Colab weren't the T4, which is a super convenient place to let users try quick experiments and ramp up....) |
pytorch/torchchat#1344 recognizes whether the target GPU has bfloat16 and either avoids using it (for the fast* alias dtypes, by falling back to fp16) or issues an error (if bf16 is explicitly specified). |
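A minimal sketch of that kind of fallback logic (an illustration only, not torchchat's actual implementation; the helper name `resolve_fast_dtype` is made up):

```python
import torch

def resolve_fast_dtype(requested: str, device: str = "cuda") -> torch.dtype:
    """Hypothetical helper: map a 'fast' alias dtype to something the GPU
    actually supports, or fail loudly if bf16 was explicitly requested."""
    has_native_bf16 = torch.cuda.get_device_capability(device) >= (8, 0)
    if requested == "fast":
        # Alias dtype: silently fall back to fp16 on pre-Ampere GPUs like the T4.
        return torch.bfloat16 if has_native_bf16 else torch.float16
    if requested == "bfloat16" and not has_native_bf16:
        raise RuntimeError("bfloat16 was requested explicitly but this GPU has no native bf16 support")
    return getattr(torch, requested)
```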
tinygemm (the INT4 weight-only kernel for CUDA) only supports BF16 (https://github.com/pytorch/pytorch/blob/86d7d39bffd3b7b099310fb351b2b36f99981d6f/aten/src/ATen/native/cuda/int4mm.cu). Though in theory it should be possible to make it work with FP16, similar to #1147. |
Yeah, the Meta-internal tinygemm version actually already supports fp16; we just need to upstream the changes. cc @yanboliang, who plans to do this. |
Hi! I'm interested in contributing to this issue. |
At a minimum we'd want a better error message (for this kernel only... PyTorch overall simply emulates bf16 elsewhere). It would be cool if we could enable this kernel to work with BF16 even on GPUs without native support, the way PyTorch emulates it elsewhere: basically, if tinygemm supports fp32 computation, we can block as fp16 and compute as fp32 (bf16 is the most significant 16 bits of FP32, so it's easy to convert to and from), especially if we already accumulate in fp32 (which some GEMMs do by default to avoid rounding artifacts). This is mostly an issue because the T4 is so widespread as the Google Colab GPU accelerator; otherwise I can't think of a platform that would be this long-lasting and still be in broad use? |
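To make the "bf16 is the most significant 16 bits of fp32" point concrete, here is a small standalone demonstration (an illustration only, not part of any kernel):

```python
import torch

x = torch.randn(4, dtype=torch.float32)

# Truncating the low 16 bits of an fp32 bit pattern yields (up to rounding)
# the corresponding bf16 value, and padding bf16 with 16 zero bits gives a
# valid fp32 back -- which is why the conversion is cheap.
bits = x.view(torch.int32)
truncated = (bits & ~0xFFFF).view(torch.float32)  # drop the low 16 mantissa bits

print(x)
print(truncated)                               # close to x, with bf16-level precision
print(x.to(torch.bfloat16).to(torch.float32))  # PyTorch's round-to-nearest bf16 cast
```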
Quantized model gets a CUDA error "Named symbol not found".
see pytorch/torchchat#1298