training acceleration via runtime semi-structured sparsity (#184)
This PR adds support for training acceleration using the runtime semi-structured (2:4) sparsity kernels that landed in PyTorch core earlier: pytorch/pytorch#122350
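
For context, 2:4 ("semi-structured") sparsity keeps the two largest-magnitude values in every contiguous group of four elements. Below is a minimal dense sketch of that pruning rule only, for illustration; `naive_sparsify24` is a hypothetical helper, and the fused kernels referenced above also produce the compressed values and metadata in a single pass:
```
import torch

def naive_sparsify24(x: torch.Tensor) -> torch.Tensor:
    # Keep the 2 largest-magnitude values in each contiguous group of 4
    # along the last dim and zero the rest (assumes x.numel() % 4 == 0).
    groups = x.reshape(-1, 4)
    _, drop = groups.abs().topk(2, dim=-1, largest=False)
    return groups.scatter(-1, drop, 0).reshape(x.shape)
```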

It collects the necessary autograd functions to support training and packages them up in a replacement `nn.Linear` module, `SemiSparseLinear`, as well as a user-facing API for swapping out modules, `swap_linear_with_semi_sparse_linear_`.
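
A rough usage sketch (assuming the swap API takes the model and a dict mapping fully-qualified module names to the replacement class; the exact config format and signature may differ):
```
import torch
from torch import nn
from torchao.sparsity.training import (
    SemiSparseLinear,
    swap_linear_with_semi_sparse_linear_,
)

model = nn.Sequential(
    nn.Linear(1024, 4096),
    nn.GELU(),
    nn.Linear(4096, 1024),
).cuda().half()

# Assumed config format: module FQN -> replacement class.
# Only the listed Linears are swapped.
sparse_config = {"0": SemiSparseLinear, "2": SemiSparseLinear}
swap_linear_with_semi_sparse_linear_(model, sparse_config)

# Training then proceeds as usual; swapped layers sparsify at runtime.
out = model(torch.randn(4096, 1024, device="cuda", dtype=torch.half))
out.sum().backward()
```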

It also adds some benchmarking code from xformers to measure the speedup of this module on DINO shapes.

We have a blog post coming out with more details about how this works. 

Testing:
```
python test/sparsity/test_fast_sparse_training.py 
```

Benchmarking:
```
python benchmarks/benchmark_semi_sparse.py 
```

For ViT-L MLP shapes we see the following results:
```
[------------------------------------------------ mlpfwbw -------------------------------------------------]
                                  |   act24   |   dense   |   w24    |  s24_inp_sparsify24  |  s24_inp_clone
1 threads: -------------------------------------------------------------------------------------------------
      f16 (44160,1024,4096,1024)  |  11881.0  |  11534.3  |  9204.7  |        255.1         |      125.8

Times are in microseconds (us).
```
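
For reference, the columns map to the `functions` dict in the benchmark below: `act24` is the MLP with runtime 2:4 activation sparsity, `dense` is the unmodified MLP, `w24` swaps in `SemiSparseLinear` for 2:4 weight sparsity, and the two `s24_inp_*` columns microbenchmark the sparsification op alone against a plain `clone` of the same input.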
jcaip authored Jun 6, 2024
1 parent f3f2ea8 commit d97ae74
Showing 7 changed files with 1,363 additions and 0 deletions.
129 changes: 129 additions & 0 deletions benchmarks/benchmark_semi_sparse.py
@@ -0,0 +1,129 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn
from xformers_benchmark_utils import DTYPE2STR, benchmark_main_helper2, product_dict

from torchao.sparsity.training import SemiSparseLinear
from torchao.sparsity.training.autograd import semi_structured_sparsify

min_run_time = 0.5
device = torch.device("cuda")

CASES = list(
    product_dict(
        B_in_hidden_out_ft=[
            # DINO ViT-L (patch16): batch 64 with 2 large crops
            # (14*14 patches + CLS) and 8 small crops (6*6 patches + CLS)
            (64 * 2 * (14 * 14 + 1) + 64 * 8 * (6 * 6 + 1), 1024, 1024 * 4, 1024),
        ],
        dtype=[torch.half],
        bias=[False],
    )
)

class Mlp(nn.Module):
    LINEAR_CLS = nn.Linear

    def __init__(
        self, B_in_hidden_out_ft: Tuple[int, int, int, int], dtype, bias: bool, bw: bool
    ) -> None:
        B, in_ft, hid_ft, out_ft = B_in_hidden_out_ft
        super().__init__()
        self.label = "mlp"
        self.sub_label = (
            f"{DTYPE2STR[dtype]} ({B},{in_ft},{hid_ft},{out_ft}){' b' if bias else ''}"
        )
        self.fc1 = self.LINEAR_CLS(in_ft, hid_ft, bias=bias)
        self.act = nn.GELU()
        self.fc2 = self.LINEAR_CLS(hid_ft, out_ft, bias=bias)
        self.grad = torch.randn([B, out_ft], device="cuda", dtype=dtype)
        self.input = torch.randn(
            [B, in_ft], device="cuda", dtype=dtype, requires_grad=True
        )
        # Placeholder output; fw() overwrites this before bw() uses it.
        self.out = self.input
        self.to("cuda").to(dtype)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

    def fw(self):
        self.out = self.forward(self.input)

    def bw(self):
        self.out.backward(self.grad, retain_graph=True)


class MlpAct24(Mlp):
    # Same MLP, but sparsifies the activations to 2:4 right after fc1.
    def fw(self):
        x = self.input
        x = self.fc1(x)
        x = semi_structured_sparsify(x)
        x = self.act(x)
        x = self.fc2(x)
        self.out = x


class MlpW24(Mlp):
    # Same MLP with 2:4-sparse weights via the SemiSparseLinear module.
    LINEAR_CLS = SemiSparseLinear


class MicrobenchmarkBase:
    def __init__(
        self, B_in_hidden_out_ft: Tuple[int, int, int, int], dtype, bias: bool, bw: bool
    ) -> None:
        B, in_ft, hid_ft, out_ft = B_in_hidden_out_ft
        super().__init__()
        self.label = "mlp"
        self.sub_label = (
            f"{DTYPE2STR[dtype]} ({B},{in_ft},{hid_ft},{out_ft}){' b' if bias else ''}"
        )
        self.input = torch.randn(
            [B, in_ft], device="cuda", dtype=dtype, requires_grad=True
        )
        self.input_colMajor = self.input.t().contiguous().t()
        self.input_sp = semi_structured_sparsify(self.input)

    def bw(self) -> None:
        return None


class MicrobenchmarkSparsify24(MicrobenchmarkBase):
    # Times the runtime 2:4 sparsification op in isolation.
    def fw(self) -> torch.Tensor:
        semi_structured_sparsify(self.input)
        return self.input


class MicrobenchmarkInputClone(MicrobenchmarkBase):
    # Baseline for the above: a plain clone of the same input.
    def fw(self) -> torch.Tensor:
        self.input.clone()
        return self.input


functions = {
    "act24": MlpAct24,
    "dense": Mlp,
    "w24": MlpW24,
    "s24_inp_sparsify24": MicrobenchmarkSparsify24,
    "s24_inp_clone": MicrobenchmarkInputClone,
}
benchmark_main_helper2(
    "sp24_fwbw",
    fw=True,
    bw=True,
    cases=CASES,
    functions=functions,
    min_run_time=min_run_time,
)
