[UT][intel GPU] use current_platform instead of device hardcode in v1 tests (#20169)

Signed-off-by: Ma, Liangliang <liangliang.ma@intel.com>
Authored by Liangliang Ma on 2025-07-02 09:06:04 +08:00, committed by GitHub
parent 3be8d312a2
commit a0389e0554
10 changed files with 44 additions and 26 deletions
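
The recurring pattern in the test changes below is to resolve the device string and device count from vllm.platforms.current_platform instead of hardcoding "cuda" and torch.cuda.device_count(). A minimal sketch of that pattern (illustrative only; make_dummy_input is a made-up helper, not part of this commit):

import torch

from vllm.platforms import current_platform

# Previously hardcoded as: DEVICE = "cuda"
DEVICE = current_platform.device_type  # e.g. "cuda" or "xpu"

# Previously built from torch.cuda.device_count()
TEST_DEVICES = [
    f"{DEVICE}:{i}" for i in range(min(current_platform.device_count(), 2))
]


def make_dummy_input(batch_size: int = 2) -> torch.Tensor:
    # Allocate the test tensor on the platform-reported device rather than
    # on a hardcoded CUDA device.
    return torch.zeros(batch_size, 8, device=torch.device(DEVICE))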

View File

@@ -34,7 +34,6 @@ from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.transformers_utils.utils import maybe_model_redirect
-from vllm.utils import cuda_device_count_stateless
 
 logger = init_logger(__name__)
@@ -1094,7 +1093,8 @@ def num_gpus_available():
     """Get number of GPUs without initializing the CUDA context
     in current process."""
-    return cuda_device_count_stateless()
+    from vllm.platforms import current_platform
+    return current_platform.device_count()
 
 
 temp_dir = tempfile.gettempdir()

View File

@@ -6,12 +6,13 @@ import pytest
 import torch
 import torch.nn.functional as F
 
+from vllm.platforms import current_platform
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
                                               RejectionSampler)
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
-DEVICE = "cuda"
+DEVICE = current_platform.device_type
 
 
 @pytest.fixture
@@ -21,7 +22,7 @@ def rejection_sampler():
 def create_logits_tensor(output_token_ids: list[list[int]],
                          vocab_size: int = 100) -> torch.Tensor:
-    """Helper function to create logits tensor that
+    """Helper function to create logits tensor that
     will produce desired token ids on argmax"""
     token_ids = [tokens[:-1] for tokens in output_token_ids]
     num_total_tokens = sum(len(tokens) for tokens in token_ids)
@@ -41,8 +42,8 @@ def create_sampling_metadata(
     top_p: Optional[torch.Tensor] = None,
     generators: Optional[dict[int, Any]] = None,
 ) -> SamplingMetadata:
-    """Create a v1 sampling metadata object with all_greedy set
-    to the given value. Either all greedy or all random sampling
+    """Create a v1 sampling metadata object with all_greedy set
+    to the given value. Either all greedy or all random sampling
     is used.
     """
     generators = generators or {}

View File

@@ -2,25 +2,26 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import pytest
 import torch
-from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
 from torch import Generator
 
 from vllm.platforms import current_platform
 from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
                                                    is_flashinfer_available)
 
-DEVICE = "cuda"
+DEVICE = current_platform.device_type
 BATCH_SIZE = 1024
 VOCAB_SIZE = 128 * 1024
 
 FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
+if is_flashinfer_available:
+    from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
 
 
 @pytest.fixture(autouse=True)
 def reset_default_device():
     """
-    Explicitly set the default device, which can affect subsequent tests.
+    Explicitly set the default device, which can affect subsequent tests.
     Adding this fixture helps avoid this problem.
     """
     original_device = torch.get_default_device()
@@ -58,8 +59,8 @@ def test_flashinfer_sampler():
     This test verifies that the FlashInfer top-k and top-p sampling
     implementation produces the same results as the Python implementation.
 
-    NOTE: FlashInfer did not directly expose an interface for fused top-k and
-    top-p prob renorm (it did provide fused sampling but we cannot compare
+    NOTE: FlashInfer did not directly expose an interface for fused top-k and
+    top-p prob renorm (it did provide fused sampling but we cannot compare
     sampling results due to randomness), so we will compare the probability
     renormed consequently by top-k and then top-p of FlashInfer implementation.
     '''

View File

@@ -10,6 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          VllmConfig)
 from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.platforms import current_platform
 from vllm.v1.spec_decode.eagle import EagleProposer
 
 model_dir = "meta-llama/Llama-3.1-8B-Instruct"
@@ -38,15 +39,17 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
         num_speculative_tokens=k,
     )
 
-    vllm_config = VllmConfig(model_config=model_config,
-                             cache_config=CacheConfig(),
-                             speculative_config=speculative_config,
-                             device_config=DeviceConfig(device="cuda"),
-                             parallel_config=ParallelConfig(),
-                             load_config=LoadConfig(),
-                             scheduler_config=SchedulerConfig())
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=CacheConfig(),
+        speculative_config=speculative_config,
+        device_config=DeviceConfig(device=current_platform.device_type),
+        parallel_config=ParallelConfig(),
+        load_config=LoadConfig(),
+        scheduler_config=SchedulerConfig())
 
-    return EagleProposer(vllm_config=vllm_config, device='cuda')
+    return EagleProposer(vllm_config=vllm_config,
+                         device=current_platform.device_type)
 
 
 def test_prepare_inputs():
@@ -59,7 +62,7 @@ def test_prepare_inputs():
     a, a + 1, ..., a + b - n2 - 1,
     a + b, a + b + 1, ..., a + b + c - n3 - 1]
     """
-    device = torch.device('cuda')
+    device = torch.device(current_platform.device_type)
 
     # a = 4, b = 7, c = 5
     # n1 = 1, n2 = 3, n3 = 2
@@ -198,7 +201,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
 @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
 def test_propose(num_speculative_tokens):
     # Use GPU device
-    device = torch.device('cuda')
+    device = torch.device(current_platform.device_type)
 
     # Setup test parameters
     batch_size = 2

View File

@@ -8,6 +8,7 @@ import numpy as np
 import pytest
 import torch
 
+from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -19,7 +20,8 @@ VOCAB_SIZE = 1024
 NUM_OUTPUT_TOKENS = 20
 MAX_PROMPT_SIZE = 100
 CUDA_DEVICES = [
-    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+    f"{current_platform.device_type}:{i}"
+    for i in range(min(current_platform.device_count(), 2))
 ]
 
 MAX_NUM_PROMPT_TOKENS = 64

View File

@@ -9,6 +9,7 @@ import torch
 from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig, VllmConfig, set_current_vllm_config)
+from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.utils import GiB_bytes
 from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
@@ -23,7 +24,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 BLOCK_SIZE = 16
 NUM_BLOCKS = 10
-DEVICE = "cuda"
+DEVICE = current_platform.device_type
 
 
 def initialize_kv_cache(runner: GPUModelRunner):

View File

@@ -18,7 +18,7 @@ from typing_extensions import ParamSpec
 import vllm._C # noqa
 import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.utils import import_pynvml
+from vllm.utils import cuda_device_count_stateless, import_pynvml
 
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
@@ -401,6 +401,10 @@ class CudaPlatformBase(Platform):
         pg._register_backend(device, backend_type, backend_class)
         return pg
 
+    @classmethod
+    def device_count(cls) -> int:
+        return cuda_device_count_stateless()
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

View File

@@ -12,6 +12,7 @@ from torch.distributed.distributed_c10d import is_nccl_available
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.utils import cuda_device_count_stateless
 
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
@@ -446,3 +447,7 @@ class RocmPlatform(Platform):
         pg._register_backend(device, backend_type, backend_class)
         return pg
+
+    @classmethod
+    def device_count(cls) -> int:
+        return cuda_device_count_stateless()
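
The device_count() classmethods added to the CUDA and ROCm platforms above (and already present on the XPU platform) are what current_platform.device_count() dispatches to. A rough sketch of how an additional platform backend could provide the same hook; the FooPlatform class and foo_count_devices() helper are hypothetical, and the exact Platform base-class attributes required are an assumption:

from vllm.platforms.interface import Platform, PlatformEnum


def foo_count_devices() -> int:
    # Stand-in for a vendor-specific runtime query (hypothetical).
    return 1


class FooPlatform(Platform):
    # A real backend also defines device_type, device_name, etc.;
    # only the counting hook is sketched here.
    _enum = PlatformEnum.OOT  # assumed out-of-tree platform slot

    @classmethod
    def device_count(cls) -> int:
        return foo_count_devices()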

View File

@@ -189,4 +189,4 @@ class XPUPlatform(Platform):
     @classmethod
     def device_count(cls) -> int:
-        return torch.xpu.device_count()
+        return torch.xpu.device_count()

View File

@@ -217,7 +217,8 @@ try:
     is_vllm_fa = True
 except ImportError:
     # For rocm use upstream flash attention
-    from flash_attn import flash_attn_varlen_func
+    if current_platform.is_rocm():
+        from flash_attn import flash_attn_varlen_func
     is_vllm_fa = False
 
 if TYPE_CHECKING: