mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-03 19:37:57 +08:00
[v1][sampler] Inplace logprobs comparison to get the token rank (#21283)
Signed-off-by: Lu Fang <lufang@fb.com>
This commit is contained in:
parent
0ec82edda5
commit
8d0a01a5f2
24
vllm/v1/sample/ops/logprobs.py
Normal file
24
vllm/v1/sample/ops/logprobs.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Some utilities for logprobs, including logits."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@torch.compile(dynamic=True)
def batched_count_greater_than(x: torch.Tensor,
                               values: torch.Tensor) -> torch.Tensor:
    """
    Count, per row, the elements of x that are >= the row's reference value.

    Despite the function name, the comparison is inclusive (``>=``): elements
    equal to the reference value are counted as well. Consequently, when
    ``values[i]`` is itself an element of ``x[i]`` (and is not NaN), the count
    for that row is at least 1, which makes the result usable directly as a
    1-based rank of that element within its row.

    The function is wrapped in torch.compile so that the comparison and the
    reduction can be generated as one optimized kernel; otherwise, it will
    create additional copies of the input tensors and cause memory issues.

    Args:
        x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
        values (torch.Tensor): A 2D tensor of shape (batch_size, 1); it is
            broadcast against x along the last dimension.

    Returns:
        torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
    """
    # Inclusive comparison on purpose: see the docstring (rank semantics).
    return (x >= values).sum(-1)
|
||||||
@ -9,6 +9,7 @@ from vllm.utils import is_pin_memory_available
|
|||||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.sample.ops.bad_words import apply_bad_words
|
from vllm.v1.sample.ops.bad_words import apply_bad_words
|
||||||
|
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
|
||||||
from vllm.v1.sample.ops.penalties import apply_all_penalties
|
from vllm.v1.sample.ops.penalties import apply_all_penalties
|
||||||
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
||||||
|
|
||||||
@ -174,7 +175,7 @@ class Sampler(nn.Module):
|
|||||||
token_logprobs = logprobs.gather(-1, token_ids)
|
token_logprobs = logprobs.gather(-1, token_ids)
|
||||||
|
|
||||||
# Compute the ranks of the actual token.
|
# Compute the ranks of the actual token.
|
||||||
token_ranks = (logprobs >= token_logprobs).sum(-1)
|
token_ranks = batched_count_greater_than(logprobs, token_logprobs)
|
||||||
|
|
||||||
# Concatenate together with the topk.
|
# Concatenate together with the topk.
|
||||||
indices = torch.cat((token_ids, topk_indices), dim=1)
|
indices = torch.cat((token_ids, topk_indices), dim=1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user