From f17a6aa4ec3462ad812331259c527388eb09eb0d Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Wed, 10 Sep 2025 22:25:34 -0700 Subject: [PATCH] [Ultravox] Fix Gemma instantiation, support quantization via --hf-overrides (#24131) Signed-off-by: Peter Salas --- vllm/config/__init__.py | 12 ++-- vllm/model_executor/models/ultravox.py | 2 +- vllm/transformers_utils/configs/ultravox.py | 79 ++++++++++++--------- 3 files changed, 53 insertions(+), 40 deletions(-) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 538d44d5337cc..3f63bd2dcf416 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1091,11 +1091,11 @@ class ModelConfig: assert_never(runner_type) - def _parse_quant_hf_config(self): - quant_cfg = getattr(self.hf_config, "quantization_config", None) + def _parse_quant_hf_config(self, hf_config: PretrainedConfig): + quant_cfg = getattr(hf_config, "quantization_config", None) if quant_cfg is None: # compressed-tensors uses a "compression_config" key - quant_cfg = getattr(self.hf_config, "compression_config", None) + quant_cfg = getattr(hf_config, "compression_config", None) else: # Set quant_method for ModelOpt models. @@ -1136,7 +1136,11 @@ class ModelConfig: self.quantization) # Parse quantization method from the HF model config, if available. - quant_cfg = self._parse_quant_hf_config() + quant_cfg = self._parse_quant_hf_config(self.hf_config) + if quant_cfg is None and (text_config := getattr( + self.hf_config, "text_config", None)): + # Check the text config as well for multi-modal models. + quant_cfg = self._parse_quant_hf_config(text_config) if quant_cfg is not None: # Use the community standard 'quant_method' diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index c883065805279..9885309035e6c 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -276,7 +276,7 @@ class UltravoxProjector(nn.Module): else: self.act = get_act_fn(config.projector_act) - dim_out = config.text_hidden_size + dim_out = config.text_config.hidden_size self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False) # Ultravox v0.4.1 and below use layer_norm after the second linear layer diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py index 87064cc12deda..71266b9327369 100644 --- a/vllm/transformers_utils/configs/ultravox.py +++ b/vllm/transformers_utils/configs/ultravox.py @@ -20,10 +20,13 @@ class UltravoxConfig(transformers.PretrainedConfig): Args: audio_config (`Union[AutoConfig, dict]`, *optional*): - Custom audio config or dict + Custom audio config or dict. text_config (`Union[AutoConfig, dict]`, *optional*): - The config object of the text backbone. Can be any of `LlamaConfig` - or `MistralConfig`. + The config object of the text backbone. + audio_model_id (`str`, *optional*): + The model ID of the audio backbone. + text_model_id (`str`, *optional*): + The model ID of the text backbone. ignore_index (`int`, *optional*, defaults to -100): The ignore index for the loss function. audio_token_index (`int`, *optional*, defaults to 32000): @@ -60,15 +63,10 @@ class UltravoxConfig(transformers.PretrainedConfig): stack_factor: int = 8, norm_init: float = 0.4, projector_act: str = "swiglu", - text_model_lora_config: Optional[dict[str, Any]] = None, - audio_model_lora_config: Optional[dict[str, Any]] = None, projector_ln_mid: bool = False, **kwargs, ): self.ignore_index = ignore_index - - self.audio_model_id = audio_model_id - self.text_model_id = text_model_id self.audio_token_index = audio_token_index self.hidden_size = hidden_size @@ -77,36 +75,47 @@ class UltravoxConfig(transformers.PretrainedConfig): self.projector_act = projector_act self.projector_ln_mid = projector_ln_mid - if text_model_id is not None: - # Avoid circular import - from vllm.transformers_utils.config import get_config - - text_config_obj = get_config(text_model_id, - trust_remote_code=False) - else: + # N.B. May set the wrapped_model_config below. + self.text_model_id = text_model_id + if text_model_id is None: text_config = text_config or {} - text_config_obj = transformers.CONFIG_MAPPING[text_config.get( - "model_type", "llama")](**text_config) + self.wrapped_model_config = transformers.CONFIG_MAPPING[ + text_config.get("model_type", "llama")](**text_config) - inner_text_config = text_config_obj.get_text_config() - - if audio_model_id is not None: - # Avoid circular import - from vllm.transformers_utils.config import get_config - - audio_config = get_config(audio_model_id, trust_remote_code=False) - else: + # N.B. May set the audio_config below. + self.audio_model_id = audio_model_id + if audio_model_id is None: + self.audio_model_id = None audio_config = audio_config or {} - audio_config = transformers.CONFIG_MAPPING[audio_config.get( + self.audio_config = transformers.CONFIG_MAPPING[audio_config.get( "model_type", "whisper")](**audio_config) - self.text_config = text_config_obj - self.audio_config = audio_config - self.text_model_lora_config = text_model_lora_config or {} - self.audio_model_lora_config = audio_model_lora_config or {} - - self.vocab_size = inner_text_config.vocab_size - self.initializer_range = inner_text_config.initializer_range - self.text_hidden_size = inner_text_config.hidden_size - super().__init__(**kwargs) + + def __setattr__(self, key, value): + # Since --hf-overrides are applied _after_ the UltravoxConfig is + # instantiated, load the configs implicitly when assigning text_model_id + # or audio_model_id. This allows: + # + # --hf-overrides.text_model_id= + # + # to behave as intended. + if key == "text_model_id" and value is not None: + from vllm.transformers_utils.config import get_config + + self.wrapped_model_config = get_config(value, + trust_remote_code=False) + elif key == "audio_model_id" and value is not None: + from vllm.transformers_utils.config import get_config + + self.audio_config = get_config(value, trust_remote_code=False) + + return super().__setattr__(key, value) + + @property + def text_config(self) -> Optional[transformers.PretrainedConfig]: + # When Ultravox wraps a multi-modal model (e.g. Gemma), we instantiate + # the full model, but the text config is the text config of the inner + # model. + return (self.wrapped_model_config.get_text_config() + if self.wrapped_model_config else None)