[DP/EP][GPTOSS] Use triton matmul-ogs kernels for GPTOSS DP/EP (#24588)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>

This commit is contained in:
  parent fafbe11af4
  commit e8db44f883
@@ -288,7 +288,11 @@ class FusedMoEQuantConfig:
     @property
     def use_mxfp4_w4a4(self) -> bool:
-        return self.quant_dtype == "mxfp4"
+        return (self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4")
+
+    @property
+    def use_mxfp4_w4a16(self) -> bool:
+        return (self._a1.dtype is None and self._w1.dtype == "mxfp4")
 
     @property
     def use_nvfp4_w4a4(self) -> bool:
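Note on the new properties: the w4a4/w4a16 split is driven purely by the activation and weight dtype descriptors. A minimal sketch of the same check, with `QuantDesc`/`QuantConfig` as simplified stand-ins for the vllm classes (not the real types):

```python
# Simplified stand-ins for FusedMoEQuantDesc / FusedMoEQuantConfig,
# only to illustrate the dtype-based mode checks above.
from dataclasses import dataclass
from typing import Optional


@dataclass
class QuantDesc:
    dtype: Optional[str] = None  # e.g. "mxfp4", "nvfp4", or None (unquantized)


@dataclass
class QuantConfig:
    a1: QuantDesc  # activation descriptor
    w1: QuantDesc  # weight descriptor

    @property
    def use_mxfp4_w4a4(self) -> bool:
        # Both activations and weights quantized to mxfp4.
        return self.a1.dtype == "mxfp4" and self.w1.dtype == "mxfp4"

    @property
    def use_mxfp4_w4a16(self) -> bool:
        # Unquantized (16-bit) activations with mxfp4 weights.
        return self.a1.dtype is None and self.w1.dtype == "mxfp4"


cfg = QuantConfig(a1=QuantDesc(None), w1=QuantDesc("mxfp4"))
assert cfg.use_mxfp4_w4a16 and not cfg.use_mxfp4_w4a4
```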
@@ -453,6 +457,22 @@ def int8_w8a8_moe_quant_config(
     )
 
 
+def mxfp4_w4a16_moe_quant_config(
+        w1_scale: Union[torch.Tensor, "PrecisionConfig"],
+        w2_scale: Union[torch.Tensor, "PrecisionConfig"],
+        w1_bias: Optional[torch.Tensor] = None,
+        w2_bias: Optional[torch.Tensor] = None) -> FusedMoEQuantConfig:
+    """
+    Construct a quant config for unquantized activations and mxfp4 weights.
+    """
+    return FusedMoEQuantConfig(
+        _a1=FusedMoEQuantDesc(),
+        _a2=FusedMoEQuantDesc(),
+        _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
+        _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
+    )
+
+
 def mxfp4_w4a4_moe_quant_config(
         w1_scale: Union[torch.Tensor, "PrecisionConfig"],
         w2_scale: Union[torch.Tensor, "PrecisionConfig"],
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input)
+from vllm.utils import round_up
 
 
 class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
@@ -18,6 +19,23 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
     Prepare/Finalize using DeepEP High-Throughput kernels.
     """
 
+    @staticmethod
+    def maybe_roundup_layer_hidden_size(hidden_size: int,
+                                        dtype: torch.dtype) -> int:
+        # Round up hidden size so it is compatible with DeepEP High Throughput
+        # kernels.
+        # DeepEP intranode kernels make copies in units of
+        # 32 (warp-size) int4 elements. Round up hidden size to respect this.
+        # For example, an input hidden size of 2880 with dtype torch.bfloat16
+        # will be rounded up to 3072.
+        hidden_size_bytes = hidden_size * dtype.itemsize
+        xfer_atom_size = 512  # 32 * 16 (size(int4))
+        if hidden_size_bytes % xfer_atom_size == 0:
+            return hidden_size
+
+        hidden_size_bytes = round_up(hidden_size_bytes, xfer_atom_size)
+        return hidden_size_bytes // dtype.itemsize
+
     def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int,
                  dp_size: int, rank_expert_offset: int):
         super().__init__()
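For reference, the rounding that `maybe_roundup_layer_hidden_size` performs on the 2880/bfloat16 case from the comment works out as follows (standalone sketch; `round_up` is redefined locally so it runs without vllm):

```python
# DeepEP intranode copies move 32 (warp size) x 16-byte int4 vectors
# = 512 bytes at a time, so the hidden size in bytes is padded to a
# multiple of 512 and then converted back to element count.
import torch


def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y


hidden_size, dtype = 2880, torch.bfloat16
hidden_size_bytes = hidden_size * dtype.itemsize   # 2880 * 2 = 5760 bytes
padded_bytes = round_up(hidden_size_bytes, 512)    # 6144 bytes
print(padded_bytes // dtype.itemsize)              # 3072 elements, as in the comment
```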
@@ -9,7 +9,8 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
     FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceDelegate)
+    TopKWeightAndReduceNoOP)
+from vllm.triton_utils import tl, triton
 from vllm.utils import has_triton_kernels
 
 logger = init_logger(__name__)
@@ -19,13 +20,55 @@ if has_triton_kernels():
         import triton_kernels.swiglu
         from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
                                                matmul_ogs)
-        from triton_kernels.routing import routing
+        from triton_kernels.routing import (RoutingData, routing,
+                                            routing_from_bitmatrix)
+        from triton_kernels.tensor import Bitmatrix
     except (ModuleNotFoundError, AttributeError) as e:
         logger.error(
             "Failed to import Triton kernels. Please make sure your triton "
             "version is compatible. Error: %s", e)
 
 
+@triton.jit
+def pack_bitmatrix(
+    bitmatrix,
+    topk_ids,
+    n_rows,  # n_rows in bitmatrix / topk_ids
+    bm_cols: tl.constexpr,  # n int32_t bitpacks in bitmatrix
+    n_expts_act,  # num_topk
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    """
+    Packs topk_ids into a bitmatrix.
+    code reference:
+    https://github.com/triton-lang/triton/blob/dd1bbc52b34d202dfe5ffea1e04fb16166c5c04e/python/triton_kernels/bench/distributed.py#L264
+    """
+    pid_m = tl.program_id(0)
+    offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offsets_k = tl.arange(0, BLOCK_SIZE_K)
+    offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :]
+    mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :]
+    indices = tl.load(topk_ids + offsets, mask=mask, other=-1)
+    div = indices // 32
+    rem = indices % 32
+    one = tl.cast(1, tl.uint32)
+
+    # Iterate through all the relevant bitmatrix columns.
+    for i in range(bm_cols):
+        # When BLOCK_SIZE_K=32, offs is just the column index.
+        offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32)
+        # All topks that need to go into this column have the correct bit set.
+        # Other bits are 0. x is a 2D tensor.
+        x = tl.where(div[:, :, None] == offs[None, None, :],
+                     (one << rem)[:, :, None], 0)
+        # Reduce x to get a single int32_t bitpack.
+        y = tl.reduce_or(x, axis=1)
+        bitmatrix_ptrs = bitmatrix + offsets_m[:,
+                                               None] * bm_cols + offs[None, :]
+        tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows)
+
+
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1,  # Tensor or triton_kernels.Tensor
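For readers unfamiliar with the bitmatrix layout the kernel produces: row `r` gets bit `e` set in 32-bit word `e // 32` exactly when expert `e` appears in `topk_ids[r]`. A plain-PyTorch reference of that packing (illustration only, not the Triton kernel above):

```python
# Reference packing: one row per token, one bit per expert, grouped into
# 32-bit words. Entries of -1 (masked-out experts) are ignored.
import torch


def pack_bitmatrix_ref(topk_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
    n_rows = topk_ids.size(0)
    bm_cols = (num_experts + 31) // 32                       # int32 words per row
    out = torch.zeros((n_rows, bm_cols), dtype=torch.int64)  # int64 to keep bit ops simple
    for r in range(n_rows):
        for e in topk_ids[r].tolist():
            if e >= 0:
                out[r, e // 32] |= (1 << (e % 32))
    return out


topk_ids = torch.tensor([[3, 34], [0, -1]])
print(pack_bitmatrix_ref(topk_ids, num_experts=64))
# row 0: word 0 has bit 3 set (8), word 1 has bit 2 set (4)
# row 1: word 0 has bit 0 set (1), word 1 stays 0
```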
@@ -124,34 +167,88 @@ def triton_kernel_fused_experts(
     return intermediate_cache3
 
 
-class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+def make_routing_data(
+    topk_ids: torch.Tensor,
+    topk_weights: torch.Tensor,
+    num_local_experts: int,
+) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:
 
-    def __init__(
-        self,
-        max_num_tokens: int,
-        num_dispatchers: int,
-        quant_config: FusedMoEQuantConfig,
-    ):
+    topk_ids = topk_ids.to(torch.int16)
+    topk_weights = topk_weights.to(torch.bfloat16)
+
+    n_rows, num_topk = topk_ids.size()
+
+    BLOCK_SIZE_M = 512
+    BLOCK_SIZE_K = 32
+
+    bm_cols = triton.cdiv(num_local_experts, BLOCK_SIZE_K)  # n_bitpacks
+    bitmatrix = torch.zeros((n_rows, bm_cols),
+                            dtype=torch.uint32,
+                            device=topk_ids.device)
+
+    grid = (triton.cdiv(n_rows, BLOCK_SIZE_M), )
+    pack_bitmatrix[grid](
+        bitmatrix,
+        topk_ids,
+        n_rows,
+        bm_cols,
+        num_topk,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+    )
+
+    bitmatrix_shape = [n_rows, bm_cols * 32]
+    bitmatrix_shape_max = [n_rows, None]
+    bitmatrix = Bitmatrix(bitmatrix,
+                          shape=bitmatrix_shape,
+                          shape_max=bitmatrix_shape_max,
+                          scratchpad=None)
+
+    # matmul_ogs expects invalid topk_weights to be -1s
+    topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
+    routing_data, gather_indx, scatter_indx = routing_from_bitmatrix(
+        bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk)
+
+    return routing_data, gather_indx, scatter_indx
+
+
+class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(self, quant_config: FusedMoEQuantConfig):
         super().__init__(quant_config)
-        self.max_num_tokens = max_num_tokens
-        self.num_dispatchers = num_dispatchers
+
+    def supports_expert_map(self) -> bool:
+        return True
+
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Weight application and reduction happens in the fused_experts kernel.
+        return TopKWeightAndReduceNoOP()
+
+    def _make_routing_data(
+        self,
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_local_experts: int,
+    ) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:
+        return make_routing_data(topk_ids, topk_weights, num_local_experts)
+
+
+class OAITritonExperts(BaseOAITritonExperts):
+
+    def __init__(self, quant_config: FusedMoEQuantConfig):
+        # TODO (varun) : Enable activation quantization
+        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
+        super().__init__(quant_config)
 
     @property
     def activation_formats(
         self
     ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
-        return (mk.FusedMoEActivationFormat.BatchedExperts,
-                mk.FusedMoEActivationFormat.BatchedExperts)
+        return (mk.FusedMoEActivationFormat.Standard,
+                mk.FusedMoEActivationFormat.Standard)
 
     def supports_chunking(self) -> bool:
-        return False
-
-    def supports_expert_map(self) -> bool:
-        return False
-
-    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        # Let PrepareAndFinalize::finalize() decide the impl.
-        return TopKWeightAndReduceDelegate()
+        return True
 
     def workspace_shapes(
         self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
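`make_routing_data` ultimately groups the token/expert assignments by expert so `matmul_ogs` can process each expert's rows contiguously. A conceptual plain-PyTorch reference of that grouping (hypothetical helper for illustration, not the `triton_kernels` routing API):

```python
# Per-expert token counts plus a gather order that sorts token/expert pairs
# by expert id; -1 entries (experts not owned by this rank) are dropped.
import torch


def reference_routing(topk_ids: torch.Tensor, num_experts: int):
    flat = topk_ids.reshape(-1)                       # (n_tokens * topk,)
    valid = flat >= 0                                 # -1 marks masked-out entries
    counts = torch.bincount(flat[valid], minlength=num_experts)
    order = torch.argsort(flat[valid], stable=True)   # rows grouped by expert
    return counts, order


topk_ids = torch.tensor([[0, 2], [1, 2], [0, 1]])
counts, order = reference_routing(topk_ids, num_experts=4)
print(counts)  # tensor([2, 2, 2, 0]) -> tokens per expert
```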
@@ -159,13 +256,10 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # workspace are allocated inside the kernel
-        assert a.dim() == 2
-        num_dp = self.num_dispatchers
-        num_experts = local_num_experts
-        max_num_tokens = self.max_num_tokens
-        workspace2 = (0, 0, 0)
-        output = (num_experts, max_num_tokens * num_dp, N)
-        return (output, workspace2, output, a.dtype)
+        workspace1 = (M, K)
+        workspace2 = (0, 0)
+        output = (M, K)
+        return (workspace1, workspace2, output, a.dtype)
 
     def apply(
         self,
@@ -185,17 +279,29 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
     ):
-        return triton_kernel_fused_experts(
-            output,
+        if expert_map is not None:
+            topk_ids = expert_map[topk_ids]
+
+        local_num_experts = w1.size(0)
+        if global_num_experts == -1:
+            global_num_experts = local_num_experts
+
+        routing_data, gather_indx, scatter_indx = self._make_routing_data(
+            topk_ids, topk_weights, local_num_experts)
+
+        experts_output = triton_kernel_fused_experts(
+            None,
             hidden_states,
             w1,
             w2,
-            routing_data=None,
-            gather_indx=None,
-            scatter_indx=None,
+            routing_data,
+            gather_indx,
+            scatter_indx,
             activation=activation,
             quant_config=self.quant_config,
             apply_router_weight_on_input=False,
-            global_num_experts=global_num_experts,
-            expert_map=expert_map,
+            global_num_experts=local_num_experts,
+            expert_map=None,  # applied already
             a1q_scale=a1q_scale)
+
+        output.copy_(experts_output, non_blocking=True)
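The new `apply` path folds the expert map in up front: `expert_map` is a per-rank lookup from global to local expert ids, with `-1` for experts this rank does not own, so a single gather converts `topk_ids` to local ids and the `-1` entries are masked later by the routing step. A small sketch of that remap (hypothetical sizes):

```python
# Global-to-local expert id remap, as done by `topk_ids = expert_map[topk_ids]`.
import torch

global_num_experts, local_experts = 8, [4, 5, 6, 7]   # this rank owns experts 4..7
expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
expert_map[local_experts] = torch.arange(len(local_experts), dtype=torch.int32)

topk_ids = torch.tensor([[1, 5], [4, 7]])             # global expert ids per token
local_topk_ids = expert_map[topk_ids]                 # one gather does the remap
print(local_topk_ids)                                 # [[-1, 1], [0, 3]]
```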
@@ -800,6 +800,49 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
         for local_index, global_index in zip(local_indices, global_indices))
 
 
+def maybe_roundup_hidden_size(
+        hidden_size: int, act_dtype: torch.dtype,
+        quant_config: Optional[QuantizationConfig],
+        moe_parallel_config: FusedMoEParallelConfig) -> int:
+    """
+    Given layer hidden size and MoE configurations, round up hidden_size
+    if necessary.
+
+    Args:
+        hidden_size(int): Layer hidden-size
+        act_dtype: Data type of the layer activations.
+        quant_config(FusedMoEQuantConfig): Fused MoE quantization configuration.
+        moe_parallel_config(FusedMoEParallelConfig): Fused MoE parallelization
+        strategy configuration.
+
+    Return:
+        Rounded up hidden_size if rounding up is required based on the configs.
+        Original hidden size otherwise.
+    """
+
+    if (moe_parallel_config.use_deepep_ht_kernels):
+        hidden_size = (
+            DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size(
+                hidden_size, act_dtype))
+
+    # we are padding globally so EP buffer allocation works
+    if quant_config and quant_config.get_name() == "mxfp4":
+
+        from vllm.model_executor.layers.quantization.mxfp4 import (
+            Mxfp4Backend, get_mxfp4_backend)
+        current_mxfp4_backend = get_mxfp4_backend()
+        if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
+                or current_mxfp4_backend
+                == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
+            hidden_size = round_up(hidden_size, 128)
+        elif (current_platform.is_rocm() or current_mxfp4_backend
+              == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
+              or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
+            hidden_size = round_up(hidden_size, 256)
+
+    return hidden_size
+
+
 @CustomOp.register("fused_moe")
 class FusedMoE(CustomOp):
     """FusedMoE layer for MoE models.
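The mxfp4 padding moved into this helper behaves as before: the FlashInfer SM90 BF16 / SM100 CUTLASS paths align the hidden size to 128, the ROCm / SM100 TRTLLM / SM100 BF16 paths to 256. A worked example of that alignment (plain Python, values only illustrative):

```python
# Alignment of the hidden size for the mxfp4 backends selected above.
def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y


hidden_size = 2880                   # e.g. the gpt-oss hidden size
print(round_up(hidden_size, 128))    # 2944
print(round_up(hidden_size, 256))    # 3072
```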
@@ -856,6 +899,18 @@ class FusedMoE(CustomOp):
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
 
+        vllm_config = get_current_vllm_config()
+
+        # FIXME (varun): We should have a better way of inferring the activation
+        # datatype. This works for now as the tensor datatype entering the MoE
+        # operation is typically unquantized (i.e. float16/bfloat16).
+        if vllm_config.model_config is not None:
+            moe_in_dtype = vllm_config.model_config.dtype
+        else:
+            # TODO (bnell): This is a hack to get test_mixtral_moe to work
+            # since model_config is not set in the pytest test.
+            moe_in_dtype = params_dtype
+
         tp_size_ = (tp_size if tp_size is not None else
                     get_tensor_model_parallel_world_size())
         dp_size_ = (dp_size
@@ -865,7 +920,6 @@ class FusedMoE(CustomOp):
         if self.is_sequence_parallel:
             self.sp_size = tp_size_
 
-        vllm_config = get_current_vllm_config()
         self.moe_parallel_config: FusedMoEParallelConfig = (
             FusedMoEParallelConfig.make(
                 tp_size_=tp_size_,
@@ -874,19 +928,10 @@ class FusedMoE(CustomOp):
 
         self.global_num_experts = num_experts + num_redundant_experts
 
-        # we are padding globally so EP buffer allocation works
-        if quant_config and quant_config.get_name() == "mxfp4":
-            from vllm.model_executor.layers.quantization.mxfp4 import (
-                Mxfp4Backend, get_mxfp4_backend)
-            current_mxfp4_backend = get_mxfp4_backend()
-            if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
-                    or current_mxfp4_backend
-                    == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
-                hidden_size = round_up(hidden_size, 128)
-            elif (current_platform.is_rocm() or current_mxfp4_backend
-                  == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or
-                  current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
-                hidden_size = round_up(hidden_size, 256)
+        # Round up hidden size if needed.
+        hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype,
+                                                quant_config,
+                                                self.moe_parallel_config)
 
         # For smuggling this layer into the fused moe custom op
         compilation_config = vllm_config.compilation_config
@@ -967,20 +1012,13 @@ class FusedMoE(CustomOp):
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
 
-        if vllm_config.model_config is not None:
-            model_dtype = vllm_config.model_config.dtype
-        else:
-            # TODO (bnell): This is a hack to get test_mixtral_moe to work
-            # since model_config is not set in the pytest test.
-            model_dtype = params_dtype
-
         moe = FusedMoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
             hidden_dim=hidden_size,
             num_local_experts=self.local_num_experts,
             moe_parallel_config=self.moe_parallel_config,
-            in_dtype=model_dtype,
+            in_dtype=moe_in_dtype,
             max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
             has_bias=has_bias,
         )
@@ -76,7 +76,7 @@ def _moe_problem_size(
     """
     assert w1.dim() == 3 and w2.dim() == 3
     E, N, _ = w1.size()
-    K = w2.size(1)
+    K = a1.size(-1)
 
     if a1.dim() == 2:
         # Make sure we are using the correct a1 (pre-permute).
@@ -13,7 +13,10 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
                                                   FusedMoEMethodBase)
 from vllm.model_executor.layers.fused_moe import modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.config import (
-    FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config)
+    FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config,
+    mxfp4_w4a16_moe_quant_config)
+from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
+    OAITritonExperts)
 from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
@@ -578,9 +581,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer.w13_bias = Parameter(w13_bias, requires_grad=False)
         layer.w2_bias = Parameter(w2_bias, requires_grad=False)
 
-        # FIXME warp need to be adjusted based on batch size
-        # only apply to batched mode
-        if self.moe.use_ep:
+        # Ideally we'd use FusedMoEModularKernel.prepare_finalize object
+        # (stored in self.fused_experts) to determine if the MoE has a
+        # batched activation format. As self.fused_experts is not
+        # initialized at this point, we resort to checking the MoE config
+        # directly.
+        is_batched_moe = (self.moe.use_pplx_kernels
+                          or self.moe.use_deepep_ll_kernels)
+        if is_batched_moe:
             num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
         else:
             num_warps = 8
@@ -640,10 +648,15 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         if self.mxfp4_backend == Mxfp4Backend.TRITON:
             w1_scale = self.w13_precision_config
             w2_scale = self.w2_precision_config
+            return mxfp4_w4a16_moe_quant_config(
+                w1_bias=layer.w13_bias,
+                w2_bias=layer.w2_bias,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+            )
         else:
             w1_scale = layer.w13_weight_scale
             w2_scale = layer.w2_weight_scale
-
         return mxfp4_w4a4_moe_quant_config(
             w1_bias=layer.w13_bias,
             w2_bias=layer.w2_bias,
@@ -661,6 +674,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             raise NotImplementedError(
                 "Mxfp4 does not support batched experts format for EP")
         else:
+            assert self.moe_quant_config is not None
             if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
                     or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
                 # B200 code-path
@@ -671,13 +685,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 # TODO(bnell): part of quant_config
                 "max_capture_size": self.max_capture_size,
             }
-            assert self.moe_quant_config is not None
             return TrtLlmGenExperts(self.moe, self.moe_quant_config,
                                     **kwargs)
         else:
-            # Use matmul_ogs from triton_kernels here!
-            raise NotImplementedError(
-                "Mxfp4 does not support non-batched experts format for EP")
+            return OAITritonExperts(self.moe_quant_config)
 
     def _route_and_experts(
         self,
@@ -722,10 +733,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             logical_to_physical_map=logical_to_physical_map,
             logical_replica_count=logical_replica_count)
 
+        w13_weight = (self.w13_weight_triton_tensor
+                      if layer.w13_weight is None else layer.w13_weight)
+        w2_weight = (self.w2_weight_triton_tensor
+                     if layer.w2_weight is None else layer.w2_weight)
+        assert all([w is not None for w in [w13_weight, w2_weight]])
+
         return self.fused_experts(
             hidden_states=x,
-            w1=layer.w13_weight,
-            w2=layer.w2_weight,
+            w1=w13_weight,
+            w2=w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,