[Core] Don't do platform detection at import time (#12933)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Russell Bryant 2025-02-11 02:25:25 -05:00 committed by GitHub
parent 58047c6f04
commit c320ca8edd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 8 additions and 8 deletions

View File

@@ -8,11 +8,11 @@ from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
import torch.nn as nn import torch.nn as nn
from typing_extensions import TypeVar from typing_extensions import TypeVar
import vllm.platforms
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async from vllm.utils import make_async
@@ -108,8 +108,8 @@ class ExecutorBase(ABC):
""" """
# NOTE: This is logged in the executor because there can be >1 workers. # NOTE: This is logged in the executor because there can be >1 workers.
logger.info("# %s blocks: %d, # CPU blocks: %d", logger.info("# %s blocks: %d, # CPU blocks: %d",
current_platform.dispatch_key, num_gpu_blocks, vllm.platforms.current_platform.dispatch_key,
num_cpu_blocks) num_gpu_blocks, num_cpu_blocks)
max_concurrency = (num_gpu_blocks * self.cache_config.block_size / max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
self.model_config.max_model_len) self.model_config.max_model_len)
logger.info("Maximum concurrency for %s tokens per request: %.2fx", logger.info("Maximum concurrency for %s tokens per request: %.2fx",

View File

@@ -7,10 +7,10 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import msgspec import msgspec
import vllm.platforms
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.executor.msgspec_utils import decode_hook, encode_hook from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import get_ip from vllm.utils import get_ip
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerWrapperBase
@@ -54,10 +54,10 @@ try:
def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
node_id = ray.get_runtime_context().get_node_id() node_id = ray.get_runtime_context().get_node_id()
device_key = current_platform.ray_device_key device_key = vllm.platforms.current_platform.ray_device_key
if not device_key: if not device_key:
raise RuntimeError("current platform %s does not support ray.", raise RuntimeError("current platform %s does not support ray.",
current_platform.device_name) vllm.platforms.current_platform.device_name)
gpu_ids = ray.get_runtime_context().get_accelerator_ids( gpu_ids = ray.get_runtime_context().get_accelerator_ids(
)[device_key] )[device_key]
return node_id, gpu_ids return node_id, gpu_ids

View File

@@ -334,10 +334,10 @@ class NvmlCudaPlatform(CudaPlatformBase):
if (len(set(device_names)) > 1 if (len(set(device_names)) > 1
and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"): and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
logger.warning( logger.warning(
"Detected different devices in the system: \n%s\nPlease" "Detected different devices in the system: %s. Please"
" make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to " " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
"avoid unexpected behavior.", "avoid unexpected behavior.",
"\n".join(device_names), ", ".join(device_names),
) )