From 66f8cbb2ff4aa76f13e80130305e3de6327ace5f Mon Sep 17 00:00:00 2001 From: VainF Date: Sun, 17 Nov 2024 08:55:14 +0800 Subject: [PATCH] Add SliceOp & Support Phi-3 --- .../LLMs/{prune_llama.py => prune_llm.py} | 41 ++++++- examples/LLMs/readme.md | 116 +++++++++++++++++- .../torchvision_global_pruning.py | 2 - .../torchvision_models/torchvision_pruning.py | 2 - examples/transformers/prune_timm_vit.py | 5 +- torch_pruning/_helpers.py | 16 +++ torch_pruning/dependency.py | 52 +++++++- torch_pruning/ops.py | 34 +++++ torch_pruning/pruner/algorithms/metapruner.py | 8 +- torch_pruning/pruner/importance.py | 1 + 10 files changed, 256 insertions(+), 21 deletions(-) rename examples/LLMs/{prune_llama.py => prune_llm.py} (90%) diff --git a/examples/LLMs/prune_llama.py b/examples/LLMs/prune_llm.py similarity index 90% rename from examples/LLMs/prune_llama.py rename to examples/LLMs/prune_llm.py index 54e9ac2..0b0d091 100644 --- a/examples/LLMs/prune_llama.py +++ b/examples/LLMs/prune_llm.py @@ -291,11 +291,22 @@ def main(): inputs = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(model.device) import torch_pruning as tp num_heads = {} + out_channel_groups = {} + seperate_qkv = False for name, m in model.named_modules(): if name.endswith("self_attn"): - num_heads[m.q_proj] = model.config.num_attention_heads - num_heads[m.k_proj] = model.config.num_key_value_heads - num_heads[m.v_proj] = model.config.num_key_value_heads + if hasattr(m, "q_proj"): + seperate_qkv = True + num_heads[m.q_proj] = model.config.num_attention_heads + num_heads[m.k_proj] = model.config.num_key_value_heads + num_heads[m.v_proj] = model.config.num_key_value_heads + elif hasattr(m, "qkv_proj"): + seperate_qkv = False + num_heads[m.qkv_proj] = model.config.num_attention_heads + if name.endswith('mlp'): + if hasattr(m, "gate_up_proj"): + out_channel_groups[m.gate_up_proj] = 2 + _is_gqa = model.config.num_attention_heads != model.config.num_key_value_heads head_pruning_ratio = args.pruning_ratio hidden_size_pruning_ratio = args.pruning_ratio @@ -311,19 +322,31 @@ def main(): prune_num_heads=True, prune_head_dims=False, # we do not prune head dims so that we don't need to prune the ROPE head_pruning_ratio=head_pruning_ratio, + out_channel_groups=out_channel_groups ) + + #with torch.no_grad(): # with importance.compute_importance(model): # calibration_data = "We recommend at least a 1TB hard drive for 4 channels, more if you plan on using 8MP \/ 4K cameras.\nDahua's Lite Series network video recorders offer excellent performance and high recording quality for IP video surveillance applications. For applications where details are critical for identification, this professional NVR provides a powerful processor with up to 4K resolution. Additionally, the NVR features a mouse shortcut operation menu, remote management and control, center storage, edge storage, and back up storage." 
# calibration_data = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(model.device) # _ = model(calibration_data) - pruner.step() + + #group = pruner.DG.get_pruning_group(model.model.layers[31].mlp.gate_up_proj, tp.prune_linear_out_channels, idxs=list(range(16384))) + #print(group) + + for g in pruner.step(interactive=True): + print(g) + g.prune() # Update model attributes model.config.hidden_size = model.lm_head.in_features for name, m in model.named_modules(): if name.endswith("self_attn"): - m.hidden_size = m.q_proj.out_features + if seperate_qkv: + m.hidden_size = m.q_proj.out_features + else: + m.hidden_size = m.qkv_proj.out_features // 3 m.num_heads = m.hidden_size // m.head_dim model.config.num_attention_heads = m.num_heads #m.head_dim = m.q_proj.out_features // m.num_heads @@ -331,7 +354,13 @@ def main(): m.num_key_value_heads = m.num_heads m.num_key_value_groups = m.num_heads // m.num_key_value_heads elif name.endswith("mlp"): - model.config.intermediate_size = m.gate_proj.out_features + if hasattr(m, "gate_proj"): + m.hidden_size = m.gate_proj.out_features + elif hasattr(m, "gate_up_proj"): + m.hidden_size = m.gate_up_proj.in_features + else: + raise ValueError("Unknown mlp layer") + if not _is_gqa: model.config.num_key_value_heads = model.config.num_attention_heads print("----------------- After Pruning -----------------") diff --git a/examples/LLMs/readme.md b/examples/LLMs/readme.md index 5f1fbb4..04b357e 100644 --- a/examples/LLMs/readme.md +++ b/examples/LLMs/readme.md @@ -13,7 +13,7 @@ pip install transformers datasets ### Llama-3 8B ```bash -python prune_llama.py --model meta-llama/Meta-Llama-3-8B --pruning_ratio 0.5 +python prune_llm.py --model meta-llama/Meta-Llama-3-8B --pruning_ratio 0.5 ```
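For readers who want to see what the renamed `prune_llm.py` configures under the hood, the hunks above build a `tp.pruner.MetaPruner` with a per-layer head-count map. Below is a minimal sketch (not the full script) for a Llama-style model with separate `q_proj`/`k_proj`/`v_proj` layers; the importance criterion and the ignored `lm_head` are illustrative assumptions:

```python
import torch
import torch_pruning as tp
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
example_inputs = torch.tensor(tokenizer.encode("Hello world")).unsqueeze(0)

# Tell the pruner how many heads each attention projection carries so that
# whole heads (rather than arbitrary channels) are removed consistently.
num_heads = {}
for name, m in model.named_modules():
    if name.endswith("self_attn"):
        num_heads[m.q_proj] = model.config.num_attention_heads
        num_heads[m.k_proj] = model.config.num_key_value_heads
        num_heads[m.v_proj] = model.config.num_key_value_heads

pruner = tp.pruner.MetaPruner(
    model,
    example_inputs=example_inputs,
    importance=tp.importance.GroupNormImportance(),  # assumed criterion; any group importance works
    pruning_ratio=0.5,                               # hidden-size pruning ratio
    ignored_layers=[model.lm_head],                  # assumption: keep the output head intact
    num_heads=num_heads,
    prune_num_heads=True,    # remove entire heads
    prune_head_dims=False,   # keep head_dim so RoPE does not need to be pruned
    head_pruning_ratio=0.5,
)
pruner.step()
```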
@@ -120,7 +120,7 @@ wikitext perplexity 552648.25 ### Llama-2 7B ```bash -python prune_llama.py --model meta-llama/Llama-2-7b-hf --pruning_ratio 0.5 +python prune_llm.py --model meta-llama/Llama-2-7b-hf --pruning_ratio 0.5 ``` @@ -224,3 +224,115 @@ wikitext perplexity 8479.0673828125
+### microsoft/Phi-3-mini-4k-instruct + +```bash +python prune_llm.py --model microsoft/Phi-3-mini-4k-instruct --pruning_ratio 0.5 +``` + + +
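Phi-3 fuses the attention projections into a single `qkv_proj` and the MLP's gate/up projections into `gate_up_proj`, which is what the new `out_channel_groups` argument and the `hasattr` branches in the diff handle. A minimal sketch of the Phi-3-specific setup follows; it reuses `model` and `example_inputs` from the Llama sketch above (loaded from `microsoft/Phi-3-mini-4k-instruct` instead), and the grouping values mirror the diff while everything else is assumed:

```python
# Phi-3-specific registration (model/example_inputs loaded as in the Llama sketch above).
num_heads = {}
out_channel_groups = {}
for name, m in model.named_modules():
    if name.endswith("self_attn") and hasattr(m, "qkv_proj"):
        # Fused Q/K/V projection: register it once with the total number of attention heads.
        num_heads[m.qkv_proj] = model.config.num_attention_heads
    if name.endswith("mlp") and hasattr(m, "gate_up_proj"):
        # gate_up_proj stacks the gate and up projections, so its output channels
        # form 2 coupled groups that must be pruned together.
        out_channel_groups[m.gate_up_proj] = 2

pruner = tp.pruner.MetaPruner(
    model,
    example_inputs=example_inputs,
    importance=tp.importance.GroupNormImportance(),  # assumed criterion
    pruning_ratio=0.5,
    ignored_layers=[model.lm_head],
    num_heads=num_heads,
    prune_num_heads=True,
    prune_head_dims=False,
    head_pruning_ratio=0.5,
    out_channel_groups=out_channel_groups,
)
for group in pruner.step(interactive=True):  # interactive mode lets you inspect each coupled group
    group.prune()
```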
+Output: + +``` +----------------- Before Pruning ----------------- +Phi3ForCausalLM( + (model): Phi3Model( + (embed_tokens): Embedding(32064, 3072, padding_idx=32000) + (embed_dropout): Dropout(p=0.0, inplace=False) + (layers): ModuleList( + (0-31): 32 x Phi3DecoderLayer( + (self_attn): Phi3Attention( + (o_proj): Linear(in_features=3072, out_features=3072, bias=False) + (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False) + (rotary_emb): Phi3RotaryEmbedding() + ) + (mlp): Phi3MLP( + (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False) + (down_proj): Linear(in_features=8192, out_features=3072, bias=False) + (activation_fn): SiLU() + ) + (input_layernorm): Phi3RMSNorm() + (resid_attn_dropout): Dropout(p=0.0, inplace=False) + (resid_mlp_dropout): Dropout(p=0.0, inplace=False) + (post_attention_layernorm): Phi3RMSNorm() + ) + ) + (norm): Phi3RMSNorm() + ) + (lm_head): Linear(in_features=3072, out_features=32064, bias=False) +) +----------------- After Pruning ----------------- +Token indices sequence length is longer than the specified maximum sequence length for this model (2824490 > 4096). Running this sequence through the model will result in indexing errors +Phi3ForCausalLM( + (model): Phi3Model( + (embed_tokens): Embedding(32064, 1536, padding_idx=32000) + (embed_dropout): Dropout(p=0.0, inplace=False) + (layers): ModuleList( + (0-31): 32 x Phi3DecoderLayer( + (self_attn): Phi3Attention( + (o_proj): Linear(in_features=1536, out_features=1536, bias=False) + (qkv_proj): Linear(in_features=1536, out_features=4608, bias=False) + (rotary_emb): Phi3RotaryEmbedding() + ) + (mlp): Phi3MLP( + (gate_up_proj): Linear(in_features=1536, out_features=8192, bias=False) + (down_proj): Linear(in_features=4096, out_features=1536, bias=False) + (activation_fn): SiLU() + ) + (input_layernorm): Phi3RMSNorm() + (resid_attn_dropout): Dropout(p=0.0, inplace=False) + (resid_mlp_dropout): Dropout(p=0.0, inplace=False) + (post_attention_layernorm): Phi3RMSNorm() + ) + ) + (norm): Phi3RMSNorm() + ) + (lm_head): Linear(in_features=1536, out_features=32064, bias=False) +) +Phi3Config { + "_name_or_path": "microsoft/Phi-3-mini-4k-instruct", + "architectures": [ + "Phi3ForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "microsoft/Phi-3-mini-4k-instruct--configuration_phi3.Phi3Config", + "AutoModelForCausalLM": "microsoft/Phi-3-mini-4k-instruct--modeling_phi3.Phi3ForCausalLM" + }, + "bos_token_id": 1, + "embd_pdrop": 0.0, + "eos_token_id": 32000, + "hidden_act": "silu", + "hidden_size": 1536, + "initializer_range": 0.02, + "intermediate_size": 8192, + "max_position_embeddings": 4096, + "model_type": "phi3", + "num_attention_heads": 16, + "num_hidden_layers": 32, + "num_key_value_heads": 16, + "original_max_position_embeddings": 4096, + "pad_token_id": 32000, + "resid_pdrop": 0.0, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 10000.0, + "sliding_window": 2047, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.36.2", + "use_cache": true, + "vocab_size": 32064 +} + +num_params 1004570112 +evaluating on wikitext2 +nsamples 83 +sample 0 +sample 50 +wikitext perplexity 92795.3984375 +``` + +
\ No newline at end of file diff --git a/examples/torchvision_models/torchvision_global_pruning.py b/examples/torchvision_models/torchvision_global_pruning.py index cd05c56..5944c96 100644 --- a/examples/torchvision_models/torchvision_global_pruning.py +++ b/examples/torchvision_models/torchvision_global_pruning.py @@ -16,8 +16,6 @@ ) from torchvision.models.detection.fcos import fcos_resnet50_fpn from torchvision.models.detection.keypoint_rcnn import keypointrcnn_resnet50_fpn -from torchvision.models.detection.mask_rcnn import maskrcnn_resnet50_fpn_v2 -from torchvision.models.detection.retinanet import retinanet_resnet50_fpn_v2 from torchvision.models.alexnet import alexnet from torchvision.models.vision_transformer import ( diff --git a/examples/torchvision_models/torchvision_pruning.py b/examples/torchvision_models/torchvision_pruning.py index 419a0dc..4551a91 100644 --- a/examples/torchvision_models/torchvision_pruning.py +++ b/examples/torchvision_models/torchvision_pruning.py @@ -16,8 +16,6 @@ ) from torchvision.models.detection.fcos import fcos_resnet50_fpn from torchvision.models.detection.keypoint_rcnn import keypointrcnn_resnet50_fpn -from torchvision.models.detection.mask_rcnn import maskrcnn_resnet50_fpn_v2 -from torchvision.models.detection.retinanet import retinanet_resnet50_fpn_v2 from torchvision.models.alexnet import alexnet from torchvision.models.vision_transformer import ( diff --git a/examples/transformers/prune_timm_vit.py b/examples/transformers/prune_timm_vit.py index 49f120b..c3c360d 100644 --- a/examples/transformers/prune_timm_vit.py +++ b/examples/transformers/prune_timm_vit.py @@ -151,7 +151,7 @@ def main(): prune_num_heads=args.prune_num_heads, # reduce num_heads by pruning entire heads (default: False) prune_head_dims=not args.prune_num_heads, # reduce head_dim by pruning featrues dims of each head (default: True) head_pruning_ratio=0.5, #args.head_pruning_ratio, # remove 50% heads, only works when prune_num_heads=True (default: 0.0) - round_to=2 + round_to=1 ) if isinstance(imp, (tp.importance.GroupTaylorImportance, tp.importance.GroupHessianImportance)): @@ -206,6 +206,9 @@ def main(): print("Base Loss: %.4f, Pruned Loss: %.4f"%(loss_ori, loss_pruned)) print("Base Accuracy: %.4f, Pruned Accuracy: %.4f"%(acc_ori, acc_pruned)) + latency_mean, latency_std = tp.utils.benchmark.measure_latency(model, example_inputs=torch.randn(16,3,224,224).to(device), repeat=300) + print("Latency: %.4f ms, Std: %.4f ms"%(latency_mean, latency_std)) + if args.save_as is not None: print("Saving the pruned model to %s..."%args.save_as) os.makedirs(os.path.dirname(args.save_as), exist_ok=True) diff --git a/torch_pruning/_helpers.py b/torch_pruning/_helpers.py index f0345ca..15adf15 100644 --- a/torch_pruning/_helpers.py +++ b/torch_pruning/_helpers.py @@ -96,6 +96,22 @@ def __call__(self, idxs: _HybridIndex): return new_idxs +class _SliceIndexMapping(object): + def __init__(self, dim, start, step, end, reverse=False): + self.start = start + self.step = step + self.end = end + self.reverse = reverse + self.dim = dim + + def __call__(self, idxs: _HybridIndex): + + if self.reverse == True: + new_idxs = [ _HybridIndex(idx=i.idx * self.step + self.start, root_idx=i.root_idx) for i in idxs] + else: + new_idxs = [ _HybridIndex(idx=(i.idx - self.start) // self.step, root_idx=i.root_idx) for i in idxs if (i.idx >= self.start and i.idx < self.end and (i.idx-self.start)%self.step==0) ] + return new_idxs + class _SplitIndexMapping(object): def __init__(self, offset, reverse=False): self.offset 
= offset diff --git a/torch_pruning/dependency.py b/torch_pruning/dependency.py index 588ce65..3114cf2 100644 --- a/torch_pruning/dependency.py +++ b/torch_pruning/dependency.py @@ -294,6 +294,7 @@ def __init__(self): ops.OPTYPE.UNBIND: ops.UnbindPruner(), ops.OPTYPE.EXPAND: ops.ExpandPruner(), ops.OPTYPE.CUSTOMIZED: ops.CustomizedPruner(), # just a placeholder + ops.OPTYPE.SLICE: ops.SlicePruner(), } self.REGISTERED_PRUNERS = function.PrunerBox.copy() # shallow copy self.REGISTERED_PRUNERS.update(_dummy_pruners) # merge dummy pruners @@ -511,7 +512,7 @@ def _fix_dependency_graph_non_recursive(dep, idxs, *args): ) _fix_dependency_graph_non_recursive(*group[0]) - + # merge pruning ops merged_group = Group() # craft a new group for merging for dep, idxs in group.items: @@ -827,6 +828,7 @@ def create_node_if_not_exists(grad_fn): # 1. link grad_fns and modules if module is None: # a new module + if not hasattr(grad_fn, "name"): # we treat all unknwon modules as element-wise operations by default, # which does not modify the #dimension/#channel of features. @@ -853,6 +855,12 @@ def create_node_if_not_exists(grad_fn): elif "view" in grad_fn.name().lower() or 'reshape' in grad_fn.name().lower(): module = ops._ReshapeOp(self._op_id) self._op_id+=1 + elif "slice" in grad_fn.name().lower() and "copyslices" not in grad_fn.name().lower(): + if hasattr(grad_fn, '_saved_start') and hasattr(grad_fn, '_saved_end') and hasattr(grad_fn, '_saved_step') and hasattr(grad_fn, '_saved_dim'): + module = ops._SliceOp(self._op_id, grad_fn) + else: # for old version of pytorch, we can not handle the slice operation + module = ops._ElementWiseOp(self._op_id, grad_fn.name()) + self._op_id+=1 else: # treate other ops as element-wise ones, like Add, Sub, Div, Mul. module = ops._ElementWiseOp(self._op_id, grad_fn.name()) @@ -924,6 +932,32 @@ def update_index_mapping(self): self._update_unbind_index_mapping(node) if node.type == ops.OPTYPE.EXPAND and torch.__version__ >= "1.8": self._update_expand_index_mapping(node) + if node.type == ops.OPTYPE.SLICE: + self._update_slice_index_mapping(node) + + + def _update_slice_index_mapping(self, slice_node: Node): + if slice_node.type != ops.OPTYPE.SLICE: + return + grad_fn = slice_node.grad_fn + if hasattr(grad_fn, '_saved_self_sym_sizes'): + if len(grad_fn._saved_self_sym_sizes)==4 and grad_fn._saved_dim != 1: + return + elif len(grad_fn._saved_self_sym_sizes)==3 and grad_fn._saved_dim != 2: + return + + start, step, end, dim = slice_node.module.start, slice_node.module.step, slice_node.module.end, slice_node.module.dim + for node in slice_node.inputs: + for dep in slice_node.dependencies: + if dep.target == node: + dep.index_mapping[0] = _helpers._SliceIndexMapping( + dim=dim, start=start, end=end, step=step, reverse=True + ) + for dep in node.dependencies: + if dep.target == slice_node: + dep.index_mapping[0] = _helpers._SliceIndexMapping( + dim=dim, start=start, end=end, step=step, reverse=False + ) def _init_shape_information(self): for module, node in self.module2node.items(): @@ -1111,10 +1145,18 @@ def _update_concat_index_mapping(self, cat_node: Node): def _update_split_index_mapping(self, split_node: Node): if split_node.type != ops.OPTYPE.SPLIT: return - - if hasattr(split_node.grad_fn, '_saved_dim') and split_node.grad_fn._saved_dim != 1: # this only works for Pytorch>=1.12 - return - + + if hasattr(split_node.grad_fn, '_saved_dim'): # this only works for Pytorch>=1.12 + + # There a issue in some pytorch version, where the _saved_dim is an uninitialized value like 
118745347895359
+            # So we need to check that _saved_dim holds a valid (non-negative) value before using it
+            if split_node.grad_fn._saved_dim >= 0 and split_node.grad_fn._saved_dim != 1:
+                return
         offsets = split_node.module.offsets
         if offsets is None:
diff --git a/torch_pruning/ops.py b/torch_pruning/ops.py
index a54f076..3697f9d 100644
--- a/torch_pruning/ops.py
+++ b/torch_pruning/ops.py
@@ -53,6 +53,18 @@ def __init__(self, id):
     def __repr__(self):
         return "_Reshape_{}()".format(self.id)
 
+class _SliceOp(nn.Module):
+    def __init__(self, id, grad_fn):
+        super(_SliceOp, self).__init__()
+        self.grad_fn = grad_fn
+        self.id = id
+        self.start = grad_fn._saved_start
+        self.end = grad_fn._saved_end
+        self.step = grad_fn._saved_step
+        self.dim = grad_fn._saved_dim
+
+    def __repr__(self):
+        return "_Slice_{}()".format(self.id)
 
 class _ElementWiseOp(nn.Module):
     def __init__(self, id, grad_fn):
@@ -124,6 +136,23 @@ def prune_out_channels(self, layer, idxs):
 
     prune_in_channels = prune_out_channels
 
+class SlicePruner(DummyPruner):
+    def prune_out_channels(self, layer, idxs):
+        if layer.grad_fn is None:
+            return
+        offset_start = 0
+        offset_end = 0
+        for i in idxs:
+            if i < layer.start:
+                offset_start += 1
+                offset_end += 1
+            elif i >= layer.start and i < layer.end:
+                offset_end += layer.step
+        layer.start -= offset_start
+        layer.end -= offset_end
+
+    prune_in_channels = prune_out_channels
+
 class SplitPruner(DummyPruner):
     def prune_out_channels(self, layer, idxs):
         if layer.split_sizes is None:
@@ -199,6 +228,7 @@ class OPTYPE(IntEnum):
     IN = 16 # nn.InstanceNorm
     UNBIND = 17
     EXPAND = 18
+    SLICE = 19
 
 
 def module2type(module):
@@ -239,6 +269,8 @@ def module2type(module):
         return OPTYPE.UNBIND
     elif isinstance(module, _ExpandOp):
         return OPTYPE.EXPAND
+    elif isinstance(module, _SliceOp):
+        return OPTYPE.SLICE
     else:
         return OPTYPE.ELEMENTWISE
 
@@ -278,6 +310,8 @@ def type2class(op_type):
         return _UnbindOp
     elif OPTYPE == OPTYPE.EXPAND:
         return _ExpandOp
+    elif OPTYPE == OPTYPE.SLICE:
+        return _SliceOp
     else:
         return _ElementWiseOp
 
diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py
index 5efea18..7e92bda 100644
--- a/torch_pruning/pruner/algorithms/metapruner.py
+++ b/torch_pruning/pruner/algorithms/metapruner.py
@@ -262,8 +262,8 @@ def step(self, interactive=False)-> typing.Union[typing.Generator, None]:
         else:
             for group in self._prune():
                 group.prune()
-
-    def manual_prune(self, layer, pruning_fn, pruning_ratios_or_idxs):
+
+    def manual_prune_width(self, layer, pruning_fn, pruning_ratios_or_idxs):
        if isinstance(pruning_ratios_or_idxs, float):
            if self.DG.is_out_channel_pruning_fn(pruning_fn):
                prunable_channels = self.DG.get_out_channels(layer)
@@ -274,7 +274,7 @@ def manual_prune(self, layer, pruning_fn, pruning_ratios_or_idxs):
            imp_argsort = torch.argsort(imp)
            n_pruned = int(prunable_channels * (1 - pruning_ratios_or_idxs))
            pruning_idxs = imp_argsort[:n_pruned]
-
+
        group = self.DG.get_pruning_group(layer, pruning_fn, pruning_idxs)
        group.prune()
 
@@ -412,6 +412,7 @@ def _prune(self) -> typing.Generator:
            # Re-order the group and use a downstream node as the root node for attention layers.
            # This will not change the group structure, but make index mapping easier for attention layers.
            _is_atten, qkv_layers = self._is_atten_group(group)
+
            if _is_atten:
                group = self._downstream_node_as_root_if_attention(group)
                if group is None: continue
@@ -622,3 +623,4 @@ def _prune(self) -> typing.Generator:
 
            if self.DG.check_pruning_group(group):
                yield group # yield the group for interactive pruning
+
diff --git a/torch_pruning/pruner/importance.py b/torch_pruning/pruner/importance.py
index 8b4969f..c58b094 100644
--- a/torch_pruning/pruner/importance.py
+++ b/torch_pruning/pruner/importance.py
@@ -181,6 +181,7 @@ def _reduce(self, group_imp: typing.List[torch.Tensor], group_idxs: typing.List[
 
     def __call__(self, group: Group):
         group_imp = []
         group_idxs = []
+        # Iterate over all coupled layers in the group and estimate their importance
         for i, (dep, idxs) in enumerate(group):
             layer = dep.layer
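For reference, the arithmetic behind the new `_SliceIndexMapping` and `SlicePruner` can be illustrated independently of the library. The helper names below are hypothetical (not part of Torch-Pruning); they only reproduce the forward/reverse index translation that the patch installs around a `tensor[..., start:end:step]` slice:

```python
# Standalone illustration of the index arithmetic used by _SliceIndexMapping.
# "reverse" maps indices of the sliced tensor back to indices of the source tensor;
# the forward direction maps source indices to their positions inside the slice.

def slice_to_source(idxs, start, step):
    # sliced index i corresponds to source index i*step + start
    return [i * step + start for i in idxs]

def source_to_slice(idxs, start, end, step):
    # keep only source indices that the slice actually selects, then renumber them
    return [(i - start) // step for i in idxs if start <= i < end and (i - start) % step == 0]

# Example: y = x[:, 2:10:2] selects source channels 2, 4, 6, 8.
print(slice_to_source([0, 1, 2, 3], start=2, step=2))           # -> [2, 4, 6, 8]
print(source_to_slice([2, 5, 6, 9], start=2, end=10, step=2))   # -> [0, 2]
```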