From d2be62378bffa4fe7bd857a524d0d2f6cc5a2c47 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 18 Sep 2025 16:33:18 -0700
Subject: [PATCH] fix

Signed-off-by: Woosuk Kwon
---
 vllm/v1/worker/gpu/sampler.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py
index eaf43d4335843..d03f6edcdec97 100644
--- a/vllm/v1/worker/gpu/sampler.py
+++ b/vllm/v1/worker/gpu/sampler.py
@@ -199,7 +199,7 @@ def _topk_logprobs_kernel(
     logits_ptr,
     logits_stride,
     topk_ids_ptr,
-    k,
+    topk,
     vocab_size,
     BLOCK_SIZE: tl.constexpr,
     PADDED_TOPK: tl.constexpr,
@@ -224,12 +224,12 @@ def _topk_logprobs_kernel(
     lse = tl.log(se)
 
     k_offset = tl.arange(0, PADDED_TOPK)
-    k_mask = k_offset < k
-    topk_ids = tl.load(topk_ids_ptr + req_idx * k + k_offset, mask=k_mask)
+    k_mask = k_offset < topk
+    topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask)
 
     l = tl.load(row_ptr + topk_ids, mask=k_mask)
     o = l - max_val - lse
-    tl.store(output_ptr + req_idx * k + k_offset, o, mask=k_mask)
+    tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
 
 
 def compute_logprobs(