diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index a6a1e36bfe953..09da1398b0309 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -161,6 +161,17 @@ class DPMetadata:
         assert self.local_sizes is not None
         return self.local_sizes
 
+    # Get the cumulative number of tokens across sequence parallel ranks.
+    # In this case the input to the MoEs will be distributed w.r.t. both
+    # the DP and TP rank.
+    # When sp_size == 1, this is just the cumulative num tokens across DP.
+    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
+        num_tokens_across_sp_cpu = (
+            self.num_tokens_across_dp_cpu - 1 + sp_size
+        ) // sp_size
+        num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
+        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
+
 
 @dataclass
 class ForwardContext:
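
To make the arithmetic in the new method concrete, here is a small standalone sketch of the same computation outside the class: ceil-divide each DP rank's token count by `sp_size`, repeat each entry `sp_size` times so every sequence-parallel rank gets its own slot, then take the running sum. The tensor values and `sp_size` below are made up for illustration and are not part of the PR.

```python
import torch

# Hypothetical per-DP-rank token counts (illustrative values only).
num_tokens_across_dp_cpu = torch.tensor([5, 7, 8])
sp_size = 2

# Ceiling division: each of the sp_size ranks within a DP rank
# handles ceil(n / sp_size) tokens.
per_sp_rank = (num_tokens_across_dp_cpu - 1 + sp_size) // sp_size
# -> tensor([3, 4, 4])

# Repeat each entry sp_size times, one slot per SP rank.
per_sp_rank = per_sp_rank.repeat_interleave(sp_size)
# -> tensor([3, 3, 4, 4, 4, 4])

# Running sum gives cumulative token offsets across all SP ranks.
print(torch.cumsum(per_sp_rank, dim=0))
# -> tensor([ 3,  6, 10, 14, 18, 22])
```

With `sp_size == 1` both the ceiling division and the `repeat_interleave` are identity operations, so the result reduces to the plain cumulative token count across DP ranks, matching the comment in the diff.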