mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 11:43:05 +08:00
[Refactor] Refactor for DeepGemmQuantScaleFMT using cache (#30898)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
1ab5213531
commit
3bd8335bd0
@ -31,6 +31,7 @@ from vllm.model_executor.utils import replace_parameter
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.deep_gemm import (
|
||||
DeepGemmQuantScaleFMT,
|
||||
fp8_gemm_nt,
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
@ -247,7 +248,6 @@ class W8A8BlockFp8LinearOp:
|
||||
self.act_quant_group_shape = act_quant_group_shape
|
||||
self.is_deep_gemm_supported = is_deep_gemm_supported()
|
||||
self.is_hopper = current_platform.is_device_capability(90)
|
||||
self.is_blackwell = current_platform.is_device_capability_family(100)
|
||||
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
|
||||
|
||||
# Get the correct blockscale mul and input quant operations.
|
||||
@ -303,7 +303,7 @@ class W8A8BlockFp8LinearOp:
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if self.use_deep_gemm_e8m0 and self.is_blackwell:
|
||||
if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
|
||||
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
|
||||
input_2d,
|
||||
group_size=self.act_quant_group_shape.col,
|
||||
|
||||
@ -32,15 +32,34 @@ class DeepGemmQuantScaleFMT(Enum):
|
||||
# element contains 4 scale values.
|
||||
UE8M0 = 2
|
||||
|
||||
@staticmethod
|
||||
def from_oracle() -> "DeepGemmQuantScaleFMT":
|
||||
if not is_deep_gemm_e8m0_used():
|
||||
return DeepGemmQuantScaleFMT.FLOAT32
|
||||
return (
|
||||
DeepGemmQuantScaleFMT.UE8M0
|
||||
if current_platform.is_device_capability_family(100)
|
||||
else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
|
||||
@classmethod
def init_oracle_cache(cls) -> None:
    """Compute the scale-format decision once and store it on the class.

    A no-op when ``_oracle_cache`` already holds a value. Otherwise the
    stored value is:

    * ``FLOAT32`` unless ``VLLM_USE_DEEP_GEMM_E8M0`` is set, DeepGEMM is
      reported as supported, and ``_fp8_gemm_nt_impl`` is not ``None``;
    * ``UE8M0`` when the platform reports device capability family 100;
    * ``FLOAT32_CEIL_UE8M0`` in every remaining case.
    """
    if getattr(cls, "_oracle_cache", None) is not None:
        # A decision has already been stored; keep it.
        return

    e8m0_enabled = (
        envs.VLLM_USE_DEEP_GEMM_E8M0
        and is_deep_gemm_supported()
        and (_fp8_gemm_nt_impl is not None)
    )

    if not e8m0_enabled:
        decision = cls.FLOAT32
    elif current_platform.is_device_capability_family(100):
        decision = cls.UE8M0
    else:
        decision = cls.FLOAT32_CEIL_UE8M0

    cls._oracle_cache = decision  # type: ignore
|
||||
|
||||
@classmethod
def from_oracle(cls) -> "DeepGemmQuantScaleFMT":
    """Return the pre-initialized oracle decision.

    Returns:
        The value currently stored in the class-level ``_oracle_cache``.

    Raises:
        AssertionError: If ``_oracle_cache`` is missing or ``None``, i.e.
            ``init_oracle_cache`` has not stored a decision yet.
    """
    cached = getattr(cls, "_oracle_cache", None)
    if cached is None:
        # Raised explicitly instead of via an `assert` statement so the
        # check is not stripped under `python -O`; otherwise callers would
        # receive None and silently take the wrong comparison branch.
        raise AssertionError("DeepGemmQuantScaleFMT oracle cache not initialized")
    return cached
|
||||
|
||||
|
||||
@functools.cache
|
||||
@ -149,6 +168,7 @@ def _lazy_init() -> None:
|
||||
_transform_sf_into_required_layout_impl = getattr(
|
||||
_dg, "transform_sf_into_required_layout", None
|
||||
)
|
||||
DeepGemmQuantScaleFMT.init_oracle_cache()
|
||||
|
||||
|
||||
def get_num_sms() -> int:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user