From 0fb07c08d03b7905147556aeae9bab945aea607b Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 17 Apr 2024 18:08:33 +0000
Subject: [PATCH] Minor

---
 vllm/model_executor/models/jax/ops/flash_attn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/jax/ops/flash_attn.py b/vllm/model_executor/models/jax/ops/flash_attn.py
index e87a6a2785e8e..4985a61f186df 100644
--- a/vllm/model_executor/models/jax/ops/flash_attn.py
+++ b/vllm/model_executor/models/jax/ops/flash_attn.py
@@ -24,5 +24,5 @@ def flash_attn(
             min(_DEFAULT_BLOCK_SIZES["block_q"], q.shape[1]),
             min(_DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[1]),
             min(_DEFAULT_BLOCK_SIZES["block_k"], k.shape[1]),
-            min(_DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]))
+            min(_DEFAULT_BLOCK_SIZES["block_b"], q.shape[0])),
     ).transpose(0, 2, 1, 3)