mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-18 02:45:01 +08:00
improve logits bias (#19041)
This commit is contained in:
parent
7353492a47
commit
8267f9916f
@ -5,6 +5,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.utils import async_tensor_h2d, is_pin_memory_available
|
||||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.sample.ops.bad_words import apply_bad_words
|
from vllm.v1.sample.ops.bad_words import apply_bad_words
|
||||||
@ -20,6 +21,7 @@ class Sampler(nn.Module):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.topk_topp_sampler = TopKTopPSampler()
|
self.topk_topp_sampler = TopKTopPSampler()
|
||||||
|
self.pin_memory = is_pin_memory_available()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -232,6 +234,10 @@ class Sampler(nn.Module):
|
|||||||
# One idea is implement this as a PyTorch C++ op, and we may
|
# One idea is implement this as a PyTorch C++ op, and we may
|
||||||
# even optimize the logit_bias layout.
|
# even optimize the logit_bias layout.
|
||||||
|
|
||||||
|
rows: list[int] = []
|
||||||
|
cols: list[int] = []
|
||||||
|
vals: list[float] = []
|
||||||
|
|
||||||
# Get vocabulary size from logits
|
# Get vocabulary size from logits
|
||||||
vocab_size = logits.shape[-1]
|
vocab_size = logits.shape[-1]
|
||||||
|
|
||||||
@ -244,7 +250,16 @@ class Sampler(nn.Module):
|
|||||||
f"token_id {token_id} in logit_bias contains "
|
f"token_id {token_id} in logit_bias contains "
|
||||||
f"out-of-vocab token id. Vocabulary size: "
|
f"out-of-vocab token id. Vocabulary size: "
|
||||||
f"{vocab_size}")
|
f"{vocab_size}")
|
||||||
logits[i, token_id] += bias
|
rows.append(i)
|
||||||
|
cols.append(token_id)
|
||||||
|
vals.append(bias)
|
||||||
|
|
||||||
|
if rows:
|
||||||
|
indices = async_tensor_h2d([rows, cols], torch.int64,
|
||||||
|
logits.device, self.pin_memory)
|
||||||
|
values = async_tensor_h2d(vals, torch.float, logits.device,
|
||||||
|
self.pin_memory)
|
||||||
|
logits.index_put_(tuple(indices), values=values, accumulate=True)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def apply_allowed_token_ids(
|
def apply_allowed_token_ids(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user