[Core] Don't do platform detection at import time (#12933)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Russell Bryant 2025-02-11 02:25:25 -05:00 committed by GitHub
parent 58047c6f04
commit c320ca8edd
3 changed files with 8 additions and 8 deletions
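The fix replaces "from vllm.platforms import current_platform" with "import vllm.platforms" and moves the attribute lookup to the call sites, so the platform probe runs the first time vllm.platforms.current_platform is actually touched rather than when the importing module is loaded. Below is a minimal, self-contained sketch of why that defers detection; it assumes the platforms package resolves current_platform lazily (e.g. via a PEP 562 module-level __getattr__), and the fake_platforms module here is purely hypothetical, not the actual vLLM implementation.

import sys
import types

# Hypothetical stand-in for a package like vllm.platforms that resolves
# current_platform lazily via a PEP 562 module-level __getattr__.
_fake = types.ModuleType("fake_platforms")


def _lazy_getattr(name):
    if name == "current_platform":
        print("expensive platform detection runs now")
        _fake.current_platform = "cuda"  # cache the result of the probe
        return _fake.current_platform
    raise AttributeError(name)


_fake.__getattr__ = _lazy_getattr
sys.modules["fake_platforms"] = _fake

# Old pattern (removed by this commit): binding the attribute at import time
# triggers detection as soon as the importing module loads.
#     from fake_platforms import current_platform
#
# New pattern (added by this commit): import only the module and defer the
# attribute lookup to the call site, so detection runs on first real use.
import fake_platforms  # no detection happens here

print("module imported; detection not triggered yet")
print(fake_platforms.current_platform)  # detection runs here, on first access

Running the sketch prints the detection message only at the last line, which mirrors how the executor and Ray utility modules can now be imported without forcing platform detection.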

vllm/executor/executor_base.py

@@ -8,11 +8,11 @@ from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
 import torch.nn as nn
 from typing_extensions import TypeVar
 
+import vllm.platforms
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.platforms import current_platform
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import ExecuteModelRequest, PoolerOutput
 from vllm.utils import make_async
@@ -108,8 +108,8 @@ class ExecutorBase(ABC):
         """
         # NOTE: This is logged in the executor because there can be >1 workers.
         logger.info("# %s blocks: %d, # CPU blocks: %d",
-                    current_platform.dispatch_key, num_gpu_blocks,
-                    num_cpu_blocks)
+                    vllm.platforms.current_platform.dispatch_key,
+                    num_gpu_blocks, num_cpu_blocks)
         max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
                            self.model_config.max_model_len)
         logger.info("Maximum concurrency for %s tokens per request: %.2fx",

vllm/executor/ray_utils.py

@@ -7,10 +7,10 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 import msgspec
 
+import vllm.platforms
 from vllm.config import ParallelConfig
 from vllm.executor.msgspec_utils import decode_hook, encode_hook
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import get_ip
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -54,10 +54,10 @@ try:
         def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
             node_id = ray.get_runtime_context().get_node_id()
-            device_key = current_platform.ray_device_key
+            device_key = vllm.platforms.current_platform.ray_device_key
             if not device_key:
                 raise RuntimeError("current platform %s does not support ray.",
-                                   current_platform.device_name)
+                                   vllm.platforms.current_platform.device_name)
             gpu_ids = ray.get_runtime_context().get_accelerator_ids(
             )[device_key]
             return node_id, gpu_ids

vllm/platforms/cuda.py

@@ -334,10 +334,10 @@ class NvmlCudaPlatform(CudaPlatformBase):
             if (len(set(device_names)) > 1
                     and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
                 logger.warning(
-                    "Detected different devices in the system: \n%s\nPlease"
+                    "Detected different devices in the system: %s. Please"
                     " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                     "avoid unexpected behavior.",
-                    "\n".join(device_names),
+                    ", ".join(device_names),
                 )