Generalize Model Size Code #364

Merged: 2 commits, Jun 14, 2024

Changes from 1 commit
54 changes: 54 additions & 0 deletions test/integration/test_integration.py
@@ -1373,6 +1373,60 @@ def forward(self, x):
        after_export = model(x)
        self.assertTrue(torch.equal(after_export, ref))

class TestUtils(unittest.TestCase):
    @parameterized.expand(COMMON_DEVICE_DTYPE)
    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
    def test_get_model_size_autoquant(self, device, dtype):
        if device != "cuda" and dtype != torch.bfloat16:
            self.skipTest(f"autoquant currently does not support {device}")
        if device != "cuda" or not torch.cuda.is_available():
            self.skipTest(f"autoquant currently does not support {device}")
        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
            if dtype == torch.bfloat16:
                self.skipTest("bfloat16 requires sm80+")
        m, k, n = 16, 128, 128
        model = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Linear(k, n),
            torch.nn.ReLU(),
        ).to(device).to(dtype)
        example_input = torch.randn(m, k, device=device, dtype=dtype)
        size = torchao.utils.get_model_size_in_bytes(model)

        from torchao.quantization.autoquant import (
            AQWeightOnlyQuantizedLinearWeight2,
        )
        qtensor_class_list = (
            AQWeightOnlyQuantizedLinearWeight2,
        )

        mod = torchao.autoquant(torch.compile(model), qtensor_class_list=qtensor_class_list)
        mod(example_input)
        size2 = torchao.utils.get_model_size_in_bytes(mod)
        self.assertTrue(size2 < size)

    @parameterized.expand(
        list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
    )
    def test_get_model_size_aqt(self, api, device, dtype):
        if dtype != torch.bfloat16:
            self.skipTest(f"{api} in {dtype} is not supported yet")
        k, n = 1024, 1024
        model = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Linear(k, n),
            torch.nn.ReLU(),
        ).to(device).to(dtype)
        size = torchao.utils.get_model_size_in_bytes(model)
        api(model)
        size2 = torchao.utils.get_model_size_in_bytes(model)
        self.assertTrue(size2 < size)


if __name__ == "__main__":
    unittest.main()
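For readers who want to try the new utility outside the test harness, here is a minimal sketch that mirrors test_get_model_size_autoquant above; it assumes a CUDA device with bfloat16 support, and all APIs used are the ones exercised by the test.

import torch
import torchao
from torchao.quantization.autoquant import AQWeightOnlyQuantizedLinearWeight2
from torchao.utils import get_model_size_in_bytes

# Same toy model shape as the test: ReLU -> Linear(128, 128) -> ReLU in bf16 on CUDA.
model = torch.nn.Sequential(
    torch.nn.ReLU(),
    torch.nn.Linear(128, 128),
    torch.nn.ReLU(),
).to("cuda").to(torch.bfloat16)
example_input = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16)

size_before = get_model_size_in_bytes(model)
# Restrict autoquant to the int8 weight-only subclass, as the test does, so the
# chosen weights are tensor subclasses smaller than the original bf16 weights.
quantized = torchao.autoquant(
    torch.compile(model),
    qtensor_class_list=(AQWeightOnlyQuantizedLinearWeight2,),
)
quantized(example_input)  # first call runs autoquant's kernel selection
size_after = get_model_size_in_bytes(quantized)
print(f"{size_before} bytes -> {size_after} bytes")  # expect size_after < size_before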
18 changes: 2 additions & 16 deletions torchao/_models/llama/generate.py
@@ -13,6 +13,7 @@
import torchao
import torch._dynamo.config
import torch._inductor.config
from torchao.utils import get_model_size_in_bytes

def device_sync(device):
    if "cuda" in device:
@@ -143,21 +144,6 @@ def _load_model(checkpoint_path, device, precision):

    return model.eval()

def _get_model_size(model):
    model_size = 0
    for name, child in model.named_children():
        if not isinstance(child, torch.nn.Embedding):
            for p in itertools.chain(child.parameters(), child.buffers()):
                # handling for tensor subclasses
                if isinstance(p, torchao.dtypes.aqt.AffineQuantizedTensor):
                    layout_tensor = p.layout_tensor
                    for attr_name in layout_tensor.__tensor_flatten__()[0]:
                        sub_tensor = getattr(layout_tensor, attr_name)
                        model_size += sub_tensor.numel() * sub_tensor.element_size()
                else:
                    model_size += p.numel() * p.element_size()
    return model_size

B_INST, E_INST = "[INST]", "[/INST]"

def main(
@@ -226,7 +212,7 @@ def main(
interactive=False
)

    model_size = _get_model_size(model) / 1e9
    model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9

    if compile:
        global decode_one_token, prefill
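A note on the replacement above: the removed _get_model_size helper always skipped torch.nn.Embedding children, and ignore_embeddings=True reproduces that behavior with the shared utility. A small self-contained sketch; the toy module is illustrative, not part of the PR.

import torch
from torchao.utils import get_model_size_in_bytes

toy = torch.nn.Sequential(
    torch.nn.Embedding(32_000, 64),  # stand-in for a token embedding table
    torch.nn.Linear(64, 64),
)
full = get_model_size_in_bytes(toy)                            # counts everything
no_emb = get_model_size_in_bytes(toy, ignore_embeddings=True)  # skips the Embedding
print(f"full: {full} bytes, excluding embeddings: {no_emb} bytes")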
34 changes: 26 additions & 8 deletions torchao/utils.py
@@ -5,6 +5,7 @@
from math import gcd
from packaging import version
import torch.nn.utils.parametrize as parametrize
import itertools

__all__ = [
    "benchmark_model",
@@ -82,14 +83,31 @@ def find_multiple(n: int, *args: Tuple[int]) -> int:
        return n
    return n + k - (n % k)

# https://discuss.pytorch.org/t/finding-model-size/130275
def get_model_size_in_bytes(model):
    s = 0
    for p in model.parameters():
        s += p.nelement() * p.element_size()
    for b in model.buffers():
        s += b.nelement() * b.element_size()
    return s
def get_model_size_in_bytes(model, ignore_embeddings=False):
    """
    Returns the model size in bytes. The option to ignore embeddings
    is useful for models with disproportionately large embeddings compared
    to other model parameters that get quantized/sparsified.
    """
    def flat_size(tensor):
        if hasattr(tensor, "__tensor_flatten__"):
            size = 0
            # 0th element is a list of attributes that
            # hold tensors
            for attr_name in tensor.__tensor_flatten__()[0]:
                sub_tensor = getattr(tensor, attr_name)
                size += flat_size(sub_tensor)
            return size
        else:
            return tensor.numel() * tensor.element_size()

    model_size = 0
    for name, child in model.named_children():
        if not (isinstance(child, torch.nn.Embedding) and ignore_embeddings):
            for p in itertools.chain(child.parameters(recurse=False), child.buffers(recurse=False)):
                model_size += flat_size(p)
            model_size += get_model_size_in_bytes(child, ignore_embeddings)
    return model_size

class UnwrapTensorSubclass(torch.nn.Module):
    def forward(self, *tensors):
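As a sanity check on the generalized implementation: for a plain float model with no tensor subclasses and no ignored embeddings, it reduces to the same byte count as the simple parameter-and-buffer loop it replaces; the flat_size path only changes the result for tensors that implement __tensor_flatten__. A brief sketch:

import torch
from torchao.utils import get_model_size_in_bytes

m = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.BatchNorm1d(64))
# Naive count: every parameter and buffer contributes numel * element_size.
naive = sum(
    t.numel() * t.element_size()
    for t in list(m.parameters()) + list(m.buffers())
)
assert get_model_size_in_bytes(m) == naive
print(naive, "bytes")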