diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 43f3d4f6dc9cc..40d01a2ecfc5b 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -244,6 +244,30 @@ class GPTBigCodeModel(nn.Module): hidden_states = self.ln_f(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if ".attn.bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method + if "c_attn.input_scale" in name or "c_attn.weight_scale" in name: + weight_loader(param, loaded_weight, 'q') + weight_loader(param, loaded_weight, 'k') + weight_loader(param, loaded_weight, 'v') + else: + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = {"c_attn": ["c_attn"]} @@ -315,26 +339,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "lm_head.weight" in name: - continue - if ".attn.bias" in name: - # Skip attention mask. - # NOTE: "c_attn.bias" should not be skipped. - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method - if "c_attn.input_scale" in name or "c_attn.weight_scale" in name: - weight_loader(param, loaded_weight, 'q') - weight_loader(param, loaded_weight, 'k') - weight_loader(param, loaded_weight, 'v') - else: - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."]), + ) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 752aec0b223dd..1f73d2ab0f1ae 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -188,6 +188,7 @@ class GPTJModel(nn.Module): quant_config = vllm_config.quant_config self.config = config + self.quant_config = quant_config self.embed_dim = config.n_embd self.wte = VocabParallelEmbedding( config.vocab_size, @@ -228,6 +229,63 @@ class GPTJModel(nn.Module): hidden_states = self.ln_f(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "attn.bias" in name or "attn.masked_bias" in name: + continue + + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class GPTJForCausalLM(nn.Module, SupportsPP): @@ -285,57 +343,5 @@ class GPTJForCausalLM(nn.Module, SupportsPP): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "attn.bias" in name or "attn.masked_bias" in name: - continue - - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) \ No newline at end of file