Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 20:04:27 +08:00)
[BUGFIX] Add cu_tokens_across_sp to DPMetadata (#26457)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Parent: 47e66c24e2
Commit: 0d37450eb7
@@ -161,6 +161,17 @@ class DPMetadata:
        assert self.local_sizes is not None
        return self.local_sizes

    # Get the cumulative tokens across sequence parallel ranks.
    # In this case the input to the MoEs will be distributed w.r.t. both
    # DP and TP rank.
    # When sp_size==1, this is just the cumulative num tokens across DP.
    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
        num_tokens_across_sp_cpu = (
            self.num_tokens_across_dp_cpu - 1 + sp_size
        ) // sp_size
        num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)


@dataclass
class ForwardContext:
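For intuition, here is a minimal, self-contained sketch of the arithmetic the new method performs; the token counts below are made up for illustration. Each DP rank's token count is rounded up to a multiple of sp_size (ceil division), expanded to one entry per sequence-parallel rank, and then accumulated into cumulative offsets.

import torch

# Hypothetical per-DP-rank token counts (stand-in for num_tokens_across_dp_cpu).
num_tokens_across_dp_cpu = torch.tensor([5, 8], dtype=torch.int64)
sp_size = 2

# Ceil-divide each DP rank's token count over its sp_size sequence-parallel ranks.
num_tokens_across_sp_cpu = (num_tokens_across_dp_cpu - 1 + sp_size) // sp_size
# tensor([3, 4])

# One entry per (DP rank, SP rank) pair.
num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
# tensor([3, 3, 4, 4])

# Cumulative token boundaries across all SP ranks.
print(torch.cumsum(num_tokens_across_sp_cpu, dim=0))
# tensor([ 3,  6, 10, 14])

With sp_size == 1 the ceil division leaves the counts unchanged and repeat_interleave is a no-op, so the result reduces to the cumulative token count across DP ranks, as the comment in the diff notes.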