From 8a044754bd083671e4bb09a68b1edae9610dfccc Mon Sep 17 00:00:00 2001
From: Chaojun Zhang
Date: Tue, 26 Aug 2025 04:09:26 +0800
Subject: [PATCH] [XPU] Delay BF16 check to worker init for spawn compatibility (#22979)

Signed-off-by: chzhang
---
 vllm/platforms/cuda.py       | 20 +++++++++++++++++++
 vllm/platforms/interface.py  |  7 +++++++
 vllm/platforms/rocm.py       | 20 +++++++++++++++++++
 vllm/platforms/xpu.py        | 37 +++++++++++-------------------------
 vllm/v1/worker/gpu_worker.py | 22 +--------------------
 vllm/v1/worker/xpu_worker.py |  1 +
 6 files changed, 60 insertions(+), 47 deletions(-)

diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 134ba36e5e73..c0e0fe35e402 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -518,6 +518,26 @@ class CudaPlatformBase(Platform):
         supported = True
         return supported
 
+    @classmethod
+    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
+        if torch_dtype == torch.bfloat16:  # noqa: SIM102
+            if not cls.has_device_capability(80):
+                capability = cls.get_device_capability()
+                gpu_name = cls.get_device_name()
+
+                if capability is None:
+                    compute_str = "does not have a compute capability"
+                else:
+                    version_str = capability.as_version_str()
+                    compute_str = f"has compute capability {version_str}"
+
+                raise ValueError(
+                    "Bfloat16 is only supported on GPUs "
+                    "with compute capability of at least 8.0. "
+                    f"Your {gpu_name} GPU {compute_str}. "
+                    "You can use float16 instead by explicitly setting the "
+                    "`dtype` flag in CLI, for example: --dtype=half.")
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 00bc555288e8..f6c17de86d05 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -572,6 +572,13 @@ class Platform:
         """
         return False
 
+    @classmethod
+    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
+        """
+        Check if the dtype is supported by the current platform.
+        """
+        raise NotImplementedError
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 323ec591c50a..85b2fe2e480c 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -462,3 +462,23 @@ class RocmPlatform(Platform):
     def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                     model_config: "ModelConfig") -> bool:
         return True
+
+    @classmethod
+    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
+        if torch_dtype == torch.bfloat16:  # noqa: SIM102
+            if not cls.has_device_capability(80):
+                capability = cls.get_device_capability()
+                gpu_name = cls.get_device_name()
+
+                if capability is None:
+                    compute_str = "does not have a compute capability"
+                else:
+                    version_str = capability.as_version_str()
+                    compute_str = f"has compute capability {version_str}"
+
+                raise ValueError(
+                    "Bfloat16 is only supported on GPUs "
+                    "with compute capability of at least 8.0. "
+                    f"Your {gpu_name} GPU {compute_str}. "
+                    "You can use float16 instead by explicitly setting the "
+                    "`dtype` flag in CLI, for example: --dtype=half.")
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index af24437f649f..235e5d8294e5 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -97,13 +97,6 @@ class XPUPlatform(Platform):
             from vllm.config import CompilationLevel
             vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION  # noqa: E501
 
-        # Instances created using VllmConfig() typically have model_config as
-        # None by default. The modification involves adding a check to prevent
-        # potential null exceptions check and update model config.
-        if model_config is not None and model_config.dtype == torch.bfloat16 \
-                and not cls.device_support_bf16():
-            model_config.dtype = torch.float16
-
         # lazy import to avoid circular import
         from vllm.config import CUDAGraphMode
         compilation_config = vllm_config.compilation_config
@@ -162,30 +155,11 @@ class XPUPlatform(Platform):
         torch.xpu.reset_peak_memory_stats(device)
         return torch.xpu.max_memory_allocated(device)
 
-    @classmethod
-    def device_support_bf16(cls) -> bool:
-        device_name = cls.get_device_name().lower()
-        if cls.is_client_gpu_a770():
-            logger.warning("Intel Arc A770 have bfloat16 accuracy known issue,"
-                           " fallback to float16")
-            return False
-        else:
-            logger.info(
-                "Device name %s supports bfloat16. Please file an issue "
-                "if you encounter any accuracy problems with bfloat16.",
-                device_name)
-            return True
-
     @classmethod
     def is_data_center_gpu(cls) -> bool:
         device_name = cls.get_device_name().lower()
         return device_name.count("data center gpu") > 0
 
-    @classmethod
-    def is_client_gpu_a770(cls) -> bool:
-        device_name = cls.get_device_name().lower()
-        return device_name.count("a770") > 0
-
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa
@@ -197,3 +171,14 @@ class XPUPlatform(Platform):
     @classmethod
     def device_count(cls) -> int:
         return torch.xpu.device_count()
+
+    @classmethod
+    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
+        if torch_dtype == torch.bfloat16:  # noqa: SIM102
+            device_name = cls.get_device_name().lower()
+            # client gpu a770
+            if device_name.count("a770") > 0:
+                raise ValueError(
+                    "Intel Arc A770 have bfloat16 accuracy known issue. "
+                    "You can use float16 instead by explicitly setting the "
+                    "`dtype` flag in CLI, for example: --dtype=half.")
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 1688b8b83e87..0dca45a75921 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -167,7 +167,7 @@ class Worker(WorkerBase):
             self.device = torch.device(f"cuda:{self.local_rank}")
             current_platform.set_device(self.device)
 
-            _check_if_gpu_supports_dtype(self.model_config.dtype)
+            current_platform.check_if_supports_dtype(self.model_config.dtype)
             gc.collect()
             torch.cuda.empty_cache()
@@ -612,23 +612,3 @@ def init_worker_distributed_environment(
                                         parallel_config.pipeline_parallel_size)
 
     ensure_kv_transfer_initialized(vllm_config)
-
-
-def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
-    # Check if the GPU supports the dtype.
-    if torch_dtype == torch.bfloat16:  # noqa: SIM102
-        if not current_platform.has_device_capability(80):
-            capability = current_platform.get_device_capability()
-            gpu_name = current_platform.get_device_name()
-
-            if capability is None:
-                compute_str = "does not have a compute capability"
-            else:
-                version_str = capability.as_version_str()
-                compute_str = f"has compute capability {version_str}"
-
-            raise ValueError(
-                "Bfloat16 is only supported on GPUs with compute capability "
-                f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
-                "You can use float16 instead by explicitly setting the "
-                "`dtype` flag in CLI, for example: --dtype=half.")
diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py
index 134d83925265..17288cda8ecc 100644
--- a/vllm/v1/worker/xpu_worker.py
+++ b/vllm/v1/worker/xpu_worker.py
@@ -145,6 +145,7 @@ class XPUWorker(Worker):
         ):
             self.device = torch.device(f"xpu:{self.local_rank}")
             current_platform.set_device(self.device)
+            current_platform.check_if_supports_dtype(self.model_config.dtype)
            torch.xpu.empty_cache()
             self.init_gpu_memory = torch.xpu.get_device_properties(
                 self.local_rank).total_memory