[Model] Ultravox Model: Support v0.5 Release (#12912)
Signed-off-by: Farzad Abdolhosseini <farzad@fixie.ai>
parent 2ae889052c
commit 08b2d845d6
@@ -856,7 +856,7 @@ See [this page](#generative-models) for more information on how to use generativ
 - * `UltravoxModel`
   * Ultravox
   * T + A<sup>E+</sup>
-  * `fixie-ai/ultravox-v0_3`
+  * `fixie-ai/ultravox-v0_5-llama-3_2-1b`
   * ✅︎
   * ✅︎
   * ✅︎
@@ -359,12 +359,12 @@ export VLLM_VIDEO_FETCH_TIMEOUT=<timeout>
 ### Audio

 Audio input is supported according to [OpenAI Audio API](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in).
-Here is a simple example using Ultravox-v0.3.
+Here is a simple example using Ultravox-v0.5-1B.

 First, launch the OpenAI-compatible server:

 ```bash
-vllm serve fixie-ai/ultravox-v0_3
+vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b
 ```

 Then, you can use the OpenAI client as follows:
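For orientation, a request against the server launched above might look like the sketch below. It assumes the default local endpoint (`http://localhost:8000/v1`) and an `audio_url` content part; the exact payload shape should be checked against the vLLM multimodal serving docs.

```python
# Hypothetical sketch: query the server launched above with one audio clip.
# Assumes the default vLLM endpoint and an `audio_url` content part.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_completion = client.chat.completions.create(
    model="fixie-ai/ultravox-v0_5-llama-3_2-1b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is happening in this recording?"},
            # Any publicly reachable audio URL; this one is illustrative only.
            {"type": "audio_url", "audio_url": {"url": "https://example.com/sample.wav"}},
        ],
    }],
)
print(chat_completion.choices[0].message.content)
```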
@@ -24,9 +24,9 @@ question_per_audio_count = {
 # Unless specified, these settings have been tested to work on a single L4.


-# Ultravox 0.3
+# Ultravox 0.5-1B
 def run_ultravox(question: str, audio_count: int):
-    model_name = "fixie-ai/ultravox-v0_3"
+    model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

     tokenizer = AutoTokenizer.from_pretrained(model_name)
     messages = [{
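The hunk cuts off at `messages = [{`. For context, the body of such a function typically continues roughly as in the sketch below; the engine arguments and the `<|audio|>` placeholder handling are assumptions for illustration, not part of this diff.

```python
# Hypothetical continuation of run_ultravox, for orientation only; the engine
# arguments below are assumptions and are not part of this diff.
from transformers import AutoTokenizer
from vllm import LLM

def run_ultravox_sketch(question: str, audio_count: int):
    model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Ultravox expects one `<|audio|>` placeholder per audio clip in the prompt.
    messages = [{
        "role": "user",
        "content": "<|audio|>\n" * audio_count + question,
    }]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    llm = LLM(model=model_name,
              trust_remote_code=True,
              max_model_len=4096,
              limit_mm_per_prompt={"audio": audio_count})
    return llm, prompt
```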
@@ -12,7 +12,7 @@ vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
     --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2

 (audio inference with Ultravox)
-vllm serve fixie-ai/ultravox-v0_3 --max-model-len 4096
+vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096
 """
 import base64

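Since this client example imports `base64`, one common pattern (sketched below with a made-up file name) is to embed local audio as a base64 data URL inside an `audio_url` content part.

```python
# Hypothetical helper: embed a local WAV file as a base64 data URL so it can be
# sent in an `audio_url` content part. The file name is illustrative only.
import base64

def audio_to_data_url(path: str) -> str:
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:audio/wav;base64,{encoded}"

audio_url = audio_to_data_url("sample.wav")
```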
@@ -215,7 +215,7 @@ MULTIMODAL_MODELS = {
     "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
     "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
     "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
-    "fixie-ai/ultravox-v0_3": PPTestSettings.fast(trust_remote_code=True),
+    "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
     # [Encoder-decoder]
     # TODO: Implement PP
     # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
@@ -234,7 +234,7 @@ TEST_MODELS = [
     # [MULTIMODAL GENERATION]
     "OpenGVLab/InternVL2-1B",
     "microsoft/Phi-3-vision-128k-instruct",
-    "fixie-ai/ultravox-v0_3",
+    "fixie-ai/ultravox-v0_5-llama-3_2-1b",
     # [LANGUAGE GENERATION - HYBRID ARCH]
     "ai21labs/Jamba-tiny-dev",
 ]
@@ -11,7 +11,7 @@ from vllm.multimodal.utils import encode_audio_base64, fetch_audio

 from ...utils import RemoteOpenAIServer

-MODEL_NAME = "fixie-ai/ultravox-v0_3"
+MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
 TEST_AUDIO_URLS = [
     AudioAsset("winning_call").url,
 ]
@@ -21,7 +21,7 @@ from ..utils import VLLM_PATH
 EXAMPLES_DIR = VLLM_PATH / "examples"

 PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
-ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_3"
+ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
 QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
 MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
@@ -15,7 +15,7 @@ from ....conftest import HfRunner, VllmRunner
 from ....utils import RemoteOpenAIServer
 from ...utils import check_logprobs_close

-MODEL_NAME = "fixie-ai/ultravox-v0_3"
+MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

 AudioTuple = Tuple[np.ndarray, int]

@@ -164,7 +164,7 @@ def _test_processing_correctness(
     "Qwen/Qwen2-VL-2B-Instruct",
     "Qwen/Qwen2.5-VL-3B-Instruct",
     "Qwen/Qwen2-Audio-7B-Instruct",
-    "fixie-ai/ultravox-v0_3",
+    "fixie-ai/ultravox-v0_5-llama-3_2-1b",
 ])
 @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
 @pytest.mark.parametrize("num_batches", [32])
@@ -267,7 +267,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"),  # noqa: E501
     "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct",  # noqa: E501
                                                           min_transformers_version="4.49"),  # noqa: E501
-    "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3",
+    "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b",
                                      trust_remote_code=True),
     # [Encoder-decoder]
     "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"),  # noqa: E501
@@ -258,27 +258,35 @@ class UltravoxProjector(nn.Module):
         super().__init__()
         self.hidden_dim = config.hidden_size
         self._pad_and_stack = StackAudioFrames(config.stack_factor)
-        dim = config.audio_config.hidden_size * config.stack_factor
-        self.ln_pre = RMSNorm(dim)
-        self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
-        dim = self.hidden_dim
+        dim_in = config.audio_config.hidden_size * config.stack_factor
+        self.ln_pre = RMSNorm(dim_in)
+        self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
+        dim_mid = self.hidden_dim

         if config.projector_act == "swiglu":
             self.act = MulAndSilu()
-            dim = dim // 2
+            dim_mid = dim_mid // 2
         else:
             self.act = get_act_fn(config.projector_act)

-        self.linear_2 = nn.Linear(dim,
-                                  config.text_config.hidden_size,
-                                  bias=False)
-        self.ln_post = RMSNorm(config.text_config.hidden_size)
+        dim_out = config.text_config.hidden_size
+        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
+
+        # Ultravox v0.4.1 and below use layer_norm after the second linear layer
+        # while v0.5.0 and above uses layer_norm after the first linear layer.
+        if config.projector_ln_mid:
+            self.ln_mid: nn.Module = RMSNorm(dim_mid)
+            self.ln_post = nn.Identity()
+        else:
+            self.ln_mid = nn.Identity()
+            self.ln_post = RMSNorm(dim_out)

     def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
         audio_features = self._pad_and_stack(audio_features)
         audio_features = self.ln_pre(audio_features)
         hidden_states = self.linear_1(audio_features)
         hidden_states = self.act(hidden_states)
+        hidden_states = self.ln_mid(hidden_states)
         hidden_states = self.linear_2(hidden_states)
         hidden_states = self.ln_post(hidden_states)
         return hidden_states
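To make the normalization change easier to compare outside of vLLM, here is a self-contained sketch of the projector's data flow. It uses plain `torch.nn` layers in place of vLLM's `RMSNorm` and `MulAndSilu`, so it is illustrative only, not a drop-in replacement.

```python
# Self-contained sketch of the projector data flow; plain torch layers stand in
# for vLLM's RMSNorm/MulAndSilu, and the SwiGLU split is only illustrative.
import torch
import torch.nn as nn

class ProjectorSketch(nn.Module):
    def __init__(self, dim_in: int, hidden: int, dim_out: int, ln_mid: bool):
        super().__init__()
        self.ln_pre = nn.LayerNorm(dim_in)
        self.linear_1 = nn.Linear(dim_in, hidden, bias=False)
        dim_mid = hidden // 2          # SwiGLU halves the width
        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
        # v0.5+: normalize between the two projections; v0.4.1 and below: after the second.
        self.ln_mid = nn.LayerNorm(dim_mid) if ln_mid else nn.Identity()
        self.ln_post = nn.Identity() if ln_mid else nn.LayerNorm(dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ln_pre(x)
        x = self.linear_1(x)
        gate, up = x.chunk(2, dim=-1)  # SwiGLU-style split
        x = nn.functional.silu(gate) * up
        x = self.ln_mid(x)
        x = self.linear_2(x)
        return self.ln_post(x)
```

With `ln_mid=True` (v0.5 checkpoints) normalization happens between the two projections; with `ln_mid=False` (v0.4.1 and below) it happens after the second projection, matching the `projector_ln_mid` switch added in the hunk above.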
@@ -37,6 +37,10 @@ class UltravoxConfig(transformers.PretrainedConfig):
             The LoRA configuration for finetuning the text model.
         audio_model_lora_config (`LoraConfigSimplified`, *optional*):
             The LoRA configuration for finetuning the audio model.
+        projector_ln_mid (`bool`, *optional*, defaults to `False`):
+            Whether to apply layer normalization at the middle of the
+            projector or at the end. Versions v0.4.1 and below
+            use `False`, but v0.5 and above use `True`.
     """

     model_type = "ultravox"
@@ -56,6 +60,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
         projector_act: str = "swiglu",
         text_model_lora_config: Optional[Dict[str, Any]] = None,
         audio_model_lora_config: Optional[Dict[str, Any]] = None,
+        projector_ln_mid: bool = False,
         **kwargs,
     ):
         self.ignore_index = ignore_index
@@ -68,6 +73,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
         self.stack_factor = stack_factor
         self.norm_init = norm_init
         self.projector_act = projector_act
+        self.projector_ln_mid = projector_ln_mid

         if text_model_id is not None:
             # Avoid circular import
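In practice the new flag is carried by the checkpoint's config, so a v0.5 model is expected to load with `projector_ln_mid=True`, while older checkpoints omit the key and fall back to the `False` default. A hypothetical construction is sketched below; the import path is an assumption based on vLLM's layout and is not shown in this diff.

```python
# Hypothetical usage of the new knob; the import path is an assumption and the
# bare construction is for illustration only.
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

v05_config = UltravoxConfig(projector_ln_mid=True)      # v0.5 and above
legacy_config = UltravoxConfig(projector_ln_mid=False)  # v0.4.1 and below (default)
```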