[Tests] Fixing bug inside MultiModalProfiler. (#21842)
Signed-off-by: Varun Shenoy <varun.vinayak.shenoy@oracle.com>
parent 30ef30ed5a
commit 547795232d
tests/models/multimodal/processing/test_mllama4.py (new file, 67 lines)
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mllama's multimodal preprocessing and profiling."""
import pytest
from torch import prod
from transformers import Llama4Config

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.profiling import MultiModalProfiler

from ...utils import build_model_context


@pytest.mark.parametrize("model_id", ["meta-llama/Llama-Guard-4-12B"])
@pytest.mark.parametrize("max_model_len", [4096, 8192, 25600, 131072])
def test_profiling(model_id: str, max_model_len: int):
    model_config_kwargs = {
        "max_model_len": max_model_len,
    }
    ctx = build_model_context(
        model_id,
        model_config_kwargs=model_config_kwargs,
        limit_mm_per_prompt={"image": 1},
    )

    mm_config = ctx.get_mm_config()
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
    profiler = MultiModalProfiler(processor)

    decoder_dummy_data = profiler.get_decoder_dummy_data(
        max_model_len,
        mm_counts=mm_config.limit_per_prompt,
    )
    dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
        max_model_len,
        mm_counts=mm_config.limit_per_prompt,
    )

    hf_config = ctx.get_hf_config(Llama4Config)

    mm_kwargs = processor.apply(
        prompt=dummy_mm_data.prompt,
        mm_data=dummy_mm_data.mm_data,
        hf_processor_mm_kwargs=dict(),
    )["mm_kwargs"]

    image_size = hf_config.vision_config.image_size
    patch_size = hf_config.vision_config.patch_size
    downsample_ratio = int(
        round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2)))
    tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio
    chunks_per_image = prod(mm_kwargs["patches_per_image"])
    total_num_patches = chunks_per_image * tokens_per_patch
    # x-y separator tokens
    num_tiles = (mm_kwargs["aspect_ratios"][0][0] *
                 mm_kwargs["aspect_ratios"][0][1])
    # +3 for the image-start, image, and image-end tokens
    total_tokens = total_num_patches.item() + num_tiles.item() + 3

    profiled_tokens = profiler.get_mm_max_contiguous_tokens(
        max_model_len,
        mm_counts=mm_config.limit_per_prompt,
    )

    assert total_tokens == profiled_tokens["image"]
    assert total_tokens == sum(
        placeholder.length
        for placeholder in decoder_dummy_data.multi_modal_placeholders["image"])
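For intuition, the expected-token arithmetic in the test can be checked by hand. The sketch below plugs in illustrative vision-config values; image_size=336, patch_size=14, and pixel_shuffle_ratio=0.5 are assumptions for this example, not values read from the Llama-Guard-4 checkpoint, whereas the test reads them from hf_config.vision_config:

# Standalone sketch of the test's image-token arithmetic.
# All config values below are illustrative assumptions; substitute the
# real vision_config values from the checkpoint to reproduce the test.
image_size = 336            # assumed vision_config.image_size
patch_size = 14             # assumed vision_config.patch_size
pixel_shuffle_ratio = 0.5   # assumed vision_config.pixel_shuffle_ratio

downsample_ratio = int(round(1.0 / pixel_shuffle_ratio**2))             # -> 4
tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio  # -> 576 // 4 = 144

# For a dummy image processed as a single chunk with a 1x1 aspect ratio:
chunks_per_image = 1
num_tiles = 1 * 1  # x-y separator tokens
total_tokens = chunks_per_image * tokens_per_patch + num_tiles + 3
print(total_tokens)  # 144 + 1 + 3 = 148

Under these assumed values, both the profiler and the processor would have to report 148 contiguous image tokens for that dummy image, which is exactly what the two assertions compare.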
tests/models/registry.py

@@ -391,7 +391,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                           extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},  # noqa: E501
                           trust_remote_code=True),
     "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct",  # noqa: E501
-                                                      max_model_len=10240),
+                                                      max_model_len=10240,
+                                                      extras={"llama-guard-4": "meta-llama/Llama-Guard-4-12B"},  # noqa: E501
+                                                      ),
     "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
                                                      extras={"mistral": "mistral-community/pixtral-12b",  # noqa: E501
                                                              "mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}),  # noqa: E501
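Assuming a standard vLLM development environment with access to the meta-llama/Llama-Guard-4-12B config and processor artifacts, the new test could be run on its own via pytest's Python API; this is a hypothetical local invocation, equivalent to running pytest on the file from the repository root:

# Hypothetical local run of the new test; equivalent to
#   pytest tests/models/multimodal/processing/test_mllama4.py -v
# from the vLLM repository root.
import pytest

raise SystemExit(
    pytest.main(["tests/models/multimodal/processing/test_mllama4.py", "-v"]))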