This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Dev pruner DataParallel #2022

Merged 7 commits on Feb 10, 2020
36 changes: 26 additions & 10 deletions examples/model_compress/MeanActivation_torch_cifar10.py
@@ -1,9 +1,10 @@
import math
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import L1FilterPruner
from nni.compression.torch import ActivationMeanRankFilterPruner
from models.cifar10.vgg import VGG


@@ -40,6 +41,12 @@ def test(model, device, test_loader):


def main():
parser = argparse.ArgumentParser("multiple gpu with pruning")
parser.add_argument("--epochs", type=int, default=160)
parser.add_argument("--retrain", default=False, action="store_true")
parser.add_argument("--parallel", default=False, action="store_true")

args = parser.parse_args()
torch.manual_seed(0)
device = torch.device('cuda')
train_loader = torch.utils.data.DataLoader(
@@ -63,14 +70,15 @@ def main():
model.to(device)

# Train the base VGG-16 model
print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
for epoch in range(160):
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
lr_scheduler.step(epoch)
torch.save(model.state_dict(), 'vgg16_cifar10.pth')
if args.retrain:
print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
for epoch in range(args.epochs):
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
lr_scheduler.step(epoch)
torch.save(model.state_dict(), 'vgg16_cifar10.pth')

# Test base model accuracy
print('=' * 10 + 'Test on the original model' + '=' * 10)
@@ -88,8 +96,16 @@ def main():

# Prune model and test accuracy without fine tuning.
print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
pruner = L1FilterPruner(model, configure_list)
pruner = ActivationMeanRankFilterPruner(model, configure_list)
model = pruner.compress()
if args.parallel:
if torch.cuda.device_count() > 1:
print("use {} gpus for pruning".format(torch.cuda.device_count()))
model = nn.DataParallel(model)
else:
print("only detect 1 gpu, fall back")

model.to(device)
test(model, device, test_loader)
# top1 = 88.19%

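Note on the ordering above: the example calls `pruner.compress()` before wrapping the model in `nn.DataParallel`, so every GPU replica carries the pruner's wrapped modules and masks. A minimal sketch of that pattern, using names from this example (the helper function itself is illustrative, not part of the PR):

```python
import torch
import torch.nn as nn
from nni.compression.torch import ActivationMeanRankFilterPruner

def prune_then_parallelize(model, configure_list, device):
    # Compress first so the pruner wraps the target Conv2d layers,
    # then replicate the already-wrapped model across GPUs.
    pruner = ActivationMeanRankFilterPruner(model, configure_list)
    model = pruner.compress()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)
    return model, pruner  # keep the pruner around for export_model later
```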
19 changes: 13 additions & 6 deletions examples/model_compress/fpgm_torch_mnist.py
@@ -1,15 +1,16 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import FPGMPruner

class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4 * 4 * 50, 500)
self.fc2 = nn.Linear(500, 10)

def forward(self, x):
x = F.relu(self.conv1(x))
@@ -27,8 +28,14 @@ def _get_conv_weight_sparsity(self, conv_layer):
return num_zero_filters, num_filters, float(num_zero_filters)/num_filters

def print_conv_filter_sparsity(self):
conv1_data = self._get_conv_weight_sparsity(self.conv1)
conv2_data = self._get_conv_weight_sparsity(self.conv2)
if isinstance(self.conv1, nn.Conv2d):
conv1_data = self._get_conv_weight_sparsity(self.conv1)
conv2_data = self._get_conv_weight_sparsity(self.conv2)
else:
# self.conv1 is wrapped as PrunerModuleWrapper
conv1_data = self._get_conv_weight_sparsity(self.conv1.module)
conv2_data = self._get_conv_weight_sparsity(self.conv2.module)

print('conv1: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv1_data[0], conv1_data[1], conv1_data[2]))
print('conv2: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv2_data[0], conv2_data[1], conv2_data[2]))

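The `isinstance` check above exists because, after `compress()`, a pruned layer is replaced by a wrapper that exposes the original layer as `.module` (the diff's comment calls it `PrunerModuleWrapper`). A small hedged helper capturing the same idea (the helper name is illustrative):

```python
import torch.nn as nn

def unwrap_conv(layer):
    # After pruning, the wrapper holds the real Conv2d under `.module`;
    # before pruning, the attribute is still the plain Conv2d itself.
    return layer if isinstance(layer, nn.Conv2d) else layer.module

# Sparsity helpers can then stay agnostic to whether pruning has happened, e.g.
# conv1_data = self._get_conv_weight_sparsity(unwrap_conv(self.conv1))
```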
2 changes: 2 additions & 0 deletions examples/model_compress/lottery_torch_mnist_fc.py
@@ -71,6 +71,8 @@ def test(model, test_loader, criterion):
pruner = LotteryTicketPruner(model, configure_list, optimizer)
pruner.compress()

#model = nn.DataParallel(model)

for i in pruner.get_prune_iterations():
pruner.prune_iteration_start()
loss = 0
101 changes: 101 additions & 0 deletions examples/model_compress/multi_gpu.py
@@ -0,0 +1,101 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from nni.compression.torch import SlimPruner

class fc1(nn.Module):

def __init__(self, num_classes=10):
super(fc1, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.relu1 = nn.ReLU(inplace=True)


self.linear1 = nn.Linear(32*28*28, 300)
self.relu2 = nn.ReLU(inplace=True)
self.linear2 = nn.Linear(300, 100)
self.relu3 = nn.ReLU(inplace=True)
self.linear3 = nn.Linear(100, num_classes)


def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)

x = torch.flatten(x,1)
x = self.linear1(x)
x = self.relu2(x)
x = self.linear2(x)
x = self.relu3(x)
x = self.linear3(x)
return x

def train(model, train_loader, optimizer, criterion, device):
model.train()
for imgs, targets in train_loader:
optimizer.zero_grad()
imgs, targets = imgs.to(device), targets.to(device)
output = model(imgs)
train_loss = criterion(output, targets)
train_loss.backward()
optimizer.step()
return train_loss.item()

def test(model, test_loader, criterion, device):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
correct += pred.eq(target.data.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = 100. * correct / len(test_loader.dataset)
return accuracy


if __name__ == '__main__':
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
traindataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
testdataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=10, drop_last=False)
test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=10, drop_last=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = fc1()

criterion = nn.CrossEntropyLoss()

configure_list = [{
'prune_iterations': 5,
'sparsity': 0.86,
'op_types': ['BatchNorm2d']
}]
pruner = SlimPruner(model, configure_list)
pruner.compress()

if torch.cuda.device_count()>1:
model = nn.DataParallel(model)

model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)
for name, par in model.named_parameters():
print(name)
# for i in pruner.get_prune_iterations():
# pruner.prune_iteration_start()
loss = 0
accuracy = 0
for epoch in range(10):
loss = train(model, train_loader, optimizer, criterion, device)
accuracy = test(model, test_loader, criterion, device)
print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
# print('prune iteration: {0}, loss: {1}, accuracy: {2}'.format(i, loss, accuracy))
pruner.export_model('model.pth', 'mask.pth')
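A hedged way to sanity-check the `mask.pth` written by `export_model` above — this assumes (without verifying against the NNI source) that the file is a dict mapping layer names either to weight-mask tensors or to dicts containing a `'weight'` mask:

```python
import torch

masks = torch.load('mask.pth', map_location='cpu')
for name, mask in masks.items():
    # Support both layouts: a bare mask tensor, or {'weight': tensor, ...}.
    weight_mask = mask['weight'] if isinstance(mask, dict) else mask
    sparsity = 1.0 - weight_mask.float().mean().item()
    print('{}: weight sparsity {:.2%}'.format(name, sparsity))
```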
45 changes: 29 additions & 16 deletions examples/model_compress/slim_torch_cifar10.py
@@ -1,12 +1,12 @@
import math
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import SlimPruner
from models.cifar10.vgg import VGG


def updateBN(model):
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
@@ -49,6 +49,13 @@ def test(model, device, test_loader):


def main():
parser = argparse.ArgumentParser("multiple gpu with pruning")
parser.add_argument("--epochs", type=int, default=160)
parser.add_argument("--retrain", default=False, action="store_true")
parser.add_argument("--parallel", default=False, action="store_true")

args = parser.parse_args()

torch.manual_seed(0)
device = torch.device('cuda')
train_loader = torch.utils.data.DataLoader(
@@ -70,18 +77,19 @@ def main():

model = VGG(depth=19)
model.to(device)

# Train the base VGG-19 model
print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
epochs = 160
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
for epoch in range(epochs):
if epoch in [epochs * 0.5, epochs * 0.75]:
for param_group in optimizer.param_groups:
param_group['lr'] *= 0.1
train(model, device, train_loader, optimizer, True)
test(model, device, test_loader)
torch.save(model.state_dict(), 'vgg19_cifar10.pth')
if args.retrain:
print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
epochs = args.epochs
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
for epoch in range(epochs):
if epoch in [epochs * 0.5, epochs * 0.75]:
for param_group in optimizer.param_groups:
param_group['lr'] *= 0.1
print("epoch {}".format(epoch))
train(model, device, train_loader, optimizer, True)
test(model, device, test_loader)
torch.save(model.state_dict(), 'vgg19_cifar10.pth')

# Test base model accuracy
print('=' * 10 + 'Test the original model' + '=' * 10)
@@ -94,14 +102,19 @@ def main():
'sparsity': 0.7,
'op_types': ['BatchNorm2d'],
}]

# Prune model and test accuracy without fine tuning.
print('=' * 10 + 'Test the pruned model before fine tune' + '=' * 10)
pruner = SlimPruner(model, configure_list)
model = pruner.compress()
test(model, device, test_loader)
# top1 = 93.55%

if args.parallel:
if torch.cuda.device_count() > 1:
print("use {} gpus for pruning".format(torch.cuda.device_count()))
model = nn.DataParallel(model)
# model = nn.DataParallel(model, device_ids=[0, 1])
else:
print("only detect 1 gpu, fall back")
model.to(device)
# Fine tune the pruned model for 40 epochs and test accuracy
print('=' * 10 + 'Fine tuning' + '=' * 10)
optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
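Once the model is wrapped in `nn.DataParallel`, a fine-tuned checkpoint is best saved from `model.module` so its state_dict keys match a plain VGG, while the pruned weights and masks are exported through the pruner, which was constructed on the unwrapped model. A sketch under those assumptions (file names are illustrative):

```python
import torch
import torch.nn as nn

def save_and_export(model, pruner, ckpt_path='vgg19_cifar10_finetuned.pth'):
    # Unwrap DataParallel before saving so keys carry no 'module.' prefix.
    net = model.module if isinstance(model, nn.DataParallel) else model
    torch.save(net.state_dict(), ckpt_path)
    # The pruner was built on the unwrapped model, so export goes through
    # the pruner rather than through the DataParallel wrapper.
    pruner.export_model('pruned_vgg19_cifar10.pth', 'mask.pth')
```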
@@ -32,7 +32,7 @@ def __init__(self, model, config_list, activation='relu', statistics_batch_num=1
"""

super().__init__(model, config_list)
self.mask_calculated_ops = set()
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
self.statistics_batch_num = statistics_batch_num
self.collected_activation = {}
self.hooks = {}
@@ -48,22 +48,29 @@ def compress(self):
"""
Compress the model, register a hook for collecting activations.
"""
if self.modules_wrapper is not None:
# already compressed
return self.bound_model
else:
self.modules_wrapper = []
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
self._instrument_layer(layer, config)
wrapper = self._wrap_modules(layer, config)
self.modules_wrapper.append(wrapper)
self.collected_activation[layer.name] = []

def _hook(module_, input_, output, name=layer.name):
if len(self.collected_activation[name]) < self.statistics_batch_num:
self.collected_activation[name].append(self.activation(output.detach().cpu()))

layer.module.register_forward_hook(_hook)
wrapper.module.register_forward_hook(_hook)
self._wrap_model()
return self.bound_model

def get_mask(self, base_mask, activations, num_prune):
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))

def calc_mask(self, layer, config):
def calc_mask(self, layer, config, **kwargs):
"""
Calculate the mask of given layer.
Filters with the smallest importance criterion which is calculated from the activation are masked.
@@ -82,14 +89,13 @@ def calc_mask(self, layer, config):
"""

weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
assert op_type in ['Conv2d'], "only support Conv2d"
assert op_type in config.get('op_types')
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
if_calculated = kwargs["if_calculated"]
if if_calculated:
return None
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(layer.module, 'bias') and layer.module.bias is not None:
mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
@@ -104,8 +110,7 @@ def calc_mask(self, layer, config):
mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune)
finally:
if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name)
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
return mask


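The substantive change in this pruner is moving the "mask already computed" flag from a Python set on the pruner (`mask_calculated_ops`) into a tensor buffer registered on each wrapper and updated in place with `copy_`; as a buffer it is part of the module's state, so it travels with the module and is present in the wrappers that `DataParallel` replicates. A hedged sketch of the pattern (class and function names here are illustrative, not NNI's actual wrapper):

```python
import torch
import torch.nn as nn

class FlaggedWrapper(nn.Module):
    """Illustrative wrapper: per-layer pruning state lives in a buffer."""

    def __init__(self, module):
        super().__init__()
        self.module = module
        # 0 = mask not computed yet, 1 = mask already computed
        self.register_buffer('if_calculated', torch.tensor(0))

    def forward(self, x):
        return self.module(x)

def calc_mask_once(wrapper, compute_mask):
    # Mirrors the diff's calc_mask logic: bail out if the flag is set,
    # otherwise compute the mask and flip the flag in place.
    if wrapper.if_calculated:
        return None
    mask = compute_mask(wrapper.module)
    wrapper.if_calculated.copy_(torch.tensor(1))
    return mask
```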