[UT][intel GPU] use current_platform instead of device hardcode in v1 tests (#20169)

Signed-off-by: Ma, Liangliang <liangliang.ma@intel.com>
Liangliang Ma 2025-07-02 09:06:04 +08:00 committed by GitHub
parent 3be8d312a2
commit a0389e0554
GPG Key ID: B5690EEEBB952194
10 changed files with 44 additions and 26 deletions
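
The pattern applied throughout is the same: hardcoded "cuda" strings become current_platform.device_type, and direct calls to cuda_device_count_stateless() go through a platform-dispatched device_count() classmethod (added here for CUDA and ROCm; XPU already returns torch.xpu.device_count()). A minimal sketch of the before/after, using only names that appear in the hunks below:

    from vllm.platforms import current_platform

    # Before: tests pinned everything to CUDA.
    # DEVICE = "cuda"
    # num_gpus = cuda_device_count_stateless()

    # After: the platform object supplies the device string and the count,
    # so the same test runs on CUDA, ROCm, or Intel XPU builds.
    DEVICE = current_platform.device_type        # e.g. "cuda" or "xpu"
    num_gpus = current_platform.device_count()   # dispatched per platform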

View File

@@ -34,7 +34,6 @@ from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.transformers_utils.utils import maybe_model_redirect
-from vllm.utils import cuda_device_count_stateless
 logger = init_logger(__name__)
@@ -1094,7 +1093,8 @@ def num_gpus_available():
     """Get number of GPUs without initializing the CUDA context
     in current process."""
-    return cuda_device_count_stateless()
+    from vllm.platforms import current_platform
+    return current_platform.device_count()
 temp_dir = tempfile.gettempdir()
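
As a hypothetical usage sketch (not part of this commit), a test that needs more than one accelerator can now gate on the platform-agnostic count instead of the CUDA-only helper:

    import pytest

    from vllm.platforms import current_platform

    # Hypothetical example test; the decorator condition is the point here.
    @pytest.mark.skipif(current_platform.device_count() < 2,
                        reason="needs at least 2 devices")
    def test_runs_on_two_devices():
        assert current_platform.device_count() >= 2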

View File

@@ -6,12 +6,13 @@ import pytest
 import torch
 import torch.nn.functional as F
+from vllm.platforms import current_platform
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
                                               RejectionSampler)
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
-DEVICE = "cuda"
+DEVICE = current_platform.device_type
 @pytest.fixture
@@ -21,7 +22,7 @@ def rejection_sampler():
 def create_logits_tensor(output_token_ids: list[list[int]],
                          vocab_size: int = 100) -> torch.Tensor:
     """Helper function to create logits tensor that
     will produce desired token ids on argmax"""
     token_ids = [tokens[:-1] for tokens in output_token_ids]
     num_total_tokens = sum(len(tokens) for tokens in token_ids)
@@ -41,8 +42,8 @@ def create_sampling_metadata(
     top_p: Optional[torch.Tensor] = None,
     generators: Optional[dict[int, Any]] = None,
 ) -> SamplingMetadata:
     """Create a v1 sampling metadata object with all_greedy set
     to the given value. Either all greedy or all random sampling
     is used.
     """
     generators = generators or {}

View File

@@ -2,25 +2,26 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
 import torch
-from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
 from torch import Generator
 from vllm.platforms import current_platform
 from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
                                                    is_flashinfer_available)
-DEVICE = "cuda"
+DEVICE = current_platform.device_type
 BATCH_SIZE = 1024
 VOCAB_SIZE = 128 * 1024
 FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
+if is_flashinfer_available:
+    from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
 @pytest.fixture(autouse=True)
 def reset_default_device():
     """
     Explicitly set the default device, which can affect subsequent tests.
     Adding this fixture helps avoid this problem.
     """
     original_device = torch.get_default_device()
@@ -58,8 +59,8 @@ def test_flashinfer_sampler():
     This test verifies that the FlashInfer top-k and top-p sampling
     implementation produces the same results as the Python implementation.
     NOTE: FlashInfer did not directly expose an interface for fused top-k and
     top-p prob renorm (it did provide fused sampling but we cannot compare
     sampling results due to randomness), so we will compare the probability
     renormed consequently by top-k and then top-p of FlashInfer implementation.
     '''

View File

@@ -10,6 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          VllmConfig)
 from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.platforms import current_platform
 from vllm.v1.spec_decode.eagle import EagleProposer
 model_dir = "meta-llama/Llama-3.1-8B-Instruct"
@@ -38,15 +39,17 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
         num_speculative_tokens=k,
     )
-    vllm_config = VllmConfig(model_config=model_config,
-                             cache_config=CacheConfig(),
-                             speculative_config=speculative_config,
-                             device_config=DeviceConfig(device="cuda"),
-                             parallel_config=ParallelConfig(),
-                             load_config=LoadConfig(),
-                             scheduler_config=SchedulerConfig())
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=CacheConfig(),
+        speculative_config=speculative_config,
+        device_config=DeviceConfig(device=current_platform.device_type),
+        parallel_config=ParallelConfig(),
+        load_config=LoadConfig(),
+        scheduler_config=SchedulerConfig())
 
-    return EagleProposer(vllm_config=vllm_config, device='cuda')
+    return EagleProposer(vllm_config=vllm_config,
+                         device=current_platform.device_type)
 def test_prepare_inputs():
@@ -59,7 +62,7 @@ def test_prepare_inputs():
     a, a + 1, ..., a + b - n2 - 1,
     a + b, a + b + 1, ..., a + b + c - n3 - 1]
     """
-    device = torch.device('cuda')
+    device = torch.device(current_platform.device_type)
     # a = 4, b = 7, c = 5
     # n1 = 1, n2 = 3, n3 = 2
@@ -198,7 +201,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
 @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
 def test_propose(num_speculative_tokens):
     # Use GPU device
-    device = torch.device('cuda')
+    device = torch.device(current_platform.device_type)
     # Setup test parameters
     batch_size = 2

View File

@@ -8,6 +8,7 @@ import numpy as np
 import pytest
 import torch
+from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -19,7 +20,8 @@ VOCAB_SIZE = 1024
 NUM_OUTPUT_TOKENS = 20
 MAX_PROMPT_SIZE = 100
 CUDA_DEVICES = [
-    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+    f"{current_platform.device_type}:{i}"
+    for i in range(min(current_platform.device_count(), 2))
 ]
 MAX_NUM_PROMPT_TOKENS = 64

View File

@@ -9,6 +9,7 @@ import torch
 from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig, VllmConfig, set_current_vllm_config)
+from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.utils import GiB_bytes
 from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
@@ -23,7 +24,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 BLOCK_SIZE = 16
 NUM_BLOCKS = 10
-DEVICE = "cuda"
+DEVICE = current_platform.device_type
 def initialize_kv_cache(runner: GPUModelRunner):

View File

@@ -18,7 +18,7 @@ from typing_extensions import ParamSpec
 import vllm._C  # noqa
 import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.utils import import_pynvml
+from vllm.utils import cuda_device_count_stateless, import_pynvml
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
@@ -401,6 +401,10 @@ class CudaPlatformBase(Platform):
         pg._register_backend(device, backend_type, backend_class)
         return pg
 
+    @classmethod
+    def device_count(cls) -> int:
+        return cuda_device_count_stateless()
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

View File

@@ -12,6 +12,7 @@ from torch.distributed.distributed_c10d import is_nccl_available
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.utils import cuda_device_count_stateless
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
@@ -446,3 +447,7 @@ class RocmPlatform(Platform):
         pg._register_backend(device, backend_type, backend_class)
         return pg
+
+    @classmethod
+    def device_count(cls) -> int:
+        return cuda_device_count_stateless()

View File

@@ -189,4 +189,4 @@ class XPUPlatform(Platform):
     @classmethod
     def device_count(cls) -> int:
         return torch.xpu.device_count()

View File

@@ -217,7 +217,8 @@ try:
     is_vllm_fa = True
 except ImportError:
     # For rocm use upstream flash attention
-    from flash_attn import flash_attn_varlen_func
+    if current_platform.is_rocm():
+        from flash_attn import flash_attn_varlen_func
     is_vllm_fa = False
 if TYPE_CHECKING: