Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Performance][B200] Fix deepgemm prologue (#27897)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
parent 478ee511de
commit 74a9a9faad
@@ -232,6 +232,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
        """
        super().__init__(quant_config)
        assert self.block_shape == get_mk_alignment_for_contiguous_layout()
        assert self.quant_config.use_fp8_w8a8
        self.max_num_tokens = max_num_tokens
        self.num_dispatchers = num_dispatchers

@@ -250,6 +251,12 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
    def supports_expert_map(self) -> bool:
        return False

    def supports_packed_ue8m0_act_scales(self) -> bool:
        """
        DeepGemm supports packed ue8m0 activation scales format in devices == sm100
        """
        return current_platform.is_device_capability(100)

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        # Let PrepareAndFinalize::finalize() decide the impl.
        return TopKWeightAndReduceDelegate()

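Note: `supports_packed_ue8m0_act_scales` keys the new behavior off SM100 devices (compute capability 10.0, i.e. B200 per the commit title). As background, a ue8m0 scale is a power-of-two scale that can be stored as a bare 8-bit exponent. The sketch below is illustrative only (not vLLM code); `FP8_MAX` and `ue8m0_block_scale` are made-up names, assuming a 128-element quantization block and float8_e4m3fn activations.

import torch

FP8_MAX = 448.0  # largest magnitude representable in float8_e4m3fn

def ue8m0_block_scale(block: torch.Tensor) -> torch.Tensor:
    # Round the ideal scale (amax / FP8_MAX) up to the next power of two, so the
    # scale has no mantissa bits and fits in a single 8-bit exponent (ue8m0).
    amax = block.abs().amax().clamp(min=1e-12)
    return torch.exp2(torch.ceil(torch.log2(amax / FP8_MAX)))

x = torch.randn(128)
s = ue8m0_block_scale(x)
x_fp8 = (x / s).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)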
@@ -6,6 +6,7 @@ import deep_ep
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceDelegate,
@@ -20,6 +21,8 @@ from vllm.v1.worker.ubatching import (
    dbo_maybe_run_recv_hook,
)

logger = init_logger(__name__)

# DeepEP kernels quantize dispatch inputs in 128 element chunks.
DEEPEP_QUANT_BLOCK_SIZE = 128
DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]
@@ -94,6 +97,29 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
        self.handles: list[tuple | None] = [None, None]
        self.num_dispatchers_ = num_dispatchers

        # We don't have enough information to determine if we should dispatch
        # activation scales in a packed ue8m0 format during object construction
        # time. This setting is handled by post_init_setup.
        self.use_ue8m0_dispatch = False

    def post_init_setup(self, fused_experts: mk.FusedMoEPermuteExpertsUnpermute):
        if not fused_experts.supports_packed_ue8m0_act_scales():
            # Early exit.
            return

        if self.use_fp8_dispatch:
            logger.debug_once(
                "Update DeepEPLLPrepareFinalize to do packed ue8m0 scales dispatch."
            )
            self.use_ue8m0_dispatch = True
        else:
            logger.warning_once(
                "DeepEPLLPrepareAndFinalize is setup to dispatch raw/unquantized "
                f"activations despite ({fused_experts.__class__.__name__}) being able "
                "to support quantized activations.",
                scope="local",
            )

    def num_dispatchers(self) -> int:
        return self.num_dispatchers_

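A minimal sketch of the handshake introduced here, with stand-in classes rather than the real vLLM implementations: the prepare/finalize side only flips its ue8m0-dispatch flag when it is already doing fp8 dispatch and the experts side reports support for packed ue8m0 scales.

class _ToyExperts:
    def supports_packed_ue8m0_act_scales(self) -> bool:
        return True  # e.g. DeepGemm experts on SM100

class _ToyPrepareFinalize:
    def __init__(self, use_fp8_dispatch: bool):
        self.use_fp8_dispatch = use_fp8_dispatch
        self.use_ue8m0_dispatch = False  # resolved later, as in the diff above

    def post_init_setup(self, experts) -> None:
        # Enable packed ue8m0 scales only when both sides can handle them.
        if experts.supports_packed_ue8m0_act_scales() and self.use_fp8_dispatch:
            self.use_ue8m0_dispatch = True

pf = _ToyPrepareFinalize(use_fp8_dispatch=True)
pf.post_init_setup(_ToyExperts())
assert pf.use_ue8m0_dispatch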
@@ -206,6 +232,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
            self.max_tokens_per_rank,
            num_experts,
            use_fp8=self.use_fp8_dispatch,
            # round_scale needs to be set to dispatch in ue8m0
            round_scale=self.use_ue8m0_dispatch,
            use_ue8m0=self.use_ue8m0_dispatch,
            async_finish=False,
            return_recv_hook=True,
        )

@@ -149,6 +149,15 @@ class FusedMoEPrepareAndFinalize(ABC):
        described above.
        """

    def post_init_setup(self, fused_experts: "FusedMoEPermuteExpertsUnpermute"):
        """
        Initialize FusedMoEPrepareAndFinalize settings that depend on
        FusedMoEPermuteExpertsUnpermute experts object.
        The FusedMoEPrepareAndFinalize implementations that have such
        dependencies may choose to override this function.
        """
        return

    @abstractmethod
    def prepare(
        self,
@@ -503,6 +512,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
        """
        raise NotImplementedError

    def supports_packed_ue8m0_act_scales(self) -> bool:
        """
        A flag indicating whether or not this class can process packed ue8m0
        activation scales.
        """
        return False

    def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
        """
        Workspace type: The dtype to use for the workspace tensors.
@@ -698,6 +714,8 @@ class FusedMoEModularKernel(torch.nn.Module):
        self.prepare_finalize = prepare_finalize
        self.fused_experts = fused_experts
        self.shared_experts = shared_experts

        self._post_init_setup()
        assert (
            prepare_finalize.activation_format == fused_experts.activation_formats[0]
        ), (
@@ -707,6 +725,13 @@ class FusedMoEModularKernel(torch.nn.Module):
            f"{fused_experts.activation_formats[0]}"
        )

    def _post_init_setup(self):
        """
        Resolve any leftover setup dependencies between self.prepare_finalize
        and self.fused_experts here.
        """
        self.prepare_finalize.post_init_setup(self.fused_experts)

    def supports_expert_map(self) -> bool:
        """
        A flag indicating whether or not this class supports expert maps.

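A toy sketch (not the vLLM classes) of the construction-time ordering this hunk establishes: the kernel resolves the prepare/finalize-to-experts dependency through the new, no-op-by-default hook before any further validation runs.

class ToyPrepareFinalize:
    def post_init_setup(self, experts) -> None:
        # Default mirrors the base class above: no leftover setup to resolve.
        return

class ToyExperts:
    activation_formats = ("standard",)

class ToyModularKernel:
    def __init__(self, prepare_finalize, fused_experts):
        self.prepare_finalize = prepare_finalize
        self.fused_experts = fused_experts
        # The hook runs first so implementations can reconfigure themselves
        # based on the experts object (e.g. enable packed ue8m0 dispatch).
        self.prepare_finalize.post_init_setup(self.fused_experts)

kernel = ToyModularKernel(ToyPrepareFinalize(), ToyExperts())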
@@ -60,11 +60,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
    expert_weight_is_col_major,
    deepgemm_post_process_fp8_weight_block,
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    requant_weight_ue8m0_inplace,
    validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -94,7 +93,6 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import (
    get_col_major_tma_aligned_tensor,
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
@@ -846,15 +844,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):

        # DeepGemm scales need to be transposed and aligned. We try to do
        # it ahead of time for performance reasons.
        if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
            if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv
        if self.allow_deep_gemm:
            dg_w13_weight, dg_w13_weight_scale_inv = (
                deepgemm_post_process_fp8_weight_block(
                    wq=layer.w13_weight.data,
                    ws=layer.w13_weight_scale_inv.data,
                    quant_block_shape=tuple(layer.weight_block_size),
                    use_e8m0=is_deep_gemm_e8m0_used(),
                )
            if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv
                )
            dg_w2_weight, dg_w2_weight_scale_inv = (
                deepgemm_post_process_fp8_weight_block(
                    wq=layer.w2_weight.data,
                    ws=layer.w2_weight_scale_inv.data,
                    quant_block_shape=tuple(layer.weight_block_size),
                    use_e8m0=is_deep_gemm_e8m0_used(),
                )
            )
            layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(
                dg_w13_weight_scale_inv, requires_grad=False
            )
            layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(
                dg_w2_weight_scale_inv, requires_grad=False
            )

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
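As a shape reference for the calls above: with block-quantized fp8 expert weights of shape (num_experts, N, K) and weight_block_size == [128, 128], the scale tensor passed as `ws` has one float32 entry per 128x128 weight block. A small sketch (sizes are made up for illustration; `ceil_div` is a local helper, not a vLLM function):

import torch

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

num_experts, N, K = 8, 512, 7168           # illustrative sizes only
block_n, block_k = 128, 128                 # layer.weight_block_size in this path

wq = torch.zeros(num_experts, N, K).to(torch.float8_e4m3fn)
ws = torch.zeros(num_experts, ceil_div(N, block_n), ceil_div(K, block_k),
                 dtype=torch.float32)       # one scale per 128x128 weight block
print(ws.shape)  # torch.Size([8, 4, 56])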
@@ -990,31 +1004,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            del layer.w13_input_scale
            del layer.w2_input_scale

        if is_deep_gemm_e8m0_used() and self.block_quant:
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale_inv.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale_inv.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv
                )
            if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv
                )

    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
        if (
            self.rocm_aiter_moe_enabled
@@ -1037,7 +1026,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        from vllm.model_executor.layers.fused_moe import (
            BatchedTritonOrDeepGemmExperts,
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonOrDeepGemmExperts,
        )

@@ -1053,20 +1043,24 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            experts_impl = (
                BatchedDeepGemmExperts if self.allow_deep_gemm else BatchedTritonExperts
            )
            logger.debug(
                "BatchedTritonOrDeepGemmExperts(%s): "
                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return BatchedTritonOrDeepGemmExperts(
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )

        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,

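A toy rendering of the reworked debug line: the class names come from this diff, while the numeric values are made up for illustration.

experts_impl_name = "BatchedDeepGemmExperts"  # chosen when self.allow_deep_gemm is True
method_name = "Fp8MoEMethod"
msg = "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s" % (
    experts_impl_name, method_name, 256, [128, 128], False
)
print(msg)
# BatchedDeepGemmExperts(Fp8MoEMethod): max_tokens_per_rank=256, block_size=[128, 128], per_act_token=False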
@@ -34,6 +34,7 @@ from vllm.utils.deep_gemm import (
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
    should_use_deepgemm_for_fp8_linear,
    transform_sf_into_required_layout,
)
from vllm.utils.torch_utils import direct_register_custom_op

@@ -929,6 +930,50 @@ def requant_weight_ue8m0_inplace(
    s_old.copy_(s_requant)


def deepgemm_post_process_fp8_weight_block(
    wq: torch.Tensor, ws: torch.Tensor, quant_block_shape: tuple[int], use_e8m0: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    assert wq.dtype == torch.float8_e4m3fn, (
        "Expected quantized tensor dtype "
        f"to be torch.float8_e4m3fn, got {wq.dtype} instead."
    )
    assert ws.dtype == torch.float32, (
        f"Expected tensor scales dtype to be torch.float32, got {ws.dtype} instead"
    )

    if use_e8m0:
        requant_weight_ue8m0_inplace(wq, ws, block_size=quant_block_shape)

    original_ndim = wq.ndim
    if wq.ndim == 2:
        assert ws.ndim == 2
        wq = wq.unsqueeze(0)
        ws = ws.unsqueeze(0)

    # From https://github.com/deepseek-ai/DeepGEMM/blob/c9f8b34dcdacc20aa746b786f983492c51072870/csrc/utils/layout.hpp#L46
    recipe = (1, 128, 128)

    # Ref: https://github.com/deepseek-ai/DeepGEMM/blob/c9f8b34dcdacc20aa746b786f983492c51072870/csrc/apis/gemm.hpp
    # DeepGemm uses the `transform_sf_into_required_layout` function to
    # represent scales in the correct format.
    dg_ws = transform_sf_into_required_layout(
        sf=ws,
        mn=wq.size(1),
        k=wq.size(2),
        recipe=recipe,
        num_groups=wq.size(0),
        # is_sfa means "is scale factors for A" (refers to the argument A in A @ B).
        # Weights are B.
        is_sfa=False,
    )

    if original_ndim == 2:
        wq = wq.squeeze(0)
        dg_ws = dg_ws.squeeze(0)

    return wq, dg_ws


def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
    """Pad the weight tensor. This is an optimization on ROCm platform, which
    can benefit from tensors located far enough from one another in memory"""
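A hedged usage sketch for the new helper (it requires a DeepGEMM installation that provides transform_sf_into_required_layout; the shapes assume 128x128 weight blocks and are illustrative only):

import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    deepgemm_post_process_fp8_weight_block,
)

N, K = 512, 1024                                          # made-up single-matrix sizes
wq = torch.randn(N, K).to(torch.float8_e4m3fn)            # block-quantized weight
ws = torch.rand(N // 128, K // 128, dtype=torch.float32)  # one scale per 128x128 block

dg_wq, dg_ws = deepgemm_post_process_fp8_weight_block(
    wq=wq, ws=ws, quant_block_shape=(128, 128), use_e8m0=False
)
# dg_ws is now laid out the way DeepGEMM expects scale factors for the B operand.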
@@ -1141,11 +1186,15 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
    should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
        layer.orig_dtype, layer.weight
    )
    if is_deep_gemm_e8m0_used() and should_use_deepgemm:
        block_sz = tuple(layer.weight_block_size)
        requant_weight_ue8m0_inplace(
            layer.weight.data, layer.weight_scale.data, block_sz
    if should_use_deepgemm:
        dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
            wq=layer.weight.data,
            ws=layer.weight_scale.data,
            quant_block_shape=tuple(layer.weight_block_size),
            use_e8m0=is_deep_gemm_e8m0_used(),
        )
        layer.weight = torch.nn.Parameter(dg_weight, requires_grad=False)
        layer.weight_scale = torch.nn.Parameter(dg_weight_scale, requires_grad=False)


def expert_weight_is_col_major(x: torch.Tensor) -> bool:

@@ -49,10 +49,6 @@ def is_deep_gemm_e8m0_used() -> bool:
        logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
        return False

    if envs.VLLM_USE_FLASHINFER_MOE_FP8:
        logger.info_once("DeepGEMM E8M0 disabled: FlashInfer MOE is enabled.")
        return False

    if envs.VLLM_USE_DEEP_GEMM_E8M0:
        logger.info_once("DeepGEMM E8M0 enabled on current platform.")
        return True
@@ -77,6 +73,7 @@ _fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
_get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None
_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None


def _lazy_init() -> None:
@@ -86,6 +83,7 @@ def _lazy_init() -> None:
    global _get_paged_mqa_logits_metadata_impl
    global _get_mn_major_tma_aligned_tensor_impl
    global _get_mk_alignment_for_contiguous_layout_impl
    global _transform_sf_into_required_layout_impl
    # fast path
    if (
        _fp8_gemm_nt_impl is not None
@@ -95,6 +93,7 @@ def _lazy_init() -> None:
        or _fp8_paged_mqa_logits_impl is not None
        or _get_paged_mqa_logits_metadata_impl is not None
        or _get_mk_alignment_for_contiguous_layout_impl is not None
        or _transform_sf_into_required_layout_impl is not None
    ):
        return

@@ -124,6 +123,9 @@ def _lazy_init() -> None:
    _get_mk_alignment_for_contiguous_layout_impl = getattr(
        _dg, "get_mk_alignment_for_contiguous_layout", None
    )
    _transform_sf_into_required_layout_impl = getattr(
        _dg, "transform_sf_into_required_layout", None
    )


def get_num_sms() -> int:
@@ -179,6 +181,15 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    )


def transform_sf_into_required_layout(*args, **kwargs):
    _lazy_init()
    if _transform_sf_into_required_layout_impl is None:
        return _missing(*args, **kwargs)
    return _transform_sf_into_required_layout_impl(
        *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
    )


def fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],

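The new transform_sf_into_required_layout wrapper follows the same lazy-resolution pattern as the other DeepGEMM shims in this file: resolve the symbol once via getattr, fall back when it is missing, and inject a vLLM-side default. A generic toy version of that pattern (names are placeholders, not vLLM code):

from typing import Any, Callable

_impl: Callable[..., Any] | None = None

def _lazy_resolve() -> None:
    global _impl
    if _impl is not None:
        return
    try:
        import deep_gemm as _dg  # optional dependency
        _impl = getattr(_dg, "transform_sf_into_required_layout", None)
    except ImportError:
        _impl = None

def transform_sf(*args, **kwargs):
    _lazy_resolve()
    if _impl is None:
        raise RuntimeError("DeepGEMM is not installed")
    # The real wrapper also injects disable_ue8m0_cast based on is_deep_gemm_e8m0_used().
    return _impl(*args, **kwargs)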