[v1][sampler] Inplace logprobs comparison to get the token rank (#21283)
Signed-off-by: Lu Fang <lufang@fb.com>
parent 0ec82edda5
commit 8d0a01a5f2
vllm/v1/sample/ops/logprobs.py (new file, +24 lines)
@@ -0,0 +1,24 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Some utilities for logprobs, including logits."""
+
+import torch
+
+
+@torch.compile(dynamic=True)
+def batched_count_greater_than(x: torch.Tensor,
+                               values: torch.Tensor) -> torch.Tensor:
+    """
+    Counts elements in each row of x that are greater than or equal to the
+    corresponding value in values. Use torch.compile to generate an optimized
+    kernel for this function; otherwise, it creates additional copies of the
+    input tensors and causes memory issues.
+
+    Args:
+        x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
+        values (torch.Tensor): A 2D tensor of shape (batch_size, 1).
+
+    Returns:
+        torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
+    """
+    return (x >= values).sum(-1)
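For context (not part of the diff): a minimal usage sketch of how the new helper turns a batch of logprobs into 1-based token ranks. The tensor values below are made-up toy examples.

import torch

from vllm.v1.sample.ops.logprobs import batched_count_greater_than

# Each row holds logprobs over a toy 3-token vocabulary.
logprobs = torch.tensor([[-0.5, -1.2, -2.0],
                         [-1.0, -0.3, -0.9]])
# Logprob of the token actually sampled in each row, shape (batch_size, 1).
token_logprobs = torch.tensor([[-0.5], [-0.9]])

ranks = batched_count_greater_than(logprobs, token_logprobs)
print(ranks)  # tensor([1, 2]): rank 1 means the sampled token scored highest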
vllm/v1/sample/sampler.py

@@ -9,6 +9,7 @@ from vllm.utils import is_pin_memory_available
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.bad_words import apply_bad_words
+from vllm.v1.sample.ops.logprobs import batched_count_greater_than
 from vllm.v1.sample.ops.penalties import apply_all_penalties
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
@@ -174,7 +175,7 @@ class Sampler(nn.Module):
         token_logprobs = logprobs.gather(-1, token_ids)

         # Compute the ranks of the actual token.
-        token_ranks = (logprobs >= token_logprobs).sum(-1)
+        token_ranks = batched_count_greater_than(logprobs, token_logprobs)

         # Concatenate together with the topk.
         indices = torch.cat((token_ids, topk_indices), dim=1)
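Why the swap matters, per the new helper's docstring: in eager mode the expression (logprobs >= token_logprobs) materializes a full (batch_size, vocab_size) boolean tensor before the reduction, whereas torch.compile can fuse the comparison and the sum into one kernel. A rough sketch for observing the peak-memory difference follows; it assumes a CUDA GPU, and the shapes are arbitrary illustrations, so the exact numbers will vary.

import torch

from vllm.v1.sample.ops.logprobs import batched_count_greater_than

batch_size, vocab_size = 256, 128 * 1024   # arbitrary illustrative shapes
logprobs = torch.randn(batch_size, vocab_size, device="cuda")
token_logprobs = logprobs.amax(dim=-1, keepdim=True)  # shape (batch_size, 1)

# Warm up once so torch.compile's compilation allocations are excluded.
_ = batched_count_greater_than(logprobs, token_logprobs)

torch.cuda.reset_peak_memory_stats()
_ = (logprobs >= token_logprobs).sum(-1)   # eager: allocates a bool buffer
eager_peak = torch.cuda.max_memory_allocated()

torch.cuda.reset_peak_memory_stats()
_ = batched_count_greater_than(logprobs, token_logprobs)  # fused kernel
fused_peak = torch.cuda.max_memory_allocated()

print(f"eager peak: {eager_peak / 2**20:.1f} MiB, "
      f"fused peak: {fused_peak / 2**20:.1f} MiB")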