mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-03 15:11:18 +08:00
DP/EP Support for gpt-oss with deepep-ht comm kernel on SM100 (#23608)
This commit is contained in:
parent
853c371fc3
commit
082cc07ef8
@ -255,7 +255,7 @@ class DeviceCommunicatorBase:
|
|||||||
if module.__class__.__name__ == "FusedMoE"
|
if module.__class__.__name__ == "FusedMoE"
|
||||||
]
|
]
|
||||||
for module in moe_modules:
|
for module in moe_modules:
|
||||||
module.quant_method.init_prepare_finalize()
|
module.quant_method.init_prepare_finalize(module)
|
||||||
|
|
||||||
def dispatch(
|
def dispatch(
|
||||||
self, hidden_states: torch.Tensor,
|
self, hidden_states: torch.Tensor,
|
||||||
|
|||||||
@ -450,6 +450,12 @@ class FusedMoEConfig:
|
|||||||
if quant_dtype is None and isinstance(quant_config, Fp8Config):
|
if quant_dtype is None and isinstance(quant_config, Fp8Config):
|
||||||
quant_dtype = torch.float8_e4m3fn
|
quant_dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.mxfp4 import (
|
||||||
|
Mxfp4Config)
|
||||||
|
if (quant_dtype is None and isinstance(quant_config, Mxfp4Config)
|
||||||
|
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
|
||||||
|
quant_dtype = "mxfp8"
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.modelopt import (
|
from vllm.model_executor.layers.quantization.modelopt import (
|
||||||
ModelOptNvFp4Config)
|
ModelOptNvFp4Config)
|
||||||
if quant_dtype is None and isinstance(quant_config,
|
if quant_dtype is None and isinstance(quant_config,
|
||||||
|
|||||||
@ -200,7 +200,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
|
|
||||||
# Note: init_prepare_finalize should only be called by
|
# Note: init_prepare_finalize should only be called by
|
||||||
# prepare_communication_buffer_for_model.
|
# prepare_communication_buffer_for_model.
|
||||||
def init_prepare_finalize(self):
|
def init_prepare_finalize(self, layer: torch.nn.Module):
|
||||||
assert self.moe is not None
|
assert self.moe is not None
|
||||||
prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
|
prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
|
||||||
|
|
||||||
@ -211,7 +211,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
assert self.fused_experts is None, \
|
assert self.fused_experts is None, \
|
||||||
f"Attempt to override experts for {id(self)}!"
|
f"Attempt to override experts for {id(self)}!"
|
||||||
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
|
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
|
||||||
experts = self.select_gemm_impl(prepare_finalize, self.moe)
|
experts = self.select_gemm_impl(prepare_finalize, self.moe, layer)
|
||||||
self.fused_experts = FusedMoEModularKernel(
|
self.fused_experts = FusedMoEModularKernel(
|
||||||
prepare_finalize,
|
prepare_finalize,
|
||||||
experts,
|
experts,
|
||||||
@ -221,6 +221,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
self,
|
self,
|
||||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||||
moe: FusedMoEConfig,
|
moe: FusedMoEConfig,
|
||||||
|
layer: torch.nn.Module,
|
||||||
) -> FusedMoEPermuteExpertsUnpermute:
|
) -> FusedMoEPermuteExpertsUnpermute:
|
||||||
# based on the all2all implementation, select the appropriate
|
# based on the all2all implementation, select the appropriate
|
||||||
# gemm implementation
|
# gemm implementation
|
||||||
@ -273,6 +274,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||||
# TODO(bnell): Remove. Every layer should have an moe config object.
|
# TODO(bnell): Remove. Every layer should have an moe config object.
|
||||||
moe: FusedMoEConfig,
|
moe: FusedMoEConfig,
|
||||||
|
layer: torch.nn.Module,
|
||||||
) -> FusedMoEPermuteExpertsUnpermute:
|
) -> FusedMoEPermuteExpertsUnpermute:
|
||||||
if (prepare_finalize.activation_format ==
|
if (prepare_finalize.activation_format ==
|
||||||
FusedMoEActivationFormat.BatchedExperts):
|
FusedMoEActivationFormat.BatchedExperts):
|
||||||
|
|||||||
197
vllm/model_executor/layers/fused_moe/trtllm_moe.py
Normal file
197
vllm/model_executor/layers/fused_moe/trtllm_moe.py
Normal file
@ -0,0 +1,197 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||||
|
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||||
|
TopKWeightAndReduceNoOP)
|
||||||
|
from vllm.utils import next_power_of_2
|
||||||
|
|
||||||
|
|
||||||
|
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
    """Experts implementation that delegates to FlashInfer's
    ``trtllm_fp4_block_scale_routed_moe`` kernel.

    Activations are consumed and produced in the ``Standard`` activation
    format. Routing (top-k ids and weights) is computed by the caller and
    handed to the kernel as a single packed int32 tensor.
    """

    def __init__(
        self,
        moe: FusedMoEConfig,
        gemm1_alpha,
        gemm1_beta,
        gemm1_clamp_limit,
        w13_bias,
        w2_bias,
        max_capture_size,
    ):
        """Store the per-layer values that are forwarded to the kernel.

        Args:
            moe: MoE configuration; its ``quant_config`` is passed to the
                base class and its ``ep_rank`` is read in ``apply``.
            gemm1_alpha: Forwarded as the kernel's ``gemm1_alpha``.
            gemm1_beta: Forwarded as the kernel's ``gemm1_beta``.
            gemm1_clamp_limit: Forwarded as the kernel's
                ``gemm1_clamp_limit``.
            w13_bias: Forwarded as the kernel's ``gemm1_bias``.
            w2_bias: Forwarded as the kernel's ``gemm2_bias``.
            max_capture_size: Forwarded as the kernel's
                ``tune_max_num_tokens``.
        """
        super().__init__(moe.quant_config)
        self.moe = moe
        self.gemm1_alpha = gemm1_alpha
        self.gemm1_beta = gemm1_beta
        self.gemm1_clamp_limit = gemm1_clamp_limit
        self.w13_bias = w13_bias
        self.w2_bias = w2_bias
        self.max_capture_size = max_capture_size

    @property
    def activation_formats(
        self
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        """Return the (input, output) activation formats: both Standard."""
        return (mk.FusedMoEActivationFormat.Standard,
                mk.FusedMoEActivationFormat.Standard)

    def supports_chunking(self) -> bool:
        """Report that this implementation supports chunking."""
        return True

    def supports_expert_map(self) -> bool:
        """Report that this implementation supports an expert map."""
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        """Return a no-op weight-and-reduce step.

        NOTE(review): presumably a no-op because the kernel is invoked with
        ``do_finalize=True`` in ``apply`` -- confirm against FlashInfer.
        """
        return TopKWeightAndReduceNoOP()

    def workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
        """Return ``(workspace1, workspace2, output, dtype)`` shapes.

        Only ``a``, ``M`` and ``K`` are used; the remaining parameters are
        accepted to satisfy the common interface.
        """
        # The workspaces for this implementation are managed by flashinfer.
        # TODO(varun): workspace1 could be used as the output tensor. This
        # is error-prone. Allow `workspace_shapes` to return None workspaces.
        workspace1 = (M, K)
        workspace2 = (0, 0)
        output = (M, K)
        return (workspace1, workspace2, output, a.dtype)

    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int,
                             local_num_experts: int):
        """Pick the per-CTA token tile size passed to the kernel.

        The value is the expected tokens-per-expert (with a 1.3x imbalance
        allowance) rounded up to a power of two and clamped to [8, 64].
        """
        # Number of tokens in the input tensor.
        num_tokens = x.shape[0]
        # Factor to account for the imbalance of the experts.
        # factor equals to the
        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
        # 1.0 means perfect expert distribution.
        # > 1.0 means some experts have more tokens than the perfect
        # distribution.
        # < 1.0 does not make sense.
        imbalance_factor = 1.3
        # Calculate the number of tokens per expert assuming perfect
        # distribution.
        num_tokens_per_expert = (num_tokens * top_k) // local_num_experts
        # Apply the imbalance factor.
        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
        # And pad the number to the next power of 2.
        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
        # Cap to 8-64 tokens per CTA tile as it's the range supported by the
        # kernel.
        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

        return tile_tokens_dim

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        w1_scale: Optional[torch.Tensor],
        w2_scale: Optional[torch.Tensor],
        w1_zp: Optional[torch.Tensor],
        w2_zp: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
        apply_router_weight_on_input: bool,
    ):
        """Run the routed MoE kernel and return ``output``.

        ``w1_scale`` and ``w2_scale`` are required. ``activation``,
        ``expert_map``, the zero points, ``a2_scale``, both workspaces,
        ``expert_tokens_meta`` and ``apply_router_weight_on_input`` are not
        used by this implementation.
        """
        topk = topk_ids.size(-1)
        local_num_experts = w1.size(0)
        intermediate_size = w2.size(1)
        # Global index of this rank's first expert.
        local_expert_offset = self.moe.ep_rank * local_num_experts

        x_quant = hidden_states
        x_scale = a1q_scale
        if x_scale is not None:
            # Reinterpret the scale bytes as fp8 and give them the same
            # leading dimensions as the activations.
            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
                *x_quant.shape[:-1], -1)

        # Pack (expert id, routing weight) into one int32 per slot: the id
        # in the upper 16 bits and the raw bfloat16 bit pattern of the
        # weight in the lower 16 bits. The weight bits are widened and
        # masked explicitly; OR-ing the int16 view directly would
        # sign-extend any pattern with the top bit set (negative weight or
        # -0.0) and overwrite the expert id.
        weight_bits = topk_weights.to(torch.bfloat16).view(torch.int16).to(
            torch.int32) & 0xFFFF
        packed_tensor = (topk_ids.to(torch.int32) << 16) | weight_bits

        assert w1_scale is not None
        assert w2_scale is not None
        kwargs = {
            "topk_ids": packed_tensor,
            "routing_bias": None,
            "hidden_states": x_quant,
            "hidden_states_scale": x_scale,
            "gemm1_weights": w1,
            "gemm1_weights_scale": w1_scale,
            "gemm1_bias": self.w13_bias,
            "gemm1_alpha": self.gemm1_alpha,
            "gemm1_beta": self.gemm1_beta,
            "gemm1_clamp_limit": self.gemm1_clamp_limit,
            "gemm2_weights": w2,
            "gemm2_weights_scale": w2_scale,
            "gemm2_bias": self.w2_bias,
            "output1_scale_scalar": None,
            "output1_scale_gate_scalar": None,
            "output2_scale_scalar": None,
            "num_experts": global_num_experts,
            "top_k": topk,
            "n_group": None,
            "topk_group": None,
            "intermediate_size": intermediate_size,
            "local_expert_offset": local_expert_offset,
            "local_num_experts": local_num_experts,
            "routed_scaling_factor": None,
            "tile_tokens_dim":
            self._get_tile_tokens_dim(x_quant, topk, local_num_experts),
            # NOTE(review): 1 presumably selects FlashInfer's "renormalize"
            # routing method -- confirm against its RoutingMethodType enum.
            "routing_method_type": 1,
            "do_finalize": True,
            "output": output,
            "tune_max_num_tokens": self.max_capture_size,
        }

        # Imported lazily so the module loads without flashinfer installed.
        from flashinfer import trtllm_fp4_block_scale_routed_moe
        trtllm_fp4_block_scale_routed_moe(**kwargs)
        return output
|
||||||
@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
|||||||
per_token_group_quant_int8, per_token_quant_int8)
|
per_token_group_quant_int8, per_token_quant_int8)
|
||||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||||
quant_dequant_mxfp4)
|
quant_dequant_mxfp4)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
|
||||||
|
mxfp8_quantize)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv
|
||||||
@ -177,6 +179,18 @@ def _mxfp4_quantize(
|
|||||||
return A, None
|
return A, None
|
||||||
|
|
||||||
|
|
||||||
|
def _mxfp8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    per_act_token_quant: bool,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``A`` with ``mxfp8_quantize`` after validating the options.

    The extra parameters exist only so the signature matches the sibling
    ``_mxfp4_quantize`` call in ``moe_kernel_quantize_input``; the only
    accepted values are ``A_scale=None``, ``per_act_token_quant=False`` and
    ``block_shape=None``.

    Raises:
        AssertionError: If any of the three options is set.
    """
    # No caller-supplied scale is accepted on this path.
    assert A_scale is None
    # Per-token activation quantization is not accepted either.
    assert not per_act_token_quant
    # Nor is an explicit block shape.
    assert block_shape is None

    quantized = mxfp8_quantize(A)
    return quantized
|
||||||
|
|
||||||
|
|
||||||
def moe_kernel_quantize_input(
|
def moe_kernel_quantize_input(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
A_scale: Optional[torch.Tensor],
|
A_scale: Optional[torch.Tensor],
|
||||||
@ -195,6 +209,8 @@ def moe_kernel_quantize_input(
|
|||||||
is_sf_swizzled_layout=is_fp4_scale_swizzled)
|
is_sf_swizzled_layout=is_fp4_scale_swizzled)
|
||||||
elif quant_dtype == "mxfp4":
|
elif quant_dtype == "mxfp4":
|
||||||
return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
|
return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||||
|
elif quant_dtype == "mxfp8":
|
||||||
|
return _mxfp8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||||
else:
|
else:
|
||||||
return A, A_scale
|
return A, A_scale
|
||||||
|
|
||||||
|
|||||||
@ -322,6 +322,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
|||||||
self,
|
self,
|
||||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||||
moe: FusedMoEConfig,
|
moe: FusedMoEConfig,
|
||||||
|
layer: torch.nn.Module,
|
||||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||||
"""Return the appropriate GEMM experts implementation."""
|
"""Return the appropriate GEMM experts implementation."""
|
||||||
experts = select_nvfp4_gemm_impl(
|
experts = select_nvfp4_gemm_impl(
|
||||||
@ -719,10 +720,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
dtype=torch.int64)
|
dtype=torch.int64)
|
||||||
|
|
||||||
def select_gemm_impl(
|
def select_gemm_impl(
|
||||||
self,
|
self, prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
moe: FusedMoEConfig,
|
||||||
moe: FusedMoEConfig,
|
layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute:
|
||||||
) -> FusedMoEPermuteExpertsUnpermute:
|
|
||||||
# cutlass path
|
# cutlass path
|
||||||
if self.use_cutlass:
|
if self.use_cutlass:
|
||||||
from vllm.model_executor.layers.fused_moe import (
|
from vllm.model_executor.layers.fused_moe import (
|
||||||
|
|||||||
@ -897,6 +897,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
self,
|
self,
|
||||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||||
moe: FusedMoEConfig,
|
moe: FusedMoEConfig,
|
||||||
|
layer: torch.nn.Module,
|
||||||
) -> FusedMoEPermuteExpertsUnpermute:
|
) -> FusedMoEPermuteExpertsUnpermute:
|
||||||
from vllm.model_executor.layers.fused_moe import (
|
from vllm.model_executor.layers.fused_moe import (
|
||||||
BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)
|
BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)
|
||||||
|
|||||||
@ -311,6 +311,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
self,
|
self,
|
||||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||||
moe: FusedMoEConfig,
|
moe: FusedMoEConfig,
|
||||||
|
layer: torch.nn.Module,
|
||||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||||
experts = select_cutlass_fp8_gemm_impl(
|
experts = select_cutlass_fp8_gemm_impl(
|
||||||
moe,
|
moe,
|
||||||
@ -1032,6 +1033,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
self,
|
self,
|
||||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||||
moe: FusedMoEConfig,
|
moe: FusedMoEConfig,
|
||||||
|
layer: torch.nn.Module,
|
||||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||||
experts = select_nvfp4_gemm_impl(
|
experts = select_nvfp4_gemm_impl(
|
||||||
moe,
|
moe,
|
||||||
|
|||||||
@ -10,6 +10,8 @@ from vllm.config import get_current_vllm_config
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||||
FusedMoEMethodBase)
|
FusedMoEMethodBase)
|
||||||
|
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
|
||||||
|
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
|
||||||
from vllm.model_executor.layers.linear import (LinearBase,
|
from vllm.model_executor.layers.linear import (LinearBase,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
@ -445,6 +447,91 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
return tile_tokens_dim
|
return tile_tokens_dim
|
||||||
|
|
||||||
|
def select_gemm_impl(
    self,
    prepare_finalize: mk.FusedMoEPrepareAndFinalize,
    moe: FusedMoEConfig,
    layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
    """Choose the experts implementation for the modular MoE kernel.

    Returns a ``TrtLlmGenExperts`` built from ``layer``'s gemm1 parameters
    and biases when the activation format is not batched and
    ``should_use_flashinfer_mxfp4()`` is true.

    Raises:
        NotImplementedError: For the batched-experts activation format, or
            when the FlashInfer MXFP4 path is not enabled.
    """
    batched_format = mk.FusedMoEActivationFormat.BatchedExperts
    if prepare_finalize.activation_format == batched_format:
        raise NotImplementedError(
            "Mxfp4 does not support batched experts format for EP")

    if not should_use_flashinfer_mxfp4():
        # Use matmul_ogs from triton_kernels here!
        raise NotImplementedError(
            "Mxfp4 does not support non-batched experts format for EP")

    # B200 code-path
    return TrtLlmGenExperts(
        moe,
        gemm1_alpha=layer.gemm1_alpha,
        gemm1_beta=layer.gemm1_beta,
        gemm1_clamp_limit=layer.gemm1_clamp_limit,
        w13_bias=layer.w13_bias,
        w2_bias=layer.w2_bias,
        max_capture_size=self.max_capture_size,
    )
|
||||||
|
|
||||||
|
def _route_and_experts(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Select experts for ``x`` and run them through the modular kernel.

    Routing is done with ``FusedMoE.select_experts`` using
    ``self.topk_indices_dtype`` for the index dtype; the resulting weights
    and ids are then passed, together with ``layer``'s w13/w2 weights and
    weight scales, to ``self.fused_experts`` with ``inplace=True``.

    Requires ``self.fused_experts`` to be a ``mk.FusedMoEModularKernel``.
    """
    modular_kernel = self.fused_experts
    assert isinstance(modular_kernel, mk.FusedMoEModularKernel)

    # Step 1: routing.
    routing_weights, routed_expert_ids = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        indices_type=self.topk_indices_dtype,
        enable_eplb=enable_eplb,
        expert_map=expert_map,
        expert_load_view=expert_load_view,
        logical_to_physical_map=logical_to_physical_map,
        logical_replica_count=logical_replica_count,
    )

    # Step 2: expert computation through the modular kernel.
    result = modular_kernel(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=routing_weights,
        topk_ids=routed_expert_ids,
        inplace=True,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
    return result
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -503,6 +590,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
activation=activation,
|
activation=activation,
|
||||||
expert_map=expert_map)
|
expert_map=expert_map)
|
||||||
|
|
||||||
|
if self.fused_experts is not None:
|
||||||
|
return self._route_and_experts(
|
||||||
|
layer,
|
||||||
|
x,
|
||||||
|
router_logits,
|
||||||
|
top_k,
|
||||||
|
renormalize,
|
||||||
|
use_grouped_topk,
|
||||||
|
topk_group,
|
||||||
|
num_expert_group,
|
||||||
|
global_num_experts,
|
||||||
|
expert_map,
|
||||||
|
custom_routing_function,
|
||||||
|
scoring_func,
|
||||||
|
e_score_correction_bias,
|
||||||
|
apply_router_weight_on_input,
|
||||||
|
activation,
|
||||||
|
enable_eplb,
|
||||||
|
expert_load_view,
|
||||||
|
logical_to_physical_map,
|
||||||
|
logical_replica_count,
|
||||||
|
)
|
||||||
|
|
||||||
assert _can_support_mxfp4(
|
assert _can_support_mxfp4(
|
||||||
use_grouped_topk, topk_group, num_expert_group, expert_map,
|
use_grouped_topk, topk_group, num_expert_group, expert_map,
|
||||||
custom_routing_function, e_score_correction_bias,
|
custom_routing_function, e_score_correction_bias,
|
||||||
|
|||||||
@ -66,11 +66,10 @@ def _can_support_mxfp4(use_grouped_topk: bool = False,
|
|||||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||||
logical_replica_count: Optional[torch.Tensor] = None):
|
logical_replica_count: Optional[torch.Tensor] = None):
|
||||||
return not (use_grouped_topk or topk_group or num_expert_group
|
return not (use_grouped_topk or topk_group or num_expert_group
|
||||||
or expert_map or custom_routing_function
|
or custom_routing_function or e_score_correction_bias
|
||||||
or e_score_correction_bias or apply_router_weight_on_input
|
or apply_router_weight_on_input or scoring_func != "softmax"
|
||||||
or scoring_func != "softmax" or activation != "swigluoai"
|
or activation != "swigluoai" or expert_load_view
|
||||||
or expert_load_view or logical_to_physical_map
|
or logical_to_physical_map or logical_replica_count)
|
||||||
or logical_replica_count)
|
|
||||||
|
|
||||||
|
|
||||||
def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,
|
def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,
|
||||||
|
|||||||
20
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
Normal file
20
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def mxfp8_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to MX-FP8 by delegating to FlashInfer.

    FlashInfer is imported lazily, at call time, so this module can be
    imported on installations that do not have it.

    Args:
        x: Tensor to quantize.

    Returns:
        The result of FlashInfer's ``mxfp8_quantize`` called with
        ``is_sf_swizzled_layout=False``. It is annotated as a pair of
        tensors (presumably the quantized values and their scale factors --
        confirm against the FlashInfer documentation).

    Raises:
        ImportError: If the ``flashinfer`` package cannot be imported.
    """
    try:
        # Aliased so the import does not shadow this function's own name.
        from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize
    except ImportError as err:
        raise ImportError(
            "The package `flashinfer` is required to do "
            "MX-FP8 quantization. Please install it with "
            "`pip install flashinfer`") from err

    return flashinfer_mxfp8_quantize(x, is_sf_swizzled_layout=False)
|
||||||
Loading…
x
Reference in New Issue
Block a user