From 6efb195a6e02468001cc2a1ef818d55426da97c0 Mon Sep 17 00:00:00 2001
From: Brayden Zhong
Date: Tue, 1 Apr 2025 22:06:44 -0400
Subject: [PATCH] [V1] Fix: make sure `k_index` is int64 for `apply_top_k_only` (#15907)

Signed-off-by: Brayden Zhong
---
 vllm/v1/sample/ops/topk_topp_sampler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index 5dfcae08b170c..d4bc23364c574 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -200,7 +200,7 @@ def apply_top_k_only(
     # topk.values tensor has shape [batch_size, max_top_k].
     # Convert top k to 0-based index in range [0, max_top_k).
     k_index = k.sub_(1).unsqueeze(1)
-    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
+    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
     # Handle non-topk rows.
     top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
     logits.masked_fill_(logits < top_k_mask, -float("inf"))