From 5c4c08f6f1609960b047c8b9d6aa003e9afc2897 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Sat, 10 May 2025 01:16:12 +0800
Subject: [PATCH] [Misc] Auto fallback to float16 for pre-Ampere GPUs when
 detected bfloat16 config (#17265)

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 vllm/config.py              | 46 +++++++++++++++++++++-------------------------
 vllm/platforms/cpu.py       | 16 +++++++++++++++-
 vllm/platforms/cuda.py      | 13 +++++++++++++
 vllm/platforms/interface.py |  8 ++++++++
 4 files changed, 57 insertions(+), 26 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index ac1dc960ccbe2..cc185b1d5bcb9 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -7,7 +7,6 @@ import hashlib
 import inspect
 import json
 import re
-import sys
 import textwrap
 import warnings
 from collections import Counter
@@ -34,7 +33,7 @@ from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                      QuantizationMethods,
                                                      get_quantization_config)
 from vllm.model_executor.models import ModelRegistry
-from vllm.platforms import CpuArchEnum, current_platform
+from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
@@ -2988,6 +2987,7 @@ def _get_and_verify_dtype(
     if isinstance(dtype, str):
         dtype = dtype.lower()
         if dtype == "auto":
+            # Set default dtype from model config
             if config_dtype == torch.float32:
                 # Following common practice, we use float16 for float32 models
                 torch_dtype = torch.float16
@@ -2995,37 +2995,33 @@ def _get_and_verify_dtype(
                 torch_dtype = config_dtype
 
             if config.model_type == "plamo2":
-                logger.info(
+                logger.warning(
                     "For PLaMo2, we cast models to bfloat16 instead of using "
                     "float16 by default. This is because float16 does not work."
                 )
                 torch_dtype = torch.bfloat16
 
+            # Deal with torch dtype fallback for device compatibility.
             from vllm.platforms import current_platform
-            if (current_platform.is_cpu()
-                    and current_platform.get_cpu_architecture()
-                    == CpuArchEnum.POWERPC
-                    and (config_dtype == torch.float16
-                         or config_dtype == torch.float32)):
-                logger.info(
-                    "For POWERPC, we cast models to bfloat16 instead of "
-                    "using float16 by default. Float16 is not currently "
-                    "supported for POWERPC.")
-                torch_dtype = torch.bfloat16
+            if torch_dtype not in current_platform.supported_dtypes:
+                device_name = current_platform.get_device_name()
 
-            # TODO: change this condition to check if the platform support bf16
-            # instead of checking the OS. For instance M2 shall supports bf16
-            # already. But we need to modify `cpu_extension.cmake` to activate
-            # the feature in the build.
-            if (current_platform.is_cpu() and sys.platform.startswith("darwin")
-                    and current_platform.get_cpu_architecture()
-                    == CpuArchEnum.ARM and config_dtype == torch.bfloat16):
-                logger.info("For macOS with Apple Silicon, currently bfloat16 "
-                            "is not supported. Setting dtype to float16.")
-                torch_dtype = torch.float16
+                if ((capability := current_platform.get_device_capability())
+                        is None):
+                    compute_str = ""
+                else:
+                    version_str = capability.as_version_str()
+                    compute_str = f" (with compute capability {version_str})"
+                fallback_dtype = current_platform.supported_dtypes[0]
+                logger.warning(
+                    "Your %s device%s doesn't support %s. "
" \ + "Falling back to %s for compatibility.", + device_name, compute_str, torch_dtype, fallback_dtype + ) + torch_dtype = fallback_dtype - if current_platform.is_hpu() and config_dtype == torch.float16: - logger.info( + if current_platform.is_hpu() and torch_dtype == torch.float16: + logger.warning( "For HPU, we cast models to bfloat16 instead of " "using float16 by default. Please specify `dtype` if you " "want to use float16.") diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index e45522a4c407e..d286c89395126 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -10,7 +10,7 @@ import torch from vllm.logger import init_logger -from .interface import Platform, PlatformEnum, _Backend +from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend logger = init_logger(__name__) @@ -26,6 +26,20 @@ class CpuPlatform(Platform): device_type: str = "cpu" dispatch_key: str = "CPU" + @property + def supported_dtypes(self) -> list: + if self.get_cpu_architecture() == CpuArchEnum.POWERPC: + return [torch.bfloat16, torch.float32] + elif sys.platform.startswith( + "darwin") and self.get_cpu_architecture() == CpuArchEnum.ARM: + # TODO: change this condition to check if the platform support bf16 + # instead of checking the OS. For instance M2 shall supports bf16 + # already. But we need to modify `cpu_extension.cmake` to activate + # the feature in the build. + return [torch.bfloat16, torch.float32] + # x86/aarch64 CPU has supported both bf16 and fp16 natively. + return [torch.bfloat16, torch.float16, torch.float32] + @classmethod def get_device_name(cls, device_id: int = 0) -> str: return "cpu" diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index ab03dece8c136..f116285870eca 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -73,6 +73,19 @@ class CudaPlatformBase(Platform): ray_device_key: str = "GPU" device_control_env_var: str = "CUDA_VISIBLE_DEVICES" + @property + def supported_dtypes(self) -> List[torch.dtype]: + if self.has_device_capability(80): + # Ampere and Hopper or later NVIDIA GPUs. + return [torch.bfloat16, torch.float16, torch.float32] + elif (not self.has_device_capability(80) + ) and self.has_device_capability(60): + # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported + return [torch.float16, torch.float32] + # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported, + # though vLLM doesn't support these GPUs. + return [torch.float32] + @classmethod def get_device_capability(cls, device_id: int = 0 diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index e9c7f0cb5835f..68b90796ece23 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -122,6 +122,14 @@ class Platform: additional_env_vars: list[str] = [] + @property + def supported_dtypes(self) -> list[torch.dtype]: + """Returns the supported dtypes for the current platform.""" + # Be careful with the order of the dtypes. The first dtype will + # be used as the default dtype fallback for the current platform, + # when encountering unsupported dtypes in "auto" dtype. + return [torch.bfloat16, torch.float16, torch.float32] + def is_cuda(self) -> bool: return self._enum == PlatformEnum.CUDA