[BUGFIX] Add cu_tokens_across_sp to DPMetadata (#26457)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Sage Moore, 2025-10-09 10:13:56 -07:00, committed by GitHub
parent 47e66c24e2
commit 0d37450eb7


@@ -161,6 +161,17 @@ class DPMetadata:
        assert self.local_sizes is not None
        return self.local_sizes

    # Get the cumulative tokens across sequence parallel ranks.
    # In this case the input to the MoEs will be distributed w.r.t. both
    # DP and TP rank.
    # When sp_size == 1, this is just the cumulative num tokens across DP.
    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
        # Ceil-divide each DP rank's token count across its SP ranks.
        num_tokens_across_sp_cpu = (
            self.num_tokens_across_dp_cpu - 1 + sp_size
        ) // sp_size
        num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
@dataclass
class ForwardContext:
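
For illustration, here is a minimal standalone sketch of the arithmetic that cu_tokens_across_sp performs; the token counts below are made-up values, not taken from the commit. Each DP rank's token count is ceil-divided across its SP ranks, each result is repeated once per SP rank, and a cumulative sum gives every (DP, SP) rank's end offset into the flattened token buffer.

import torch

# Suppose two DP ranks hold 5 and 7 tokens respectively, and sp_size == 2.
# (Hypothetical values for illustration only.)
num_tokens_across_dp_cpu = torch.tensor([5, 7])
sp_size = 2

# Ceil-divide each DP rank's token count across its SP ranks:
# ceil(5/2) = 3, ceil(7/2) = 4.
num_tokens_across_sp_cpu = (num_tokens_across_dp_cpu - 1 + sp_size) // sp_size

# Repeat each entry sp_size times so every SP rank gets a slot: [3, 3, 4, 4].
num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)

# Cumulative sum yields each rank's end offset: tensor([3, 6, 10, 14]).
print(torch.cumsum(num_tokens_across_sp_cpu, dim=0))

With sp_size == 1 the same arithmetic leaves the per-rank counts unchanged, so the result reduces to the cumulative token counts across DP ranks (tensor([5, 12]) in this example), matching the comment in the diff.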