[Bugfix][CI/Test][Spec Decode] Fix illegal memory access in offline_inference/spec_decode.py (Issue 27619) (#28432)

Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
Author: rasmith
Date: 2025-11-14 00:34:01 -06:00 (committed by GitHub)
Parent: 0b25498990
Commit: 15ae8e0784

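A likely reading of the fix below: the kernel's launch grid sized its first axis from `key.shape[0]`, which can exceed `slot_mapping.shape[0]` (for example when the key tensor carries padded or otherwise extra rows during spec decode), so the surplus programs index past `slot_mapping` and trigger the illegal memory access. The sketch that follows only illustrates that shape relationship; the tensors and sizes are hypothetical and not taken from the issue or the commit.

```python
import torch

# Hypothetical shapes, for illustration only.
key = torch.randn(8, 4, 64)       # 8 rows, possibly including extra/padded tokens
slot_mapping = torch.arange(5)    # only 5 tokens have valid cache slots

# Old grid axis 0: one Triton program per key row, including the extra rows.
# Programs 5..7 would read slot_mapping out of bounds inside the kernel.
old_axis0 = int(key.shape[0])       # 8

# New grid axis 0: one program per slot-mapping entry, so every program
# has a valid slot to write to.
new_axis0 = slot_mapping.shape[0]   # 5

print(old_axis0, new_axis0)
```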

@@ -97,7 +97,6 @@ def triton_reshape_and_cache_flash(
     k_scale: torch.Tensor,  # float32
     v_scale: torch.Tensor,  # float32
 ):
-    num_tokens = key.shape[0]
     num_heads = key.shape[1]
     head_size = key.shape[2]
     block_size = key_cache.shape[1]
@@ -155,7 +154,10 @@
     # TODO(ngl): maybe replace with static launch grid to avoid overhead if
     # using cudagraphs
-    grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))
+    grid = lambda meta: (
+        slot_mapping.shape[0],
+        triton.cdiv(n, meta["TILE_SIZE"]),
+    )
     reshape_and_cache_kernel_flash[grid](
         key_ptr=key,