[Model] Refactor Phi-4-multimodal to use merged processor and support V1 (#15477)

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Isotr0py 2025-04-19 17:26:11 +08:00 committed by GitHub
parent d9737ca1c6
commit 83f3c3bd91
15 changed files with 818 additions and 1246 deletions
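For orientation, here is a minimal offline-inference sketch using the engine arguments that appear in the updated examples in this commit; the prompt format mirrors the example scripts, while the image path and sampling settings are illustrative assumptions rather than part of the diff.

from PIL import Image

from vllm import LLM, SamplingParams

# Engine arguments mirror the updated single-image example below;
# dynamic_hd bounds how many HD crops the merged processor produces per image.
llm = LLM(
    model="microsoft/Phi-4-multimodal-instruct",
    trust_remote_code=True,
    max_model_len=5120,
    max_num_seqs=2,
    max_num_batched_tokens=12800,
    enable_lora=True,
    max_lora_rank=320,
    mm_processor_kwargs={"dynamic_hd": 16},
    limit_mm_per_prompt={"image": 1},
)

# Hypothetical local image; placeholder indexing starts at 1.
image = Image.open("example.jpg")
prompt = "<|user|>\n<|image_1|>\nDescribe this image.<|end|>\n<|assistant|>\n"

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)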

View File

@ -1004,7 +1004,7 @@ See [this page](#generative-models) for more information on how to use generativ
* `microsoft/Phi-4-multimodal-instruct`, etc.
* ✅︎
*
*
* ✅︎
- * `PixtralForConditionalGeneration`
* Pixtral
* T + I<sup>+</sup>

View File

@ -89,7 +89,7 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
engine_args = EngineArgs(
model=model_path,
trust_remote_code=True,
max_model_len=4096,
max_model_len=12800,
max_num_seqs=2,
enable_lora=True,
max_lora_rank=320,

View File

@ -814,10 +814,13 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
engine_args = EngineArgs(
model=model_path,
trust_remote_code=True,
max_model_len=4096,
max_model_len=5120,
max_num_seqs=2,
max_num_batched_tokens=12800,
enable_lora=True,
max_lora_rank=320,
# Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs={"dynamic_hd": 16},
limit_mm_per_prompt={"image": 1},
)
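As the note above says, mm_processor_kwargs can also be supplied per request rather than at engine construction; a hedged sketch, assuming an LLM built from these engine args and a PIL image bound to image:

# Per-request override of dynamic_hd via the prompt dict; this takes
# precedence over the value passed through EngineArgs for this request only.
outputs = llm.generate({
    "prompt": "<|user|>\n<|image_1|>\nWhat is shown?<|end|>\n<|assistant|>\n",
    "multi_modal_data": {"image": image},
    "mm_processor_kwargs": {"dynamic_hd": 4},
})
print(outputs[0].outputs[0].text)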

View File

@ -503,11 +503,13 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
engine_args = EngineArgs(
model=model_path,
trust_remote_code=True,
max_model_len=10000,
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)},
enable_lora=True,
max_lora_rank=320,
# Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs={"dynamic_hd": 4},
)
placeholders = "".join(f"<|image_{i}|>"

View File

@ -18,6 +18,7 @@ transformers
mistral_common >= 1.5.4
aiohttp
starlette
scipy
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
fastapi # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args

View File

@ -1,14 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
import json
from typing import Optional
from typing import Any, Optional
import numpy as np
import pytest
import pytest_asyncio
from transformers import AutoModel, AutoTokenizer
from vllm.multimodal.audio import resample_audio
from vllm.multimodal.audio import resample_audio_librosa
from vllm.sequence import SampleLogprobs
from ....conftest import HfRunner, VllmRunner
@ -43,6 +43,18 @@ def audio(request):
return AudioAsset(request.param)
def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]:
"""Convert kwargs to CLI args."""
args = []
for key, value in params_kwargs.items():
if isinstance(value, bool):
if value:
args.append(f"--{key.replace('_','-')}")
else:
args.append(f"--{key.replace('_','-')}={value}")
return args
@pytest.fixture(params=[
pytest.param({}, marks=pytest.mark.cpu_model),
pytest.param(CHUNKED_PREFILL_KWARGS),
@ -52,10 +64,7 @@ def server(request, audio_assets):
"--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
"--limit-mm-per-prompt",
json.dumps({"audio": len(audio_assets)}), "--trust-remote-code"
] + [
f"--{key.replace('_','-')}={value}"
for key, value in request.param.items()
]
] + params_kwargs_to_cli_args(request.param)
with RemoteOpenAIServer(MODEL_NAME,
args,
@ -136,9 +145,9 @@ def run_test(
[hf_prompt],
max_tokens,
num_logprobs=num_logprobs,
audios=[(resample_audio(audio[0],
orig_sr=audio[1],
target_sr=16000), 16000)])
audios=[(resample_audio_librosa(audio[0],
orig_sr=audio[1],
target_sr=16000), 16000)])
for _, hf_prompt, audio in prompts_and_audios
]
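For reference, the params_kwargs_to_cli_args helper introduced above turns boolean kwargs into bare flags and everything else into --key=value pairs; a small check with hypothetical inputs:

# True booleans become bare flags; other values become "--key=value".
assert params_kwargs_to_cli_args(
    {"enable_chunked_prefill": True, "max_num_seqs": 2}
) == ["--enable-chunked-prefill", "--max-num-seqs=2"]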

View File

@ -181,7 +181,7 @@ def run_test(
],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [4096])
@pytest.mark.parametrize("max_model_len", [12800])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
@ -225,7 +225,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [10000])
@pytest.mark.parametrize("max_model_len", [25600])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
@ -258,7 +258,7 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [10000])
@pytest.mark.parametrize("max_model_len", [12800])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str,

View File

@ -274,6 +274,7 @@ def _test_processing_correctness_mistral(
"nvidia/NVLM-D-72B",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
"microsoft/Phi-4-multimodal-instruct",
"mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b",
"Qwen/Qwen-VL-Chat",

View File

@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for phi4mm's multimodal preprocessing kwargs."""
import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import _ImageAssets
from ...utils import build_model_context
@pytest.mark.parametrize("model_id", ["microsoft/Phi-4-multimodal-instruct"])
# yapf: disable
@pytest.mark.parametrize(
("mm_processor_kwargs", "expected_toks_per_img"),
[
({"dynamic_hd": 4}, 1329),
({"dynamic_hd": 16}, 4433),
# the default dynamic_hd of phi-4-multimodal is 36
({}, 9585),
])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
image_assets: _ImageAssets,
model_id: str,
mm_processor_kwargs: dict[str, int],
expected_toks_per_img: int,
num_imgs: int,
kwargs_on_init: bool,
):
"""Ensure Phi4MMMultiModalProcessor handles dynamic_hd properly."""
# Avoid initializing CUDA early
from vllm.model_executor.models.phi4mm import _IMAGE_PLACEHOLDER_TOKEN_ID
ctx = build_model_context(
model_id,
mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None,
limit_mm_per_prompt={"image": num_imgs},
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs
# Build the image str / prompt based on the number of images we pass
img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
image_size = ctx.get_hf_config(
).embd_layer["image_embd_layer"]["crop_size"]
dummy_image_size = (image_size * 7, image_size * 7)
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
mm_data = {"image": [dummy_image] * num_imgs}
processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
# Ensure we have the right number of placeholders for the given dynamic_hd
img_tok_count = processed_inputs["prompt_token_ids"].count(
_IMAGE_PLACEHOLDER_TOKEN_ID)
assert img_tok_count == expected_toks_per_img * num_imgs

View File

@ -482,11 +482,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if modality in ("image", "image_embeds"):
if model_type == "chatglm":
return "<|begin_of_image|><|endoftext|><|end_of_image|>"
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
if model_type in ("phi3_v", "phi4mm"):
return f"<|image_{current_count}|>"
if model_type == "phi4mm":
return "<|endoftext10|>" # 200010 (see vocab.json in hf model)
if model_type in ("minicpmo", "minicpmv"):
return "(<image>./</image>)"
if model_type in ("blip-2", "florence2", "fuyu", "paligemma",
@ -522,7 +519,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type == "ultravox":
return "<|audio|>"
if model_type == "phi4mm":
return "<|endoftext11|>" # 200011 (see vocab.json in hf model)
return f"<|audio_{current_count}|>"
if model_type in ("qwen2_audio", "qwen2_5_omni"):
return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>")

View File

@ -327,7 +327,7 @@ class Phi3VProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: Optional[ProcessorMixin],
processor: Optional[ProcessorMixin] = None,
) -> int:
if processor is None:
processor = self.get_hf_processor()

File diff suppressed because it is too large

View File

@ -1159,8 +1159,11 @@ class AudioEmbedding(nn.Module):
input_embeds: torch.FloatTensor,
audio_attention_mask: torch.Tensor = None,
audio_projection_mode: str = "speech",
):
) -> torch.FloatTensor:
"""
arguments:
input_embeds: audio features (B, T, D) B: num audios in a sequence
"""
if self.freeze_audio_processor:
with torch.no_grad():
audio_features, masks = self.encoder(input_embeds,
@ -1210,62 +1213,20 @@ class AudioEmbedding(nn.Module):
def forward(
self,
input_ids: torch.LongTensor,
input_embeds: torch.FloatTensor,
audio_embed_sizes,
**kwargs,
audio_features: torch.FloatTensor,
audio_attention_mask: torch.Tensor = None,
audio_projection_mode: str = "speech",
) -> torch.FloatTensor:
"""
arguments:
input_ids: input text ids (B, U)
input_embeds: audio features (B, T, D) B: num audios in a sequence
audio_features: audio features (T, D)
returns:
audio_embeds: audio embeddings (num_audio_tokens, hidden_dim)
"""
assert input_embeds is not None and len(input_embeds) == len(
audio_embed_sizes)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
with torch.no_grad():
positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero(
as_tuple=False)
if not isinstance(input_embeds, list):
input_embeds = [input_embeds]
audio_projection_mode = kwargs.get("audio_projection_mode", "speech")
audio_set_tensor = [
self.get_audio_features(
input_embed, audio_projection_mode=audio_projection_mode)
for input_embed in input_embeds
]
with torch.no_grad():
input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
if "wte" in kwargs:
# we use the token embedding layer from the huggingface model, this
# is REQUIRED to make sure we are using the loaded weights.
hidden_states = kwargs["wte"](input_ids)
else:
# otherwise, we use token embedding in pretrained mixformer from
# phi team
hidden_states = self.wte(input_ids)
if len(positions.tolist()) > 0:
assert sum(audio_embed_sizes) == len(
positions
), "please ensure the encoder outputs have the same length as"\
" defined in input_ids!"
idx = 0
for i in range(len(audio_embed_sizes)):
cnt = audio_embed_sizes[i]
assert audio_set_tensor[i].shape[0] == 1
hidden_states[
positions[idx, 0],
positions[idx, 1]:positions[idx, 1] + cnt,
] = (audio_set_tensor[i][0, :audio_embed_sizes[i], :].to(
hidden_states.dtype).to(hidden_states.device))
idx += cnt
return hidden_states
audio_embeds = self.get_audio_features(
audio_features.unsqueeze(0),
audio_attention_mask=audio_attention_mask,
audio_projection_mode=audio_projection_mode,
)
return audio_embeds.squeeze(0)

View File

@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
import base64
from io import BytesIO
from pathlib import Path
from typing import Literal, Optional
import numpy as np
import numpy.typing as npt
@ -43,7 +43,7 @@ class AudioPlugin(MultiModalPlugin):
"There is no default maximum multimodal tokens")
def resample_audio(
def resample_audio_librosa(
audio: npt.NDArray[np.floating],
*,
orig_sr: float,
@ -52,6 +52,55 @@ def resample_audio(
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
def resample_audio_scipy(
audio: npt.NDArray[np.floating],
*,
orig_sr: float,
target_sr: float,
):
# Lazy import of scipy.signal, otherwise it will crash the doc build.
import scipy.signal
if orig_sr > target_sr:
return scipy.signal.resample_poly(audio, 1, orig_sr // target_sr)
elif orig_sr < target_sr:
return scipy.signal.resample_poly(audio, target_sr // orig_sr, 1)
return audio
class AudioResampler:
"""Resample audio data to a target sample rate."""
def __init__(
self,
target_sr: Optional[float] = None,
method: Literal["librosa", "scipy"] = "librosa",
):
self.target_sr = target_sr
self.method = method
def resample(
self,
audio: npt.NDArray[np.floating],
*,
orig_sr: float,
) -> npt.NDArray[np.floating]:
if self.target_sr is None:
raise RuntimeError("Audio resampling is not supported when "
"`target_sr` is not provided")
if self.method == "librosa":
return resample_audio_librosa(audio,
orig_sr=orig_sr,
target_sr=self.target_sr)
elif self.method == "scipy":
return resample_audio_scipy(audio,
orig_sr=orig_sr,
target_sr=self.target_sr)
else:
raise ValueError(f"Invalid resampling method: {self.method}. "
"Supported methods are 'librosa' and 'scipy'.")
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
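
A short usage sketch of the new AudioResampler; the input array is synthetic and the rates are chosen so the scipy path's integer-ratio resampling applies cleanly:

import numpy as np

from vllm.multimodal.audio import AudioResampler

# One second of 48 kHz audio downsampled to 16 kHz with the scipy backend
# (48000 is an integer multiple of 16000, which the scipy path assumes).
audio = np.zeros(48000, dtype=np.float32)
resampler = AudioResampler(target_sr=16000, method="scipy")
resampled = resampler.resample(audio, orig_sr=48000)
assert resampled.shape[0] == 16000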

View File

@ -3,8 +3,8 @@
from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar,
Union)
from typing import (TYPE_CHECKING, Any, Generic, Literal, NamedTuple, Optional,
TypeVar, Union)
import numpy as np
import torch
@ -14,7 +14,7 @@ from typing_extensions import TypeAlias, TypeGuard, assert_never
from vllm.utils import is_list_of
from .audio import resample_audio
from .audio import AudioResampler
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
ImageItem, ModalityData, MultiModalDataDict,
MultiModalFieldConfig, MultiModalKwargs, VideoItem)
@ -308,10 +308,18 @@ class MultiModalDataParser:
items to the model's expected sampling rate.
"""
def __init__(self, *, target_sr: Optional[float] = None) -> None:
def __init__(
self,
*,
target_sr: Optional[float] = None,
audio_resample_method: Literal["librosa", "scipy"] = "librosa",
) -> None:
super().__init__()
self.target_sr = target_sr
self.audio_resampler = AudioResampler(
target_sr=target_sr,
method=audio_resample_method,
)
def _is_embeddings(
self, data: object
@ -374,15 +382,8 @@ class MultiModalDataParser:
if orig_sr is None:
new_audio = audio
else:
target_sr = self.target_sr
if target_sr is None:
raise RuntimeError(
"Audio resampling is not supported when "
"`target_sr` is not provided")
new_audio = resample_audio(audio,
orig_sr=orig_sr,
target_sr=target_sr)
new_audio = self.audio_resampler.resample(audio,
orig_sr=orig_sr)
new_audios.append(new_audio)
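
Downstream multimodal processors can opt into the scipy backend by constructing the parser themselves; a hedged sketch of how a processor subclass might do so (the 16 kHz target and the scipy choice are assumptions for illustration):

from vllm.multimodal.parse import MultiModalDataParser

# Hypothetical processor hook: return a parser that resamples incoming
# audio to 16 kHz with the new scipy backend instead of librosa.
class MyMultiModalProcessor:  # stand-in for a BaseMultiModalProcessor subclass

    def _get_data_parser(self) -> MultiModalDataParser:
        return MultiModalDataParser(
            target_sr=16_000,
            audio_resample_method="scipy",
        )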