Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-25 17:44:27 +08:00)
[Models][Quantization] Add quantization configuration update in Voxtral model (#24122)
Signed-off-by: Alexandre Marques <almarque@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
commit 5931b7e5d9
parent cc99baf14d
@@ -626,9 +626,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
         loaded_weight: torch.Tensor,
     ) -> tuple[str, torch.Tensor]:
 
-        def permute(w: torch.Tensor, n_heads: int):
+        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
             attn_in = self.config.head_dim * n_heads
-            attn_out = self.config.hidden_size
 
             return w.view(n_heads, attn_in // n_heads // 2, 2,
                           attn_out).transpose(1, 2).reshape(attn_in, attn_out)
@@ -637,12 +636,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
         modules = name.split(".")
 
         # rotary embeds should be sliced
+        # If using a quantized model in mistral format, the quantization
+        # scales (qscale_weight) also need to be sliced
         if "wk" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
-                                    self.config.num_key_value_heads)
+                                    self.config.num_key_value_heads,
+                                    self.config.hidden_size)
+        elif "wk" in modules and modules[
+                -1] == "qscale_weight" and loaded_weight.numel() > 1:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_key_value_heads, 1)
         elif "wq" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
-                                    self.config.num_attention_heads)
+                                    self.config.num_attention_heads,
+                                    self.config.hidden_size)
+        elif "wq" in modules and modules[
+                -1] == "qscale_weight" and loaded_weight.numel() > 1:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_attention_heads, 1)
 
         num_modules = len(modules)
         for i in range(num_modules):
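What the reworked permute does, shown on toy sizes. The sketch below is not part of the commit; n_heads, head_dim, and hidden are invented values. It de-interleaves each head's rotary pairs (x0, x1), (x2, x3), ... into the half-split layout x0, x2, ..., x1, x3, ... used by vLLM's rotary embedding, and it shows why the new qscale_weight branches pass attn_out=1: a per-output-channel scale vector must be reordered row-for-row with the weight it quantizes.

import torch

n_heads, head_dim, hidden = 2, 4, 8
attn_in = n_heads * head_dim

def permute(w: torch.Tensor, n_heads: int, attn_out: int) -> torch.Tensor:
    # Same reshape as the patched helper; attn_in stands in for
    # self.config.head_dim * n_heads.
    return w.view(n_heads, attn_in // n_heads // 2, 2,
                  attn_out).transpose(1, 2).reshape(attn_in, attn_out)

# Tag every output row of a fake wq weight with its original row index.
w = torch.arange(attn_in, dtype=torch.float32).unsqueeze(1).repeat(1, hidden)
print(permute(w, n_heads, hidden)[:, 0].tolist())
# [0.0, 2.0, 1.0, 3.0, 4.0, 6.0, 5.0, 7.0]  (pairs de-interleaved per head)

# Per-row quantization scales must follow the same reordering, hence the
# attn_out=1 call in the new qscale_weight branches.
scales = torch.arange(attn_in, dtype=torch.float32).unsqueeze(1)
print(permute(scales, n_heads, 1).squeeze(1).tolist())
# [0.0, 2.0, 1.0, 3.0, 4.0, 6.0, 5.0, 7.0]  (same order as the weight rows)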
@@ -23,6 +23,7 @@ from transformers.tokenization_utils_base import TextInput
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import SupportsPP
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -327,6 +328,12 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         super().__init__()
         self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
 
+        # update the quant config so that ignored-module and target-module
+        # names match the vLLM model names
+        if hasattr(vllm_config, "quant_config"):
+            vllm_config.quant_config = self.maybe_update_quant_config(
+                vllm_config.quant_config)
+
         config = vllm_config.model_config.hf_config
         self.config = config
         self.downsample_factor = self.config.audio_config.downsample_factor
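For context, a hypothetical compressed-tensors config fragment (every value below is invented for illustration) with the two fields the new hook rewrites: the top-level ignore list and each config group's targets, both of which arrive with checkpoint-native (mistral-format) module names.

quant_config_dict = {
    "ignore": [
        "output",
        "mm_whisper_embeddings.audio_language_projection.0",
    ],
    "config_groups": {
        "group_0": {
            "targets": ["layers.0.attention.wq"],
            "weights": {"num_bits": 8, "type": "float"},
        },
    },
}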
@@ -558,6 +565,72 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return loaded_weights
 
+    def maybe_update_quant_config(
+            self, quant_config: QuantizationConfig) -> QuantizationConfig:
+        """
+        Update the quant config so that ignored-module and target-module
+        names match the vLLM model names.
+        Right now this is specific to the compressed-tensors format and
+        load_format mistral.
+        """
+        remapping_rules = [
+            (r"output", r"language_model.lm_head"),
+            (r"layers\.(\d+)\.attention\.wo",
+             r"language_model.model.layers.\1.self_attn.out_proj"),
+            (r"layers\.(\d+)\.attention\.w(.*)",
+             r"language_model.model.layers.\1.self_attn.\2_proj"),
+            (r"layers\.(\d+)\.feed_forward\.w1",
+             r"language_model.model.layers.\1.mlp.gate_proj"),
+            (r"layers\.(\d+)\.feed_forward\.w2",
+             r"language_model.model.layers.\1.mlp.down_proj"),
+            (r"layers\.(\d+)\.feed_forward\.w3",
+             r"language_model.model.layers.\1.mlp.up_proj"),
+            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
+             r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj"
+             ),
+            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
+             r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj"
+             ),
+            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
+             r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2"),
+            (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
+             r"whisper_encoder.whisper_encoder.conv1"),
+            (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
+             r"whisper_encoder.whisper_encoder.conv2"),
+            (r"mm_whisper_embeddings\.audio_language_projection\.0",
+             r"audio_language_adapter.w_in"),
+            (r"mm_whisper_embeddings\.audio_language_projection\.2",
+             r"audio_language_adapter.w_out"),
+        ]
+
+        # Update the ignore list
+        if hasattr(quant_config, "ignore"):
+            mistral_ignore = []
+            for name in quant_config.ignore:
+                mistral_name = name
+                for pattern, repl in remapping_rules:
+                    if re.fullmatch(pattern, name):
+                        mistral_name = re.sub(pattern, repl, name)
+                mistral_ignore.append(mistral_name)
+            quant_config.ignore = mistral_ignore
+
+        # Update the target lists
+        if hasattr(quant_config, "config_groups"):
+            config_groups = quant_config.config_groups
+            for group_name in config_groups:
+                if "targets" in config_groups[group_name]:
+                    targets = []
+                    for name in config_groups[group_name]["targets"]:
+                        mistral_name = name
+                        for pattern, repl in remapping_rules:
+                            if re.fullmatch(pattern, name):
+                                mistral_name = re.sub(pattern, repl, name)
+                        targets.append(mistral_name)
+                    config_groups[group_name]["targets"] = targets
+            quant_config.config_groups = config_groups
+
+        return quant_config
+
+
 class AudioLanguageAdapter(nn.Module):
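To see the remapping in action, here is a minimal standalone sketch. It repeats only three of the commit's rules, and sample_names is made-up input resembling a mistral-format ignore list; the inner loop mirrors the method's logic, where re.fullmatch keeps a rule from firing on partial hits and the last matching rule wins.

import re

remapping_rules = [
    (r"output", r"language_model.lm_head"),
    (r"layers\.(\d+)\.attention\.w(.*)",
     r"language_model.model.layers.\1.self_attn.\2_proj"),
    (r"layers\.(\d+)\.feed_forward\.w1",
     r"language_model.model.layers.\1.mlp.gate_proj"),
]

sample_names = ["output", "layers.0.attention.wq", "layers.31.feed_forward.w1"]
for name in sample_names:
    mapped = name
    for pattern, repl in remapping_rules:
        if re.fullmatch(pattern, name):
            mapped = re.sub(pattern, repl, name)
    print(f"{name} -> {mapped}")
# output -> language_model.lm_head
# layers.0.attention.wq -> language_model.model.layers.0.self_attn.q_proj
# layers.31.feed_forward.w1 -> language_model.model.layers.31.mlp.gate_proj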