mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 16:57:53 +08:00
maybe fix
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
This commit is contained in:
parent
da63274d9f
commit
cd3ea013d6
@ -167,6 +167,14 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
|
|
||||||
MAX_HEADS = 128
|
MAX_HEADS = 128
|
||||||
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
|
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
|
||||||
|
if H < MAX_HEADS:
|
||||||
|
q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
|
||||||
|
q_nope_padded[:, :H] = q_nope
|
||||||
|
q_nope = q_nope_padded
|
||||||
|
|
||||||
|
q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
|
||||||
|
q_pe_padded[:, :H] = q_pe
|
||||||
|
q_pe = q_pe_padded
|
||||||
|
|
||||||
assert len(page_table.shape) == 2
|
assert len(page_table.shape) == 2
|
||||||
B_block_table, block_num = page_table.shape
|
B_block_table, block_num = page_table.shape
|
||||||
@ -209,8 +217,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
|
|
||||||
if H < MAX_HEADS:
|
if H < MAX_HEADS:
|
||||||
# Extract the subsets of the outputs
|
# Extract the subsets of the outputs
|
||||||
lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
|
lse = lse[:, :H].contiguous(
|
||||||
out = out[:, :H]
|
) if self.need_to_return_lse_for_decode else lse
|
||||||
|
out = out[:, :H].contiguous()
|
||||||
|
|
||||||
return out, lse
|
return out, lse
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user