Mirror of https://git.datalinker.icu/vllm-project/vllm.git

Commit dd66fd2b01 (parent 0f465ab533)
[CI] fix pre-commit error (#12494)
Signed-off-by: Mengqing Cao <cmq0113@163.com>

The hunks below are formatting-only rewraps (plus one comment spelling fix) that appear to have been produced to satisfy the project's pre-commit checks; no functional change is intended.
@@ -106,11 +106,12 @@ def _flash_attention_core(
     assert (continuous_batching_mask
             is not None), "continuous_batching_mask input is required."
     if continuous_batching_mask is not None:
-        assert (logit_bias_tile is
-                None), "continuous_batching_mask does not support logit_bias!"
+        assert (
+            logit_bias_tile
+            is None), "continuous_batching_mask does not support logit_bias!"
 
     # mask are used to only apply computation to the lower half of the matrix,
-    # which reduce the arthimetic intensity by half
+    # which reduce the arithmetic intensity by half
     forward_mask = (q_tile_idx * B_P_SIZE >= local_k_large_tile_idx *
                     LARGE_TILE_SZ if use_causal_mask else None)
 
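As a rough illustration of the causal-mask comment in this hunk (not part of the commit): the forward_mask condition keeps only query/key tile pairs on or below the diagonal at tile granularity. The tile sizes below are made-up stand-ins for B_P_SIZE and LARGE_TILE_SZ.

# Illustrative sketch only; B_P_SIZE and LARGE_TILE_SZ values are assumed.
B_P_SIZE = 128        # query rows covered by one query tile (assumed)
LARGE_TILE_SZ = 512   # key columns covered by one large key tile (assumed)

def tile_needs_compute(q_tile_idx: int, k_large_tile_idx: int) -> bool:
    # Same condition as forward_mask: a key tile only matters if it is not
    # entirely in the "future" of the query tile (lower half of the matrix).
    return q_tile_idx * B_P_SIZE >= k_large_tile_idx * LARGE_TILE_SZ

for q_tile_idx in range(6):
    kept = [k for k in range(2) if tile_needs_compute(q_tile_idx, k)]
    print(f"q_tile {q_tile_idx}: key tiles computed -> {kept}")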
@@ -468,9 +469,11 @@ def flash_paged_attention(
                 block_in_partition)
             loaded_v = nl.load(value_cache[block_tables_sbuf[v_i, j], :,
                                            head_id, :])
-            cur_v_tile[partition_idx,
-                       nl.ds(block_in_partition *
-                             block_size, block_size), :, ] = loaded_v
+            cur_v_tile[
+                partition_idx,
+                nl.ds(block_in_partition * block_size, block_size),
+                :,
+            ] = loaded_v
 
         cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                               dtype=mask.dtype)
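For readers unfamiliar with NKI indexing, here is a small NumPy analogue of the rewrapped cur_v_tile assignment (not part of the commit), assuming nl.ds(start, size) selects a contiguous run of size elements starting at start; all shapes are made up.

import numpy as np

# Made-up shapes standing in for the kernel's tile dimensions.
num_partitions, LARGE_TILE_SZ, head_dim, block_size = 2, 512, 128, 32

cur_v_tile = np.zeros((num_partitions, LARGE_TILE_SZ, head_dim))
loaded_v = np.ones((block_size, head_dim))   # stand-in for the nl.load result

partition_idx, block_in_partition = 1, 3

# nl.ds(block_in_partition * block_size, block_size) behaves like this slice:
start = block_in_partition * block_size
cur_v_tile[partition_idx, start:start + block_size, :] = loaded_v
assert cur_v_tile[partition_idx].sum() == loaded_v.sum()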
@@ -601,20 +604,30 @@ def flash_paged_attention(
             )
 
             nl.store(
-                o[batch_id, head_id * q_h_per_k_h + i_q_h,
-                  nl.ds(i * B_P_SIZE, B_P_SIZE), :, ],
+                o[
+                    batch_id,
+                    head_id * q_h_per_k_h + i_q_h,
+                    nl.ds(i * B_P_SIZE, B_P_SIZE),
+                    :,
+                ],
                 out,
             )
             # maximum and summation statistics
             if return_debug_tensors:
                 nl.store(
-                    hbm_m_buffer[batch_id, head_id * q_h_per_k_h + i_q_h,
-                                 nl.ds(i * B_P_SIZE, B_P_SIZE), ],
+                    hbm_m_buffer[
+                        batch_id,
+                        head_id * q_h_per_k_h + i_q_h,
+                        nl.ds(i * B_P_SIZE, B_P_SIZE),
+                    ],
                     m_buffer[i, i_q_h, :, :],
                 )
                 nl.store(
-                    hbm_l_buffer[batch_id, head_id * q_h_per_k_h + i_q_h,
-                                 nl.ds(i * B_P_SIZE, B_P_SIZE), ],
+                    hbm_l_buffer[
+                        batch_id,
+                        head_id * q_h_per_k_h + i_q_h,
+                        nl.ds(i * B_P_SIZE, B_P_SIZE),
+                    ],
                     l_buffer[:, i, i_q_h],
                 )
                 nl.store(
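The o[...] and hbm_*_buffer[...] index lists in this hunk all address the same kind of slot; as a hypothetical sketch (not part of the commit), the head and row arithmetic works out as below, assuming a grouped layout where each KV head serves q_h_per_k_h query heads.

# Hypothetical values; the real ones come from the kernel's shapes.
q_h_per_k_h = 4   # query heads per KV head (assumed)
B_P_SIZE = 128    # rows written by each nl.store (assumed)

def store_slot(head_id: int, i_q_h: int, i: int):
    # head_id * q_h_per_k_h + i_q_h flattens (KV head, head-in-group) into a
    # single query-head index; nl.ds(i * B_P_SIZE, B_P_SIZE) picks row block i.
    q_head = head_id * q_h_per_k_h + i_q_h
    rows = slice(i * B_P_SIZE, (i + 1) * B_P_SIZE)
    return q_head, rows

print(store_slot(head_id=2, i_q_h=1, i=3))   # (9, slice(384, 512, None))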
@@ -870,10 +870,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             accepted_index = accepted_token_ids + 1  # Convert -1 to 0
             accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b
             # Drop non-terminal prefill chunks hidden states.
-            hidden_states = hidden_states[
-                accepted_index != VLLM_INVALID_TOKEN_ID]
-            accepted_index = accepted_index[
-                accepted_index != VLLM_INVALID_TOKEN_ID]
+            hidden_states = hidden_states[accepted_index !=
+                                          VLLM_INVALID_TOKEN_ID]
+            accepted_index = accepted_index[accepted_index !=
+                                            VLLM_INVALID_TOKEN_ID]
             assert len(accepted_index) == hidden_states.shape[0] == len(
                 terminal_metadata)
             index = accepted_index[:, None, None].expand(-1, 1,
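A minimal runnable sketch of what the rewrapped SpecDecodeWorker lines compute (not part of the commit); the toy tensors and the VLLM_INVALID_TOKEN_ID value (-1) are assumptions for illustration.

import torch

VLLM_INVALID_TOKEN_ID = -1   # assumed sentinel value, for illustration only

# Toy accepted token ids, one row per sequence; -1 marks rejected slots.
accepted_token_ids = torch.tensor([
    [11, 12, -1, -1],    # two tokens accepted   -> last accepted index 1
    [-1, -1, -1, -1],    # nothing accepted      -> index -1 (filtered out)
    [21, 22, 23, -1],    # three tokens accepted -> last accepted index 2
])
hidden_states = torch.randn(3, 8)   # toy per-sequence hidden states

accepted_index = accepted_token_ids + 1                        # Convert -1 to 0
accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # shape (b,)

# Keep only rows with a valid index, keeping both tensors aligned,
# exactly as the reformatted lines do.
keep = accepted_index != VLLM_INVALID_TOKEN_ID
hidden_states = hidden_states[keep]
accepted_index = accepted_index[keep]
assert len(accepted_index) == hidden_states.shape[0]   # 2 rows survive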