From dd66fd2b01e1195b7ccc8ffcd4b5d49ff1946a56 Mon Sep 17 00:00:00 2001
From: Mengqing Cao
Date: Tue, 28 Jan 2025 14:11:05 +0800
Subject: [PATCH] [CI] fix pre-commit error (#12494)

Signed-off-by: Mengqing Cao
---
 vllm/attention/ops/nki_flash_attn.py   | 37 +++++++++++++++++---------
 vllm/spec_decode/spec_decode_worker.py |  8 +++---
 2 files changed, 29 insertions(+), 16 deletions(-)

diff --git a/vllm/attention/ops/nki_flash_attn.py b/vllm/attention/ops/nki_flash_attn.py
index b9765b0f0283d..9de4ef7f5a140 100644
--- a/vllm/attention/ops/nki_flash_attn.py
+++ b/vllm/attention/ops/nki_flash_attn.py
@@ -106,11 +106,12 @@ def _flash_attention_core(
     assert (continuous_batching_mask is not
             None), "continuous_batching_mask input is required."
     if continuous_batching_mask is not None:
-        assert (logit_bias_tile is
-                None), "continuous_batching_mask does not support logit_bias!"
+        assert (
+            logit_bias_tile
+            is None), "continuous_batching_mask does not support logit_bias!"
 
     # mask are used to only apply computation to the lower half of the matrix,
-    # which reduce the arthimetic intensity by half
+    # which reduce the arithmetic intensity by half
     forward_mask = (q_tile_idx * B_P_SIZE
                     >= local_k_large_tile_idx * LARGE_TILE_SZ
                     if use_causal_mask else None)
@@ -468,9 +469,11 @@ def flash_paged_attention(
                                                block_in_partition)
                 loaded_v = nl.load(value_cache[block_tables_sbuf[v_i, j], :,
                                                head_id, :])
-                cur_v_tile[partition_idx,
-                           nl.ds(block_in_partition *
-                                 block_size, block_size), :, ] = loaded_v
+                cur_v_tile[
+                    partition_idx,
+                    nl.ds(block_in_partition * block_size, block_size),
+                    :,
+                ] = loaded_v
 
     cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                           dtype=mask.dtype)
@@ -601,20 +604,30 @@ def flash_paged_attention(
         )
         nl.store(
-            o[batch_id, head_id * q_h_per_k_h + i_q_h,
-              nl.ds(i * B_P_SIZE, B_P_SIZE), :, ],
+            o[
+                batch_id,
+                head_id * q_h_per_k_h + i_q_h,
+                nl.ds(i * B_P_SIZE, B_P_SIZE),
+                :,
+            ],
             out,
         )
 
         # maximum and summation statistics
         if return_debug_tensors:
             nl.store(
-                hbm_m_buffer[batch_id, head_id * q_h_per_k_h + i_q_h,
-                             nl.ds(i * B_P_SIZE, B_P_SIZE), ],
+                hbm_m_buffer[
+                    batch_id,
+                    head_id * q_h_per_k_h + i_q_h,
+                    nl.ds(i * B_P_SIZE, B_P_SIZE),
+                ],
                 m_buffer[i, i_q_h, :, :],
             )
             nl.store(
-                hbm_l_buffer[batch_id, head_id * q_h_per_k_h + i_q_h,
-                             nl.ds(i * B_P_SIZE, B_P_SIZE), ],
+                hbm_l_buffer[
+                    batch_id,
+                    head_id * q_h_per_k_h + i_q_h,
+                    nl.ds(i * B_P_SIZE, B_P_SIZE),
+                ],
                 l_buffer[:, i, i_q_h],
             )
             nl.store(
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index af1c4dfcebbc0..8d6d05cbaea75 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -870,10 +870,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             accepted_index = accepted_token_ids + 1  # Convert -1 to 0
             accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b
             # Drop non-terminal prefill chunks hidden states.
-            hidden_states = hidden_states[
-                accepted_index != VLLM_INVALID_TOKEN_ID]
-            accepted_index = accepted_index[
-                accepted_index != VLLM_INVALID_TOKEN_ID]
+            hidden_states = hidden_states[accepted_index !=
+                                          VLLM_INVALID_TOKEN_ID]
+            accepted_index = accepted_index[accepted_index !=
+                                            VLLM_INVALID_TOKEN_ID]
             assert len(accepted_index) == hidden_states.shape[0] == len(
                 terminal_metadata)
             index = accepted_index[:, None, None].expand(-1, 1,
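
Note on the spec_decode_worker.py hunk: the reformatted lines implement a
compact indexing trick. Below is a minimal, self-contained PyTorch sketch of
that trick, not vLLM code: the toy tensors and shapes are invented, the
assumption that VLLM_INVALID_TOKEN_ID equals -1 is illustrative, and the
final gather completes the patch's truncated expand(-1, 1, ...) line with an
assumed hidden-size argument. Shifting the -1 sentinels to 0 lets
count_nonzero(dim=1) - 1 give the index of the last accepted token per
sequence; rows with no accepted tokens come out as -1 and are dropped.

    # Hedged sketch, plain PyTorch; names mirror the patch for readability.
    import torch

    VLLM_INVALID_TOKEN_ID = -1  # assumed sentinel, matching the -1 shift below

    # (batch=3, k=4) accepted token ids; -1 marks rejected/unfilled slots.
    accepted_token_ids = torch.tensor([
        [11, 12, -1, -1],  # two accepted  -> last accepted index 1
        [21, 22, 23, 24],  # four accepted -> last accepted index 3
        [-1, -1, -1, -1],  # none accepted -> index -1, row dropped
    ])
    hidden_states = torch.randn(3, 4, 8)  # (batch, positions, hidden_size)

    accepted_index = accepted_token_ids + 1  # Convert -1 to 0
    accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b

    # Drop rows whose index equals the invalid sentinel (nothing accepted).
    keep = accepted_index != VLLM_INVALID_TOKEN_ID
    hidden_states = hidden_states[keep]
    accepted_index = accepted_index[keep]

    # Gather the hidden state at the last accepted position per sequence;
    # the third expand() argument is an assumption (the patch truncates here).
    index = accepted_index[:, None, None].expand(-1, 1,
                                                 hidden_states.shape[-1])
    print(hidden_states.gather(1, index).squeeze(1).shape)  # torch.Size([2, 8])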