diff --git a/vllm/model_executor/models/jax/ops/flash_attn.py b/vllm/model_executor/models/jax/ops/flash_attn.py index e87a6a2785e8e..4985a61f186df 100644 --- a/vllm/model_executor/models/jax/ops/flash_attn.py +++ b/vllm/model_executor/models/jax/ops/flash_attn.py @@ -24,5 +24,5 @@ def flash_attn( min(_DEFAULT_BLOCK_SIZES["block_q"], q.shape[1]), min(_DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[1]), min(_DEFAULT_BLOCK_SIZES["block_k"], k.shape[1]), - min(_DEFAULT_BLOCK_SIZES["block_b"], q.shape[0])) + min(_DEFAULT_BLOCK_SIZES["block_b"], q.shape[0])), ).transpose(0, 2, 1, 3)