mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 09:06:19 +08:00
improve logits bias (#19041)
This commit is contained in:
parent
7353492a47
commit
8267f9916f
@ -5,6 +5,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.utils import async_tensor_h2d, is_pin_memory_available
|
||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.bad_words import apply_bad_words
|
||||
@ -20,6 +21,7 @@ class Sampler(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.topk_topp_sampler = TopKTopPSampler()
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -232,6 +234,10 @@ class Sampler(nn.Module):
|
||||
# One idea is implement this as a PyTorch C++ op, and we may
|
||||
# even optimize the logit_bias layout.
|
||||
|
||||
rows: list[int] = []
|
||||
cols: list[int] = []
|
||||
vals: list[float] = []
|
||||
|
||||
# Get vocabulary size from logits
|
||||
vocab_size = logits.shape[-1]
|
||||
|
||||
@ -244,7 +250,16 @@ class Sampler(nn.Module):
|
||||
f"token_id {token_id} in logit_bias contains "
|
||||
f"out-of-vocab token id. Vocabulary size: "
|
||||
f"{vocab_size}")
|
||||
logits[i, token_id] += bias
|
||||
rows.append(i)
|
||||
cols.append(token_id)
|
||||
vals.append(bias)
|
||||
|
||||
if rows:
|
||||
indices = async_tensor_h2d([rows, cols], torch.int64,
|
||||
logits.device, self.pin_memory)
|
||||
values = async_tensor_h2d(vals, torch.float, logits.device,
|
||||
self.pin_memory)
|
||||
logits.index_put_(tuple(indices), values=values, accumulate=True)
|
||||
return logits
|
||||
|
||||
def apply_allowed_token_ids(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user