[Model] Refactor Phi-4-multimodal to use merged processor and support V1 (#15477)

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>

commit 83f3c3bd91 (parent d9737ca1c6)
@@ -1004,7 +1004,7 @@ See [this page](#generative-models) for more information on how to use generative models.
   * `microsoft/Phi-4-multimodal-instruct`, etc.
   * ✅︎
   *
-  *
+  * ✅︎
 - * `PixtralForConditionalGeneration`
   * Pixtral
   * T + I<sup>+</sup>
@@ -89,7 +89,7 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
     engine_args = EngineArgs(
         model=model_path,
         trust_remote_code=True,
-        max_model_len=4096,
+        max_model_len=12800,
         max_num_seqs=2,
         enable_lora=True,
         max_lora_rank=320,
@@ -814,10 +814,13 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
     engine_args = EngineArgs(
         model=model_path,
         trust_remote_code=True,
-        max_model_len=4096,
+        max_model_len=5120,
         max_num_seqs=2,
+        max_num_batched_tokens=12800,
         enable_lora=True,
         max_lora_rank=320,
+        # Note - mm_processor_kwargs can also be passed to generate/chat calls
+        mm_processor_kwargs={"dynamic_hd": 16},
         limit_mm_per_prompt={"image": 1},
     )
 
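The note added above mentions that mm_processor_kwargs can also be supplied per call rather than at engine construction. A minimal sketch of that usage, not part of this commit, assuming the standard offline API and TextPrompt fields, with `image` standing in for a PIL image loaded elsewhere:

from vllm import LLM, SamplingParams

llm = LLM(
    model="microsoft/Phi-4-multimodal-instruct",
    trust_remote_code=True,
    max_model_len=5120,
    max_num_seqs=2,
)

# Per-request override of the merged processor's dynamic_hd kwarg.
outputs = llm.generate(
    {
        "prompt": "<|user|>\n<|image_1|>\nDescribe the image.<|end|>\n<|assistant|>\n",
        "multi_modal_data": {"image": image},  # `image` is a PIL.Image loaded elsewhere
        "mm_processor_kwargs": {"dynamic_hd": 16},
    },
    SamplingParams(max_tokens=64),
)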
@@ -503,11 +503,13 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
     engine_args = EngineArgs(
         model=model_path,
         trust_remote_code=True,
-        max_model_len=10000,
+        max_model_len=4096,
         max_num_seqs=2,
         limit_mm_per_prompt={"image": len(image_urls)},
         enable_lora=True,
         max_lora_rank=320,
+        # Note - mm_processor_kwargs can also be passed to generate/chat calls
+        mm_processor_kwargs={"dynamic_hd": 4},
     )
 
     placeholders = "".join(f"<|image_{i}|>"
@@ -18,6 +18,7 @@ transformers
 mistral_common >= 1.5.4
 aiohttp
 starlette
+scipy
 openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
 fastapi # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
 partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
@@ -1,14 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import json
-from typing import Optional
+from typing import Any, Optional
 
 import numpy as np
 import pytest
 import pytest_asyncio
 from transformers import AutoModel, AutoTokenizer
 
-from vllm.multimodal.audio import resample_audio
+from vllm.multimodal.audio import resample_audio_librosa
 from vllm.sequence import SampleLogprobs
 
 from ....conftest import HfRunner, VllmRunner
@@ -43,6 +43,18 @@ def audio(request):
     return AudioAsset(request.param)
 
 
+def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]:
+    """Convert kwargs to CLI args."""
+    args = []
+    for key, value in params_kwargs.items():
+        if isinstance(value, bool):
+            if value:
+                args.append(f"--{key.replace('_','-')}")
+        else:
+            args.append(f"--{key.replace('_','-')}={value}")
+    return args
+
+
 @pytest.fixture(params=[
     pytest.param({}, marks=pytest.mark.cpu_model),
     pytest.param(CHUNKED_PREFILL_KWARGS),
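The new params_kwargs_to_cli_args helper simply flattens the fixture's kwargs into server flags. An illustrative call, not part of the diff and with a made-up input dict, behaves like this:

# Boolean True becomes a bare flag, other values become --key=value;
# dict insertion order is preserved, so output order matches input order.
assert params_kwargs_to_cli_args({
    "enable_chunked_prefill": True,
    "max_num_batched_tokens": 1024,
}) == ["--enable-chunked-prefill", "--max-num-batched-tokens=1024"]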
@@ -52,10 +64,7 @@ def server(request, audio_assets):
         "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
         "--limit-mm-per-prompt",
         json.dumps({"audio": len(audio_assets)}), "--trust-remote-code"
-    ] + [
-        f"--{key.replace('_','-')}={value}"
-        for key, value in request.param.items()
-    ]
+    ] + params_kwargs_to_cli_args(request.param)
 
     with RemoteOpenAIServer(MODEL_NAME,
                             args,
@@ -136,9 +145,9 @@ def run_test(
             [hf_prompt],
             max_tokens,
             num_logprobs=num_logprobs,
-            audios=[(resample_audio(audio[0],
+            audios=[(resample_audio_librosa(audio[0],
                                     orig_sr=audio[1],
                                     target_sr=16000), 16000)])
             for _, hf_prompt, audio in prompts_and_audios
         ]
 
@@ -181,7 +181,7 @@ def run_test(
     ],
 )
 @pytest.mark.parametrize("dtype", [target_dtype])
-@pytest.mark.parametrize("max_model_len", [4096])
+@pytest.mark.parametrize("max_model_len", [12800])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [10])
 def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
@@ -225,7 +225,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
     ],
 )
 @pytest.mark.parametrize("dtype", [target_dtype])
-@pytest.mark.parametrize("max_model_len", [10000])
+@pytest.mark.parametrize("max_model_len", [25600])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [10])
 def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
@@ -258,7 +258,7 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
 
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize("dtype", [target_dtype])
-@pytest.mark.parametrize("max_model_len", [10000])
+@pytest.mark.parametrize("max_model_len", [12800])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [10])
 def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str,
@@ -274,6 +274,7 @@ def _test_processing_correctness_mistral(
     "nvidia/NVLM-D-72B",
     "google/paligemma-3b-mix-224",
     "google/paligemma2-3b-ft-docci-448",
+    "microsoft/Phi-4-multimodal-instruct",
     "mistralai/Pixtral-12B-2409",
     "mistral-community/pixtral-12b",
     "Qwen/Qwen-VL-Chat",
tests/models/multimodal/processing/test_phi4mm.py (new file, 59 lines)
@@ -0,0 +1,59 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for phi4mm's multimodal preprocessing kwargs."""
+import pytest
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+
+from ....conftest import _ImageAssets
+from ...utils import build_model_context
+
+
+@pytest.mark.parametrize("model_id", ["microsoft/Phi-4-multimodal-instruct"])
+# yapf: disable
+@pytest.mark.parametrize(
+    ("mm_processor_kwargs", "expected_toks_per_img"),
+    [
+        ({"dynamic_hd": 4}, 1329),
+        ({"dynamic_hd": 16}, 4433),
+        # the default num_crops of phi-4-multimodal is 36
+        ({}, 9585),
+    ])
+# yapf: enable
+@pytest.mark.parametrize("num_imgs", [1, 2])
+@pytest.mark.parametrize("kwargs_on_init", [True, False])
+def test_processor_override(
+    image_assets: _ImageAssets,
+    model_id: str,
+    mm_processor_kwargs: dict[str, int],
+    expected_toks_per_img: int,
+    num_imgs: int,
+    kwargs_on_init: bool,
+):
+    """Ensure Phi4MMMultiModalProcessor handles dynamic_hd properly."""
+    # Avoid initializing CUDA early
+    from vllm.model_executor.models.phi4mm import _IMAGE_PLACEHOLDER_TOKEN_ID
+
+    ctx = build_model_context(
+        model_id,
+        mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None,
+        limit_mm_per_prompt={"image": num_imgs},
+    )
+    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
+    hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs
+
+    # Build the image str / prompt based on the number of images we pass
+    img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
+    prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
+
+    image_size = ctx.get_hf_config(
+    ).embd_layer["image_embd_layer"]["crop_size"]
+    dummy_image_size = (image_size * 7, image_size * 7)
+    dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
+    mm_data = {"image": [dummy_image] * num_imgs}
+
+    processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
+
+    # Ensure we have the right number of placeholders per num_crops size
+    img_tok_count = processed_inputs["prompt_token_ids"].count(
+        _IMAGE_PLACEHOLDER_TOKEN_ID)
+    assert img_tok_count == expected_toks_per_img * num_imgs
@@ -482,11 +482,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
         if modality in ("image", "image_embeds"):
            if model_type == "chatglm":
                return "<|begin_of_image|><|endoftext|><|end_of_image|>"
-            if model_type == "phi3_v":
-                # Workaround since this token is not defined in the tokenizer
+            if model_type in ("phi3_v", "phi4mm"):
                return f"<|image_{current_count}|>"
-            if model_type == "phi4mm":
-                return "<|endoftext10|>"  # 200010 (see vocab.json in hf model)
            if model_type in ("minicpmo", "minicpmv"):
                return "(<image>./</image>)"
            if model_type in ("blip-2", "florence2", "fuyu", "paligemma",
@@ -522,7 +519,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
            if model_type == "ultravox":
                return "<|audio|>"
            if model_type == "phi4mm":
-                return "<|endoftext11|>"  # 200011 (see vocab.json in hf model)
+                return f"<|audio_{current_count}|>"
            if model_type in ("qwen2_audio", "qwen2_5_omni"):
                return (f"Audio {current_count}: "
                        f"<|audio_bos|><|AUDIO|><|audio_eos|>")
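With the two chat_utils hunks above, chat prompts for phi4mm now use indexed placeholders instead of the raw special tokens. A small illustration, not taken from the diff:

# Placeholders emitted for the first image and first audio item of a request.
model_type = "phi4mm"
current_count = 1
image_placeholder = f"<|image_{current_count}|>"  # previously "<|endoftext10|>"
audio_placeholder = f"<|audio_{current_count}|>"  # previously "<|endoftext11|>"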
@@ -327,7 +327,7 @@ class Phi3VProcessingInfo(BaseProcessingInfo):
         *,
         image_width: int,
         image_height: int,
-        processor: Optional[ProcessorMixin],
+        processor: Optional[ProcessorMixin] = None,
     ) -> int:
         if processor is None:
             processor = self.get_hf_processor()
(File diff suppressed because it is too large.)
@@ -1159,8 +1159,11 @@ class AudioEmbedding(nn.Module):
         input_embeds: torch.FloatTensor,
         audio_attention_mask: torch.Tensor = None,
         audio_projection_mode: str = "speech",
-    ):
+    ) -> torch.FloatTensor:
+        """
+        arguments:
+            input_embeds: audio features (B, T, D) B: num audios in a sequence
+        """
         if self.freeze_audio_processor:
             with torch.no_grad():
                 audio_features, masks = self.encoder(input_embeds,
@@ -1210,62 +1213,20 @@ class AudioEmbedding(nn.Module):
 
     def forward(
         self,
-        input_ids: torch.LongTensor,
-        input_embeds: torch.FloatTensor,
-        audio_embed_sizes,
-        **kwargs,
+        audio_features: torch.FloatTensor,
+        audio_attention_mask: torch.Tensor = None,
+        audio_projection_mode: str = "speech",
     ) -> torch.FloatTensor:
         """
         arguments:
-            input_ids: input text ids (B, U)
-            input_embeds: audio features (B, T, D) B: num audios in a sequence
+            audio_features: audio features (T, D)
+
+        returns:
+            audio_embeds: audio embeddings (num_audio_tokens, hidden_dim)
         """
-        assert input_embeds is not None and len(input_embeds) == len(
-            audio_embed_sizes)
-
-        input_shape = input_ids.size()
-        input_ids = input_ids.view(-1, input_shape[-1])
-
-        with torch.no_grad():
-            positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero(
-                as_tuple=False)
-
-        if not isinstance(input_embeds, list):
-            input_embeds = [input_embeds]
-
-        audio_projection_mode = kwargs.get("audio_projection_mode", "speech")
-        audio_set_tensor = [
-            self.get_audio_features(
-                input_embed, audio_projection_mode=audio_projection_mode)
-            for input_embed in input_embeds
-        ]
-
-        with torch.no_grad():
-            input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
-
-        if "wte" in kwargs:
-            # we use the token embedding layer from the huggingface model, this
-            # is REQUIRED to make sure we are using the loaded weights.
-            hidden_states = kwargs["wte"](input_ids)
-        else:
-            # otherwise, we use token embedding in pretrained mixformer from
-            # phi team
-            hidden_states = self.wte(input_ids)
-
-        if len(positions.tolist()) > 0:
-            assert sum(audio_embed_sizes) == len(
-                positions
-            ), "please ensure the encoder outputs have the same length as"\
-                " defined in input_ids!"
-            idx = 0
-            for i in range(len(audio_embed_sizes)):
-                cnt = audio_embed_sizes[i]
-                assert audio_set_tensor[i].shape[0] == 1
-                hidden_states[
-                    positions[idx, 0],
-                    positions[idx, 1]:positions[idx, 1] + cnt,
-                ] = (audio_set_tensor[i][0, :audio_embed_sizes[i], :].to(
-                    hidden_states.dtype).to(hidden_states.device))
-                idx += cnt
-
-        return hidden_states
+        audio_embeds = self.get_audio_features(
+            audio_features.unsqueeze(0),
+            audio_attention_mask=audio_attention_mask,
+            audio_projection_mode=audio_projection_mode,
+        )
+        return audio_embeds.squeeze(0)
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import base64
 from io import BytesIO
 from pathlib import Path
+from typing import Literal, Optional
 
 import numpy as np
 import numpy.typing as npt
@@ -43,7 +43,7 @@ class AudioPlugin(MultiModalPlugin):
             "There is no default maximum multimodal tokens")
 
 
-def resample_audio(
+def resample_audio_librosa(
     audio: npt.NDArray[np.floating],
     *,
     orig_sr: float,
@@ -52,6 +52,55 @@ def resample_audio(
     return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
 
 
+def resample_audio_scipy(
+    audio: npt.NDArray[np.floating],
+    *,
+    orig_sr: float,
+    target_sr: float,
+):
+    # lazy import scipy.signal, otherwise it will crash doc build.
+    import scipy.signal
+
+    if orig_sr > target_sr:
+        return scipy.signal.resample_poly(audio, 1, orig_sr // target_sr)
+    elif orig_sr < target_sr:
+        return scipy.signal.resample_poly(audio, target_sr // orig_sr, 1)
+    return audio
+
+
+class AudioResampler:
+    """Resample audio data to a target sample rate."""
+
+    def __init__(
+        self,
+        target_sr: Optional[float] = None,
+        method: Literal["librosa", "scipy"] = "librosa",
+    ):
+        self.target_sr = target_sr
+        self.method = method
+
+    def resample(
+        self,
+        audio: npt.NDArray[np.floating],
+        *,
+        orig_sr: float,
+    ) -> npt.NDArray[np.floating]:
+        if self.target_sr is None:
+            raise RuntimeError("Audio resampling is not supported when "
+                               "`target_sr` is not provided")
+        if self.method == "librosa":
+            return resample_audio_librosa(audio,
+                                          orig_sr=orig_sr,
+                                          target_sr=self.target_sr)
+        elif self.method == "scipy":
+            return resample_audio_scipy(audio,
+                                        orig_sr=orig_sr,
+                                        target_sr=self.target_sr)
+        else:
+            raise ValueError(f"Invalid resampling method: {self.method}. "
+                             "Supported methods are 'librosa' and 'scipy'.")
+
+
 class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
 
     def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
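A minimal usage sketch for the new AudioResampler, not part of the diff; the 48 kHz input and its length are arbitrary example values:

import numpy as np

from vllm.multimodal.audio import AudioResampler

audio = np.zeros(48_000, dtype=np.float32)  # 1 s of silence at 48 kHz
resampler = AudioResampler(target_sr=16_000, method="scipy")
resampled = resampler.resample(audio, orig_sr=48_000)
# The scipy path calls resample_poly(audio, 1, 48000 // 16000),
# so len(resampled) == 16000 here.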
@@ -3,8 +3,8 @@
 from abc import ABC, abstractmethod
 from collections import UserDict
 from collections.abc import Callable, Iterator, Mapping, Sequence
-from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar,
-                    Union)
+from typing import (TYPE_CHECKING, Any, Generic, Literal, NamedTuple, Optional,
+                    TypeVar, Union)
 
 import numpy as np
 import torch
@@ -14,7 +14,7 @@ from typing_extensions import TypeAlias, TypeGuard, assert_never
 
 from vllm.utils import is_list_of
 
-from .audio import resample_audio
+from .audio import AudioResampler
 from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
                      ImageItem, ModalityData, MultiModalDataDict,
                      MultiModalFieldConfig, MultiModalKwargs, VideoItem)
@@ -308,10 +308,18 @@ class MultiModalDataParser:
     items to the model's expected sampling rate.
     """
 
-    def __init__(self, *, target_sr: Optional[float] = None) -> None:
+    def __init__(
+        self,
+        *,
+        target_sr: Optional[float] = None,
+        audio_resample_method: Literal["librosa", "scipy"] = "librosa",
+    ) -> None:
         super().__init__()
 
-        self.target_sr = target_sr
+        self.audio_resampler = AudioResampler(
+            target_sr=target_sr,
+            method=audio_resample_method,
+        )
 
     def _is_embeddings(
             self, data: object
@@ -374,15 +382,8 @@ class MultiModalDataParser:
             if orig_sr is None:
                 new_audio = audio
             else:
-                target_sr = self.target_sr
-                if target_sr is None:
-                    raise RuntimeError(
-                        "Audio resampling is not supported when "
-                        "`target_sr` is not provided")
-
-                new_audio = resample_audio(audio,
-                                           orig_sr=orig_sr,
-                                           target_sr=target_sr)
+                new_audio = self.audio_resampler.resample(audio,
+                                                          orig_sr=orig_sr)
 
             new_audios.append(new_audio)
 
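The parser now delegates resampling to AudioResampler, so a model's multimodal processor can opt into the scipy path when it needs to match scipy-based HF preprocessing. A hedged sketch follows; the processor class is hypothetical, and only MultiModalDataParser and its new keyword come from this diff:

from vllm.multimodal.parse import MultiModalDataParser


class MyPhiStyleProcessor:  # stand-in for a multimodal processor subclass

    def _get_data_parser(self) -> MultiModalDataParser:
        # Resample incoming audio to 16 kHz via scipy.signal.resample_poly.
        return MultiModalDataParser(target_sr=16000,
                                    audio_resample_method="scipy")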