mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-27 20:29:41 +08:00
[Ultravox] Fix Gemma instantiation, support quantization via --hf-overrides (#24131)
Signed-off-by: Peter Salas <peter@fixie.ai>
This commit is contained in:
parent
6c8deacd72
commit
f17a6aa4ec
@ -1091,11 +1091,11 @@ class ModelConfig:
|
||||
|
||||
assert_never(runner_type)
|
||||
|
||||
def _parse_quant_hf_config(self):
|
||||
quant_cfg = getattr(self.hf_config, "quantization_config", None)
|
||||
def _parse_quant_hf_config(self, hf_config: PretrainedConfig):
|
||||
quant_cfg = getattr(hf_config, "quantization_config", None)
|
||||
if quant_cfg is None:
|
||||
# compressed-tensors uses a "compression_config" key
|
||||
quant_cfg = getattr(self.hf_config, "compression_config", None)
|
||||
quant_cfg = getattr(hf_config, "compression_config", None)
|
||||
|
||||
else:
|
||||
# Set quant_method for ModelOpt models.
|
||||
@ -1136,7 +1136,11 @@ class ModelConfig:
|
||||
self.quantization)
|
||||
|
||||
# Parse quantization method from the HF model config, if available.
|
||||
quant_cfg = self._parse_quant_hf_config()
|
||||
quant_cfg = self._parse_quant_hf_config(self.hf_config)
|
||||
if quant_cfg is None and (text_config := getattr(
|
||||
self.hf_config, "text_config", None)):
|
||||
# Check the text config as well for multi-modal models.
|
||||
quant_cfg = self._parse_quant_hf_config(text_config)
|
||||
|
||||
if quant_cfg is not None:
|
||||
# Use the community standard 'quant_method'
|
||||
|
||||
@ -276,7 +276,7 @@ class UltravoxProjector(nn.Module):
|
||||
else:
|
||||
self.act = get_act_fn(config.projector_act)
|
||||
|
||||
dim_out = config.text_hidden_size
|
||||
dim_out = config.text_config.hidden_size
|
||||
self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
|
||||
|
||||
# Ultravox v0.4.1 and below use layer_norm after the second linear layer
|
||||
|
||||
@ -20,10 +20,13 @@ class UltravoxConfig(transformers.PretrainedConfig):
|
||||
|
||||
Args:
|
||||
audio_config (`Union[AutoConfig, dict]`, *optional*):
|
||||
Custom audio config or dict
|
||||
Custom audio config or dict.
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*):
|
||||
The config object of the text backbone. Can be any of `LlamaConfig`
|
||||
or `MistralConfig`.
|
||||
The config object of the text backbone.
|
||||
audio_model_id (`str`, *optional*):
|
||||
The model ID of the audio backbone.
|
||||
text_model_id (`str`, *optional*):
|
||||
The model ID of the text backbone.
|
||||
ignore_index (`int`, *optional*, defaults to -100):
|
||||
The ignore index for the loss function.
|
||||
audio_token_index (`int`, *optional*, defaults to 32000):
|
||||
@ -60,15 +63,10 @@ class UltravoxConfig(transformers.PretrainedConfig):
|
||||
stack_factor: int = 8,
|
||||
norm_init: float = 0.4,
|
||||
projector_act: str = "swiglu",
|
||||
text_model_lora_config: Optional[dict[str, Any]] = None,
|
||||
audio_model_lora_config: Optional[dict[str, Any]] = None,
|
||||
projector_ln_mid: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
self.audio_model_id = audio_model_id
|
||||
self.text_model_id = text_model_id
|
||||
self.audio_token_index = audio_token_index
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
@ -77,36 +75,47 @@ class UltravoxConfig(transformers.PretrainedConfig):
|
||||
self.projector_act = projector_act
|
||||
self.projector_ln_mid = projector_ln_mid
|
||||
|
||||
if text_model_id is not None:
|
||||
# Avoid circular import
|
||||
from vllm.transformers_utils.config import get_config
|
||||
|
||||
text_config_obj = get_config(text_model_id,
|
||||
trust_remote_code=False)
|
||||
else:
|
||||
# N.B. May set the wrapped_model_config below.
|
||||
self.text_model_id = text_model_id
|
||||
if text_model_id is None:
|
||||
text_config = text_config or {}
|
||||
text_config_obj = transformers.CONFIG_MAPPING[text_config.get(
|
||||
"model_type", "llama")](**text_config)
|
||||
self.wrapped_model_config = transformers.CONFIG_MAPPING[
|
||||
text_config.get("model_type", "llama")](**text_config)
|
||||
|
||||
inner_text_config = text_config_obj.get_text_config()
|
||||
|
||||
if audio_model_id is not None:
|
||||
# Avoid circular import
|
||||
from vllm.transformers_utils.config import get_config
|
||||
|
||||
audio_config = get_config(audio_model_id, trust_remote_code=False)
|
||||
else:
|
||||
# N.B. May set the audio_config below.
|
||||
self.audio_model_id = audio_model_id
|
||||
if audio_model_id is None:
|
||||
self.audio_model_id = None
|
||||
audio_config = audio_config or {}
|
||||
audio_config = transformers.CONFIG_MAPPING[audio_config.get(
|
||||
self.audio_config = transformers.CONFIG_MAPPING[audio_config.get(
|
||||
"model_type", "whisper")](**audio_config)
|
||||
|
||||
self.text_config = text_config_obj
|
||||
self.audio_config = audio_config
|
||||
self.text_model_lora_config = text_model_lora_config or {}
|
||||
self.audio_model_lora_config = audio_model_lora_config or {}
|
||||
|
||||
self.vocab_size = inner_text_config.vocab_size
|
||||
self.initializer_range = inner_text_config.initializer_range
|
||||
self.text_hidden_size = inner_text_config.hidden_size
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
# Since --hf-overrides are applied _after_ the UltravoxConfig is
|
||||
# instantiated, load the configs implicitly when assigning text_model_id
|
||||
# or audio_model_id. This allows:
|
||||
#
|
||||
# --hf-overrides.text_model_id=<quantized variant>
|
||||
#
|
||||
# to behave as intended.
|
||||
if key == "text_model_id" and value is not None:
|
||||
from vllm.transformers_utils.config import get_config
|
||||
|
||||
self.wrapped_model_config = get_config(value,
|
||||
trust_remote_code=False)
|
||||
elif key == "audio_model_id" and value is not None:
|
||||
from vllm.transformers_utils.config import get_config
|
||||
|
||||
self.audio_config = get_config(value, trust_remote_code=False)
|
||||
|
||||
return super().__setattr__(key, value)
|
||||
|
||||
@property
|
||||
def text_config(self) -> Optional[transformers.PretrainedConfig]:
|
||||
# When Ultravox wraps a multi-modal model (e.g. Gemma), we instantiate
|
||||
# the full model, but the text config is the text config of the inner
|
||||
# model.
|
||||
return (self.wrapped_model_config.get_text_config()
|
||||
if self.wrapped_model_config else None)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user