diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 13e813952b30a..15ea9f7d60fff 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -31,6 +31,7 @@ from vllm.model_executor.utils import replace_parameter from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import ( + DeepGemmQuantScaleFMT, fp8_gemm_nt, is_deep_gemm_e8m0_used, is_deep_gemm_supported, @@ -247,7 +248,6 @@ class W8A8BlockFp8LinearOp: self.act_quant_group_shape = act_quant_group_shape self.is_deep_gemm_supported = is_deep_gemm_supported() self.is_hopper = current_platform.is_device_capability(90) - self.is_blackwell = current_platform.is_device_capability_family(100) self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used() # Get the correct blockscale mul and input quant operations. @@ -303,7 +303,7 @@ class W8A8BlockFp8LinearOp: weight: torch.Tensor, weight_scale: torch.Tensor, ) -> torch.Tensor: - if self.use_deep_gemm_e8m0 and self.is_blackwell: + if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0: q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm( input_2d, group_size=self.act_quant_group_shape.col, diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 3d4f8449ad3b6..bcda46421e827 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -32,15 +32,34 @@ class DeepGemmQuantScaleFMT(Enum): # element contains 4 scale values. UE8M0 = 2 - @staticmethod - def from_oracle() -> "DeepGemmQuantScaleFMT": - if not is_deep_gemm_e8m0_used(): - return DeepGemmQuantScaleFMT.FLOAT32 - return ( - DeepGemmQuantScaleFMT.UE8M0 - if current_platform.is_device_capability_family(100) - else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0 + @classmethod + def init_oracle_cache(cls) -> None: + """Initialize the oracle decision and store it in the class cache""" + cached = getattr(cls, "_oracle_cache", None) + if cached is not None: + return + + use_e8m0 = ( + envs.VLLM_USE_DEEP_GEMM_E8M0 + and is_deep_gemm_supported() + and (_fp8_gemm_nt_impl is not None) ) + if not use_e8m0: + cls._oracle_cache = cls.FLOAT32 # type: ignore + return + + cls._oracle_cache = ( # type: ignore + cls.UE8M0 + if current_platform.is_device_capability_family(100) + else cls.FLOAT32_CEIL_UE8M0 + ) + + @classmethod + def from_oracle(cls) -> "DeepGemmQuantScaleFMT": + """Return the pre-initialized oracle decision""" + cached = getattr(cls, "_oracle_cache", None) + assert cached is not None, "DeepGemmQuantScaleFMT oracle cache not initialized" + return cached @functools.cache @@ -149,6 +168,7 @@ def _lazy_init() -> None: _transform_sf_into_required_layout_impl = getattr( _dg, "transform_sf_into_required_layout", None ) + DeepGemmQuantScaleFMT.init_oracle_cache() def get_num_sms() -> int: