[UT][intel GPU] use current_platform instead of device hardcode in v1 tests (#20169)
Signed-off-by: Ma, Liangliang <liangliang.ma@intel.com>
Parent: 3be8d312a2
Commit: a0389e0554
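The whole diff applies one pattern: v1 tests and helpers ask vllm.platforms.current_platform for the device type and device count instead of hardcoding "cuda", and the CUDA/ROCm platform classes gain a device_count() classmethod backed by cuda_device_count_stateless(). Below is a minimal sketch of the test-side pattern, assuming a vLLM install where current_platform resolves to the active accelerator; the test function and its names are illustrative only and not part of this commit.

# Sketch of the device-agnostic pattern used throughout this commit.
import torch

from vllm.platforms import current_platform

# Resolve the device type ("cuda", "xpu", ...) from the active platform
# instead of hardcoding "cuda".
DEVICE = current_platform.device_type

# Enumerate up to two devices without assuming CUDA is present.
DEVICES = [
    f"{current_platform.device_type}:{i}"
    for i in range(min(current_platform.device_count(), 2))
]


def test_tensor_on_current_platform():
    # Hypothetical test (not from the diff) showing how DEVICE is used.
    x = torch.zeros(4, device=torch.device(DEVICE))
    assert x.device.type == DEVICE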
@@ -34,7 +34,6 @@ from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.transformers_utils.utils import maybe_model_redirect
-from vllm.utils import cuda_device_count_stateless
 
 logger = init_logger(__name__)
 
@@ -1094,7 +1093,8 @@ def num_gpus_available():
     """Get number of GPUs without initializing the CUDA context
    in current process."""
 
-    return cuda_device_count_stateless()
+    from vllm.platforms import current_platform
+    return current_platform.device_count()
 
 
 temp_dir = tempfile.gettempdir()
@@ -6,12 +6,13 @@ import pytest
 import torch
 import torch.nn.functional as F
 
+from vllm.platforms import current_platform
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
                                               RejectionSampler)
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
-DEVICE = "cuda"
+DEVICE = current_platform.device_type
 
 
 @pytest.fixture
@@ -2,19 +2,20 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
 import torch
-from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
 from torch import Generator
 
 from vllm.platforms import current_platform
 from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
                                                    is_flashinfer_available)
 
-DEVICE = "cuda"
+DEVICE = current_platform.device_type
 
 BATCH_SIZE = 1024
 VOCAB_SIZE = 128 * 1024
 
 FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
+if is_flashinfer_available:
+    from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
 
 
 @pytest.fixture(autouse=True)
@@ -10,6 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          VllmConfig)
 from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.platforms import current_platform
 from vllm.v1.spec_decode.eagle import EagleProposer
 
 model_dir = "meta-llama/Llama-3.1-8B-Instruct"
@@ -38,15 +39,17 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
         num_speculative_tokens=k,
     )
 
-    vllm_config = VllmConfig(model_config=model_config,
-                             cache_config=CacheConfig(),
-                             speculative_config=speculative_config,
-                             device_config=DeviceConfig(device="cuda"),
-                             parallel_config=ParallelConfig(),
-                             load_config=LoadConfig(),
-                             scheduler_config=SchedulerConfig())
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=CacheConfig(),
+        speculative_config=speculative_config,
+        device_config=DeviceConfig(device=current_platform.device_type),
+        parallel_config=ParallelConfig(),
+        load_config=LoadConfig(),
+        scheduler_config=SchedulerConfig())
 
-    return EagleProposer(vllm_config=vllm_config, device='cuda')
+    return EagleProposer(vllm_config=vllm_config,
+                         device=current_platform.device_type)
 
 
 def test_prepare_inputs():
@@ -59,7 +62,7 @@ def test_prepare_inputs():
     a, a + 1, ..., a + b - n2 - 1,
     a + b, a + b + 1, ..., a + b + c - n3 - 1]
     """
-    device = torch.device('cuda')
+    device = torch.device(current_platform.device_type)
 
     # a = 4, b = 7, c = 5
     # n1 = 1, n2 = 3, n3 = 2
@@ -198,7 +201,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
 @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
 def test_propose(num_speculative_tokens):
     # Use GPU device
-    device = torch.device('cuda')
+    device = torch.device(current_platform.device_type)
 
     # Setup test parameters
     batch_size = 2
@@ -8,6 +8,7 @@ import numpy as np
 import pytest
 import torch
 
+from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -19,7 +20,8 @@ VOCAB_SIZE = 1024
 NUM_OUTPUT_TOKENS = 20
 MAX_PROMPT_SIZE = 100
 CUDA_DEVICES = [
-    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+    f"{current_platform.device_type}:{i}"
+    for i in range(min(current_platform.device_count(), 2))
 ]
 MAX_NUM_PROMPT_TOKENS = 64
 
@@ -9,6 +9,7 @@ import torch
 from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig, VllmConfig, set_current_vllm_config)
+from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.utils import GiB_bytes
 from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
@@ -23,7 +24,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 BLOCK_SIZE = 16
 NUM_BLOCKS = 10
-DEVICE = "cuda"
+DEVICE = current_platform.device_type
 
 
 def initialize_kv_cache(runner: GPUModelRunner):
@@ -18,7 +18,7 @@ from typing_extensions import ParamSpec
 import vllm._C  # noqa
 import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.utils import import_pynvml
+from vllm.utils import cuda_device_count_stateless, import_pynvml
 
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
@@ -401,6 +401,10 @@ class CudaPlatformBase(Platform):
         pg._register_backend(device, backend_type, backend_class)
         return pg
 
+    @classmethod
+    def device_count(cls) -> int:
+        return cuda_device_count_stateless()
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
@@ -12,6 +12,7 @@ from torch.distributed.distributed_c10d import is_nccl_available
 
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.utils import cuda_device_count_stateless
 
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
@@ -446,3 +447,7 @@ class RocmPlatform(Platform):
 
         pg._register_backend(device, backend_type, backend_class)
         return pg
+
+    @classmethod
+    def device_count(cls) -> int:
+        return cuda_device_count_stateless()
@@ -217,7 +217,8 @@ try:
     is_vllm_fa = True
 except ImportError:
     # For rocm use upstream flash attention
-    from flash_attn import flash_attn_varlen_func
+    if current_platform.is_rocm():
+        from flash_attn import flash_attn_varlen_func
     is_vllm_fa = False
 
 if TYPE_CHECKING: