diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index 8ba3c2087a5cb..6bc0cecdd4940 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -5,6 +5,7 @@
 import torch
 import torch.nn as nn
 
+from vllm.utils import async_tensor_h2d, is_pin_memory_available
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.bad_words import apply_bad_words
@@ -20,6 +21,7 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
         self.topk_topp_sampler = TopKTopPSampler()
+        self.pin_memory = is_pin_memory_available()
 
     def forward(
         self,
@@ -232,6 +234,10 @@ class Sampler(nn.Module):
         # One idea is implement this as a PyTorch C++ op, and we may
         # even optimize the logit_bias layout.
 
+        rows: list[int] = []
+        cols: list[int] = []
+        vals: list[float] = []
+
         # Get vocabulary size from logits
         vocab_size = logits.shape[-1]
 
@@ -244,7 +250,16 @@ class Sampler(nn.Module):
                         f"token_id {token_id} in logit_bias contains "
                         f"out-of-vocab token id. Vocabulary size: "
                         f"{vocab_size}")
-                    logits[i, token_id] += bias
+                    rows.append(i)
+                    cols.append(token_id)
+                    vals.append(bias)
+
+        if rows:
+            indices = async_tensor_h2d([rows, cols], torch.int64,
+                                       logits.device, self.pin_memory)
+            values = async_tensor_h2d(vals, torch.float, logits.device,
+                                      self.pin_memory)
+            logits.index_put_(tuple(indices), values=values, accumulate=True)
         return logits
 
     def apply_allowed_token_ids(
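
For reference, below is a minimal standalone sketch of the pattern this diff applies: instead of issuing one `logits[i, token_id] += bias` per bias entry (each a separate GPU kernel launch), all (row, col, value) triples are collected on the host and applied in a single scatter-add via `index_put_(..., accumulate=True)`. The sketch uses plain `torch.tensor(..., pin_memory=...)` and `.to(device, non_blocking=True)` in place of vLLM's `async_tensor_h2d` helper; the function name `apply_logit_bias_batched` and the example data are hypothetical, not part of the patch.

```python
import torch


def apply_logit_bias_batched(
    logits: torch.Tensor,
    logit_bias: list[dict[int, float] | None],
) -> torch.Tensor:
    """Apply per-request logit biases with one batched update.

    Hypothetical standalone version of the diff's technique.
    """
    rows: list[int] = []
    cols: list[int] = []
    vals: list[float] = []
    for i, bias_dict in enumerate(logit_bias):
        if bias_dict:
            for token_id, bias in bias_dict.items():
                rows.append(i)
                cols.append(token_id)
                vals.append(bias)

    if rows:
        # Stage the indices/values on the host, then copy each once. With
        # pinned (page-locked) memory and non_blocking=True, the H2D copy
        # can overlap other GPU work instead of serializing on it.
        pin = torch.cuda.is_available()
        indices = torch.tensor([rows, cols], dtype=torch.int64,
                               pin_memory=pin)
        values = torch.tensor(vals, dtype=torch.float, pin_memory=pin)
        indices = indices.to(logits.device, non_blocking=True)
        values = values.to(logits.device, non_blocking=True)
        # index_put_ with accumulate=True is a scatter-add: one kernel
        # replaces len(rows) separate `logits[i, token_id] += bias` ops.
        logits.index_put_((indices[0], indices[1]), values, accumulate=True)
    return logits


# Example usage (hypothetical data): bias token 3 for request 0, and
# tokens 0 and 7 for request 1, in a single batched update.
logits = torch.zeros(2, 8)
out = apply_logit_bias_batched(logits, [{3: 1.5}, {0: -2.0, 7: 0.5}])
```

Note the design trade-off the diff makes: the Python loop over bias dicts still runs on the host, but the per-element GPU indexing (one kernel launch and implicit sync per entry) collapses into two async host-to-device copies plus one `index_put_` call, which is where the cost was.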