[major] merge pruning branch to master #2

Merged · 3 commits · Mar 27, 2021
142 changes: 141 additions & 1 deletion examples/core/trainers.py
@@ -1,3 +1,4 @@
from torchquantum.operators import Operation
import torch.nn as nn

from typing import Any, Callable, Dict
@@ -6,9 +7,11 @@
from torchpack.utils.config import configs
from torchquantum.utils import get_unitary_loss, legalize_unitary
from torchquantum.super_utils import ArchSampler
from torchquantum.prune_utils import PhaseL1UnstructuredPruningMethod, ThresholdScheduler


__all__ = ['QTrainer', 'LayerRegressionTrainer', 'SuperQTrainer']
__all__ = ['QTrainer', 'LayerRegressionTrainer', 'SuperQTrainer',
'PruningTrainer']


class LayerRegressionTrainer(Trainer):
@@ -254,3 +257,140 @@ def _load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.model.load_state_dict(state_dict['model'])
self.optimizer.load_state_dict(state_dict['optimizer'])
self.scheduler.load_state_dict(state_dict['scheduler'])



class PruningTrainer(Trainer):
'''
Perform pruning-aware training: globally prune rotation phases during
training, with the pruning ratio annealed over the epochs.
'''
def __init__(self, *, model: nn.Module, criterion: Callable,
optimizer: Optimizer, scheduler: Scheduler) -> None:
self.model = model
self.legalized_model = None
self.criterion = criterion
self.optimizer = optimizer
self.scheduler = scheduler
self.init_pruning()

def extract_prunable_parameters(self, model: nn.Module) -> tuple:
    # materialize a tuple: this collection is iterated once per epoch, so a
    # one-shot generator would be silently exhausted after its first use
    _parameters_to_prune = tuple(
        (module, "params")
        for _, module in model.named_modules()
        if isinstance(module, Operation) and module.params is not None)
    return _parameters_to_prune

def init_pruning(self) -> None:
    """Initialize the pruning procedure."""
    self._parameters_to_prune = self.extract_prunable_parameters(self.model)
    self._target_pruning_amount = configs.prune.target_pruning_amount
    self._init_pruning_amount = configs.prune.init_pruning_amount
    self.prune_amount_scheduler = ThresholdScheduler(
        0, self.num_epochs,
        self._init_pruning_amount, self._target_pruning_amount)
    self.prune_amount = self._init_pruning_amount

def _remove_pruning(self):
    for module, name in self._parameters_to_prune:
        # prune.remove raises on a parameter that has never been pruned
        # (e.g. on the very first epoch), so guard with is_pruned
        if nn.utils.prune.is_pruned(module):
            nn.utils.prune.remove(module, name)

def _prune_model(self, prune_amount) -> None:
    """Perform global threshold/percentage pruning on the quantum model.
    This only applies the pruning reparametrization, i.e., it records the
    original parameter as params_orig and generates a params_mask buffer.
    """
    ### first clear the current pruning container, since we do not want
    ### cascaded pruning methods; the remove operation makes the previous
    ### round of pruning permanent
    self._remove_pruning()
### perform global phase pruning based on the given pruning amount
nn.utils.prune.global_unstructured(
self._parameters_to_prune,
pruning_method=PhaseL1UnstructuredPruningMethod,
amount=prune_amount,
)
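### each pruned module now carries params_orig (the trainable copy) and a
### params_mask buffer; a forward pre-hook recomputes params as
### params_orig * params_mask on every forward pass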
self.summary.add_scalar('prune_amount', prune_amount)

def _before_epoch(self) -> None:
self.model.train()

def run_step(self, feed_dict: Dict[str, Any], legalize=False) -> Dict[
str, Any]:
output_dict = self._run_step(feed_dict, legalize=legalize)
return output_dict

def _run_step(self, feed_dict: Dict[str, Any], legalize=False) -> Dict[
str, Any]:
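### NOTE: assumes an architecture sampler (a SuperQTrainer-style
### config_sampler) has been attached to this trainer; __init__ above
### does not create one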
self.sample_config = self.config_sampler.get_sample_config()
self.model.set_sample_config(self.sample_config)

if configs.run.device == 'gpu':
inputs = feed_dict['image'].cuda(non_blocking=True)
targets = feed_dict['digit'].cuda(non_blocking=True)
else:
inputs = feed_dict['image']
targets = feed_dict['digit']
if legalize:
outputs = self.legalized_model(inputs)
else:
outputs = self.model(inputs)
loss = self.criterion(outputs, targets)
nll_loss = loss.item()
unitary_loss = 0

if configs.regularization.unitary_loss:
unitary_loss = get_unitary_loss(self.model)
if configs.regularization.unitary_loss_lambda_trainable:
loss += self.model.unitary_loss_lambda[0] * unitary_loss
else:
loss += configs.regularization.unitary_loss_lambda * \
unitary_loss

if loss.requires_grad:
self.summary.add_scalar('loss', loss.item())
self.summary.add_scalar('nll_loss', nll_loss)

if configs.regularization.unitary_loss:
if configs.regularization.unitary_loss_lambda_trainable:
self.summary.add_scalar(
'u_loss_lambda',
self.model.unitary_loss_lambda.item())
else:
self.summary.add_scalar(
'u_loss_lambda',
configs.regularization.unitary_loss_lambda)
self.summary.add_scalar('u_loss', unitary_loss.item())

self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

return {'outputs': outputs, 'targets': targets}

def _after_epoch(self) -> None:
self.model.eval()
self.scheduler.step()
if configs.legalization.legalize:
if self.epoch_num % configs.legalization.epoch_interval == 0:
legalize_unitary(self.model)
### update pruning amount using the scheduler
self.prune_amount = self.prune_amount_scheduler.step()
### prune the model
self._prune_model(self.prune_amount)
### commit pruned parameters after training
if self.epoch_num == self.num_epochs:
    self._remove_pruning()

def _after_step(self, output_dict) -> None:
if configs.legalization.legalize:
if self.global_step % configs.legalization.step_interval == 0:
legalize_unitary(self.model)

def _state_dict(self) -> Dict[str, Any]:
state_dict = dict()
# need to store model arch because of randomness of random layers
state_dict['model_arch'] = self.model
state_dict['model'] = self.model.state_dict()
state_dict['optimizer'] = self.optimizer.state_dict()
state_dict['scheduler'] = self.scheduler.state_dict()
return state_dict

def _load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.model.load_state_dict(state_dict['model'])
self.optimizer.load_state_dict(state_dict['optimizer'])
self.scheduler.load_state_dict(state_dict['scheduler'])
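
For context, here is a minimal sketch of how this trainer might be wired up. The config keys mirror the names read by init_pruning() above; the model, loss, learning rate, and epoch count are illustrative assumptions, not part of this diff:

import torch
import torch.nn as nn
from examples.core.trainers import PruningTrainer

# assumed YAML entries consumed by init_pruning(), e.g.:
#   prune:
#     init_pruning_amount: 0.05     # prune 5% of phases at the start
#     target_pruning_amount: 0.5    # anneal to 50% by the last epoch

model = ...  # a torchquantum model whose Operation modules hold `params`
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

# note: init_pruning() reads self.num_epochs, so this sketch assumes the
# torchpack Trainer machinery has made it available by construction time
trainer = PruningTrainer(model=model, criterion=criterion,
                         optimizer=optimizer, scheduler=scheduler)
# the torchpack training loop then drives _run_step(); _after_epoch()
# re-prunes with the annealed ratio and commits the masks after the
# final epoch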
90 changes: 90 additions & 0 deletions torchquantum/prune_utils.py
@@ -0,0 +1,90 @@
'''
Description:
Author: Jiaqi Gu (jqgu@utexas.edu)
Date: 2021-03-24 21:52:50
LastEditors: Jiaqi Gu (jqgu@utexas.edu)
LastEditTime: 2021-03-24 22:33:25
'''
import torch
import torch.nn.utils.prune  # explicit import: not re-exported by `import torch` in all versions
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

__all__ = ["PhaseL1UnstructuredPruningMethod", "ThresholdScheduler"]

class PhaseL1UnstructuredPruningMethod(torch.nn.utils.prune.BasePruningMethod):
"""Prune rotation phases which is close to 0, +2pi, ..
"""
PRUNING_TYPE = 'unstructured'

def __init__(self, amount):
# Check range of validity of pruning amount
torch.nn.utils.prune._validate_pruning_amount_init(amount)
self.amount = amount

def compute_mask(self, t, default_mask):
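# wrap phases into (-pi, pi] so that any integer multiple of 2*pi maps to ~0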
t = t % (2 * np.pi)
t[t > np.pi] -= 2 * np.pi

tensor_size = t.numel()
# Compute number of units to prune: amount if int,
# else amount * tensor_size
nparams_toprune = torch.nn.utils.prune._compute_nparams_toprune(
    self.amount, tensor_size)
# This should raise an error if the number of units to prune is larger
# than the number of units in the tensor
torch.nn.utils.prune._validate_pruning_amount(nparams_toprune, tensor_size)

mask = default_mask.clone(memory_format=torch.contiguous_format)

if nparams_toprune != 0: # k=0 not supported by torch.kthvalue
# largest=True --> top k; largest=False --> bottom k
# Prune the smallest k
topk = torch.topk(torch.abs(t).view(-1), k=nparams_toprune, largest=False)
# topk will have .indices and .values
mask.view(-1)[topk.indices] = 0

return mask


class ThresholdScheduler(object):
'''Smoothly anneal a threshold between two values, driven by the
TensorFlow model-optimization pruning schedule; handles both
increasing and decreasing thresholds.
'''

def __init__(self, step_beg, step_end, thres_beg, thres_end):
    # TF1-style eager setup; on TF2 eager execution is already enabled
    # and ConfigProto lives under tf.compat.v1
    if not tf.executing_eagerly():
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        tf.compat.v1.enable_eager_execution(config=config)
    self.step_beg = step_beg
    self.step_end = step_end
    self.thres_beg = thres_beg
    self.thres_end = thres_end
    if thres_beg < thres_end:
        self.thres_min = thres_beg
        self.thres_range = thres_end - thres_beg
        self.descend = False
    else:
        self.thres_min = thres_end
        self.thres_range = thres_beg - thres_end
        self.descend = True

self.pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
initial_sparsity=0, final_sparsity=0.9999999,
begin_step=self.step_beg, end_step=self.step_end)
self.global_step = 0

def step(self):
    if self.global_step < self.step_beg:
        return self.thres_beg
    elif self.global_step > self.step_end:
        return self.thres_end
    # tfmot schedules return (should_prune, sparsity); take the sparsity
    # in [0, 1] and rescale it onto [thres_beg, thres_end]
    res_norm = self.pruning_schedule(self.global_step)[1].numpy()
    if not self.descend:
        res = res_norm * self.thres_range + self.thres_beg
    else:
        res = self.thres_beg - res_norm * self.thres_range
    # snap to the end value once close enough, to avoid float residue
    if np.abs(res - self.thres_end) <= 1e-6:
        res = self.thres_end
    self.global_step += 1
    return res
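
To make the intended behavior concrete, here is a small self-contained sketch; a Linear layer stands in for a quantum Operation module, and every number is illustrative:

import numpy as np
import torch
from torchquantum.prune_utils import (
    PhaseL1UnstructuredPruningMethod, ThresholdScheduler)

layer = torch.nn.Linear(4, 4)
with torch.no_grad():
    layer.weight[0, 0] = 2 * np.pi - 1e-3  # wraps to ~0, so it is pruned first

# mask the 30% of weights whose wrapped phase is closest to a multiple of 2*pi
PhaseL1UnstructuredPruningMethod.apply(layer, 'weight', amount=0.3)
print(int(layer.weight_mask.sum()))  # 11, i.e. 5 of the 16 entries are masked

# anneal a pruning ratio from 5% to 50% over 30 epochs
sched = ThresholdScheduler(0, 30, 0.05, 0.5)
ratios = [sched.step() for _ in range(31)]
print(ratios[0], ratios[-1])  # starts at 0.05 and ends snapped to 0.5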