From 15ae8e0784d3889c6aa2c487ca00df4e3fde6f44 Mon Sep 17 00:00:00 2001
From: rasmith
Date: Fri, 14 Nov 2025 00:34:01 -0600
Subject: [PATCH] [Bugfix][CI/Test][Spec Decode] Fix illegal memory access in
 offline_inference/spec_decode.py (Issue 27619) (#28432)

Signed-off-by: Randall Smith
Co-authored-by: Randall Smith
Co-authored-by: TJian
---
 vllm/attention/ops/triton_reshape_and_cache_flash.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm/attention/ops/triton_reshape_and_cache_flash.py b/vllm/attention/ops/triton_reshape_and_cache_flash.py
index bbcd560ad56e3..5d2ba154ae018 100644
--- a/vllm/attention/ops/triton_reshape_and_cache_flash.py
+++ b/vllm/attention/ops/triton_reshape_and_cache_flash.py
@@ -97,7 +97,6 @@ def triton_reshape_and_cache_flash(
     k_scale: torch.Tensor,  # float32
     v_scale: torch.Tensor,  # float32
 ):
-    num_tokens = key.shape[0]
     num_heads = key.shape[1]
     head_size = key.shape[2]
     block_size = key_cache.shape[1]
@@ -155,7 +154,10 @@ def triton_reshape_and_cache_flash(
     # TODO(ngl): maybe replace with static launch grid to avoid overhead if
     # using cudagraphs
-    grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))
+    grid = lambda meta: (
+        slot_mapping.shape[0],
+        triton.cdiv(n, meta["TILE_SIZE"]),
+    )
 
     reshape_and_cache_kernel_flash[grid](
         key_ptr=key,