diff --git a/tests/conftest.py b/tests/conftest.py
index feb52e26300a..b294b50a5cdd 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -34,7 +34,6 @@ from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.transformers_utils.utils import maybe_model_redirect
-from vllm.utils import cuda_device_count_stateless
 
 logger = init_logger(__name__)
 
@@ -1094,7 +1093,8 @@ def num_gpus_available():
     """Get number of GPUs without initializing the CUDA context
     in current process."""
 
-    return cuda_device_count_stateless()
+    from vllm.platforms import current_platform
+    return current_platform.device_count()
 
 
 temp_dir = tempfile.gettempdir()
diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py
index f35c3e194fa7..1f2bdb3c5ff6 100644
--- a/tests/v1/sample/test_rejection_sampler.py
+++ b/tests/v1/sample/test_rejection_sampler.py
@@ -6,12 +6,13 @@ import pytest
 import torch
 import torch.nn.functional as F
 
+from vllm.platforms import current_platform
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
                                               RejectionSampler)
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
-DEVICE = "cuda"
+DEVICE = current_platform.device_type
 
 
 @pytest.fixture
@@ -21,7 +22,7 @@ def rejection_sampler():
 
 def create_logits_tensor(output_token_ids: list[list[int]],
                          vocab_size: int = 100) -> torch.Tensor:
-    """Helper function to create logits tensor that 
+    """Helper function to create logits tensor that
     will produce desired token ids on argmax"""
     token_ids = [tokens[:-1] for tokens in output_token_ids]
     num_total_tokens = sum(len(tokens) for tokens in token_ids)
@@ -41,8 +42,8 @@ def create_sampling_metadata(
     top_p: Optional[torch.Tensor] = None,
     generators: Optional[dict[int, Any]] = None,
 ) -> SamplingMetadata:
-    """Create a v1 sampling metadata object with all_greedy set 
-    to the given value. Either all greedy or all random sampling 
+    """Create a v1 sampling metadata object with all_greedy set
+    to the given value. Either all greedy or all random sampling
     is used.
     """
     generators = generators or {}
diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py
index 9d695cd91a97..ccf38c31d39e 100644
--- a/tests/v1/sample/test_topk_topp_sampler.py
+++ b/tests/v1/sample/test_topk_topp_sampler.py
@@ -2,25 +2,26 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
 import torch
-from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
 from torch import Generator
 
 from vllm.platforms import current_platform
 from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
                                                   is_flashinfer_available)
 
-DEVICE = "cuda"
+DEVICE = current_platform.device_type
 
 BATCH_SIZE = 1024
 VOCAB_SIZE = 128 * 1024
 
 FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
+if is_flashinfer_available:
+    from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
 
 
 @pytest.fixture(autouse=True)
 def reset_default_device():
     """
-    Explicitly set the default device, which can affect subsequent tests. 
+    Explicitly set the default device, which can affect subsequent tests.
     Adding this fixture helps avoid this problem.
     """
     original_device = torch.get_default_device()
@@ -58,8 +59,8 @@ def test_flashinfer_sampler():
     This test verifies that the FlashInfer top-k and top-p sampling
     implementation produces the same results as the Python implementation.
 
-    NOTE: FlashInfer did not directly expose an interface for fused top-k and
-    top-p prob renorm (it did provide fused sampling but we cannot compare
+    NOTE: FlashInfer did not directly expose an interface for fused top-k and
+    top-p prob renorm (it did provide fused sampling but we cannot compare
     sampling results due to randomness), so we will compare the probability
     renormed consequently by top-k and then top-p of FlashInfer implementation.
     '''
diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index c93b7f57c041..5efab2c14407 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -10,6 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          VllmConfig)
 from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.platforms import current_platform
 from vllm.v1.spec_decode.eagle import EagleProposer
 
 model_dir = "meta-llama/Llama-3.1-8B-Instruct"
@@ -38,15 +39,17 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
         num_speculative_tokens=k,
     )
 
-    vllm_config = VllmConfig(model_config=model_config,
-                             cache_config=CacheConfig(),
-                             speculative_config=speculative_config,
-                             device_config=DeviceConfig(device="cuda"),
-                             parallel_config=ParallelConfig(),
-                             load_config=LoadConfig(),
-                             scheduler_config=SchedulerConfig())
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=CacheConfig(),
+        speculative_config=speculative_config,
+        device_config=DeviceConfig(device=current_platform.device_type),
+        parallel_config=ParallelConfig(),
+        load_config=LoadConfig(),
+        scheduler_config=SchedulerConfig())
 
-    return EagleProposer(vllm_config=vllm_config, device='cuda')
+    return EagleProposer(vllm_config=vllm_config,
+                         device=current_platform.device_type)
 
 
 def test_prepare_inputs():
@@ -59,7 +62,7 @@
     a, a + 1, ..., a + b - n2 - 1,
     a + b, a + b + 1, ..., a + b + c - n3 - 1]
     """
-    device = torch.device('cuda')
+    device = torch.device(current_platform.device_type)
 
     # a = 4, b = 7, c = 5
     # n1 = 1, n2 = 3, n3 = 2
@@ -198,7 +201,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
 @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
 def test_propose(num_speculative_tokens):
     # Use GPU device
-    device = torch.device('cuda')
+    device = torch.device(current_platform.device_type)
 
     # Setup test parameters
     batch_size = 2
diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py
index 9e5e06cdc1f5..59b28e675c25 100644
--- a/tests/v1/worker/test_gpu_input_batch.py
+++ b/tests/v1/worker/test_gpu_input_batch.py
@@ -8,6 +8,7 @@ import numpy as np
 import pytest
 import torch
 
+from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -19,7 +20,8 @@ VOCAB_SIZE = 1024
 NUM_OUTPUT_TOKENS = 20
 MAX_PROMPT_SIZE = 100
 CUDA_DEVICES = [
-    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+    f"{current_platform.device_type}:{i}"
+    for i in range(min(current_platform.device_count(), 2))
 ]
 MAX_NUM_PROMPT_TOKENS = 64
 
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index c42200b70da3..2e1deecbd9e6 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -9,6 +9,7 @@ import torch
 from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig, VllmConfig, set_current_vllm_config)
+from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.utils import GiB_bytes
 from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
@@ -23,7 +24,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 BLOCK_SIZE = 16
 NUM_BLOCKS = 10
-DEVICE = "cuda"
+DEVICE = current_platform.device_type
 
 
 def initialize_kv_cache(runner: GPUModelRunner):
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 879d094f6578..15cab757d2c0 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -18,7 +18,7 @@ from typing_extensions import ParamSpec
 import vllm._C  # noqa
 import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.utils import import_pynvml
+from vllm.utils import cuda_device_count_stateless, import_pynvml
 
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
@@ -401,6 +401,10 @@ class CudaPlatformBase(Platform):
         pg._register_backend(device, backend_type, backend_class)
 
         return pg
 
+    @classmethod
+    def device_count(cls) -> int:
+        return cuda_device_count_stateless()
+
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index ee53a76ceb6d..4550ef570684 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -12,6 +12,7 @@ from torch.distributed.distributed_c10d import is_nccl_available
 
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.utils import cuda_device_count_stateless
 
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
@@ -446,3 +447,7 @@ class RocmPlatform(Platform):
         pg._register_backend(device, backend_type, backend_class)
 
         return pg
+
+    @classmethod
+    def device_count(cls) -> int:
+        return cuda_device_count_stateless()
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index f361f5e2616e..61a0453dcbc8 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -189,4 +189,4 @@ class XPUPlatform(Platform):
 
     @classmethod
     def device_count(cls) -> int:
-        return torch.xpu.device_count()
+        return torch.xpu.device_count()
\ No newline at end of file
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index d45ec04472a6..39379b11863c 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -217,7 +217,8 @@ try:
     is_vllm_fa = True
 except ImportError:
     # For rocm use upstream flash attention
-    from flash_attn import flash_attn_varlen_func
+    if current_platform.is_rocm():
+        from flash_attn import flash_attn_varlen_func
     is_vllm_fa = False
 
 if TYPE_CHECKING:
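For reference, a minimal sketch (not part of the patch) of the device-agnostic pattern the tests above switch to: device strings come from `current_platform.device_type` and device enumeration goes through the `device_count()` classmethod added in this diff, instead of hard-coded `"cuda"` and `torch.cuda.device_count()`. The test function, its name, and the tensor shapes below are hypothetical.

```python
# Illustrative sketch only -- mirrors how the updated tests pick a device
# string and enumerate devices without touching torch.cuda directly.
import pytest
import torch

from vllm.platforms import current_platform

# e.g. "cuda" on NVIDIA GPUs; other platforms report their own device type.
DEVICE = current_platform.device_type

# Parametrize over at most two devices, as test_gpu_input_batch.py now does.
DEVICES = [
    f"{DEVICE}:{i}" for i in range(min(current_platform.device_count(), 2))
]


@pytest.mark.parametrize("device", DEVICES)
def test_tensor_lands_on_platform_device(device):  # hypothetical test
    x = torch.ones(8, device=device)
    assert x.device.type == DEVICE
```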