mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:15:20 +08:00
[UT][intel GPU] use current_platform instead of device hardcode in v1 tests (#20169)
Signed-off-by: Ma, Liangliang <liangliang.ma@intel.com>
This commit is contained in:
parent
3be8d312a2
commit
a0389e0554
@ -34,7 +34,6 @@ from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.transformers_utils.utils import maybe_model_redirect
|
||||
from vllm.utils import cuda_device_count_stateless
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -1094,7 +1093,8 @@ def num_gpus_available():
|
||||
"""Get number of GPUs without initializing the CUDA context
|
||||
in current process."""
|
||||
|
||||
return cuda_device_count_stateless()
|
||||
from vllm.platforms import current_platform
|
||||
return current_platform.device_count()
|
||||
|
||||
|
||||
temp_dir = tempfile.gettempdir()
|
||||
|
||||
@ -6,12 +6,13 @@ import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
|
||||
RejectionSampler)
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
DEVICE = "cuda"
|
||||
DEVICE = current_platform.device_type
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -2,19 +2,20 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
|
||||
from torch import Generator
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
|
||||
is_flashinfer_available)
|
||||
|
||||
DEVICE = "cuda"
|
||||
DEVICE = current_platform.device_type
|
||||
|
||||
BATCH_SIZE = 1024
|
||||
VOCAB_SIZE = 128 * 1024
|
||||
|
||||
FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
|
||||
if is_flashinfer_available:
|
||||
from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
||||
@ -10,6 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||
VllmConfig)
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
|
||||
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
@ -38,15 +39,17 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
|
||||
num_speculative_tokens=k,
|
||||
)
|
||||
|
||||
vllm_config = VllmConfig(model_config=model_config,
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=CacheConfig(),
|
||||
speculative_config=speculative_config,
|
||||
device_config=DeviceConfig(device="cuda"),
|
||||
device_config=DeviceConfig(device=current_platform.device_type),
|
||||
parallel_config=ParallelConfig(),
|
||||
load_config=LoadConfig(),
|
||||
scheduler_config=SchedulerConfig())
|
||||
|
||||
return EagleProposer(vllm_config=vllm_config, device='cuda')
|
||||
return EagleProposer(vllm_config=vllm_config,
|
||||
device=current_platform.device_type)
|
||||
|
||||
|
||||
def test_prepare_inputs():
|
||||
@ -59,7 +62,7 @@ def test_prepare_inputs():
|
||||
a, a + 1, ..., a + b - n2 - 1,
|
||||
a + b, a + b + 1, ..., a + b + c - n3 - 1]
|
||||
"""
|
||||
device = torch.device('cuda')
|
||||
device = torch.device(current_platform.device_type)
|
||||
|
||||
# a = 4, b = 7, c = 5
|
||||
# n1 = 1, n2 = 3, n3 = 2
|
||||
@ -198,7 +201,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
||||
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
|
||||
def test_propose(num_speculative_tokens):
|
||||
# Use GPU device
|
||||
device = torch.device('cuda')
|
||||
device = torch.device(current_platform.device_type)
|
||||
|
||||
# Setup test parameters
|
||||
batch_size = 2
|
||||
|
||||
@ -8,6 +8,7 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
@ -19,7 +20,8 @@ VOCAB_SIZE = 1024
|
||||
NUM_OUTPUT_TOKENS = 20
|
||||
MAX_PROMPT_SIZE = 100
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
f"{current_platform.device_type}:{i}"
|
||||
for i in range(min(current_platform.device_count(), 2))
|
||||
]
|
||||
MAX_NUM_PROMPT_TOKENS = 64
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ import torch
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig, VllmConfig, set_current_vllm_config)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import GiB_bytes
|
||||
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
|
||||
@ -23,7 +24,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
NUM_BLOCKS = 10
|
||||
DEVICE = "cuda"
|
||||
DEVICE = current_platform.device_type
|
||||
|
||||
|
||||
def initialize_kv_cache(runner: GPUModelRunner):
|
||||
|
||||
@ -18,7 +18,7 @@ from typing_extensions import ParamSpec
|
||||
import vllm._C # noqa
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import import_pynvml
|
||||
from vllm.utils import cuda_device_count_stateless, import_pynvml
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
||||
|
||||
@ -401,6 +401,10 @@ class CudaPlatformBase(Platform):
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
return pg
|
||||
|
||||
@classmethod
|
||||
def device_count(cls) -> int:
|
||||
return cuda_device_count_stateless()
|
||||
|
||||
|
||||
# NVML utils
|
||||
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
||||
|
||||
@ -12,6 +12,7 @@ from torch.distributed.distributed_c10d import is_nccl_available
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cuda_device_count_stateless
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
||||
|
||||
@ -446,3 +447,7 @@ class RocmPlatform(Platform):
|
||||
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
return pg
|
||||
|
||||
@classmethod
|
||||
def device_count(cls) -> int:
|
||||
return cuda_device_count_stateless()
|
||||
|
||||
@ -217,6 +217,7 @@ try:
|
||||
is_vllm_fa = True
|
||||
except ImportError:
|
||||
# For rocm use upstream flash attention
|
||||
if current_platform.is_rocm():
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
is_vllm_fa = False
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user