diff --git a/vllm/attention/ops/triton_reshape_and_cache_flash.py b/vllm/attention/ops/triton_reshape_and_cache_flash.py
index bbcd560ad56e3..5d2ba154ae018 100644
--- a/vllm/attention/ops/triton_reshape_and_cache_flash.py
+++ b/vllm/attention/ops/triton_reshape_and_cache_flash.py
@@ -97,7 +97,6 @@ def triton_reshape_and_cache_flash(
     k_scale: torch.Tensor,  # float32
     v_scale: torch.Tensor,  # float32
 ):
-    num_tokens = key.shape[0]
     num_heads = key.shape[1]
     head_size = key.shape[2]
     block_size = key_cache.shape[1]
@@ -155,7 +154,10 @@ def triton_reshape_and_cache_flash(
     # TODO(ngl): maybe replace with static launch grid to avoid overhead if
     # using cudagraphs
-    grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))
+    grid = lambda meta: (
+        slot_mapping.shape[0],
+        triton.cdiv(n, meta["TILE_SIZE"]),
+    )
     reshape_and_cache_kernel_flash[grid](
         key_ptr=key,
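
The removed `num_tokens` was only used to size the first launch-grid dimension; the change takes that dimension from `slot_mapping.shape[0]` instead. As a minimal, self-contained sketch of the Triton launch pattern the diff touches, the snippet below uses a callable grid that receives the kernel's meta-parameters, so the tile count can depend on `TILE_SIZE` while the first dimension comes from a tensor shape. The kernel and tensor names (`copy_kernel`, `src`, `dst`) are illustrative only and not part of vLLM.

```python
# Sketch of a 2D Triton launch grid driven by a callable, assuming a toy
# row-wise copy kernel; names are hypothetical, not vLLM's.
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n, TILE_SIZE: tl.constexpr):
    # One program per (row, tile) pair, mirroring the 2D grid in the diff.
    row = tl.program_id(0)
    tile = tl.program_id(1)
    offs = tile * TILE_SIZE + tl.arange(0, TILE_SIZE)
    mask = offs < n
    vals = tl.load(src_ptr + row * n + offs, mask=mask)
    tl.store(dst_ptr + row * n + offs, vals, mask=mask)


src = torch.randn(4, 1000, device="cuda")
dst = torch.empty_like(src)
n = src.shape[1]

# The grid callable receives the kernel's meta-parameters, so the second
# dimension can be derived from TILE_SIZE at launch time; the first
# dimension is read from a tensor shape, as the diff does with
# slot_mapping.shape[0] instead of a separately tracked num_tokens.
grid = lambda meta: (src.shape[0], triton.cdiv(n, meta["TILE_SIZE"]))
copy_kernel[grid](src, dst, n, TILE_SIZE=256)
```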