mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 10:35:52 +08:00
[Bugfix][Wide EP] Fix redundant work when using DeepEP, TP Attn, and EP MoE (#24134)
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
parent
4f87abdcc6
commit
955c624915
@ -35,7 +35,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
|
||||
from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
|
||||
round_up)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
@ -786,6 +786,7 @@ class FusedMoE(CustomOp):
|
||||
enable_eplb: bool = False,
|
||||
num_redundant_experts: int = 0,
|
||||
has_bias: bool = False,
|
||||
is_sequence_parallel=False,
|
||||
):
|
||||
super().__init__()
|
||||
if params_dtype is None:
|
||||
@ -797,6 +798,10 @@ class FusedMoE(CustomOp):
|
||||
dp_size_ = (dp_size
|
||||
if dp_size is not None else get_dp_group().world_size)
|
||||
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
if self.is_sequence_parallel:
|
||||
self.sp_size = tp_size_
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.moe_parallel_config: FusedMoEParallelConfig = (
|
||||
FusedMoEParallelConfig.make(
|
||||
@ -1699,14 +1704,22 @@ class FusedMoE(CustomOp):
|
||||
|
||||
ctx = get_forward_context()
|
||||
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
|
||||
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
|
||||
max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
|
||||
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
|
||||
|
||||
# If the input to the MoE is sequence parallel then divide by sp_size
|
||||
# to find the maximum number of tokens for any individual dispatcher.
|
||||
if self.is_sequence_parallel:
|
||||
max_tokens_across_dispatchers = cdiv(max_tokens_across_dispatchers,
|
||||
self.sp_size)
|
||||
|
||||
num_tokens = full_hidden_states.size(0)
|
||||
for chunk_idx, chunk_start_ in enumerate(
|
||||
range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)):
|
||||
range(0, max_tokens_across_dispatchers,
|
||||
moe_dp_chunk_size_per_rank)):
|
||||
chunk_start = chunk_start_
|
||||
chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
|
||||
max_tokens_across_dp)
|
||||
max_tokens_across_dispatchers)
|
||||
# clamp start and end
|
||||
chunk_start = min(chunk_start, num_tokens - 1)
|
||||
chunk_end = min(chunk_end, num_tokens)
|
||||
|
||||
@ -37,8 +37,6 @@ class DeepseekV2Model(nn.Module):
|
||||
super().__init__()
|
||||
self.config = vllm_config. \
|
||||
speculative_config.draft_model_config.hf_config
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.vocab_size = self.config.vocab_size
|
||||
|
||||
@ -51,11 +49,8 @@ class DeepseekV2Model(nn.Module):
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
DeepseekV2DecoderLayer(
|
||||
self.config,
|
||||
vllm_config,
|
||||
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
) for i in range(self.config.num_hidden_layers)
|
||||
])
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@ -43,23 +43,19 @@ class SharedHead(nn.Module):
|
||||
|
||||
class DeepSeekMultiTokenPredictorLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
model_config: ModelConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.shared_head = SharedHead(config=config, quant_config=quant_config)
|
||||
self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
|
||||
cache_config, quant_config)
|
||||
self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -95,13 +91,8 @@ class DeepSeekMultiTokenPredictor(nn.Module):
|
||||
# to map the exact layer index from weights
|
||||
self.layers = torch.nn.ModuleDict({
|
||||
str(idx):
|
||||
DeepSeekMultiTokenPredictorLayer(
|
||||
config,
|
||||
f"{prefix}.layers.{idx}",
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=vllm_config.quant_config,
|
||||
)
|
||||
DeepSeekMultiTokenPredictorLayer(vllm_config,
|
||||
f"{prefix}.layers.{idx}")
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
})
|
||||
|
||||
@ -32,12 +32,14 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import DeepseekV2Config, DeepseekV3Config
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
||||
get_current_vllm_config)
|
||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
||||
from vllm.distributed import (get_ep_group, get_pp_group,
|
||||
get_tensor_model_parallel_world_size)
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@ -55,7 +57,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import cdiv, direct_register_custom_op
|
||||
|
||||
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
@ -72,19 +76,27 @@ class DeepseekV2MLP(nn.Module):
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
is_sequence_parallel=False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# If is_sequence_parallel, the input and output tensors are sharded
|
||||
# across the ranks within the tp_group. In this case the weights are
|
||||
# replicated and no collective ops are needed.
|
||||
# Otherwise we use standard TP with an allreduce at the end.
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
disable_tp=is_sequence_parallel,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
disable_tp=is_sequence_parallel,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
@ -98,17 +110,58 @@ class DeepseekV2MLP(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# Chunk x along the num_tokens axis for sequence parallelism
|
||||
# NOTE: This is wrapped in a torch custom op to work around the following issue:
|
||||
# The output tensor can have a sequence length 0 at small input sequence lengths
|
||||
# even though we explicitly pad to avoid this.
|
||||
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
# all_gather needs the sequence length to be divisible by tp_size
|
||||
seq_len = x.size(0)
|
||||
remainder = seq_len % tp_size
|
||||
if remainder != 0:
|
||||
pad_len = tp_size - remainder
|
||||
x = nn.functional.pad(x, (0, 0, 0, pad_len))
|
||||
|
||||
chunk = x.shape[0] // tp_size
|
||||
start = tp_rank * chunk
|
||||
return torch.narrow(x, 0, start, chunk)
|
||||
|
||||
|
||||
def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
seq_len = cdiv(x.size(0), tp_size)
|
||||
shape = list(x.shape)
|
||||
shape[0] = seq_len
|
||||
out = torch.empty(shape, dtype=x.dtype, device=x.device)
|
||||
return out
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="sequence_parallel_chunk",
|
||||
op_func=sequence_parallel_chunk,
|
||||
mutates_args=[],
|
||||
fake_impl=sequence_parallel_chunk_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV2MoE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Union[DeepseekV2Config, DeepseekV3Config],
|
||||
parallel_config: ParallelConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
|
||||
self.ep_group = get_ep_group().device_group
|
||||
@ -117,6 +170,21 @@ class DeepseekV2MoE(nn.Module):
|
||||
self.n_routed_experts: int = config.n_routed_experts
|
||||
self.n_shared_experts: int = config.n_shared_experts
|
||||
|
||||
# The all_reduce at the end of attention (during o_proj) means that
|
||||
# inputs are replicated across each rank of the tensor parallel group.
|
||||
# If using expert-parallelism with DeepEP All2All ops, replicated
|
||||
# tokens results in useless duplicate computation and communication.
|
||||
#
|
||||
# In this case, ensure the input to the experts is sequence parallel
|
||||
# to avoid the excess work.
|
||||
#
|
||||
# Not needed for pplx-kernels as it can handle duplicate input tokens.
|
||||
self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
|
||||
in ("deepep_high_throughput",
|
||||
"deepep_low_latency")
|
||||
and parallel_config.enable_expert_parallel
|
||||
and self.tp_size > 1)
|
||||
|
||||
if config.hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -133,9 +201,8 @@ class DeepseekV2MoE(nn.Module):
|
||||
self.gate.e_score_correction_bias = None
|
||||
|
||||
# Load balancing settings.
|
||||
vllm_config = get_current_vllm_config()
|
||||
eplb_config = vllm_config.parallel_config.eplb_config
|
||||
self.enable_eplb = enable_eplb
|
||||
eplb_config = parallel_config.eplb_config
|
||||
self.enable_eplb = parallel_config.enable_eplb
|
||||
|
||||
self.n_redundant_experts = eplb_config.num_redundant_experts
|
||||
self.n_logical_experts = self.n_routed_experts
|
||||
@ -166,7 +233,9 @@ class DeepseekV2MoE(nn.Module):
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts)
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
)
|
||||
self.shared_experts = None
|
||||
else:
|
||||
intermediate_size = (config.moe_intermediate_size *
|
||||
@ -177,6 +246,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
reduce_results=False,
|
||||
prefix=f"{prefix}.shared_experts",
|
||||
)
|
||||
@ -199,11 +269,22 @@ class DeepseekV2MoE(nn.Module):
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts)
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# Chunk the hidden states so they aren't replicated across TP ranks.
|
||||
# This avoids duplicate computation in self.experts.
|
||||
# TODO: We can replace the all_reduce at the end of attn with a
|
||||
# reduce_scatter instead of chunking here.
|
||||
if self.is_sequence_parallel:
|
||||
hidden_states = torch.ops.vllm.sequence_parallel_chunk(
|
||||
hidden_states)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
@ -228,7 +309,11 @@ class DeepseekV2MoE(nn.Module):
|
||||
assert shared_output is not None
|
||||
final_hidden_states += shared_output
|
||||
|
||||
if self.tp_size > 1:
|
||||
if self.is_sequence_parallel:
|
||||
final_hidden_states = tensor_model_parallel_all_gather(
|
||||
final_hidden_states, 0)
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
elif self.tp_size > 1:
|
||||
final_hidden_states = (
|
||||
self.experts.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states))
|
||||
@ -532,16 +617,15 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
|
||||
class DeepseekV2DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Union[DeepseekV2Config, DeepseekV3Config],
|
||||
prefix: str,
|
||||
model_config: ModelConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
enable_eplb: bool = False,
|
||||
) -> None:
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
@ -578,9 +662,9 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
and layer_idx % config.moe_layer_freq == 0):
|
||||
self.mlp = DeepseekV2MoE(
|
||||
config=config,
|
||||
parallel_config=parallel_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
enable_eplb=enable_eplb,
|
||||
)
|
||||
else:
|
||||
self.mlp = DeepseekV2MLP(
|
||||
@ -650,10 +734,7 @@ class DeepseekV2Model(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
enable_eplb = vllm_config.parallel_config.enable_eplb
|
||||
self.config = config
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
@ -669,14 +750,7 @@ class DeepseekV2Model(nn.Module):
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: DeepseekV2DecoderLayer(
|
||||
config,
|
||||
prefix,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
enable_eplb=enable_eplb,
|
||||
),
|
||||
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user