improve logits bias (#19041)

This commit is contained in:
Yu Guo 2025-06-06 04:59:25 -07:00 committed by GitHub
parent 7353492a47
commit 8267f9916f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,6 +5,7 @@
import torch
import torch.nn as nn
from vllm.utils import async_tensor_h2d, is_pin_memory_available
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
@ -20,6 +21,7 @@ class Sampler(nn.Module):
# Constructor for Sampler (an nn.Module subclass per the hunk header above):
# initialises the nn.Module base class first, then sets the two attributes
# visible in this hunk.
def __init__(self):
super().__init__()
self.topk_topp_sampler = TopKTopPSampler()
# Evaluated once at construction and reused: the logit-bias code later in
# this class passes self.pin_memory to async_tensor_h2d when it moves the
# collected index/value lists to logits.device.
# NOTE(review): presumably a bool telling whether pinned host memory can be
# used -- confirm against vllm.utils.is_pin_memory_available.
self.pin_memory = is_pin_memory_available()
def forward(
self,
@ -232,6 +234,10 @@ class Sampler(nn.Module):
# One idea is to implement this as a PyTorch C++ op, and we may
# even optimize the logit_bias layout.
rows: list[int] = []
cols: list[int] = []
vals: list[float] = []
# Get vocabulary size from logits
vocab_size = logits.shape[-1]
@ -244,7 +250,16 @@ class Sampler(nn.Module):
f"token_id {token_id} in logit_bias contains "
f"out-of-vocab token id. Vocabulary size: "
f"{vocab_size}")
logits[i, token_id] += bias
rows.append(i)
cols.append(token_id)
vals.append(bias)
if rows:
indices = async_tensor_h2d([rows, cols], torch.int64,
logits.device, self.pin_memory)
values = async_tensor_h2d(vals, torch.float, logits.device,
self.pin_memory)
logits.index_put_(tuple(indices), values=values, accumulate=True)
return logits
def apply_allowed_token_ids(