From 8d0a01a5f2b53794e4bc6b734d7b63cb8a9b7d7d Mon Sep 17 00:00:00 2001
From: Lu Fang <30275821+houseroad@users.noreply.github.com>
Date: Mon, 21 Jul 2025 13:47:47 -0700
Subject: [PATCH] [v1][sampler] Inplace logprobs comparison to get the token
 rank (#21283)

Signed-off-by: Lu Fang
---
 vllm/v1/sample/ops/logprobs.py | 24 ++++++++++++++++++++++++
 vllm/v1/sample/sampler.py      |  3 ++-
 2 files changed, 26 insertions(+), 1 deletion(-)
 create mode 100644 vllm/v1/sample/ops/logprobs.py

diff --git a/vllm/v1/sample/ops/logprobs.py b/vllm/v1/sample/ops/logprobs.py
new file mode 100644
index 0000000000000..a4d65485140ec
--- /dev/null
+++ b/vllm/v1/sample/ops/logprobs.py
@@ -0,0 +1,24 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Some utilities for logprobs, including logits."""
+
+import torch
+
+
+@torch.compile(dynamic=True)
+def batched_count_greater_than(x: torch.Tensor,
+                               values: torch.Tensor) -> torch.Tensor:
+    """
+    Counts elements in each row of x that are greater than or equal to
+    the corresponding value in values. Use torch.compile to generate an
+    optimized kernel for this function; otherwise, it will create
+    additional copies of the input tensors and cause memory issues.
+
+    Args:
+        x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
+        values (torch.Tensor): A 2D tensor of shape (batch_size, 1).
+
+    Returns:
+        torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
+    """
+    return (x >= values).sum(-1)
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index e79e4451a3a3f..fa078e6287685 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -9,6 +9,7 @@
 from vllm.utils import is_pin_memory_available
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.bad_words import apply_bad_words
+from vllm.v1.sample.ops.logprobs import batched_count_greater_than
 from vllm.v1.sample.ops.penalties import apply_all_penalties
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
@@ -174,7 +175,7 @@ class Sampler(nn.Module):
         token_logprobs = logprobs.gather(-1, token_ids)
 
         # Compute the ranks of the actual token.
-        token_ranks = (logprobs >= token_logprobs).sum(-1)
+        token_ranks = batched_count_greater_than(logprobs, token_logprobs)
 
         # Concatenate together with the topk.
         indices = torch.cat((token_ids, topk_indices), dim=1)
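
For context beyond the diff: a minimal, standalone sketch of the rank
computation this patch moves into a compiled helper. The helper body matches
the new vllm/v1/sample/ops/logprobs.py; the batch, vocabulary size, and
sample values below are illustrative assumptions, not vLLM code.

import torch


@torch.compile(dynamic=True)
def batched_count_greater_than(x: torch.Tensor,
                               values: torch.Tensor) -> torch.Tensor:
    # Fused compare-and-reduce. Compiling lets the backend avoid
    # materializing the full (batch_size, n_elements) boolean mask as a
    # separate eager-mode allocation before the reduction.
    return (x >= values).sum(-1)


# Illustrative batch: 2 sequences over a 5-token vocabulary.
logprobs = torch.log_softmax(torch.randn(2, 5), dim=-1)
token_ids = torch.tensor([[3], [1]])             # sampled token per row
token_logprobs = logprobs.gather(-1, token_ids)  # shape (2, 1)

# Rank of each sampled token: the number of vocab entries whose logprob
# is >= the sampled token's logprob (rank 1 == the most likely token).
token_ranks = batched_count_greater_than(logprobs, token_logprobs)
print(token_ranks)  # e.g. tensor([2, 4])

The eager expression (logprobs >= token_logprobs).sum(-1) allocates the
intermediate boolean mask before reducing; fusing the comparison and the
reduction into one compiled kernel is what avoids the extra copies the
docstring warns about, which matters at large batch and vocabulary sizes.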