diff --git a/vllm/model_executor/models/jax/ops/flash_attn.py b/vllm/model_executor/models/jax/ops/flash_attn.py
index 4985a61f186df..1b9f949d1752a 100644
--- a/vllm/model_executor/models/jax/ops/flash_attn.py
+++ b/vllm/model_executor/models/jax/ops/flash_attn.py
@@ -8,12 +8,13 @@ _DEFAULT_BLOCK_SIZES = {
     "block_b": 2,
 }
 
+
 def flash_attn(
     q: jax.Array,  # [batch, seq_len, num_heads, head_size]
     k: jax.Array,  # [batch, seq_len, num_heads, head_size]
     v: jax.Array,  # [batch, seq_len, num_heads, head_size]
     sm_scale: float,
-) -> jax.Array: # [batch, seq_len, num_heads, head_size]
+) -> jax.Array:  # [batch, seq_len, num_heads, head_size]
     return flash_attention(
         q.transpose(0, 2, 1, 3),
         k.transpose(0, 2, 1, 3),
@@ -21,8 +22,8 @@ def flash_attn(
         causal=True,
         sm_scale=sm_scale,
         block_sizes=BlockSizes(
-          min(_DEFAULT_BLOCK_SIZES["block_q"], q.shape[1]),
-          min(_DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[1]),
-          min(_DEFAULT_BLOCK_SIZES["block_k"], k.shape[1]),
-          min(_DEFAULT_BLOCK_SIZES["block_b"], q.shape[0])),
+            min(_DEFAULT_BLOCK_SIZES["block_q"], q.shape[1]),
+            min(_DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[1]),
+            min(_DEFAULT_BLOCK_SIZES["block_k"], k.shape[1]),
+            min(_DEFAULT_BLOCK_SIZES["block_b"], q.shape[0])),
     ).transpose(0, 2, 1, 3)
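
For context, a minimal usage sketch of the wrapper this diff touches (not part of the patch): the module path comes from the file header above, the tensor shapes from the parameter comments, and the dimension values are illustrative assumptions. It also assumes a TPU runtime, since the underlying Pallas flash_attention kernel is TPU-only; the min() clamps in the diff keep the default block sizes valid for short sequences and small batches.

    # Usage sketch; dimensions are hypothetical, chosen as multiples of the
    # default block sizes so no clamping is needed here.
    import jax
    import jax.numpy as jnp

    from vllm.model_executor.models.jax.ops.flash_attn import flash_attn

    batch, seq_len, num_heads, head_size = 2, 1024, 8, 128
    shape = (batch, seq_len, num_heads, head_size)

    kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(kq, shape, dtype=jnp.bfloat16)
    k = jax.random.normal(kk, shape, dtype=jnp.bfloat16)
    v = jax.random.normal(kv, shape, dtype=jnp.bfloat16)

    # Causal flash attention; the wrapper transposes to [batch, heads, seq, dim]
    # for the kernel and transposes the result back.
    out = flash_attn(q, k, v, sm_scale=1.0 / head_size**0.5)
    assert out.shape == shape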