[Bugfix][WideEP] Apply TP Attn + EP MoE fix to other models (#24982)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Tyler Michael Smith 2025-09-27 10:22:28 -04:00 committed by yewentao256
parent 1e5e5d757e
commit e94aabe03d
23 changed files with 540 additions and 375 deletions

View File

@ -279,6 +279,24 @@ class ParallelConfig:
assert last_exc is not None
raise last_exc
# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens results in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
@property
def use_sequence_parallel_moe(self) -> bool:
return (envs.VLLM_ALL2ALL_BACKEND
in ("allgather_reducescatter", "naive",
"deepep_high_throughput", "deepep_low_latency")
and self.enable_expert_parallel
and self.tensor_parallel_size > 1
and self.data_parallel_size > 1)
@staticmethod
def has_unfinished_dp(dp_group: ProcessGroup,
has_unfinished: bool) -> bool:

View File

@ -6,7 +6,7 @@ import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.distributed import get_dp_group
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx
@ -34,41 +34,60 @@ class NaiveAll2AllManager(All2AllManagerBase):
super().__init__(cpu_group)
def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
cu_tokens_across_sp_cpu: torch.Tensor,
is_sequence_parallel: bool) -> torch.Tensor:
assert (len(x.shape) == 2)
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
device=x.device,
dtype=x.dtype)
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.dp_rank - 1]
end = cu_tokens_across_dp_cpu[self.dp_rank]
rank = self.rank if is_sequence_parallel else self.dp_rank
world_size = (self.world_size
if is_sequence_parallel else self.dp_world_size)
start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
end = cu_tokens_across_sp_cpu[rank]
buffer[start:end, :].copy_(x)
for idx in range(self.dp_world_size):
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
end = cu_tokens_across_dp_cpu[idx]
self.dp_group.broadcast(buffer[start:end, :], idx)
for idx in range(world_size):
start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
end = cu_tokens_across_sp_cpu[idx]
get_ep_group().broadcast(buffer[start:end, :], idx)
return buffer
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states, router_logits = get_dp_group().all_gatherv(
[hidden_states, router_logits],
dim=0,
sizes=sizes,
)
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_sp_cpu,
is_sequence_parallel)
router_logits = self.naive_multicast(router_logits,
cu_tokens_across_sp_cpu,
is_sequence_parallel)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
dim=0,
sizes=sizes)
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
ep_rank = self.rank if is_sequence_parallel else self.dp_rank
dp_metadata = get_forward_context().dp_metadata
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
end = cu_tokens_across_sp_cpu[ep_rank]
all_hidden_states = get_ep_group().all_reduce(hidden_states)
hidden_states = all_hidden_states[start:end, :]
return hidden_states
def destroy(self):
@ -84,29 +103,40 @@ class AgRsAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group):
super().__init__(cpu_group)
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Gather hidden_states and router_logits from all dp ranks.
"""
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states, router_logits = get_dp_group().all_gatherv(
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
hidden_states, router_logits = dist_group.all_gatherv(
[hidden_states, router_logits],
dim=0,
sizes=sizes,
)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
"""
Reduce-scatter hidden_states across all dp ranks.
"""
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
dim=0,
sizes=sizes)
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
hidden_states = dist_group.reduce_scatterv(hidden_states,
dim=0,
sizes=sizes)
return hidden_states
def destroy(self):
@ -148,11 +178,17 @@ class PPLXAll2AllManager(All2AllManagerBase):
kwargs, pplx.AllToAll.internode
if self.internode else pplx.AllToAll.intranode)
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
@ -184,11 +220,17 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def get_handle(self, kwargs):
raise NotImplementedError
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
raise NotImplementedError
def destroy(self):

View File

@ -28,6 +28,8 @@ class Cache:
class All2AllManagerBase:
rank: int
world_size: int
def __init__(self, cpu_group):
self.cpu_group = cpu_group
@ -40,6 +42,7 @@ class All2AllManagerBase:
# all2all lives in ep group, which is merged from dp and tp group
self.dp_group = get_dp_group()
self.tp_group = get_tp_group()
# no self.ep_group since self.ep_group is still in construction
# when we create this object
self.dp_rank = self.dp_group.rank_in_group
@ -60,17 +63,21 @@ class All2AllManagerBase:
# and reuse it for the same config.
raise NotImplementedError
def dispatch(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False):
raise NotImplementedError
def set_num_sms(self, num_sms: int):
pass
def max_sms_used(self) -> Optional[int]:
return None # None means it could use the whole GPU
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False):
raise NotImplementedError
def destroy(self):
@ -267,15 +274,20 @@ class DeviceCommunicatorBase:
module.quant_method.init_prepare_finalize(module)
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.

View File

@ -39,10 +39,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
# ep does not use pynccl
use_pynccl = "ep" not in unique_name
self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem
@ -57,7 +53,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
SymmMemCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
if self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
@ -308,14 +304,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits)
hidden_states, router_logits, is_sequence_parallel)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states)
hidden_states = self.all2all_manager.combine(hidden_states,
is_sequence_parallel)
return hidden_states

View File

@ -75,14 +75,20 @@ class XpuCommunicator(DeviceCommunicatorBase):
dist.broadcast(input_, src=src, group=self.device_group)
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits)
hidden_states, router_logits, is_sequence_parallel)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states)
hidden_states = self.all2all_manager.combine(hidden_states,
is_sequence_parallel)
return hidden_states

View File

@ -871,17 +871,24 @@ class GroupCoordinator:
model)
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
if self.device_communicator is not None:
return self.device_communicator.dispatch(hidden_states,
router_logits)
router_logits,
is_sequence_parallel)
else:
return hidden_states, router_logits
def combine(self, hidden_states) -> torch.Tensor:
def combine(self,
hidden_states,
is_sequence_parallel: bool = False) -> torch.Tensor:
if self.device_communicator is not None:
return self.device_communicator.combine(hidden_states)
return self.device_communicator.combine(hidden_states,
is_sequence_parallel)
else:
return hidden_states

View File

@ -49,16 +49,29 @@ class BatchDescriptor(NamedTuple):
return BatchDescriptor(self.num_tokens, uniform_decode=False)
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
sequence_parallel_size: int) -> list[int]:
sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) //
sequence_parallel_size)
sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
return sp_tokens.tolist()
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
sequence_parallel_size: int,
max_num_tokens: int,
chunk_idx: int) -> list[int]:
dp_size = len(num_tokens_across_dp_cpu)
local_size = [-1] * dp_size
for i in range(dp_size):
dp_tokens = num_tokens_across_dp_cpu[i]
sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu,
sequence_parallel_size)
sp_size = len(sp_tokens)
local_size = [-1] * sp_size
for i in range(sp_size):
# Take into account sharding if MoE activation is sequence parallel.
local_size[i] = min(max_num_tokens,
dp_tokens - (max_num_tokens * chunk_idx))
sp_tokens[i] - (max_num_tokens * chunk_idx))
if local_size[i] <= 0:
local_size[i] = 1 # ensure lockstep even if done
return local_size
@ -67,7 +80,9 @@ def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
@dataclass
class DPMetadata:
max_tokens_across_dp_cpu: torch.Tensor
cu_tokens_across_dp_cpu: torch.Tensor
num_tokens_across_dp_cpu: torch.Tensor
# NOTE: local_sizes should only be set by the chunked_sizes context manager
local_sizes: Optional[list[int]] = None
@staticmethod
@ -98,6 +113,17 @@ class DPMetadata:
dist.all_reduce(num_tokens_tensor, group=group)
return num_tokens_tensor.cpu()
# Get the cumulative tokens across sequence parallel ranks.
# In this case the input to the MoEs will be distributed w.r.t both
# DP and TP rank.
# When sp_size==1, this is just the cummulative num tokens across DP.
def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
num_tokens_across_sp_cpu = (
(self.num_tokens_across_dp_cpu - 1 + sp_size) // sp_size)
num_tokens_across_sp_cpu = (
num_tokens_across_sp_cpu.repeat_interleave(sp_size))
return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
@staticmethod
def should_ubatch_across_dp(
should_ubatch: bool, orig_num_tokens_per_ubatch: int,
@ -147,10 +173,10 @@ class DPMetadata:
@staticmethod
def make(
parallel_config: ParallelConfig,
attn_metadata: Any,
num_tokens: int,
num_tokens_across_dp: Optional[torch.Tensor] = None
parallel_config: ParallelConfig,
attn_metadata: Any,
num_tokens: int,
num_tokens_across_dp_cpu: Optional[torch.Tensor] = None
) -> "DPMetadata":
assert parallel_config.data_parallel_size > 1
@ -167,18 +193,18 @@ class DPMetadata:
# If num_tokens_across_dp is None, it will be computed by all_reduce
# Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
== batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
if num_tokens_across_dp is None:
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
assert (num_tokens_across_dp_cpu is None
or num_tokens_across_dp_cpu[dp_rank] == batchsize
), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
if num_tokens_across_dp_cpu is None:
num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
batchsize, dp_size, dp_rank)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
num_tokens_across_dp)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)
@contextmanager
def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
def chunked_sizes(self, sequence_parallel_size: int,
max_chunk_size_per_rank: int, chunk_idx: int):
"""
Context manager to compute and temporarily set the per-rank local token
sizes for a specific chunk during chunked forward execution.
@ -192,31 +218,40 @@ class DPMetadata:
`chunk_idx`, this context manager sets `self.local_sizes` to the number
of tokens to process in that chunk on each rank.
It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
to determine the chunk-wise split.
`self.local_sizes` is only valid inside the context.
Args:
sequence_parallel_size: When Attn is TP and MoE layers are EP,
we use SP between the layers to avoid
redundant ops. We need this value to
compute the chunked sizes.
max_chunk_size_per_rank: The max number of tokens each rank is
allowed to process in this chunk.
chunk_idx: The index of the chunk to compute sizes for.
"""
cu_sizes = self.cu_tokens_across_dp_cpu
num_tokens_across_dp_cpu = [
(cu_sizes[i] -
cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item()
for i in range(len(cu_sizes))
]
self.local_sizes = _compute_chunked_local_num_tokens(
num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
self.num_tokens_across_dp_cpu, sequence_parallel_size,
max_chunk_size_per_rank, chunk_idx)
try:
yield self.local_sizes
finally:
self.local_sizes = None
@contextmanager
def sp_local_sizes(self, sequence_parallel_size: int):
"""
Context mamager for setting self.local_sizes. Same as self.chunked_sizes
but without any chunking.
"""
self.local_sizes = _compute_sp_num_tokens(
self.num_tokens_across_dp_cpu, sequence_parallel_size)
try:
yield self.local_sizes
finally:
self.local_sizes = None
def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
assert self.local_sizes is not None
return self.local_sizes

View File

@ -3,6 +3,7 @@
from abc import abstractmethod
from collections.abc import Iterable
from contextlib import nullcontext
from enum import Enum
from typing import Callable, Literal, Optional, Union, get_args, overload
@ -983,8 +984,7 @@ class FusedMoE(CustomOp):
if dp_size is not None else get_dp_group().world_size)
self.is_sequence_parallel = is_sequence_parallel
if self.is_sequence_parallel:
self.sp_size = tp_size_
self.sp_size = tp_size_ if is_sequence_parallel else 1
self.moe_parallel_config: FusedMoEParallelConfig = (
FusedMoEParallelConfig.make(
@ -1966,7 +1966,8 @@ class FusedMoE(CustomOp):
# clamp start and end
chunk_start = min(chunk_start, num_tokens - 1)
chunk_end = min(chunk_end, num_tokens)
with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank,
with ctx.dp_metadata.chunked_sizes(self.sp_size,
moe_dp_chunk_size_per_rank,
chunk_idx):
process_chunk(chunk_start,
chunk_end,
@ -2011,65 +2012,73 @@ class FusedMoE(CustomOp):
else:
shared_output = None
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)
ctx = get_forward_context()
sp_ctx = ctx.dp_metadata.sp_local_sizes(
self.sp_size) if ctx.dp_metadata else nullcontext()
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
)
with sp_ctx:
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits, self.is_sequence_parallel)
if shared_output is not None:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None
final_hidden_states = (
shared_output,
final_hidden_states,
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
)
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, tuple)
final_hidden_states, zero_expert_result = final_hidden_states
def reduce_output(states: torch.Tensor,
do_combine: bool = True) -> torch.Tensor:
if do_naive_dispatch_combine and do_combine:
states = get_ep_group().combine(states)
if shared_output is not None:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None
final_hidden_states = (
shared_output,
final_hidden_states,
)
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, tuple)
final_hidden_states, zero_expert_result = final_hidden_states
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
states = self.maybe_all_reduce_tensor_model_parallel(states)
def reduce_output(states: torch.Tensor,
do_combine: bool = True) -> torch.Tensor:
if do_naive_dispatch_combine and do_combine:
states = get_ep_group().combine(states,
self.is_sequence_parallel)
return states
if (not self.is_sequence_parallel and self.reduce_results
and (self.tp_size > 1 or self.ep_size > 1)):
states = self.maybe_all_reduce_tensor_model_parallel(
states)
if self.shared_experts is not None:
return (
reduce_output(final_hidden_states[0], do_combine=False),
reduce_output(final_hidden_states[1]),
)
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, torch.Tensor)
return reduce_output(final_hidden_states) + zero_expert_result
else:
return reduce_output(final_hidden_states)
return states
if self.shared_experts is not None:
return (
reduce_output(final_hidden_states[0], do_combine=False),
reduce_output(final_hidden_states[1]),
)
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, torch.Tensor)
return reduce_output(final_hidden_states) + zero_expert_result
else:
return reduce_output(final_hidden_states)
@classmethod
def make_expert_params_mapping(

View File

@ -9,7 +9,7 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
from transformers.models.aria.modeling_aria import AriaCrossAttention
from transformers.models.aria.processing_aria import AriaProcessor
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
from vllm.config import QuantizationConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE
@ -297,14 +297,12 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
Experts (MoE) Layer.
"""
def __init__(
self,
config: AriaTextConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config, cache_config, quant_config, prefix)
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__(vllm_config, prefix)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.mlp = AriaTextMoELayer(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")

View File

@ -32,7 +32,6 @@ import torch
from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config
import vllm.envs as envs
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
@ -56,8 +55,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv, direct_register_custom_op
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
@ -108,43 +107,6 @@ class DeepseekV2MLP(nn.Module):
return x
# Chunk x along the num_tokens axis for sequence parallelism
# NOTE: This is wrapped in a torch custom op to work around the following issue:
# The output tensor can have a sequence length 0 at small input sequence lengths
# even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
# all_gather needs the sequence length to be divisible by tp_size
seq_len = x.size(0)
remainder = seq_len % tp_size
if remainder != 0:
pad_len = tp_size - remainder
x = nn.functional.pad(x, (0, 0, 0, pad_len))
chunk = x.shape[0] // tp_size
start = tp_rank * chunk
return torch.narrow(x, 0, start, chunk)
def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
tp_size = get_tensor_model_parallel_world_size()
seq_len = cdiv(x.size(0), tp_size)
shape = list(x.shape)
shape[0] = seq_len
out = torch.empty(shape, dtype=x.dtype, device=x.device)
return out
direct_register_custom_op(
op_name="sequence_parallel_chunk",
op_func=sequence_parallel_chunk,
fake_impl=sequence_parallel_chunk_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
class DeepseekV2MoE(nn.Module):
def __init__(
@ -166,20 +128,7 @@ class DeepseekV2MoE(nn.Module):
self.n_routed_experts: int = config.n_routed_experts
self.n_shared_experts: int = config.n_shared_experts
# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens results in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
in ("deepep_high_throughput",
"deepep_low_latency")
and parallel_config.enable_expert_parallel
and self.tp_size > 1)
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
if config.hidden_act != "silu":
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
@ -278,8 +227,7 @@ class DeepseekV2MoE(nn.Module):
# TODO: We can replace the all_reduce at the end of attn with a
# reduce_scatter instead of chunking here.
if self.is_sequence_parallel:
hidden_states = torch.ops.vllm.sequence_parallel_chunk(
hidden_states)
hidden_states = sequence_parallel_chunk(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

View File

@ -29,10 +29,9 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -47,13 +46,11 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
vllm_config: VllmConfig,
prefix: str,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
self.mtp_emb_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@ -62,8 +59,7 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
self.mtp_linear_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config,
prefix)
self.mtp_block = LlamaDecoderLayer(vllm_config, prefix)
def forward(
self,
@ -102,10 +98,8 @@ class ErnieMultiTokenPredictor(nn.Module):
self.layers = torch.nn.ModuleDict({
str(idx):
ErnieMultiTokenPredictorLayer(
config,
vllm_config,
f"{prefix}.layers.{idx}",
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
)
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)

View File

@ -136,14 +136,16 @@ class Glm4Attention(nn.Module):
class Glm4DecoderLayer(nn.Module):
def __init__(
self,
config: Glm4Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self,
vllm_config: VllmConfig,
prefix: str = "",
config: Optional[Glm4Config] = None) -> None:
super().__init__()
config = config or vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)

View File

@ -13,7 +13,8 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -24,6 +25,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
@ -132,12 +134,18 @@ class MLPBlock(torch.nn.Module):
def __init__(
self,
config: GptOssConfig,
vllm_config: VllmConfig,
layer_idx: int,
quant_config: QuantizationConfig,
prefix: str = "",
):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
self.layer_idx = layer_idx
self.num_experts = config.num_local_experts
self.experts_per_token = config.num_experts_per_tok
@ -155,11 +163,20 @@ class MLPBlock(torch.nn.Module):
prefix=f"{prefix}.experts",
apply_router_weight_on_input=False,
has_bias=True,
activation="swigluoai")
activation="swigluoai",
is_sequence_parallel=self.is_sequence_parallel)
def forward(self, x: torch.Tensor) -> torch.Tensor:
num_tokens = x.shape[0]
if self.is_sequence_parallel:
x = sequence_parallel_chunk(x)
g = self.router(x)
x = self.experts(hidden_states=x, router_logits=g)
if self.is_sequence_parallel:
x = tensor_model_parallel_all_gather(x.contiguous(), 0)
x = x[:num_tokens]
return x
@ -167,19 +184,20 @@ class TransformerBlock(torch.nn.Module):
def __init__(
self,
config: GptOssConfig,
cache_config: CacheConfig,
quant_config: QuantizationConfig,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
self.layer_idx = extract_layer_index(prefix)
self.attn = OAIAttention(config,
prefix=f"{prefix}.attn",
cache_config=cache_config)
self.mlp = MLPBlock(config,
self.mlp = MLPBlock(vllm_config,
self.layer_idx,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
@ -216,8 +234,6 @@ class GptOssModel(nn.Module):
):
super().__init__()
self.config = vllm_config.model_config.hf_config
self.cache_config = vllm_config.cache_config
self.quant_config = vllm_config.quant_config
self.parallel_config = vllm_config.parallel_config
self.config.hidden_size = self.config.hidden_size
self.embedding = VocabParallelEmbedding(
@ -227,9 +243,7 @@ class GptOssModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
self.config.num_hidden_layers,
lambda prefix: TransformerBlock(
self.config,
cache_config=self.cache_config,
quant_config=self.quant_config,
vllm_config,
prefix=prefix,
),
prefix=f"{prefix}.layers",

View File

@ -29,12 +29,13 @@ from typing import Any, Optional
import torch
from torch import nn
from transformers.models.granitemoe import GraniteMoeConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -48,6 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
@ -71,9 +73,11 @@ class GraniteMoeMoE(nn.Module):
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
is_sequence_parallel=False,
prefix: str = ""):
super().__init__()
self.hidden_size = hidden_size
self.is_sequence_parallel = is_sequence_parallel
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(hidden_size,
@ -92,15 +96,27 @@ class GraniteMoeMoE(nn.Module):
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
prefix=f"{prefix}.experts")
prefix=f"{prefix}.experts",
is_sequence_parallel=self.is_sequence_parallel)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0)
num_tokens = orig_shape[0]
final_hidden_states = final_hidden_states[:num_tokens]
return final_hidden_states.view(orig_shape)
@ -191,12 +207,16 @@ class GraniteMoeDecoderLayer(nn.Module):
def __init__(
self,
config: GraniteMoeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
@ -218,6 +238,7 @@ class GraniteMoeDecoderLayer(nn.Module):
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
is_sequence_parallel=parallel_config.use_sequence_parallel_moe,
prefix=f"{prefix}.block_sparse_moe")
self.input_layernorm = RMSNorm(config.hidden_size,
@ -255,7 +276,6 @@ class GraniteMoeModel(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -275,9 +295,7 @@ class GraniteMoeModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: GraniteMoeDecoderLayer(
config, cache_config, quant_config=quant_config, prefix=prefix
),
lambda prefix: GraniteMoeDecoderLayer(vllm_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@ -68,6 +68,7 @@ class LlamaMLP(nn.Module):
bias: bool = False,
prefix: str = "",
reduce_results: bool = True,
disable_tp: bool = False,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@ -75,6 +76,7 @@ class LlamaMLP(nn.Module):
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
disable_tp=disable_tp,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
@ -83,6 +85,7 @@ class LlamaMLP(nn.Module):
bias=bias,
quant_config=quant_config,
reduce_results=reduce_results,
disable_tp=disable_tp,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
@ -237,14 +240,16 @@ class LlamaAttention(nn.Module):
class LlamaDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self,
vllm_config: VllmConfig,
prefix: str = "",
config: Optional[LlamaConfig] = None) -> None:
super().__init__()
config = config or vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
@ -335,7 +340,6 @@ class LlamaModel(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -357,10 +361,7 @@ class LlamaModel(nn.Module):
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: layer_type(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:

View File

@ -28,7 +28,8 @@ from vllm.attention import Attention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -39,6 +40,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
@ -59,13 +61,16 @@ class Llama4MoE(nn.Module):
router_scores = torch.sigmoid(router_scores.float())
return (router_scores, router_indices.to(torch.int32))
def __init__(self,
config: Llama4TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size()
self.top_k = config.num_experts_per_tok
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
intermediate_size_moe = config.intermediate_size
self.router = ReplicatedLinear(config.hidden_size,
@ -82,6 +87,7 @@ class Llama4MoE(nn.Module):
bias=False,
prefix=f"{prefix}.shared_expert",
reduce_results=False,
disable_tp=self.is_sequence_parallel,
)
self.experts = SharedFusedMoE(
@ -96,9 +102,14 @@ class Llama4MoE(nn.Module):
renormalize=False,
quant_config=quant_config,
prefix=f"{prefix}.experts",
is_sequence_parallel=self.is_sequence_parallel,
)
def forward(self, hidden_states):
num_tokens = hidden_states.shape[0]
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
router_logits, _ = self.router(hidden_states)
shared_out, routed_out = self.experts(
@ -107,7 +118,10 @@ class Llama4MoE(nn.Module):
)
experts_out = routed_out + shared_out
if self.tp_size > 1:
if self.is_sequence_parallel:
experts_out = tensor_model_parallel_all_gather(experts_out, 0)
experts_out = experts_out[:num_tokens]
elif self.tp_size > 1:
experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(
experts_out)
@ -257,15 +271,16 @@ class Llama4Attention(nn.Module):
class Llama4DecoderLayer(nn.Module):
def __init__(
self,
config: Llama4TextConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self,
vllm_config: VllmConfig,
prefix: str = "",
config: Optional[Llama4TextConfig] = None) -> None:
super().__init__()
config = config or vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.layer_idx = extract_layer_index(prefix)
self.global_layer = config.no_rope_layers[self.layer_idx] == 0
self.hidden_size = config.hidden_size
@ -291,8 +306,7 @@ class Llama4DecoderLayer(nn.Module):
self.layer_idx + 1) % config.interleave_moe_layer_step == 0
if is_moe_layer:
self.feed_forward = Llama4MoE(
config=config,
quant_config=quant_config,
vllm_config=vllm_config,
prefix=f"{prefix}.feed_forward",
)
else:

View File

@ -68,9 +68,9 @@ class LlamaModel(nn.Module):
self.layers = nn.ModuleList([
Llama4DecoderLayer(
self.config,
quant_config=quant_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
config=self.config,
) for i in range(self.config.num_hidden_layers)
])
self.fc = torch.nn.Linear(self.config.hidden_size * 2,

View File

@ -28,11 +28,12 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
def __init__(
self,
config: LlamaConfig,
vllm_config: VllmConfig,
disable_input_layernorm: bool,
prefix: str = "",
config: Optional[LlamaConfig] = None,
) -> None:
super().__init__(config, prefix=prefix)
super().__init__(vllm_config, prefix=prefix, config=config)
# Skip the input_layernorm
# https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
@ -64,9 +65,10 @@ class LlamaModel(nn.Module):
self.layers = nn.ModuleList([
LlamaDecoderLayer(
self.config,
vllm_config,
i == 0,
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
config=self.config,
) for i in range(self.config.num_hidden_layers)
])
self.fc = torch.nn.Linear(self.config.hidden_size * 2,

View File

@ -8,13 +8,11 @@ import torch
import torch.nn as nn
from transformers import LlamaConfig
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -28,17 +26,14 @@ logger = init_logger(__name__)
class LlamaDecoderLayer(LlamaDecoderLayer):
def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix)
def __init__(self,
vllm_config: VllmConfig,
prefix: str = "",
config: Optional[LlamaConfig] = None) -> None:
super().__init__(vllm_config, prefix=prefix, config=config)
config = config or vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
# override qkv
self.self_attn.qkv_proj = QKVParallelLinear(
@ -125,9 +120,9 @@ class LlamaModel(nn.Module):
self.layers = nn.ModuleList([
LlamaDecoderLayer(
config=self.config,
cache_config=current_vllm_config.cache_config,
current_vllm_config,
prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
config=self.config,
)
])
if hasattr(self.config, "target_hidden_size"):

View File

@ -29,13 +29,13 @@ from typing import Any, Optional, Union
import torch
from torch import nn
from transformers import Qwen3MoeConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_world_size)
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
@ -51,6 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
@ -101,12 +102,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
def __init__(
self,
config: Qwen3MoeConfig,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
enable_eplb: bool = False,
):
super().__init__()
config = vllm_config.model_config.hf_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group
@ -114,6 +118,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
@ -122,7 +128,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# Load balancing settings.
vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb
self.enable_eplb = parallel_config.enable_eplb
self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
@ -144,7 +150,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel)
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
@ -156,14 +163,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
assert hidden_states.dim(
) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
is_input_1d = hidden_states.dim() == 1
hidden_dim = hidden_states.shape[-1]
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0)
final_hidden_states = final_hidden_states[:num_tokens]
# return to 1d if input is 1d
return final_hidden_states.squeeze(0) if is_input_1d else \
final_hidden_states
@ -275,15 +290,13 @@ class Qwen3MoeAttention(nn.Module):
class Qwen3MoeDecoderLayer(nn.Module):
def __init__(
self,
config: Qwen3MoeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
) -> None:
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
@ -315,10 +328,8 @@ class Qwen3MoeDecoderLayer(nn.Module):
if (layer_idx not in mlp_only_layers) and (
config.num_experts > 0 and
(layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen3MoeSparseMoeBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb)
self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
@ -362,10 +373,8 @@ class Qwen3MoeModel(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config.get_text_config()
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
enable_eplb = parallel_config.enable_eplb
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
@ -379,11 +388,8 @@ class Qwen3MoeModel(nn.Module):
prefix=f"{prefix}.embed_tokens")
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Qwen3MoeDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
enable_eplb=enable_eplb),
lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@ -17,7 +17,8 @@ from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
VllmConfig, get_current_vllm_config)
from vllm.distributed import (divide, get_ep_group, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.fla.ops import (
@ -47,6 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
@ -69,14 +71,13 @@ KVCache = tuple[torch.Tensor, torch.Tensor]
class Qwen3NextSparseMoeBlock(nn.Module):
def __init__(
self,
config: Qwen3NextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group
@ -84,6 +85,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
@ -92,7 +95,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
# Load balancing settings.
vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb
self.enable_eplb = parallel_config.enable_eplb
self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
@ -114,7 +117,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel)
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
@ -141,9 +145,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
shared_output = None
if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states)
@ -158,7 +165,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0)
final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states)
@ -719,17 +731,17 @@ class Qwen3NextDecoderLayer(nn.Module):
def __init__(
self,
config: Qwen3NextConfig,
vllm_config: VllmConfig,
layer_type: str,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
) -> None:
super().__init__()
self.config = config
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
speculative_config = vllm_config.speculative_config
self.layer_type = layer_type
self.layer_idx = extract_layer_index(prefix)
@ -759,10 +771,8 @@ class Qwen3NextDecoderLayer(nn.Module):
config.num_experts > 0 and
(self.layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen3NextSparseMoeBlock(
config=config,
quant_config=quant_config,
vllm_config=vllm_config,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
)
else:
self.mlp = Qwen3NextMLP(
@ -783,14 +793,14 @@ class Qwen3NextDecoderLayer(nn.Module):
torch.zeros(
1,
1,
self.config.hidden_size,
config.hidden_size,
dtype=config.torch_dtype,
), )
self.ffn_layer_scale = torch.nn.Parameter(
torch.zeros(
1,
1,
self.config.hidden_size,
config.hidden_size,
dtype=config.torch_dtype,
), )
@ -858,13 +868,8 @@ class Qwen3NextModel(nn.Module):
super().__init__()
config: Qwen3NextConfig = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
lora_config = vllm_config.lora_config
speculative_config = vllm_config.speculative_config
enable_eplb = parallel_config.enable_eplb
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
@ -881,14 +886,9 @@ class Qwen3NextModel(nn.Module):
def get_layer(prefix: str):
return Qwen3NextDecoderLayer(
config,
vllm_config,
layer_type=config.layer_types[extract_layer_index(prefix)],
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
speculative_config=speculative_config,
prefix=prefix,
enable_eplb=enable_eplb,
)
self.start_layer, self.end_layer, self.layers = make_layers(

View File

@ -38,7 +38,6 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
super().__init__()
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
config: Qwen3NextConfig = model_config.hf_config
@ -68,11 +67,8 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
self.layers = torch.nn.ModuleList(
Qwen3NextDecoderLayer(
config,
vllm_config,
layer_type="full_attention",
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f'{prefix}.layers.{idx}',
) for idx in range(self.num_mtp_layers))

View File

@ -13,11 +13,14 @@ from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
from vllm.utils import (cdiv, direct_register_custom_op,
get_cuda_view_from_cpu_tensor, is_pin_memory_available,
is_uva_available)
logger = init_logger(__name__)
@ -743,3 +746,46 @@ def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
return hf_config.hidden_size
text_config = hf_config.get_text_config()
return text_config.hidden_size
# Chunk x along the num_tokens axis for sequence parallelism
# NOTE: This is wrapped in a torch custom op to work around the following issue:
# The output tensor can have a sequence length 0 at small input sequence lengths
# even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
return torch.ops.vllm.sequence_parallel_chunk_impl(x)
def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
# all_gather needs the sequence length to be divisible by tp_size
seq_len = x.size(0)
remainder = seq_len % tp_size
if remainder != 0:
pad_len = tp_size - remainder
y = nn.functional.pad(x, (0, 0, 0, pad_len))
else:
y = x
chunk = y.shape[0] // tp_size
start = tp_rank * chunk
return torch.narrow(y, 0, start, chunk)
def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
tp_size = get_tensor_model_parallel_world_size()
seq_len = cdiv(x.size(0), tp_size)
shape = list(x.shape)
shape[0] = seq_len
out = torch.empty(shape, dtype=x.dtype, device=x.device)
return out
direct_register_custom_op(
op_name="sequence_parallel_chunk_impl",
op_func=sequence_parallel_chunk_impl,
fake_impl=sequence_parallel_chunk_impl_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)