diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 96530562b072c..f8ea2111fed57 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -626,9 +626,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
         loaded_weight: torch.Tensor,
     ) -> tuple[str, torch.Tensor]:
 
-        def permute(w: torch.Tensor, n_heads: int):
+        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
             attn_in = self.config.head_dim * n_heads
-            attn_out = self.config.hidden_size
 
             return w.view(n_heads, attn_in // n_heads // 2, 2,
                           attn_out).transpose(1, 2).reshape(attn_in, attn_out)
@@ -637,12 +636,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
         modules = name.split(".")
 
         # rotary embeds should be sliced
+        # If using a quantized model in mistral format,
+        # quantization scales (qscale_weight) also need to be sliced
         if "wk" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
-                                    self.config.num_key_value_heads)
+                                    self.config.num_key_value_heads,
+                                    self.config.hidden_size)
+        elif "wk" in modules and modules[
+                -1] == "qscale_weight" and loaded_weight.numel() > 1:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_key_value_heads, 1)
         elif "wq" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
-                                    self.config.num_attention_heads)
+                                    self.config.num_attention_heads,
+                                    self.config.hidden_size)
+        elif "wq" in modules and modules[
+                -1] == "qscale_weight" and loaded_weight.numel() > 1:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_attention_heads, 1)
 
         num_modules = len(modules)
         for i in range(num_modules):
diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py
index 27e8b6fa55351..1ea317c2f95f9 100644
--- a/vllm/model_executor/models/voxtral.py
+++ b/vllm/model_executor/models/voxtral.py
@@ -23,6 +23,7 @@ from transformers.tokenization_utils_base import TextInput
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import SupportsPP
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -327,6 +328,12 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         super().__init__()
         self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
 
+        # update quant config so that ignored module and target module names
+        # match the vLLM model names
+        if hasattr(vllm_config, "quant_config"):
+            vllm_config.quant_config = self.maybe_update_quant_config(
+                vllm_config.quant_config)
+
         config = vllm_config.model_config.hf_config
         self.config = config
         self.downsample_factor = self.config.audio_config.downsample_factor
@@ -558,6 +565,72 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return loaded_weights
 
+    def maybe_update_quant_config(
+            self, quant_config: QuantizationConfig) -> QuantizationConfig:
+        """
+        Update the quant config so that ignored module and target module names
+        match the vLLM model names.
+        Right now this is specific to the compressed-tensors format and
+        load_format mistral.
+ """ + remapping_rules = [ + (r"output", r"language_model.lm_head"), + (r"layers\.(\d+)\.attention\.wo", + r"language_model.model.layers.\1.self_attn.out_proj"), + (r"layers\.(\d+)\.attention\.w(.*)", + r"language_model.model.layers.\1.self_attn.\2_proj"), + (r"layers\.(\d+)\.feed_forward\.w1", + r"language_model.model.layers.\1.mlp.gate_proj"), + (r"layers\.(\d+)\.feed_forward\.w2", + r"language_model.model.layers.\1.mlp.down_proj"), + (r"layers\.(\d+)\.feed_forward\.w3", + r"language_model.model.layers.\1.mlp.up_proj"), + (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo", + r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj" + ), + (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)", + r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj" + ), + (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)", + r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2"), + (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0", + r"whisper_encoder.whisper_encoder.conv1"), + (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1", + r"whisper_encoder.whisper_encoder.conv2"), + (r"mm_whisper_embeddings\.audio_language_projection\.0", + r"audio_language_adapter.w_in"), + (r"mm_whisper_embeddings\.audio_language_projection\.2", + r"audio_language_adapter.w_out"), + ] + + # Update ignore list + if hasattr(quant_config, "ignore"): + mistral_ignore = [] + for name in quant_config.ignore: + mistral_name = name + for pattern, repl in remapping_rules: + if re.fullmatch(pattern, name): + mistral_name = re.sub(pattern, repl, name) + mistral_ignore.append(mistral_name) + quant_config.ignore = mistral_ignore + + # Update target list + if hasattr(quant_config, "config_groups"): + config_groups = quant_config.config_groups + for group_name in config_groups: + if "targets" in config_groups[group_name]: + targets = [] + for name in config_groups[group_name]["targets"]: + mistral_name = name + for pattern, repl in remapping_rules: + if re.fullmatch(pattern, name): + mistral_name = re.sub(pattern, repl, name) + targets.append(mistral_name) + config_groups[group_name]["targets"] = targets + quant_config.config_groups = config_groups + + return quant_config + class AudioLanguageAdapter(nn.Module):