[CI/Build] Fix broken mm processor test Mistral-3-large (#30597)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-12-13 20:43:01 +08:00 committed by GitHub
parent 64251f48df
commit e5db3e2774
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,6 +8,7 @@ from typing import Any, TypeAlias
import numpy as np import numpy as np
import pytest import pytest
import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from PIL import Image
@ -35,6 +36,7 @@ from vllm.tokenizers import cached_tokenizer_from_config
from vllm.utils.collection_utils import is_list_of from vllm.utils.collection_utils import is_list_of
from vllm.utils.torch_utils import set_default_torch_dtype from vllm.utils.torch_utils import set_default_torch_dtype
from ....utils import create_new_process_for_each_test
from ...registry import HF_EXAMPLE_MODELS from ...registry import HF_EXAMPLE_MODELS
from ...utils import dummy_hf_overrides from ...utils import dummy_hf_overrides
from .test_common import get_model_ids_to_test, get_text_token_prompts from .test_common import get_model_ids_to_test, get_text_token_prompts
@ -136,6 +138,7 @@ def create_batched_mm_kwargs(
) )
# TODO(Isotr0py): Don't initalize model during test
@contextmanager @contextmanager
def initialize_dummy_model( def initialize_dummy_model(
model_cls: type[nn.Module], model_cls: type[nn.Module],
@ -150,16 +153,21 @@ def initialize_dummy_model(
backend="nccl", backend="nccl",
) )
initialize_model_parallel(tensor_model_parallel_size=1) initialize_model_parallel(tensor_model_parallel_size=1)
current_device = torch.get_default_device()
vllm_config = VllmConfig(model_config=model_config) vllm_config = VllmConfig(model_config=model_config)
with set_current_vllm_config(vllm_config=vllm_config): with set_current_vllm_config(vllm_config=vllm_config):
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
torch.set_default_device(current_platform.device_type)
model = model_cls(vllm_config=vllm_config) model = model_cls(vllm_config=vllm_config)
torch.set_default_device(current_device)
yield model yield model
del model del model
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@create_new_process_for_each_test()
@pytest.mark.parametrize("model_id", get_model_ids_to_test()) @pytest.mark.parametrize("model_id", get_model_ids_to_test())
def test_model_tensor_schema(model_id: str): def test_model_tensor_schema(model_id: str):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)