[Models][Quantization] Add quantization configuration update in Voxtral model (#24122)

Signed-off-by: Alexandre Marques <almarque@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Alexandre Marques, 2025-09-10 22:13:56 -04:00 (committed via GitHub)
parent cc99baf14d
commit 5931b7e5d9
2 changed files with 88 additions and 4 deletions

vllm/model_executor/models/llama.py

@@ -626,9 +626,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
         loaded_weight: torch.Tensor,
     ) -> tuple[str, torch.Tensor]:
 
-        def permute(w: torch.Tensor, n_heads: int):
+        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
             attn_in = self.config.head_dim * n_heads
-            attn_out = self.config.hidden_size
 
             return w.view(n_heads, attn_in // n_heads // 2, 2,
                           attn_out).transpose(1, 2).reshape(attn_in, attn_out)
@@ -637,12 +636,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
         modules = name.split(".")
 
         # rotary embeds should be sliced
+        # If using a quantized model in mistral format,
+        # quantization scales (qscale_weight) also need to be sliced
         if "wk" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
-                                    self.config.num_key_value_heads)
+                                    self.config.num_key_value_heads,
+                                    self.config.hidden_size)
+        elif "wk" in modules and modules[
+                -1] == "qscale_weight" and loaded_weight.numel() > 1:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_key_value_heads, 1)
         elif "wq" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
-                                    self.config.num_attention_heads)
+                                    self.config.num_attention_heads,
+                                    self.config.hidden_size)
+        elif "wq" in modules and modules[
+                -1] == "qscale_weight" and loaded_weight.numel() > 1:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_attention_heads, 1)
 
         num_modules = len(modules)
         for i in range(num_modules):
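Why permute gained the attn_out parameter: the rotary permutation only reorders the rows of wq/wk, and a per-output-channel quantization scale carries exactly one value per row, so it must be reordered identically, just one column wide (attn_out=1) instead of hidden_size wide. The loaded_weight.numel() > 1 guard skips per-tensor scales, which are a single scalar and indifferent to row order. A minimal standalone sketch of this invariant (shapes are illustrative, and head_dim is passed explicitly instead of being read from self.config):

import torch

def permute(w: torch.Tensor, n_heads: int, head_dim: int, attn_out: int):
    # Same interleaved-to-half rotary permutation as in the patch above,
    # made standalone by taking head_dim as an argument.
    attn_in = head_dim * n_heads
    return w.view(n_heads, attn_in // n_heads // 2, 2,
                  attn_out).transpose(1, 2).reshape(attn_in, attn_out)

n_heads, head_dim, hidden_size = 2, 4, 8
weight = torch.arange(n_heads * head_dim * hidden_size,
                      dtype=torch.float32).view(-1, hidden_size)
scales = torch.arange(n_heads * head_dim, dtype=torch.float32).view(-1, 1)

# Weight rows keep their full width, so attn_out = hidden_size ...
weight_p = permute(weight, n_heads, head_dim, hidden_size)
# ... while per-channel scales are one column wide, so attn_out = 1.
scales_p = permute(scales, n_heads, head_dim, 1)

# Row i of the weight and its scale land at the same new position.
assert torch.equal(weight_p[:, 0] // hidden_size, scales_p[:, 0])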

vllm/model_executor/models/voxtral.py

@@ -23,6 +23,7 @@ from transformers.tokenization_utils_base import TextInput
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import SupportsPP
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -327,6 +328,12 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         super().__init__()
         self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
+        # update quant config so that ignored module and target module names
+        # match the vLLM model names
+        if hasattr(vllm_config, "quant_config"):
+            vllm_config.quant_config = self.maybe_update_quant_config(
+                vllm_config.quant_config)
         config = vllm_config.model_config.hf_config
         self.config = config
         self.downsample_factor = self.config.audio_config.downsample_factor
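In a mistral-format compressed-tensors checkpoint, the config lists modules under their original names, so this rewrite must happen before any quantized layers are constructed from vllm_config. A hedged before/after sketch using a stand-in object (SimpleNamespace and the single ignore entry are illustrative; the real object is a compressed-tensors QuantizationConfig):

from types import SimpleNamespace

# Illustrative stand-in for the checkpoint's quantization config.
quant_config = SimpleNamespace(ignore=["output"], config_groups={})

# After the constructor hook runs maybe_update_quant_config (next hunk),
# the mistral-format name is rewritten to the matching vLLM module name:
#   quant_config.ignore == ["language_model.lm_head"]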
@@ -558,6 +565,72 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return loaded_weights
 
+    def maybe_update_quant_config(
+            self, quant_config: QuantizationConfig) -> QuantizationConfig:
+        """
+        Update quant config so that ignored module and target module names
+        match the vLLM model names.
+        Right now this is specific to the compressed-tensors format and
+        load_format mistral.
+        """
+        remapping_rules = [
+            (r"output", r"language_model.lm_head"),
+            (r"layers\.(\d+)\.attention\.wo",
+             r"language_model.model.layers.\1.self_attn.out_proj"),
+            (r"layers\.(\d+)\.attention\.w(.*)",
+             r"language_model.model.layers.\1.self_attn.\2_proj"),
+            (r"layers\.(\d+)\.feed_forward\.w1",
+             r"language_model.model.layers.\1.mlp.gate_proj"),
+            (r"layers\.(\d+)\.feed_forward\.w2",
+             r"language_model.model.layers.\1.mlp.down_proj"),
+            (r"layers\.(\d+)\.feed_forward\.w3",
+             r"language_model.model.layers.\1.mlp.up_proj"),
+            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo",
+             r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj"
+             ),
+            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w(.*)",
+             r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj"
+             ),
+            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w(\d+)",
+             r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2"),
+            (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
+             r"whisper_encoder.whisper_encoder.conv1"),
+            (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
+             r"whisper_encoder.whisper_encoder.conv2"),
+            (r"mm_whisper_embeddings\.audio_language_projection\.0",
+             r"audio_language_adapter.w_in"),
+            (r"mm_whisper_embeddings\.audio_language_projection\.2",
+             r"audio_language_adapter.w_out"),
+        ]
+
+        # Update ignore list
+        if hasattr(quant_config, "ignore"):
+            mistral_ignore = []
+            for name in quant_config.ignore:
+                mistral_name = name
+                for pattern, repl in remapping_rules:
+                    if re.fullmatch(pattern, name):
+                        mistral_name = re.sub(pattern, repl, name)
+                mistral_ignore.append(mistral_name)
+            quant_config.ignore = mistral_ignore
+
+        # Update target list
+        if hasattr(quant_config, "config_groups"):
+            config_groups = quant_config.config_groups
+            for group_name in config_groups:
+                if "targets" in config_groups[group_name]:
+                    targets = []
+                    for name in config_groups[group_name]["targets"]:
+                        mistral_name = name
+                        for pattern, repl in remapping_rules:
+                            if re.fullmatch(pattern, name):
+                                mistral_name = re.sub(pattern, repl, name)
+                        targets.append(mistral_name)
+                    config_groups[group_name]["targets"] = targets
+            quant_config.config_groups = config_groups
+
+        return quant_config
+
 
 class AudioLanguageAdapter(nn.Module):
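To make the renaming concrete, here is a minimal standalone sketch of the same match-and-substitute loop applied to a few mistral-format names (the input names are illustrative, not taken from a real checkpoint config). Note that every rule is tried and the last full match wins, mirroring the loop above:

import re

# Subset of the remapping rules from maybe_update_quant_config above.
remapping_rules = [
    (r"output", r"language_model.lm_head"),
    (r"layers\.(\d+)\.attention\.w(.*)",
     r"language_model.model.layers.\1.self_attn.\2_proj"),
    (r"layers\.(\d+)\.feed_forward\.w1",
     r"language_model.model.layers.\1.mlp.gate_proj"),
]

def remap(name: str) -> str:
    # As in the patch: try every rule; the last full match wins.
    new_name = name
    for pattern, repl in remapping_rules:
        if re.fullmatch(pattern, name):
            new_name = re.sub(pattern, repl, name)
    return new_name

print(remap("output"))                    # language_model.lm_head
print(remap("layers.0.attention.wk"))     # language_model.model.layers.0.self_attn.k_proj
print(remap("layers.7.feed_forward.w1"))  # language_model.model.layers.7.mlp.gate_proj
print(remap("tok_embeddings"))            # unchanged: no rule matches

Only a full match triggers a rewrite, so names that merely contain a pattern as a substring pass through untouched; the same loop is applied to both quant_config.ignore and every config group's targets list, leaving unmatched names verbatim.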