[Bugfix][WideEP] Apply TP Attn + EP MoE fix to other models (#24982)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Tyler Michael Smith 2025-09-27 10:22:28 -04:00 committed by yewentao256
parent 1e5e5d757e
commit e94aabe03d
23 changed files with 540 additions and 375 deletions

View File

@@ -279,6 +279,24 @@ class ParallelConfig:
         assert last_exc is not None
         raise last_exc
 
+    # The all_reduce at the end of attention (during o_proj) means that
+    # inputs are replicated across each rank of the tensor parallel group.
+    # If using expert-parallelism with DeepEP All2All ops, replicated
+    # tokens results in useless duplicate computation and communication.
+    #
+    # In this case, ensure the input to the experts is sequence parallel
+    # to avoid the excess work.
+    #
+    # Not needed for pplx-kernels as it can handle duplicate input tokens.
+    @property
+    def use_sequence_parallel_moe(self) -> bool:
+        return (envs.VLLM_ALL2ALL_BACKEND
+                in ("allgather_reducescatter", "naive",
+                    "deepep_high_throughput", "deepep_low_latency")
+                and self.enable_expert_parallel
+                and self.tensor_parallel_size > 1
+                and self.data_parallel_size > 1)
+
     @staticmethod
     def has_unfinished_dp(dp_group: ProcessGroup,
                           has_unfinished: bool) -> bool:
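
For reference, a minimal standalone sketch of the gating logic added above (pure Python, no vLLM imports; the free function and literal values are illustrative only, the real check is the ParallelConfig property in the hunk):

def use_sequence_parallel_moe(all2all_backend: str,
                              enable_expert_parallel: bool,
                              tensor_parallel_size: int,
                              data_parallel_size: int) -> bool:
    # Mirrors the property: sequence-parallel MoE input is only worthwhile
    # when EP is enabled and both TP and DP are larger than 1.
    return (all2all_backend in ("allgather_reducescatter", "naive",
                                "deepep_high_throughput",
                                "deepep_low_latency")
            and enable_expert_parallel
            and tensor_parallel_size > 1
            and data_parallel_size > 1)

assert use_sequence_parallel_moe("deepep_low_latency", True, 2, 2)
assert not use_sequence_parallel_moe("pplx", True, 2, 2)  # pplx tolerates duplicate tokens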

View File

@@ -6,7 +6,7 @@ import torch
 import torch.distributed as dist
 
 import vllm.envs as envs
-from vllm.distributed import get_dp_group
+from vllm.distributed import get_dp_group, get_ep_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.utils import has_deep_ep, has_pplx
@@ -34,41 +34,60 @@ class NaiveAll2AllManager(All2AllManagerBase):
         super().__init__(cpu_group)
 
     def naive_multicast(self, x: torch.Tensor,
-                        cu_tokens_across_dp_cpu: torch.Tensor):
+                        cu_tokens_across_sp_cpu: torch.Tensor,
+                        is_sequence_parallel: bool) -> torch.Tensor:
         assert (len(x.shape) == 2)
-        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+        buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
                              device=x.device,
                              dtype=x.dtype)
 
-        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-            self.dp_rank - 1]
-        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        rank = self.rank if is_sequence_parallel else self.dp_rank
+        world_size = (self.world_size
+                      if is_sequence_parallel else self.dp_world_size)
+
+        start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
+        end = cu_tokens_across_sp_cpu[rank]
         buffer[start:end, :].copy_(x)
-        for idx in range(self.dp_world_size):
-            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
-            end = cu_tokens_across_dp_cpu[idx]
-            self.dp_group.broadcast(buffer[start:end, :], idx)
+        for idx in range(world_size):
+            start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
+            end = cu_tokens_across_sp_cpu[idx]
+            get_ep_group().broadcast(buffer[start:end, :], idx)
 
         return buffer
 
-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
-        sizes = get_forward_context(
-        ).dp_metadata.get_chunk_sizes_across_dp_rank()
-        hidden_states, router_logits = get_dp_group().all_gatherv(
-            [hidden_states, router_logits],
-            dim=0,
-            sizes=sizes,
-        )
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
+        dp_metadata = get_forward_context().dp_metadata
+        cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
+
+        hidden_states = self.naive_multicast(hidden_states,
+                                             cu_tokens_across_sp_cpu,
+                                             is_sequence_parallel)
+        router_logits = self.naive_multicast(router_logits,
+                                             cu_tokens_across_sp_cpu,
+                                             is_sequence_parallel)
         return hidden_states, router_logits
 
-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        sizes = get_forward_context(
-        ).dp_metadata.get_chunk_sizes_across_dp_rank()
-        hidden_states = get_dp_group().reduce_scatterv(hidden_states,
-                                                       dim=0,
-                                                       sizes=sizes)
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
+
+        ep_rank = self.rank if is_sequence_parallel else self.dp_rank
+
+        dp_metadata = get_forward_context().dp_metadata
+        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
+        cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
+
+        start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
+        end = cu_tokens_across_sp_cpu[ep_rank]
+
+        all_hidden_states = get_ep_group().all_reduce(hidden_states)
+        hidden_states = all_hidden_states[start:end, :]
         return hidden_states
 
     def destroy(self):
@@ -84,29 +103,40 @@ class AgRsAll2AllManager(All2AllManagerBase):
     def __init__(self, cpu_group):
         super().__init__(cpu_group)
 
-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Gather hidden_states and router_logits from all dp ranks.
        """
         sizes = get_forward_context(
         ).dp_metadata.get_chunk_sizes_across_dp_rank()
-        hidden_states, router_logits = get_dp_group().all_gatherv(
+
+        dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
+        assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
+        hidden_states, router_logits = dist_group.all_gatherv(
             [hidden_states, router_logits],
             dim=0,
             sizes=sizes,
         )
         return hidden_states, router_logits
 
-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
        """
        Reduce-scatter hidden_states across all dp ranks.
        """
         sizes = get_forward_context(
         ).dp_metadata.get_chunk_sizes_across_dp_rank()
-        hidden_states = get_dp_group().reduce_scatterv(hidden_states,
-                                                       dim=0,
-                                                       sizes=sizes)
+
+        dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
+        hidden_states = dist_group.reduce_scatterv(hidden_states,
+                                                   dim=0,
+                                                   sizes=sizes)
         return hidden_states
 
     def destroy(self):
@@ -148,11 +178,17 @@ class PPLXAll2AllManager(All2AllManagerBase):
             kwargs, pplx.AllToAll.internode
             if self.internode else pplx.AllToAll.intranode)
 
-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError
 
-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         raise NotImplementedError
 
     def destroy(self):
@@ -184,11 +220,17 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
     def get_handle(self, kwargs):
         raise NotImplementedError
 
-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError
 
-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         raise NotImplementedError
 
     def destroy(self):
@@ -395,4 +437,4 @@ class FlashInferAllToAllManager(All2AllManagerBase):
         self.workspace_tensor = None
         self.prepare_workspace_tensor = None
         self.mapping = None
         self.initialized = False
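
The start/end bookkeeping used by naive_multicast and combine above, shown in isolation (the token counts below are made up; only torch is needed):

import torch

cu_tokens_across_sp_cpu = torch.tensor([3, 6, 10, 14])  # cumulative tokens per SP rank

def rank_slice(rank: int) -> tuple[int, int]:
    # Same indexing as the diff: rank 0 starts at 0, later ranks start at the
    # previous cumulative count.
    start = 0 if rank == 0 else int(cu_tokens_across_sp_cpu[rank - 1])
    end = int(cu_tokens_across_sp_cpu[rank])
    return start, end

print([rank_slice(r) for r in range(4)])  # [(0, 3), (3, 6), (6, 10), (10, 14)]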

View File

@@ -28,6 +28,8 @@ class Cache:
 
 class All2AllManagerBase:
+    rank: int
+    world_size: int
 
     def __init__(self, cpu_group):
         self.cpu_group = cpu_group
@@ -40,6 +42,7 @@ class All2AllManagerBase:
         # all2all lives in ep group, which is merged from dp and tp group
         self.dp_group = get_dp_group()
         self.tp_group = get_tp_group()
+
         # no self.ep_group since self.ep_group is still in construction
         # when we create this object
         self.dp_rank = self.dp_group.rank_in_group
@@ -60,17 +63,21 @@ class All2AllManagerBase:
         # and reuse it for the same config.
         raise NotImplementedError
 
+    def dispatch(self,
+                 hidden_states: torch.Tensor,
+                 router_logits: torch.Tensor,
+                 is_sequence_parallel: bool = False):
+        raise NotImplementedError
+
     def set_num_sms(self, num_sms: int):
         pass
 
     def max_sms_used(self) -> Optional[int]:
         return None  # None means it could use the whole GPU
 
-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
-        raise NotImplementedError
-
-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False):
         raise NotImplementedError
 
     def destroy(self):
@@ -267,15 +274,20 @@ class DeviceCommunicatorBase:
             module.quant_method.init_prepare_finalize(module)
 
     def dispatch(
-        self, hidden_states: torch.Tensor,
-        router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Dispatch the hidden states and router logits to the appropriate device.
        This is a no-op in the base class.
        """
         return hidden_states, router_logits
 
-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
        """
        Combine the hidden states and router logits from the appropriate device.
        This is a no-op in the base class.

View File

@@ -39,10 +39,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
         use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
 
-        # ep does not use pynccl
-        use_pynccl = "ep" not in unique_name
-
-        self.use_pynccl = use_pynccl
         self.use_custom_allreduce = use_custom_allreduce
         self.use_torch_symm_mem = use_torch_symm_mem
@@ -57,7 +53,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
             SymmMemCommunicator)
 
         self.pynccl_comm: Optional[PyNcclCommunicator] = None
-        if use_pynccl and self.world_size > 1:
+        if self.world_size > 1:
             self.pynccl_comm = PyNcclCommunicator(
                 group=self.cpu_group,
                 device=self.device,
@@ -308,14 +304,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
         return output_list
 
     def dispatch(
-        self, hidden_states: torch.Tensor,
-        router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         assert self.all2all_manager is not None
         hidden_states, router_logits = self.all2all_manager.dispatch(
-            hidden_states, router_logits)
+            hidden_states, router_logits, is_sequence_parallel)
         return hidden_states, router_logits
 
-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         assert self.all2all_manager is not None
-        hidden_states = self.all2all_manager.combine(hidden_states)
+        hidden_states = self.all2all_manager.combine(hidden_states,
+                                                     is_sequence_parallel)
         return hidden_states

View File

@@ -75,14 +75,20 @@ class XpuCommunicator(DeviceCommunicatorBase):
             dist.broadcast(input_, src=src, group=self.device_group)
 
     def dispatch(
-        self, hidden_states: torch.Tensor,
-        router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         assert self.all2all_manager is not None
         hidden_states, router_logits = self.all2all_manager.dispatch(
-            hidden_states, router_logits)
+            hidden_states, router_logits, is_sequence_parallel)
         return hidden_states, router_logits
 
-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         assert self.all2all_manager is not None
-        hidden_states = self.all2all_manager.combine(hidden_states)
+        hidden_states = self.all2all_manager.combine(hidden_states,
+                                                     is_sequence_parallel)
         return hidden_states

View File

@@ -871,17 +871,24 @@ class GroupCoordinator:
                                                     model)
 
     def dispatch(
-        self, hidden_states: torch.Tensor,
-        router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         if self.device_communicator is not None:
             return self.device_communicator.dispatch(hidden_states,
-                                                     router_logits)
+                                                     router_logits,
+                                                     is_sequence_parallel)
         else:
             return hidden_states, router_logits
 
-    def combine(self, hidden_states) -> torch.Tensor:
+    def combine(self,
+                hidden_states,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         if self.device_communicator is not None:
-            return self.device_communicator.combine(hidden_states)
+            return self.device_communicator.combine(hidden_states,
+                                                    is_sequence_parallel)
         else:
             return hidden_states

View File

@@ -49,16 +49,29 @@ class BatchDescriptor(NamedTuple):
         return BatchDescriptor(self.num_tokens, uniform_decode=False)
 
 
-def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
+def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
+                           sequence_parallel_size: int) -> list[int]:
+    sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) //
+                 sequence_parallel_size)
+    sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
+    return sp_tokens.tolist()
+
+
+def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
+                                      sequence_parallel_size: int,
                                       max_num_tokens: int,
                                       chunk_idx: int) -> list[int]:
-    dp_size = len(num_tokens_across_dp_cpu)
 
-    local_size = [-1] * dp_size
-    for i in range(dp_size):
-        dp_tokens = num_tokens_across_dp_cpu[i]
+    sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu,
+                                       sequence_parallel_size)
+    sp_size = len(sp_tokens)
+
+    local_size = [-1] * sp_size
+    for i in range(sp_size):
+        # Take into account sharding if MoE activation is sequence parallel.
         local_size[i] = min(max_num_tokens,
-                            dp_tokens - (max_num_tokens * chunk_idx))
+                            sp_tokens[i] - (max_num_tokens * chunk_idx))
         if local_size[i] <= 0:
             local_size[i] = 1  # ensure lockstep even if done
     return local_size
@@ -67,7 +80,9 @@ def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
 @dataclass
 class DPMetadata:
     max_tokens_across_dp_cpu: torch.Tensor
-    cu_tokens_across_dp_cpu: torch.Tensor
+    num_tokens_across_dp_cpu: torch.Tensor
+
+    # NOTE: local_sizes should only be set by the chunked_sizes context manager
     local_sizes: Optional[list[int]] = None
 
     @staticmethod
@@ -98,6 +113,17 @@ class DPMetadata:
             dist.all_reduce(num_tokens_tensor, group=group)
         return num_tokens_tensor.cpu()
 
+    # Get the cumulative tokens across sequence parallel ranks.
+    # In this case the input to the MoEs will be distributed w.r.t both
+    # DP and TP rank.
+    # When sp_size==1, this is just the cumulative num tokens across DP.
+    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
+        num_tokens_across_sp_cpu = (
+            (self.num_tokens_across_dp_cpu - 1 + sp_size) // sp_size)
+        num_tokens_across_sp_cpu = (
+            num_tokens_across_sp_cpu.repeat_interleave(sp_size))
+        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
+
     @staticmethod
     def should_ubatch_across_dp(
             should_ubatch: bool, orig_num_tokens_per_ubatch: int,
@@ -147,10 +173,10 @@ class DPMetadata:
 
     @staticmethod
     def make(
         parallel_config: ParallelConfig,
         attn_metadata: Any,
         num_tokens: int,
-        num_tokens_across_dp: Optional[torch.Tensor] = None
+        num_tokens_across_dp_cpu: Optional[torch.Tensor] = None
     ) -> "DPMetadata":
 
         assert parallel_config.data_parallel_size > 1
@@ -167,18 +193,18 @@ class DPMetadata:
 
         # If num_tokens_across_dp is None, it will be computed by all_reduce
         # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
-        assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
-                == batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
-        if num_tokens_across_dp is None:
-            num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
+        assert (num_tokens_across_dp_cpu is None
+                or num_tokens_across_dp_cpu[dp_rank] == batchsize
+                ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
+        if num_tokens_across_dp_cpu is None:
+            num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
                 batchsize, dp_size, dp_rank)
-        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
-        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
-        return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
-                          num_tokens_across_dp)
+        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
+        return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)
 
     @contextmanager
-    def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
+    def chunked_sizes(self, sequence_parallel_size: int,
+                      max_chunk_size_per_rank: int, chunk_idx: int):
        """
        Context manager to compute and temporarily set the per-rank local token
        sizes for a specific chunk during chunked forward execution.
@@ -192,31 +218,40 @@ class DPMetadata:
        `chunk_idx`, this context manager sets `self.local_sizes` to the number
        of tokens to process in that chunk on each rank.
 
-        It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
-        number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
-        to determine the chunk-wise split.
-
        `self.local_sizes` is only valid inside the context.
 
        Args:
+            sequence_parallel_size: When Attn is TP and MoE layers are EP,
+                                    we use SP between the layers to avoid
+                                    redundant ops. We need this value to
+                                    compute the chunked sizes.
            max_chunk_size_per_rank: The max number of tokens each rank is
                                     allowed to process in this chunk.
            chunk_idx: The index of the chunk to compute sizes for.
        """
-        cu_sizes = self.cu_tokens_across_dp_cpu
-        num_tokens_across_dp_cpu = [
-            (cu_sizes[i] -
-             cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item()
-            for i in range(len(cu_sizes))
-        ]
         self.local_sizes = _compute_chunked_local_num_tokens(
-            num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
+            self.num_tokens_across_dp_cpu, sequence_parallel_size,
+            max_chunk_size_per_rank, chunk_idx)
+        try:
+            yield self.local_sizes
+        finally:
+            self.local_sizes = None
+
+    @contextmanager
+    def sp_local_sizes(self, sequence_parallel_size: int):
+        """
+        Context manager for setting self.local_sizes. Same as self.chunked_sizes
+        but without any chunking.
+        """
+        self.local_sizes = _compute_sp_num_tokens(
+            self.num_tokens_across_dp_cpu, sequence_parallel_size)
+
         try:
             yield self.local_sizes
         finally:
             self.local_sizes = None
 
     def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
+        assert self.local_sizes is not None
         return self.local_sizes
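
A worked example of the new cu_tokens_across_sp bookkeeping, assuming two DP ranks with 5 and 7 tokens and sp_size=2 (the TP degree):

import torch

num_tokens_across_dp_cpu = torch.tensor([5, 7])
sp_size = 2

# Ceil-divide each DP rank's tokens over its TP ranks, then accumulate.
sp_tokens = (num_tokens_across_dp_cpu - 1 + sp_size) // sp_size
sp_tokens = sp_tokens.repeat_interleave(sp_size)
print(sp_tokens)                       # -> 3, 3, 4, 4
print(torch.cumsum(sp_tokens, dim=0))  # -> 3, 6, 10, 14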

View File

@@ -3,6 +3,7 @@
 
 from abc import abstractmethod
 from collections.abc import Iterable
+from contextlib import nullcontext
 from enum import Enum
 from typing import Callable, Literal, Optional, Union, get_args, overload
@@ -983,8 +984,7 @@ class FusedMoE(CustomOp):
                    if dp_size is not None else get_dp_group().world_size)
 
         self.is_sequence_parallel = is_sequence_parallel
-        if self.is_sequence_parallel:
-            self.sp_size = tp_size_
+        self.sp_size = tp_size_ if is_sequence_parallel else 1
 
         self.moe_parallel_config: FusedMoEParallelConfig = (
             FusedMoEParallelConfig.make(
@@ -1966,7 +1966,8 @@ class FusedMoE(CustomOp):
             # clamp start and end
             chunk_start = min(chunk_start, num_tokens - 1)
             chunk_end = min(chunk_end, num_tokens)
-            with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank,
+            with ctx.dp_metadata.chunked_sizes(self.sp_size,
+                                               moe_dp_chunk_size_per_rank,
                                                chunk_idx):
                 process_chunk(chunk_start,
                               chunk_end,
@@ -2011,65 +2012,73 @@ class FusedMoE(CustomOp):
         else:
             shared_output = None
 
-        if do_naive_dispatch_combine:
-            hidden_states, router_logits = get_ep_group().dispatch(
-                hidden_states, router_logits)
-
-        # Matrix multiply.
-        final_hidden_states = self.quant_method.apply(
-            layer=self,
-            x=hidden_states,
-            router_logits=router_logits,
-            top_k=self.top_k,
-            renormalize=self.renormalize,
-            use_grouped_topk=self.use_grouped_topk,
-            global_num_experts=self.global_num_experts,
-            expert_map=self.expert_map,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            custom_routing_function=self.custom_routing_function,
-            scoring_func=self.scoring_func,
-            routed_scaling_factor=self.routed_scaling_factor,
-            e_score_correction_bias=self.e_score_correction_bias,
-            activation=self.activation,
-            apply_router_weight_on_input=self.apply_router_weight_on_input,
-            enable_eplb=self.enable_eplb,
-            expert_load_view=self.expert_load_view,
-            logical_to_physical_map=self.logical_to_physical_map,
-            logical_replica_count=self.logical_replica_count,
-        )
-
-        if shared_output is not None:
-            assert not isinstance(final_hidden_states, tuple)
-            assert self.shared_experts is not None
-            final_hidden_states = (
-                shared_output,
-                final_hidden_states,
-            )
-        elif self.zero_expert_num is not None and self.zero_expert_num > 0:
-            assert isinstance(final_hidden_states, tuple)
-            final_hidden_states, zero_expert_result = final_hidden_states
-
-        def reduce_output(states: torch.Tensor,
-                          do_combine: bool = True) -> torch.Tensor:
-            if do_naive_dispatch_combine and do_combine:
-                states = get_ep_group().combine(states)
-
-            if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
-                states = self.maybe_all_reduce_tensor_model_parallel(states)
-
-            return states
-
-        if self.shared_experts is not None:
-            return (
-                reduce_output(final_hidden_states[0], do_combine=False),
-                reduce_output(final_hidden_states[1]),
-            )
-        elif self.zero_expert_num is not None and self.zero_expert_num > 0:
-            assert isinstance(final_hidden_states, torch.Tensor)
-            return reduce_output(final_hidden_states) + zero_expert_result
-        else:
-            return reduce_output(final_hidden_states)
+        ctx = get_forward_context()
+        sp_ctx = ctx.dp_metadata.sp_local_sizes(
+            self.sp_size) if ctx.dp_metadata else nullcontext()
+
+        with sp_ctx:
+            if do_naive_dispatch_combine:
+                hidden_states, router_logits = get_ep_group().dispatch(
+                    hidden_states, router_logits, self.is_sequence_parallel)
+
+            # Matrix multiply.
+            final_hidden_states = self.quant_method.apply(
+                layer=self,
+                x=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                renormalize=self.renormalize,
+                use_grouped_topk=self.use_grouped_topk,
+                global_num_experts=self.global_num_experts,
+                expert_map=self.expert_map,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                custom_routing_function=self.custom_routing_function,
+                scoring_func=self.scoring_func,
+                routed_scaling_factor=self.routed_scaling_factor,
+                e_score_correction_bias=self.e_score_correction_bias,
+                activation=self.activation,
+                apply_router_weight_on_input=self.apply_router_weight_on_input,
+                enable_eplb=self.enable_eplb,
+                expert_load_view=self.expert_load_view,
+                logical_to_physical_map=self.logical_to_physical_map,
+                logical_replica_count=self.logical_replica_count,
+            )
+
+            if shared_output is not None:
+                assert not isinstance(final_hidden_states, tuple)
+                assert self.shared_experts is not None
+                final_hidden_states = (
+                    shared_output,
+                    final_hidden_states,
+                )
+            elif self.zero_expert_num is not None and self.zero_expert_num > 0:
+                assert isinstance(final_hidden_states, tuple)
+                final_hidden_states, zero_expert_result = final_hidden_states
+
+            def reduce_output(states: torch.Tensor,
+                              do_combine: bool = True) -> torch.Tensor:
+                if do_naive_dispatch_combine and do_combine:
+                    states = get_ep_group().combine(states,
+                                                    self.is_sequence_parallel)
+
+                if (not self.is_sequence_parallel and self.reduce_results
+                        and (self.tp_size > 1 or self.ep_size > 1)):
+                    states = self.maybe_all_reduce_tensor_model_parallel(
+                        states)
+
+                return states
+
+            if self.shared_experts is not None:
+                return (
+                    reduce_output(final_hidden_states[0], do_combine=False),
+                    reduce_output(final_hidden_states[1]),
+                )
+            elif self.zero_expert_num is not None and self.zero_expert_num > 0:
+                assert isinstance(final_hidden_states, torch.Tensor)
+                return reduce_output(final_hidden_states) + zero_expert_result
+            else:
+                return reduce_output(final_hidden_states)
 
     @classmethod
     def make_expert_params_mapping(
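
In rough pseudocode, the naive dispatch/combine path now threads is_sequence_parallel through the EP group; in the sketch below, ep_group and apply_experts are stand-ins for the real get_ep_group() and self.quant_method.apply, not vLLM API:

def naive_dispatch_combine(ep_group, apply_experts, hidden_states,
                           router_logits, is_sequence_parallel: bool):
    # Gather tokens onto the EP group (SP-aware when requested) ...
    hidden_states, router_logits = ep_group.dispatch(
        hidden_states, router_logits, is_sequence_parallel)
    # ... run the experts ...
    out = apply_experts(hidden_states, router_logits)
    # ... and reduce/scatter the results back the same way.
    return ep_group.combine(out, is_sequence_parallel)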

View File

@@ -9,7 +9,7 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
 from transformers.models.aria.modeling_aria import AriaCrossAttention
 from transformers.models.aria.processing_aria import AriaProcessor
 
-from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
+from vllm.config import QuantizationConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -297,14 +297,12 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
     Experts (MoE) Layer.
     """
 
-    def __init__(
-        self,
-        config: AriaTextConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__(config, cache_config, quant_config, prefix)
+    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
+        super().__init__(vllm_config, prefix)
+
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+
         self.mlp = AriaTextMoELayer(config,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.mlp")

View File

@@ -32,7 +32,6 @@ import torch
 from torch import nn
 from transformers import DeepseekV2Config, DeepseekV3Config
 
-import vllm.envs as envs
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ParallelConfig, VllmConfig
@@ -56,8 +55,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
-from vllm.utils import cdiv, direct_register_custom_op
 
 from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
 from .utils import (PPMissingLayer, is_pp_missing_parameter,
@@ -108,43 +107,6 @@ class DeepseekV2MLP(nn.Module):
         return x
 
 
-# Chunk x along the num_tokens axis for sequence parallelism
-# NOTE: This is wrapped in a torch custom op to work around the following issue:
-# The output tensor can have a sequence length 0 at small input sequence lengths
-# even though we explicitly pad to avoid this.
-def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
-    tp_size = get_tensor_model_parallel_world_size()
-    tp_rank = get_tensor_model_parallel_rank()
-
-    # all_gather needs the sequence length to be divisible by tp_size
-    seq_len = x.size(0)
-    remainder = seq_len % tp_size
-    if remainder != 0:
-        pad_len = tp_size - remainder
-        x = nn.functional.pad(x, (0, 0, 0, pad_len))
-
-    chunk = x.shape[0] // tp_size
-    start = tp_rank * chunk
-    return torch.narrow(x, 0, start, chunk)
-
-
-def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
-    tp_size = get_tensor_model_parallel_world_size()
-    seq_len = cdiv(x.size(0), tp_size)
-    shape = list(x.shape)
-    shape[0] = seq_len
-    out = torch.empty(shape, dtype=x.dtype, device=x.device)
-    return out
-
-
-direct_register_custom_op(
-    op_name="sequence_parallel_chunk",
-    op_func=sequence_parallel_chunk,
-    fake_impl=sequence_parallel_chunk_fake,
-    tags=(torch.Tag.needs_fixed_stride_order, ),
-)
-
-
 class DeepseekV2MoE(nn.Module):
 
     def __init__(
@@ -166,20 +128,7 @@ class DeepseekV2MoE(nn.Module):
         self.n_routed_experts: int = config.n_routed_experts
         self.n_shared_experts: int = config.n_shared_experts
 
-        # The all_reduce at the end of attention (during o_proj) means that
-        # inputs are replicated across each rank of the tensor parallel group.
-        # If using expert-parallelism with DeepEP All2All ops, replicated
-        # tokens results in useless duplicate computation and communication.
-        #
-        # In this case, ensure the input to the experts is sequence parallel
-        # to avoid the excess work.
-        #
-        # Not needed for pplx-kernels as it can handle duplicate input tokens.
-        self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
-                                     in ("deepep_high_throughput",
-                                         "deepep_low_latency")
-                                     and parallel_config.enable_expert_parallel
-                                     and self.tp_size > 1)
+        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
 
         if config.hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
@@ -278,8 +227,7 @@ class DeepseekV2MoE(nn.Module):
         # TODO: We can replace the all_reduce at the end of attn with a
         # reduce_scatter instead of chunking here.
         if self.is_sequence_parallel:
-            hidden_states = torch.ops.vllm.sequence_parallel_chunk(
-                hidden_states)
+            hidden_states = sequence_parallel_chunk(hidden_states)
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
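
The deleted helper now lives in vllm/model_executor/models/utils (see the new import above). A sketch of the chunking it performs, reconstructed from the code removed in this file; tp_size/tp_rank are passed explicitly here purely for illustration:

import torch
import torch.nn as nn

def sequence_parallel_chunk_sketch(x: torch.Tensor, tp_size: int,
                                   tp_rank: int) -> torch.Tensor:
    # Pad so the token dimension divides evenly; the later all_gather needs this.
    remainder = x.size(0) % tp_size
    if remainder != 0:
        x = nn.functional.pad(x, (0, 0, 0, tp_size - remainder))
    # Each TP rank keeps only its contiguous slice of the tokens.
    chunk = x.shape[0] // tp_size
    return torch.narrow(x, 0, tp_rank * chunk, chunk)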

View File

@@ -29,10 +29,9 @@ import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
 
-from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -47,13 +46,11 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
     def __init__(
         self,
-        config: PretrainedConfig,
+        vllm_config: VllmConfig,
         prefix: str,
-        model_config: ModelConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
 
         self.mtp_emb_norm = RMSNorm(config.hidden_size,
                                     eps=config.rms_norm_eps)
@@ -62,8 +59,7 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
         self.mtp_linear_proj = nn.Linear(config.hidden_size * 2,
                                          config.hidden_size,
                                          bias=False)
-        self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config,
-                                           prefix)
+        self.mtp_block = LlamaDecoderLayer(vllm_config, prefix)
 
     def forward(
         self,
@@ -102,10 +98,8 @@ class ErnieMultiTokenPredictor(nn.Module):
         self.layers = torch.nn.ModuleDict({
             str(idx):
             ErnieMultiTokenPredictorLayer(
-                config,
+                vllm_config,
                 f"{prefix}.layers.{idx}",
-                model_config=vllm_config.model_config,
-                cache_config=vllm_config.cache_config,
             )
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)

View File

@@ -136,14 +136,16 @@ class Glm4Attention(nn.Module):
 class Glm4DecoderLayer(nn.Module):
 
-    def __init__(
-        self,
-        config: Glm4Config,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 config: Optional[Glm4Config] = None) -> None:
         super().__init__()
+        config = config or vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)

View File

@@ -13,7 +13,8 @@ from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_ep_group, get_pp_group,
                               get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -24,6 +25,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
 from vllm.utils import cdiv
@@ -132,12 +134,18 @@ class MLPBlock(torch.nn.Module):
     def __init__(
         self,
-        config: GptOssConfig,
+        vllm_config: VllmConfig,
         layer_idx: int,
-        quant_config: QuantizationConfig,
         prefix: str = "",
     ):
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        parallel_config = vllm_config.parallel_config
+        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
+
         self.layer_idx = layer_idx
         self.num_experts = config.num_local_experts
         self.experts_per_token = config.num_experts_per_tok
@@ -155,11 +163,20 @@ class MLPBlock(torch.nn.Module):
             prefix=f"{prefix}.experts",
             apply_router_weight_on_input=False,
             has_bias=True,
-            activation="swigluoai")
+            activation="swigluoai",
+            is_sequence_parallel=self.is_sequence_parallel)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        num_tokens = x.shape[0]
+        if self.is_sequence_parallel:
+            x = sequence_parallel_chunk(x)
+
         g = self.router(x)
         x = self.experts(hidden_states=x, router_logits=g)
+
+        if self.is_sequence_parallel:
+            x = tensor_model_parallel_all_gather(x.contiguous(), 0)
+            x = x[:num_tokens]
+
         return x
@@ -167,19 +184,20 @@ class TransformerBlock(torch.nn.Module):
     def __init__(
         self,
-        config: GptOssConfig,
-        cache_config: CacheConfig,
-        quant_config: QuantizationConfig,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ):
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+
         self.layer_idx = extract_layer_index(prefix)
         self.attn = OAIAttention(config,
                                  prefix=f"{prefix}.attn",
                                  cache_config=cache_config)
-        self.mlp = MLPBlock(config,
+        self.mlp = MLPBlock(vllm_config,
                             self.layer_idx,
-                            quant_config=quant_config,
                             prefix=f"{prefix}.mlp")
         self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
@@ -216,8 +234,6 @@ class GptOssModel(nn.Module):
     ):
         super().__init__()
         self.config = vllm_config.model_config.hf_config
-        self.cache_config = vllm_config.cache_config
-        self.quant_config = vllm_config.quant_config
         self.parallel_config = vllm_config.parallel_config
         self.config.hidden_size = self.config.hidden_size
         self.embedding = VocabParallelEmbedding(
@@ -227,9 +243,7 @@ class GptOssModel(nn.Module):
         self.start_layer, self.end_layer, self.layers = make_layers(
             self.config.num_hidden_layers,
             lambda prefix: TransformerBlock(
-                self.config,
-                cache_config=self.cache_config,
-                quant_config=self.quant_config,
+                vllm_config,
                 prefix=prefix,
             ),
             prefix=f"{prefix}.layers",

View File

@@ -29,12 +29,13 @@ from typing import Any, Optional
 
 import torch
 from torch import nn
 from transformers.models.granitemoe import GraniteMoeConfig
 
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (get_pp_group,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -48,6 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
@@ -71,9 +73,11 @@ class GraniteMoeMoE(nn.Module):
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  tp_size: Optional[int] = None,
+                 is_sequence_parallel=False,
                  prefix: str = ""):
         super().__init__()
         self.hidden_size = hidden_size
+        self.is_sequence_parallel = is_sequence_parallel
 
         # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(hidden_size,
@@ -92,15 +96,27 @@ class GraniteMoeMoE(nn.Module):
                                 renormalize=True,
                                 quant_config=quant_config,
                                 tp_size=tp_size,
-                                prefix=f"{prefix}.experts")
+                                prefix=f"{prefix}.experts",
+                                is_sequence_parallel=self.is_sequence_parallel)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
+
+        if self.is_sequence_parallel:
+            hidden_states = sequence_parallel_chunk(hidden_states)
+
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
+
+        if self.is_sequence_parallel:
+            final_hidden_states = tensor_model_parallel_all_gather(
+                final_hidden_states, 0)
+            num_tokens = orig_shape[0]
+            final_hidden_states = final_hidden_states[:num_tokens]
+
         return final_hidden_states.view(orig_shape)
@@ -191,12 +207,16 @@ class GraniteMoeDecoderLayer(nn.Module):
     def __init__(
         self,
-        config: GraniteMoeConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ) -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        parallel_config = vllm_config.parallel_config
+
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -218,6 +238,7 @@ class GraniteMoeDecoderLayer(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
+            is_sequence_parallel=parallel_config.use_sequence_parallel_moe,
             prefix=f"{prefix}.block_sparse_moe")
 
         self.input_layernorm = RMSNorm(config.hidden_size,
@@ -255,7 +276,6 @@ class GraniteMoeModel(nn.Module):
         super().__init__()
 
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
@@ -275,9 +295,7 @@ class GraniteMoeModel(nn.Module):
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: GraniteMoeDecoderLayer(
-                config, cache_config, quant_config=quant_config, prefix=prefix
-            ),
+            lambda prefix: GraniteMoeDecoderLayer(vllm_config, prefix=prefix),
             prefix=f"{prefix}.layers")
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@@ -68,6 +68,7 @@ class LlamaMLP(nn.Module):
         bias: bool = False,
         prefix: str = "",
         reduce_results: bool = True,
+        disable_tp: bool = False,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -75,6 +76,7 @@ class LlamaMLP(nn.Module):
             output_sizes=[intermediate_size] * 2,
             bias=bias,
             quant_config=quant_config,
+            disable_tp=disable_tp,
             prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
@@ -83,6 +85,7 @@ class LlamaMLP(nn.Module):
             bias=bias,
             quant_config=quant_config,
             reduce_results=reduce_results,
+            disable_tp=disable_tp,
             prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
@@ -237,14 +240,16 @@ class LlamaAttention(nn.Module):
 class LlamaDecoderLayer(nn.Module):
 
-    def __init__(
-        self,
-        config: LlamaConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 config: Optional[LlamaConfig] = None) -> None:
         super().__init__()
+
+        config = config or vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -335,7 +340,6 @@ class LlamaModel(nn.Module):
         super().__init__()
 
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
@@ -357,10 +361,7 @@ class LlamaModel(nn.Module):
             self.embed_tokens = PPMissingLayer()
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: layer_type(config=config,
-                                      cache_config=cache_config,
-                                      quant_config=quant_config,
-                                      prefix=prefix),
+            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
             prefix=f"{prefix}.layers",
         )
         if get_pp_group().is_last_rank:

View File

@ -28,7 +28,8 @@ from vllm.attention import Attention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -39,6 +40,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk, from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
@ -59,13 +61,16 @@ class Llama4MoE(nn.Module):
router_scores = torch.sigmoid(router_scores.float()) router_scores = torch.sigmoid(router_scores.float())
return (router_scores, router_indices.to(torch.int32)) return (router_scores, router_indices.to(torch.int32))
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
config: Llama4TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.top_k = config.num_experts_per_tok self.top_k = config.num_experts_per_tok
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
intermediate_size_moe = config.intermediate_size intermediate_size_moe = config.intermediate_size
self.router = ReplicatedLinear(config.hidden_size, self.router = ReplicatedLinear(config.hidden_size,
@ -82,6 +87,7 @@ class Llama4MoE(nn.Module):
bias=False, bias=False,
prefix=f"{prefix}.shared_expert", prefix=f"{prefix}.shared_expert",
reduce_results=False, reduce_results=False,
disable_tp=self.is_sequence_parallel,
) )
self.experts = SharedFusedMoE( self.experts = SharedFusedMoE(
@ -96,9 +102,14 @@ class Llama4MoE(nn.Module):
renormalize=False, renormalize=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
is_sequence_parallel=self.is_sequence_parallel,
) )
def forward(self, hidden_states): def forward(self, hidden_states):
num_tokens = hidden_states.shape[0]
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
router_logits, _ = self.router(hidden_states) router_logits, _ = self.router(hidden_states)
shared_out, routed_out = self.experts( shared_out, routed_out = self.experts(
@ -107,7 +118,10 @@ class Llama4MoE(nn.Module):
) )
experts_out = routed_out + shared_out experts_out = routed_out + shared_out
if self.tp_size > 1: if self.is_sequence_parallel:
experts_out = tensor_model_parallel_all_gather(experts_out, 0)
experts_out = experts_out[:num_tokens]
elif self.tp_size > 1:
experts_out = self.experts.maybe_all_reduce_tensor_model_parallel( experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(
experts_out) experts_out)
@ -257,15 +271,16 @@ class Llama4Attention(nn.Module):
class Llama4DecoderLayer(nn.Module): class Llama4DecoderLayer(nn.Module):
def __init__( def __init__(self,
self, vllm_config: VllmConfig,
config: Llama4TextConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None, config: Optional[Llama4TextConfig] = None) -> None:
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = config or vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.layer_idx = extract_layer_index(prefix) self.layer_idx = extract_layer_index(prefix)
self.global_layer = config.no_rope_layers[self.layer_idx] == 0 self.global_layer = config.no_rope_layers[self.layer_idx] == 0
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -291,8 +306,7 @@ class Llama4DecoderLayer(nn.Module):
self.layer_idx + 1) % config.interleave_moe_layer_step == 0 self.layer_idx + 1) % config.interleave_moe_layer_step == 0
if is_moe_layer: if is_moe_layer:
self.feed_forward = Llama4MoE( self.feed_forward = Llama4MoE(
config=config, vllm_config=vllm_config,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward", prefix=f"{prefix}.feed_forward",
) )
else: else:
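The forward-pass change above follows the shape discipline applied to every model in this commit: record num_tokens, chunk the token axis before the experts, then all_gather along dim 0 and trim back to num_tokens. Below is a minimal single-process sketch of why the trim recovers the original token count; it uses a toy per-token expert and plain tensor ops instead of a distributed group, and the helper names (_toy_expert, sp_moe_roundtrip) are hypothetical.

import torch

def _toy_expert(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the routed/shared expert computation: any per-token map works.
    return x * 2.0 + 1.0

def sp_moe_roundtrip(hidden_states: torch.Tensor, tp_size: int) -> torch.Tensor:
    # Emulates: sequence_parallel_chunk -> experts -> all_gather(dim 0) -> trim.
    num_tokens = hidden_states.shape[0]
    pad = (-num_tokens) % tp_size                     # pad so tokens split evenly
    padded = torch.nn.functional.pad(hidden_states, (0, 0, 0, pad))
    chunks = padded.chunk(tp_size, dim=0)             # each "rank" sees one slice
    gathered = torch.cat([_toy_expert(c) for c in chunks], dim=0)
    return gathered[:num_tokens]                      # drop the padding rows

x = torch.randn(10, 16)
assert torch.allclose(sp_moe_roundtrip(x, tp_size=4), _toy_expert(x))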

View File

@ -68,9 +68,9 @@ class LlamaModel(nn.Module):
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
Llama4DecoderLayer( Llama4DecoderLayer(
self.config, vllm_config=vllm_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
config=self.config,
) for i in range(self.config.num_hidden_layers) ) for i in range(self.config.num_hidden_layers)
]) ])
self.fc = torch.nn.Linear(self.config.hidden_size * 2, self.fc = torch.nn.Linear(self.config.hidden_size * 2,

View File

@ -28,11 +28,12 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
def __init__( def __init__(
self, self,
config: LlamaConfig, vllm_config: VllmConfig,
disable_input_layernorm: bool, disable_input_layernorm: bool,
prefix: str = "", prefix: str = "",
config: Optional[LlamaConfig] = None,
) -> None: ) -> None:
super().__init__(config, prefix=prefix) super().__init__(vllm_config, prefix=prefix, config=config)
# Skip the input_layernorm # Skip the input_layernorm
# https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
@ -64,9 +65,10 @@ class LlamaModel(nn.Module):
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
LlamaDecoderLayer( LlamaDecoderLayer(
self.config, vllm_config,
i == 0, i == 0,
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
config=self.config,
) for i in range(self.config.num_hidden_layers) ) for i in range(self.config.num_hidden_layers)
]) ])
self.fc = torch.nn.Linear(self.config.hidden_size * 2, self.fc = torch.nn.Linear(self.config.hidden_size * 2,

View File

@ -8,13 +8,11 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -28,17 +26,14 @@ logger = init_logger(__name__)
class LlamaDecoderLayer(LlamaDecoderLayer): class LlamaDecoderLayer(LlamaDecoderLayer):
def __init__( def __init__(self,
self, vllm_config: VllmConfig,
config: LlamaConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None, config: Optional[LlamaConfig] = None) -> None:
quant_config: Optional[QuantizationConfig] = None, super().__init__(vllm_config, prefix=prefix, config=config)
prefix: str = "",
) -> None: config = config or vllm_config.model_config.hf_config
super().__init__(config, quant_config = vllm_config.quant_config
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix)
# override qkv # override qkv
self.self_attn.qkv_proj = QKVParallelLinear( self.self_attn.qkv_proj = QKVParallelLinear(
@ -125,9 +120,9 @@ class LlamaModel(nn.Module):
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
LlamaDecoderLayer( LlamaDecoderLayer(
config=self.config, current_vllm_config,
cache_config=current_vllm_config.cache_config,
prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"), prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
config=self.config,
) )
]) ])
if hasattr(self.config, "target_hidden_size"): if hasattr(self.config, "target_hidden_size"):
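The constructor hunks above all apply one refactor: decoder layers now take the full vllm_config and unpack hf_config, cache_config, and quant_config themselves, while keeping an optional config override so the EAGLE draft layers can reuse the class with a different text config. The sketch below shows only that calling convention; the underscore-prefixed classes are hypothetical stand-ins, not vLLM's real config objects.

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class _HFConfig:
    hidden_size: int = 4096

@dataclass
class _ModelConfig:
    hf_config: _HFConfig = field(default_factory=_HFConfig)

@dataclass
class _VllmConfig:
    model_config: _ModelConfig = field(default_factory=_ModelConfig)
    cache_config: Optional[object] = None
    quant_config: Optional[object] = None

class _DecoderLayer:
    def __init__(self, vllm_config: _VllmConfig, prefix: str = "",
                 config: Optional[_HFConfig] = None) -> None:
        # The optional `config` argument lets a draft model reuse this layer
        # class while substituting its own text config.
        config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.hidden_size = config.hidden_size

layer = _DecoderLayer(_VllmConfig(), prefix="model.layers.0")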

View File

@ -29,13 +29,13 @@ from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
from transformers import Qwen3MoeConfig
from vllm.attention import Attention from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (get_ep_group, get_pp_group, from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
@ -51,6 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
@ -101,12 +102,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
def __init__( def __init__(
self, self,
config: Qwen3MoeConfig, vllm_config: VllmConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
enable_eplb: bool = False,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
@ -114,6 +118,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts self.n_routed_experts = config.num_experts
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
if self.tp_size > config.num_experts: if self.tp_size > config.num_experts:
raise ValueError( raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than " f"Tensor parallel size {self.tp_size} is greater than "
@ -122,7 +128,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# Load balancing settings. # Load balancing settings.
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb self.enable_eplb = parallel_config.enable_eplb
self.n_logical_experts = self.n_routed_experts self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = eplb_config.num_redundant_experts self.n_redundant_experts = eplb_config.num_redundant_experts
@ -144,7 +150,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts) num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel)
self.gate = ReplicatedLinear(config.hidden_size, self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts, config.num_experts,
@ -156,14 +163,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
assert hidden_states.dim( assert hidden_states.dim(
) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs" ) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
is_input_1d = hidden_states.dim() == 1 is_input_1d = hidden_states.dim() == 1
hidden_dim = hidden_states.shape[-1] num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0)
final_hidden_states = final_hidden_states[:num_tokens]
# return to 1d if input is 1d # return to 1d if input is 1d
return final_hidden_states.squeeze(0) if is_input_1d else \ return final_hidden_states.squeeze(0) if is_input_1d else \
final_hidden_states final_hidden_states
@ -275,15 +290,13 @@ class Qwen3MoeAttention(nn.Module):
class Qwen3MoeDecoderLayer(nn.Module): class Qwen3MoeDecoderLayer(nn.Module):
def __init__( def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
self,
config: Qwen3MoeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None) rope_scaling = getattr(config, "rope_scaling", None)
@ -315,10 +328,8 @@ class Qwen3MoeDecoderLayer(nn.Module):
if (layer_idx not in mlp_only_layers) and ( if (layer_idx not in mlp_only_layers) and (
config.num_experts > 0 and config.num_experts > 0 and
(layer_idx + 1) % config.decoder_sparse_step == 0): (layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen3MoeSparseMoeBlock(config=config, self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
quant_config=quant_config, prefix=f"{prefix}.mlp")
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb)
else: else:
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
@ -362,10 +373,8 @@ class Qwen3MoeModel(nn.Module):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config.get_text_config() config = vllm_config.model_config.hf_config.get_text_config()
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
enable_eplb = parallel_config.enable_eplb
eplb_config = parallel_config.eplb_config eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts self.num_redundant_experts = eplb_config.num_redundant_experts
@ -379,11 +388,8 @@ class Qwen3MoeModel(nn.Module):
prefix=f"{prefix}.embed_tokens") prefix=f"{prefix}.embed_tokens")
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: Qwen3MoeDecoderLayer(config=config, lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config,
cache_config=cache_config, prefix=prefix),
quant_config=quant_config,
prefix=prefix,
enable_eplb=enable_eplb),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@ -17,7 +17,8 @@ from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
VllmConfig, get_current_vllm_config) VllmConfig, get_current_vllm_config)
from vllm.distributed import (divide, get_ep_group, get_pp_group, from vllm.distributed import (divide, get_ep_group, get_pp_group,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fla.ops import ( from vllm.model_executor.layers.fla.ops import (
@ -47,6 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, sharded_weight_loader) default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
@ -69,14 +71,13 @@ KVCache = tuple[torch.Tensor, torch.Tensor]
class Qwen3NextSparseMoeBlock(nn.Module): class Qwen3NextSparseMoeBlock(nn.Module):
def __init__( def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
self,
config: Qwen3NextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
@ -84,6 +85,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts self.n_routed_experts = config.num_experts
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
if self.tp_size > config.num_experts: if self.tp_size > config.num_experts:
raise ValueError( raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than " f"Tensor parallel size {self.tp_size} is greater than "
@ -92,7 +95,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
# Load balancing settings. # Load balancing settings.
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb self.enable_eplb = parallel_config.enable_eplb
self.n_logical_experts = self.n_routed_experts self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = eplb_config.num_redundant_experts self.n_redundant_experts = eplb_config.num_redundant_experts
@ -114,7 +117,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts) num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel)
self.gate = ReplicatedLinear(config.hidden_size, self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts, config.num_experts,
@ -141,9 +145,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape. # NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1] num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
shared_output = None shared_output = None
if self.shared_expert is not None: if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states) shared_output = self.shared_expert(hidden_states)
@ -158,7 +165,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
if shared_output is not None: if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0)
final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states) final_hidden_states)
@ -719,17 +731,17 @@ class Qwen3NextDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: Qwen3NextConfig, vllm_config: VllmConfig,
layer_type: str, layer_type: str,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prefix: str = "", prefix: str = "",
enable_eplb: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
speculative_config = vllm_config.speculative_config
self.layer_type = layer_type self.layer_type = layer_type
self.layer_idx = extract_layer_index(prefix) self.layer_idx = extract_layer_index(prefix)
@ -759,10 +771,8 @@ class Qwen3NextDecoderLayer(nn.Module):
config.num_experts > 0 and config.num_experts > 0 and
(self.layer_idx + 1) % config.decoder_sparse_step == 0): (self.layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen3NextSparseMoeBlock( self.mlp = Qwen3NextSparseMoeBlock(
config=config, vllm_config=vllm_config,
quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
) )
else: else:
self.mlp = Qwen3NextMLP( self.mlp = Qwen3NextMLP(
@ -783,14 +793,14 @@ class Qwen3NextDecoderLayer(nn.Module):
torch.zeros( torch.zeros(
1, 1,
1, 1,
self.config.hidden_size, config.hidden_size,
dtype=config.torch_dtype, dtype=config.torch_dtype,
), ) ), )
self.ffn_layer_scale = torch.nn.Parameter( self.ffn_layer_scale = torch.nn.Parameter(
torch.zeros( torch.zeros(
1, 1,
1, 1,
self.config.hidden_size, config.hidden_size,
dtype=config.torch_dtype, dtype=config.torch_dtype,
), ) ), )
@ -858,13 +868,8 @@ class Qwen3NextModel(nn.Module):
super().__init__() super().__init__()
config: Qwen3NextConfig = vllm_config.model_config.hf_config config: Qwen3NextConfig = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
speculative_config = vllm_config.speculative_config
enable_eplb = parallel_config.enable_eplb
eplb_config = parallel_config.eplb_config eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts self.num_redundant_experts = eplb_config.num_redundant_experts
@ -881,14 +886,9 @@ class Qwen3NextModel(nn.Module):
def get_layer(prefix: str): def get_layer(prefix: str):
return Qwen3NextDecoderLayer( return Qwen3NextDecoderLayer(
config, vllm_config,
layer_type=config.layer_types[extract_layer_index(prefix)], layer_type=config.layer_types[extract_layer_index(prefix)],
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
speculative_config=speculative_config,
prefix=prefix, prefix=prefix,
enable_eplb=enable_eplb,
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
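The get_layer factory above closes over vllm_config and selects each layer's type from the index encoded in its prefix. A self-contained sketch of that prefix-driven construction follows; _make_layers and _extract_layer_index are simplified stand-ins (the real helpers in vllm.model_executor.models.utils also handle, among other things, pipeline-parallel partitioning).

from typing import Callable
import torch.nn as nn

def _extract_layer_index(prefix: str) -> int:
    # Simplified: take the last integer path component of the prefix.
    return int(prefix.split(".")[-1])

def _make_layers(num_layers: int, layer_fn: Callable[[str], nn.Module],
                 prefix: str) -> nn.ModuleList:
    return nn.ModuleList(layer_fn(f"{prefix}.{i}") for i in range(num_layers))

layer_types = ["linear_attention", "full_attention", "linear_attention"]

class _ToyLayer(nn.Module):
    def __init__(self, layer_type: str, prefix: str):
        super().__init__()
        self.layer_type, self.prefix = layer_type, prefix

def _get_layer(prefix: str) -> nn.Module:
    # As in the diff: the factory picks the layer type from the layer index
    # embedded in the prefix.
    return _ToyLayer(layer_types[_extract_layer_index(prefix)], prefix)

layers = _make_layers(len(layer_types), _get_layer, prefix="model.layers")
assert [layer.layer_type for layer in layers] == layer_types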

View File

@ -38,7 +38,6 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
super().__init__() super().__init__()
model_config = vllm_config.model_config model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
config: Qwen3NextConfig = model_config.hf_config config: Qwen3NextConfig = model_config.hf_config
@ -68,11 +67,8 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
self.layers = torch.nn.ModuleList( self.layers = torch.nn.ModuleList(
Qwen3NextDecoderLayer( Qwen3NextDecoderLayer(
config, vllm_config,
layer_type="full_attention", layer_type="full_attention",
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f'{prefix}.layers.{idx}', prefix=f'{prefix}.layers.{idx}',
) for idx in range(self.num_mtp_layers)) ) for idx in range(self.num_mtp_layers))

View File

@ -13,11 +13,14 @@ from transformers import PretrainedConfig
import vllm.envs as envs import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import NestedTensors from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available, from vllm.utils import (cdiv, direct_register_custom_op,
get_cuda_view_from_cpu_tensor, is_pin_memory_available,
is_uva_available) is_uva_available)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -743,3 +746,46 @@ def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
return hf_config.hidden_size return hf_config.hidden_size
text_config = hf_config.get_text_config() text_config = hf_config.get_text_config()
return text_config.hidden_size return text_config.hidden_size
# Chunk x along the num_tokens axis for sequence parallelism
# NOTE: This is wrapped in a torch custom op to work around the following issue: # NOTE: This is wrapped in a torch custom op to work around the following issue:
# without the wrapper, the output tensor can end up with a sequence length of 0 # without the wrapper, the output tensor can end up with a sequence length of 0
# at small input sequence lengths, even though we explicitly pad to avoid this. # at small input sequence lengths, even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
return torch.ops.vllm.sequence_parallel_chunk_impl(x)
def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
# all_gather needs the sequence length to be divisible by tp_size
seq_len = x.size(0)
remainder = seq_len % tp_size
if remainder != 0:
pad_len = tp_size - remainder
y = nn.functional.pad(x, (0, 0, 0, pad_len))
else:
y = x
chunk = y.shape[0] // tp_size
start = tp_rank * chunk
return torch.narrow(y, 0, start, chunk)
def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
tp_size = get_tensor_model_parallel_world_size()
seq_len = cdiv(x.size(0), tp_size)
shape = list(x.shape)
shape[0] = seq_len
out = torch.empty(shape, dtype=x.dtype, device=x.device)
return out
direct_register_custom_op(
op_name="sequence_parallel_chunk_impl",
op_func=sequence_parallel_chunk_impl,
fake_impl=sequence_parallel_chunk_impl_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
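As a check on the shape contract encoded by the fake implementation, the snippet below re-implements the chunking math locally, with the rank and world size passed in explicitly (an assumption made here because the real op reads them from the initialized tensor-parallel group), and verifies that every rank always receives exactly cdiv(seq_len, tp_size) rows, never a zero-length chunk.

import torch

def _cdiv(a: int, b: int) -> int:
    return -(-a // b)

def _chunk_locally(x: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Local re-implementation of the eager path above, minus the TP group lookup.
    seq_len = x.size(0)
    remainder = seq_len % tp_size
    if remainder != 0:
        x = torch.nn.functional.pad(x, (0, 0, 0, tp_size - remainder))
    chunk = x.shape[0] // tp_size
    return torch.narrow(x, 0, tp_rank * chunk, chunk)

# The fake (meta) impl advertises cdiv(seq_len, tp_size) rows to the compiler;
# the eager impl must match it for every rank and shape.
for seq_len in (1, 3, 7, 8, 64):
    for tp_size in (1, 2, 4, 8):
        x = torch.randn(seq_len, 16)
        for tp_rank in range(tp_size):
            out = _chunk_locally(x, tp_size, tp_rank)
            assert out.shape[0] == _cdiv(seq_len, tp_size)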