mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 06:55:00 +08:00
[BUGFIX] Add cu_tokens_across_sp to DPMetadata (#26457)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
47e66c24e2
commit
0d37450eb7
@ -161,6 +161,17 @@ class DPMetadata:
|
|||||||
assert self.local_sizes is not None
|
assert self.local_sizes is not None
|
||||||
return self.local_sizes
|
return self.local_sizes
|
||||||
|
|
||||||
|
# Get the cumulative tokens across sequence parallel ranks.
|
||||||
|
# In this case the input to the MoEs will be distributed w.r.t both
|
||||||
|
# DP and TP rank.
|
||||||
|
# When sp_size==1, this is just the cummulative num tokens across DP.
|
||||||
|
def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
|
||||||
|
num_tokens_across_sp_cpu = (
|
||||||
|
self.num_tokens_across_dp_cpu - 1 + sp_size
|
||||||
|
) // sp_size
|
||||||
|
num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
|
||||||
|
return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ForwardContext:
|
class ForwardContext:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user