From 6aa6de380b299493ac693c1b729d0e06f31a4312 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Wed, 1 May 2024 09:52:38 -0700 Subject: [PATCH] remove dummy path in arctic --- vllm/model_executor/models/arctic.py | 64 +++++++++++++--------------- 1 file changed, 30 insertions(+), 34 deletions(-) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 208279170a739..b6d6f0b8deaa4 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -481,53 +481,49 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) - if use_dummy: - logger.info("Using dummy weights. Skip loading weights.") - else: - logger.info( - "It takes ~10 minutes to load the weights. Please be patient.") - for name, loaded_weight in weights: - for (param_name, weight_name, - shard_id) in stacked_params_mapping: + logger.info( + "It takes ~10 minutes to load the weights. Please be patient.") + for name, loaded_weight in weights: + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, shard_id in mlp_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - for param_name, weight_name, shard_id in mlp_params_mapping: + for param_name, weight_name, shard_id \ + in expert_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + weight_loader(param, + loaded_weight, + weight_name, + expert_id=shard_id) break else: - for param_name, weight_name, shard_id \ - in expert_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=shard_id) - break - else: - if name.endswith( - ".bias") and name not in params_dict: - continue - param = params_dict[name] - - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + if name.endswith( + ".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight)