From eea6a87f9961ad0b646bf10699de1401ad91102a Mon Sep 17 00:00:00 2001 From: JeremieMelo Date: Wed, 24 Mar 2021 23:28:42 -0500 Subject: [PATCH 1/2] [major] pruning trainer --- examples/core/trainers.py | 141 +++++++++++++++++++++++++++++++++++++- 1 file changed, 140 insertions(+), 1 deletion(-) diff --git a/examples/core/trainers.py b/examples/core/trainers.py index 884bea27..417668e1 100644 --- a/examples/core/trainers.py +++ b/examples/core/trainers.py @@ -1,3 +1,4 @@ +from torchquantum.operators import Operation import torch.nn as nn from typing import Any, Callable, Dict @@ -6,9 +7,10 @@ from torchpack.utils.config import configs from torchquantum.utils import get_unitary_loss, legalize_unitary from torchquantum.super_utils import ConfigSampler +from torchquantum.prune_utils import PhaseL1UnstructuredPruningMethod, ThresholdScheduler -__all__ = ['QTrainer', 'LayerRegressionTrainer', 'SuperQTrainer'] +__all__ = ['QTrainer', 'LayerRegressionTrainer', 'SuperQTrainer', 'PruningTrain'] class LayerRegressionTrainer(Trainer): @@ -240,3 +242,140 @@ def _load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.model.load_state_dict(state_dict['model']) self.optimizer.load_state_dict(state_dict['optimizer']) self.scheduler.load_state_dict(state_dict['scheduler']) + + + +class PruningTrainer(Trainer): + ''' + Perform pruning-aware training + ''' + def __init__(self, *, model: nn.Module, criterion: Callable, + optimizer: Optimizer, scheduler: Scheduler) -> None: + self.model = model + self.legalized_model = None + self.criterion = criterion + self.optimizer = optimizer + self.scheduler = scheduler + self.init_pruning() + + def extract_prunable_parameters(self, model: nn.Module) -> tuple: + _parameters_to_prune = ( + (module, "params") + for _, module in model.named_modules() if isinstance(module, Operation) and module.params is not None) + return _parameters_to_prune + + def init_pruning(self) -> None: + """Initialize pruning procedure + """ + self._parameters_to_prune = self.extract_prunable_parameters(self.model) + self._target_pruning_amount = configs.prune.target_pruning_amount + self._init_pruning_amount = configs.prune.init_pruning_amount + self.prune_amount_scheduler = ThresholdScheduler(0, self.num_epochs, self._init_pruning_amount, self._target_pruning_amount) + self.prune_amount = self._init_pruning_amount + + def _remove_pruning(self): + for module, name in self._parameters_to_prune: + nn.utils.prune.remove(module, name) + + def _prune_model(self, prune_amount) -> None: + """Perform global threshold/percentage pruning on the quantum model. This function just performs pruning reparametrization, i.e., record weight_orig and generate weight_mask + """ + ### first clear current prunine container, since we do not want cascaded pruning methods + ### remove operation will make pruning permanent + self._remove_pruning() + ### perform global phase pruning based on the given pruning amount + nn.utils.prune.global_unstructured( + self._parameters_to_prune, + pruning_method=PhaseL1UnstructuredPruningMethod, + amount=prune_amount, + ) + self.summary.add_scalar('prune_amount', prune_amount) + + def _before_epoch(self) -> None: + self.model.train() + + def run_step(self, feed_dict: Dict[str, Any], legalize=False) -> Dict[ + str, Any]: + output_dict = self._run_step(feed_dict, legalize=legalize) + return output_dict + + def _run_step(self, feed_dict: Dict[str, Any], legalize=False) -> Dict[ + str, Any]: + self.sample_config = self.config_sampler.get_sample_config() + self.model.set_sample_config(self.sample_config) + + if configs.run.device == 'gpu': + inputs = feed_dict['image'].cuda(non_blocking=True) + targets = feed_dict['digit'].cuda(non_blocking=True) + else: + inputs = feed_dict['image'] + targets = feed_dict['digit'] + if legalize: + outputs = self.legalized_model(inputs) + else: + outputs = self.model(inputs) + loss = self.criterion(outputs, targets) + nll_loss = loss.item() + unitary_loss = 0 + + if configs.regularization.unitary_loss: + unitary_loss = get_unitary_loss(self.model) + if configs.regularization.unitary_loss_lambda_trainable: + loss += self.model.unitary_loss_lambda[0] * unitary_loss + else: + loss += configs.regularization.unitary_loss_lambda * \ + unitary_loss + + if loss.requires_grad: + self.summary.add_scalar('loss', loss.item()) + self.summary.add_scalar('nll_loss', nll_loss) + + if configs.regularization.unitary_loss: + if configs.regularization.unitary_loss_lambda_trainable: + self.summary.add_scalar( + 'u_loss_lambda', + self.model.unitary_loss_lambda.item()) + else: + self.summary.add_scalar( + 'u_loss_lambda', + configs.regularization.unitary_loss_lambda) + self.summary.add_scalar('u_loss', unitary_loss.item()) + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + return {'outputs': outputs, 'targets': targets} + + def _after_epoch(self) -> None: + self.model.eval() + self.scheduler.step() + if configs.legalization.legalize: + if self.epoch_num % configs.legalization.epoch_interval == 0: + legalize_unitary(self.model) + ### update pruning amount using the scheduler + self.prune_amount = self.prune_amount_scheduler.step() + ### prune the model + self._prune_model(self.prune_amount) + ### commit pruned parameters after training + if(self.epoch_num == self.num_epochs): + self._remove_pruning() + + def _after_step(self, output_dict) -> None: + if configs.legalization.legalize: + if self.global_step % configs.legalization.step_interval == 0: + legalize_unitary(self.model) + + def _state_dict(self) -> Dict[str, Any]: + state_dict = dict() + # need to store model arch because of randomness of random layers + state_dict['model_arch'] = self.model + state_dict['model'] = self.model.state_dict() + state_dict['optimizer'] = self.optimizer.state_dict() + state_dict['scheduler'] = self.scheduler.state_dict() + return state_dict + + def _load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.model.load_state_dict(state_dict['model']) + self.optimizer.load_state_dict(state_dict['optimizer']) + self.scheduler.load_state_dict(state_dict['scheduler']) From 9fdefcf2f7f281d8cc6785f573a7f6a1dd08ed92 Mon Sep 17 00:00:00 2001 From: JeremieMelo Date: Wed, 24 Mar 2021 23:28:56 -0500 Subject: [PATCH 2/2] [major] pruning method and scheduler --- torchquantum/prune_utils.py | 90 +++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 torchquantum/prune_utils.py diff --git a/torchquantum/prune_utils.py b/torchquantum/prune_utils.py new file mode 100644 index 00000000..c0109554 --- /dev/null +++ b/torchquantum/prune_utils.py @@ -0,0 +1,90 @@ +''' +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2021-03-24 21:52:50 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2021-03-24 22:33:25 +''' +import torch +import numpy as np +import tensorflow as tf +import tensorflow_model_optimization as tfmot + +__all__ = ["PhaseL1UnstructuredPruningMethod", "ThresholdScheduler"] + +class PhaseL1UnstructuredPruningMethod(torch.nn.utils.prune.BasePruningMethod): + """Prune rotation phases which is close to 0, +2pi, .. + """ + PRUNING_TYPE = 'unstructured' + + def __init__(self, amount): + # Check range of validity of pruning amount + torch.nn.utils.prune._validate_pruning_amount_init(amount) + self.amount = amount + + def compute_mask(self, t, default_mask): + t = t % (2 * np.pi) + t[t > np.pi] -= 2 * np.pi + + tensor_size = t.numel() + # Compute number of units to prune: amount if int, + # else amount * tensor_size + nparams_toprune = torch.nn.utils.prune_compute_nparams_toprune(self.amount, tensor_size) + # This should raise an error if the number of units to prune is larger + # than the number of units in the tensor + torch.nn.utils.prune_validate_pruning_amount(nparams_toprune, tensor_size) + + mask = default_mask.clone(memory_format=torch.contiguous_format) + + if nparams_toprune != 0: # k=0 not supported by torch.kthvalue + # largest=True --> top k; largest=False --> bottom k + # Prune the smallest k + topk = torch.topk(torch.abs(t).view(-1), k=nparams_toprune, largest=False) + # topk will have .indices and .values + mask.view(-1)[topk.indices] = 0 + + return mask + + +class ThresholdScheduler(object): + ''' smooth increasing threshold with tensorflow model pruning scheduler + ''' + + def __init__(self, step_beg, step_end, thres_beg, thres_end): + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + tf.enable_eager_execution(config=config) + self.step_beg = step_beg + self.step_end = step_end + self.thres_beg = thres_beg + self.thres_end = thres_end + if(thres_beg < thres_end): + self.thres_min = thres_beg + self.thres_range = (thres_end - thres_beg) + self.descend = False + + else: + self.thres_min = thres_end + self.thres_range = (thres_beg - thres_end) + self.descend = True + + self.pruning_schedule = tfmot.sparsity.keras.PolynomialDecay( + initial_sparsity=0, final_sparsity=0.9999999, + begin_step=self.step_beg, end_step=self.step_end) + self.global_step = 0 + + def step(self): + if(self.global_step < self.step_beg): + return self.thres_beg + elif(self.global_step > self.step_end): + return self.thres_end + res_norm = self.pruning_schedule(self.global_step)[1].numpy() + if(self.descend == False): + res = res_norm * self.thres_range + self.thres_beg + else: + res = self.thres_beg - res_norm * self.thres_range + + if(np.abs(res - self.thres_end) <= 1e-6): + res = self.thres_end + self.global_step += 1 + return res