mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-10 05:00:07 +08:00
[CI/Build] Fix broken mm processor test Mistral-3-large (#30597)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
64251f48df
commit
e5db3e2774
@ -8,6 +8,7 @@ from typing import Any, TypeAlias
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
@ -35,6 +36,7 @@ from vllm.tokenizers import cached_tokenizer_from_config
|
|||||||
from vllm.utils.collection_utils import is_list_of
|
from vllm.utils.collection_utils import is_list_of
|
||||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||||
|
|
||||||
|
from ....utils import create_new_process_for_each_test
|
||||||
from ...registry import HF_EXAMPLE_MODELS
|
from ...registry import HF_EXAMPLE_MODELS
|
||||||
from ...utils import dummy_hf_overrides
|
from ...utils import dummy_hf_overrides
|
||||||
from .test_common import get_model_ids_to_test, get_text_token_prompts
|
from .test_common import get_model_ids_to_test, get_text_token_prompts
|
||||||
@ -136,6 +138,7 @@ def create_batched_mm_kwargs(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Isotr0py): Don't initalize model during test
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def initialize_dummy_model(
|
def initialize_dummy_model(
|
||||||
model_cls: type[nn.Module],
|
model_cls: type[nn.Module],
|
||||||
@ -150,16 +153,21 @@ def initialize_dummy_model(
|
|||||||
backend="nccl",
|
backend="nccl",
|
||||||
)
|
)
|
||||||
initialize_model_parallel(tensor_model_parallel_size=1)
|
initialize_model_parallel(tensor_model_parallel_size=1)
|
||||||
|
|
||||||
|
current_device = torch.get_default_device()
|
||||||
vllm_config = VllmConfig(model_config=model_config)
|
vllm_config = VllmConfig(model_config=model_config)
|
||||||
with set_current_vllm_config(vllm_config=vllm_config):
|
with set_current_vllm_config(vllm_config=vllm_config):
|
||||||
with set_default_torch_dtype(model_config.dtype):
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
|
torch.set_default_device(current_platform.device_type)
|
||||||
model = model_cls(vllm_config=vllm_config)
|
model = model_cls(vllm_config=vllm_config)
|
||||||
|
torch.set_default_device(current_device)
|
||||||
yield model
|
yield model
|
||||||
|
|
||||||
del model
|
del model
|
||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
|
@create_new_process_for_each_test()
|
||||||
@pytest.mark.parametrize("model_id", get_model_ids_to_test())
|
@pytest.mark.parametrize("model_id", get_model_ids_to_test())
|
||||||
def test_model_tensor_schema(model_id: str):
|
def test_model_tensor_schema(model_id: str):
|
||||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user