[Bugfix][Wide EP] Fix redundant work when using DeepEP, TP Attn, and EP MoE (#24134)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Tyler Michael Smith 2025-09-08 22:01:51 -04:00 committed by GitHub
parent 4f87abdcc6
commit 955c624915
4 changed files with 132 additions and 59 deletions

View File

@@ -35,7 +35,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
+from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
                         round_up)

 if current_platform.is_cuda_alike():
@@ -786,6 +786,7 @@ class FusedMoE(CustomOp):
         enable_eplb: bool = False,
         num_redundant_experts: int = 0,
         has_bias: bool = False,
+        is_sequence_parallel=False,
     ):
         super().__init__()
         if params_dtype is None:
@@ -797,6 +798,10 @@ class FusedMoE(CustomOp):
         dp_size_ = (dp_size
                     if dp_size is not None else get_dp_group().world_size)

+        self.is_sequence_parallel = is_sequence_parallel
+        if self.is_sequence_parallel:
+            self.sp_size = tp_size_
+
         vllm_config = get_current_vllm_config()
         self.moe_parallel_config: FusedMoEParallelConfig = (
             FusedMoEParallelConfig.make(
@@ -1699,14 +1704,22 @@ class FusedMoE(CustomOp):
         ctx = get_forward_context()

         # flashinfer_cutlass_kernels can handle: optional DP + TP/EP
-        max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
+        max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
         moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens

+        # If the input to the MoE is sequence parallel then divide by sp_size
+        # to find the maximum number of tokens for any individual dispatcher.
+        if self.is_sequence_parallel:
+            max_tokens_across_dispatchers = cdiv(max_tokens_across_dispatchers,
+                                                 self.sp_size)
+
         num_tokens = full_hidden_states.size(0)
         for chunk_idx, chunk_start_ in enumerate(
-                range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)):
+                range(0, max_tokens_across_dispatchers,
+                      moe_dp_chunk_size_per_rank)):
             chunk_start = chunk_start_
             chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
-                            max_tokens_across_dp)
+                            max_tokens_across_dispatchers)
             # clamp start and end
             chunk_start = min(chunk_start, num_tokens - 1)
             chunk_end = min(chunk_end, num_tokens)
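
To make the effect of the new loop bound concrete, here is a minimal standalone sketch (not part of the patch; the helper name and the numbers are hypothetical) of how many chunk iterations the loop above runs before and after the sequence-parallel adjustment. cdiv is ceiling division, as in vllm.utils:

def cdiv(a: int, b: int) -> int:
    return -(a // -b)

def num_chunk_iterations(max_tokens_across_dp: int, chunk_size: int,
                         is_sequence_parallel: bool, sp_size: int) -> int:
    # With sequence parallelism each dispatcher only sees ~1/sp_size of the
    # tokens, so the loop bound shrinks accordingly.
    max_tokens = (cdiv(max_tokens_across_dp, sp_size)
                  if is_sequence_parallel else max_tokens_across_dp)
    return cdiv(max_tokens, chunk_size)

# e.g. 8192 tokens across DP, chunk size 1024, sp_size (= tp_size) 4:
assert num_chunk_iterations(8192, 1024, False, 4) == 8
assert num_chunk_iterations(8192, 1024, True, 4) == 2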

View File

@@ -37,8 +37,6 @@ class DeepseekV2Model(nn.Module):
         super().__init__()
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
-        model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         self.vocab_size = self.config.vocab_size
@@ -51,11 +49,8 @@ class DeepseekV2Model(nn.Module):
         self.layers = nn.ModuleList([
             DeepseekV2DecoderLayer(
-                self.config,
+                vllm_config,
                 prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
-                model_config=model_config,
-                cache_config=cache_config,
-                quant_config=quant_config,
             ) for i in range(self.config.num_hidden_layers)
         ])

View File

@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn
 from transformers import PretrainedConfig

-from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -43,23 +43,19 @@ class SharedHead(nn.Module):
 class DeepSeekMultiTokenPredictorLayer(nn.Module):

-    def __init__(
-        self,
-        config: PretrainedConfig,
-        prefix: str,
-        model_config: ModelConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
+    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+
         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2,
                                  config.hidden_size,
                                  bias=False)
         self.shared_head = SharedHead(config=config, quant_config=quant_config)
-        self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
-                                                cache_config, quant_config)
+        self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix)

     def forward(
         self,
@@ -95,13 +91,8 @@ class DeepSeekMultiTokenPredictor(nn.Module):
         # to map the exact layer index from weights
         self.layers = torch.nn.ModuleDict({
             str(idx):
-            DeepSeekMultiTokenPredictorLayer(
-                config,
-                f"{prefix}.layers.{idx}",
-                model_config=vllm_config.model_config,
-                cache_config=vllm_config.cache_config,
-                quant_config=vllm_config.quant_config,
-            )
+            DeepSeekMultiTokenPredictorLayer(vllm_config,
+                                             f"{prefix}.layers.{idx}")
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)
         })

View File

@@ -32,12 +32,14 @@ import torch
 from torch import nn
 from transformers import DeepseekV2Config, DeepseekV3Config

+import vllm.envs as envs
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
-                         get_current_vllm_config)
+from vllm.config import CacheConfig, ParallelConfig, VllmConfig
 from vllm.distributed import (get_ep_group, get_pp_group,
-                              get_tensor_model_parallel_world_size)
+                              get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -55,7 +57,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
+from vllm.utils import cdiv, direct_register_custom_op

 from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
 from .utils import (PPMissingLayer, is_pp_missing_parameter,
@@ -72,19 +76,27 @@ class DeepseekV2MLP(nn.Module):
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        is_sequence_parallel=False,
         prefix: str = "",
     ) -> None:
         super().__init__()
+        # If is_sequence_parallel, the input and output tensors are sharded
+        # across the ranks within the tp_group. In this case the weights are
+        # replicated and no collective ops are needed.
+        # Otherwise we use standard TP with an allreduce at the end.
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            disable_tp=is_sequence_parallel,
             prefix=f"{prefix}.gate_up_proj")
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            quant_config=quant_config,
                                            reduce_results=reduce_results,
+                                           disable_tp=is_sequence_parallel,
                                            prefix=f"{prefix}.down_proj")
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
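
The disable_tp comment above can be read as a shape-bookkeeping statement. A rough sketch (illustrative only; the sizes are made up) of what each rank holds in the two modes of this MLP:

# tokens, hidden size, MoE intermediate size, TP group size (all hypothetical)
T, H, I, R = 1024, 7168, 2048, 4

# Standard TP: every rank sees all T tokens, the intermediate dimension is
# sharded R ways, and an all-reduce at the end combines the partial outputs.
tp_tokens_per_rank = T
tp_intermediate_per_rank = I // R

# Sequence parallel (disable_tp=True): each rank sees ~T/R tokens, the weights
# are fully replicated, and no collective op runs inside the MLP.
sp_tokens_per_rank = -(T // -R)   # ceil(T / R)
sp_intermediate_per_rank = I

print(tp_tokens_per_rank, tp_intermediate_per_rank)   # 1024 512
print(sp_tokens_per_rank, sp_intermediate_per_rank)   # 256 2048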
@@ -98,17 +110,58 @@ class DeepseekV2MLP(nn.Module):
         return x


+# Chunk x along the num_tokens axis for sequence parallelism
+# NOTE: This is wrapped in a torch custom op to work around the following issue:
+# The output tensor can have a sequence length 0 at small input sequence lengths
+# even though we explicitly pad to avoid this.
+def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
+    tp_size = get_tensor_model_parallel_world_size()
+    tp_rank = get_tensor_model_parallel_rank()
+
+    # all_gather needs the sequence length to be divisible by tp_size
+    seq_len = x.size(0)
+    remainder = seq_len % tp_size
+    if remainder != 0:
+        pad_len = tp_size - remainder
+        x = nn.functional.pad(x, (0, 0, 0, pad_len))
+
+    chunk = x.shape[0] // tp_size
+    start = tp_rank * chunk
+    return torch.narrow(x, 0, start, chunk)
+
+
+def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
+    tp_size = get_tensor_model_parallel_world_size()
+    seq_len = cdiv(x.size(0), tp_size)
+    shape = list(x.shape)
+    shape[0] = seq_len
+    out = torch.empty(shape, dtype=x.dtype, device=x.device)
+    return out
+
+
+direct_register_custom_op(
+    op_name="sequence_parallel_chunk",
+    op_func=sequence_parallel_chunk,
+    mutates_args=[],
+    fake_impl=sequence_parallel_chunk_fake,
+    dispatch_key=current_platform.dispatch_key,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)
+
+
 class DeepseekV2MoE(nn.Module):

     def __init__(
         self,
         config: Union[DeepseekV2Config, DeepseekV3Config],
+        parallel_config: ParallelConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-        enable_eplb: bool = False,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()

         self.routed_scaling_factor = config.routed_scaling_factor
         self.ep_group = get_ep_group().device_group
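
A standalone sketch of what the new sequence_parallel_chunk op computes for one rank (tp_rank and tp_size are passed explicitly here instead of being read from the TP group; this is illustrative, not the registered op itself):

import torch

def chunk_for_rank(x: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Pad the token dimension so it divides evenly, then take this rank's slice.
    remainder = x.size(0) % tp_size
    if remainder != 0:
        x = torch.nn.functional.pad(x, (0, 0, 0, tp_size - remainder))
    chunk = x.size(0) // tp_size
    return torch.narrow(x, 0, tp_rank * chunk, chunk)

x = torch.randn(10, 8)                         # 10 tokens, tp_size 4 -> padded to 12
shards = [chunk_for_rank(x, r, 4) for r in range(4)]
assert all(s.shape == (3, 8) for s in shards)  # every rank gets an equal, non-empty shard

The padding is what lets the fake implementation report a cdiv-based output shape during tracing, since every rank's shard has the same length.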
@@ -117,6 +170,21 @@ class DeepseekV2MoE(nn.Module):
         self.n_routed_experts: int = config.n_routed_experts
         self.n_shared_experts: int = config.n_shared_experts

+        # The all_reduce at the end of attention (during o_proj) means that
+        # inputs are replicated across each rank of the tensor parallel group.
+        # If using expert-parallelism with DeepEP All2All ops, replicated
+        # tokens results in useless duplicate computation and communication.
+        #
+        # In this case, ensure the input to the experts is sequence parallel
+        # to avoid the excess work.
+        #
+        # Not needed for pplx-kernels as it can handle duplicate input tokens.
+        self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
+                                     in ("deepep_high_throughput",
+                                         "deepep_low_latency")
+                                     and parallel_config.enable_expert_parallel
+                                     and self.tp_size > 1)
+
         if config.hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")
@@ -133,9 +201,8 @@ class DeepseekV2MoE(nn.Module):
             self.gate.e_score_correction_bias = None

         # Load balancing settings.
-        vllm_config = get_current_vllm_config()
-        eplb_config = vllm_config.parallel_config.eplb_config
-        self.enable_eplb = enable_eplb
+        eplb_config = parallel_config.eplb_config
+        self.enable_eplb = parallel_config.enable_eplb

         self.n_redundant_experts = eplb_config.num_redundant_experts
         self.n_logical_experts = self.n_routed_experts
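
As a back-of-the-envelope illustration of the redundancy that the is_sequence_parallel flag above removes (the numbers are hypothetical; this is not vLLM code): with TP attention, every rank in the TP group holds the same tokens, so without sequence parallelism each rank would push all of them through the DeepEP dispatch.

tokens, tp_size, top_k = 4096, 4, 8

# Replicated input: every one of the tp_size ranks dispatches every
# (token, routed expert) pair, so the same work is sent tp_size times.
dispatched_without_sp = tp_size * tokens * top_k

# Sequence-parallel input: each token lives on exactly one rank and is
# dispatched once.
dispatched_with_sp = tokens * top_k

print(dispatched_without_sp // dispatched_with_sp)   # 4x fewer dispatch sends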
@@ -166,7 +233,9 @@ class DeepseekV2MoE(nn.Module):
                 routed_scaling_factor=1.0,
                 e_score_correction_bias=self.gate.e_score_correction_bias,
                 enable_eplb=self.enable_eplb,
-                num_redundant_experts=self.n_redundant_experts)
+                num_redundant_experts=self.n_redundant_experts,
+                is_sequence_parallel=self.is_sequence_parallel,
+            )
             self.shared_experts = None
         else:
             intermediate_size = (config.moe_intermediate_size *
@@ -177,6 +246,7 @@ class DeepseekV2MoE(nn.Module):
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
+                is_sequence_parallel=self.is_sequence_parallel,
                 reduce_results=False,
                 prefix=f"{prefix}.shared_experts",
             )
@@ -199,11 +269,22 @@ class DeepseekV2MoE(nn.Module):
                 routed_scaling_factor=1.0,
                 e_score_correction_bias=self.gate.e_score_correction_bias,
                 enable_eplb=self.enable_eplb,
-                num_redundant_experts=self.n_redundant_experts)
+                num_redundant_experts=self.n_redundant_experts,
+                is_sequence_parallel=self.is_sequence_parallel,
+            )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

+        # Chunk the hidden states so they aren't replicated across TP ranks.
+        # This avoids duplicate computation in self.experts.
+        # TODO: We can replace the all_reduce at the end of attn with a
+        # reduce_scatter instead of chunking here.
+        if self.is_sequence_parallel:
+            hidden_states = torch.ops.vllm.sequence_parallel_chunk(
+                hidden_states)
+
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
@@ -228,7 +309,11 @@ class DeepseekV2MoE(nn.Module):
             assert shared_output is not None
             final_hidden_states += shared_output

-        if self.tp_size > 1:
+        if self.is_sequence_parallel:
+            final_hidden_states = tensor_model_parallel_all_gather(
+                final_hidden_states, 0)
+            final_hidden_states = final_hidden_states[:num_tokens]
+        elif self.tp_size > 1:
             final_hidden_states = (
                 self.experts.maybe_all_reduce_tensor_model_parallel(
                     final_hidden_states))
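
A small sketch of the sequence-parallel round trip that this epilogue completes: the input is chunked across ranks before routing, the experts run on local shards, and the all_gather plus slice above reassembles the original number of tokens. Here torch.cat stands in for tensor_model_parallel_all_gather and the shapes are hypothetical:

import torch

num_tokens, hidden, world = 10, 8, 4
hidden_states = torch.randn(num_tokens, hidden)

# What each rank holds after sequence_parallel_chunk (padded to 12 rows, 3 each).
padded = torch.nn.functional.pad(hidden_states, (0, 0, 0, 2))
shards = list(torch.chunk(padded, world, dim=0))
# ... each rank would run its experts on its own shard here ...

gathered = torch.cat(shards, dim=0)   # stand-in for tensor_model_parallel_all_gather(x, 0)
final = gathered[:num_tokens]         # drop the padding rows
assert torch.equal(final, hidden_states)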
@@ -532,16 +617,15 @@ class DeepseekV2MLAAttention(nn.Module):
 class DeepseekV2DecoderLayer(nn.Module):

-    def __init__(
-        self,
-        config: Union[DeepseekV2Config, DeepseekV3Config],
-        prefix: str,
-        model_config: ModelConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        enable_eplb: bool = False,
-    ) -> None:
+    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        parallel_config = vllm_config.parallel_config
+
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -578,9 +662,9 @@ class DeepseekV2DecoderLayer(nn.Module):
                 and layer_idx % config.moe_layer_freq == 0):
             self.mlp = DeepseekV2MoE(
                 config=config,
+                parallel_config=parallel_config,
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
-                enable_eplb=enable_eplb,
             )
         else:
             self.mlp = DeepseekV2MLP(
@@ -650,10 +734,7 @@ class DeepseekV2Model(nn.Module):
         super().__init__()

         config = vllm_config.model_config.hf_config
-        model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
-        enable_eplb = vllm_config.parallel_config.enable_eplb
         self.config = config

         self.vocab_size = config.vocab_size
@@ -669,14 +750,7 @@ class DeepseekV2Model(nn.Module):
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: DeepseekV2DecoderLayer(
-                config,
-                prefix,
-                model_config=model_config,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                enable_eplb=enable_eplb,
-            ),
+            lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix),
             prefix=f"{prefix}.layers")

         if get_pp_group().is_last_rank: