mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 20:35:40 +08:00
[Bugfix] Remove contiguous output req for context parallel MLA (#25414)
Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
parent
eea1783989
commit
78237e43bf
@ -134,6 +134,5 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
|
|||||||
cp_attn_lse = cp_attn_lse.contiguous()
|
cp_attn_lse = cp_attn_lse.contiguous()
|
||||||
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
|
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
|
||||||
out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
|
out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
|
||||||
assert out.is_contiguous()
|
|
||||||
out = cp_group.reduce_scatter(out, dim=1)
|
out = cp_group.reduce_scatter(out, dim=1)
|
||||||
return out
|
return out
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user