[TPU] Remove TopKTopPSampler dependency for TPU sampler (#24391)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 62f66be1f7, commit 105d3d62ef
@@ -6,8 +6,12 @@ import pytest
 import torch
 
 from vllm.platforms import current_platform
-from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
-                                                  apply_top_k_top_p_tpu)
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+
+# isort: off
+from vllm.v1.sample.tpu.sampler import (apply_top_k_top_p as
+                                        apply_top_k_top_p_tpu)
+# isort: on
 
 if not current_platform.is_tpu():
     pytest.skip("This test needs a TPU.", allow_module_level=True)
@@ -73,10 +73,8 @@ class TopKTopPSampler(nn.Module):
                 self.forward = self.forward_native
         else:
             self.forward = self.forward_native
-        if current_platform.is_tpu():
-            self.apply_top_k_top_p = apply_top_k_top_p_tpu
-        else:
-            self.apply_top_k_top_p = apply_top_k_top_p
+
+        self.apply_top_k_top_p = apply_top_k_top_p
 
     def forward_native(
         self,
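Taken together with the test hunk above, the net effect (a consolidation for readers, not new API; both import paths appear verbatim in this diff) is that the generic sampler always binds the native kernel, while TPU callers import the TPU implementation directly:

# Native kernel, still exported from the ops module.
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p

# TPU implementation, now imported from the TPU sampler module instead of
# being selected inside TopKTopPSampler.__init__.
from vllm.v1.sample.tpu.sampler import (apply_top_k_top_p as
                                        apply_top_k_top_p_tpu)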
@@ -125,53 +123,6 @@ class TopKTopPSampler(nn.Module):
         return flashinfer_sample(logits.contiguous(), k, p, generators), None
 
 
-def apply_top_k_top_p_tpu(
-    logits: torch.Tensor,
-    k: torch.Tensor,
-    p: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Apply top-k and top-p optimized for TPU.
-
-    This algorithm avoids using torch.scatter, which is extremely slow on TPU.
-    It works by finding a "cut-off" element in the original logits; after
-    thresholding the logits with this cut-off, the remaining elements
-    constitute the top-p set.
-
-    Note: in the case of a tie (i.e. multiple cut-off elements present in the
-    logits), all tied elements are included in the top-p set. In other words,
-    this function does not break ties. Instead, the tied tokens have an equal
-    chance of being chosen during final sampling, so the tie can be considered
-    broken then.
-    """
-    probs = logits.softmax(dim=-1)
-    probs_sort, _ = probs.sort(dim=-1, descending=False)
-
-    if k is not None:
-        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
-        top_k_count = top_k_count.unsqueeze(dim=1)
-        top_k_cutoff = probs_sort.gather(-1, top_k_count)
-
-        # Make sure the no-top-k rows are no-ops.
-        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
-        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
-
-        elements_to_discard = probs < top_k_cutoff
-        logits.masked_fill_(elements_to_discard, -float("inf"))
-
-    if p is not None:
-        cumprob = torch.cumsum(probs_sort, dim=-1)
-        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
-        top_p_mask[:, -1] = False  # at least one
-
-        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
-        top_p_cutoff = probs_sort.gather(-1, top_p_count)
-        elements_to_discard = probs < top_p_cutoff
-        logits.masked_fill_(elements_to_discard, -float("inf"))
-
-    return logits
-
-
 def apply_top_k_top_p(
     logits: torch.Tensor,
     k: Optional[torch.Tensor],
@@ -2,11 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Sampler layer implementing TPU supported operations."""
 
+from typing import Optional
+
 import torch
 import torch.nn as nn
 
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
-from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
 
 _SAMPLING_EPS = 1e-5
@@ -17,7 +18,6 @@ class Sampler(nn.Module):
     def __init__(self):
         # TODO(houseroad): Add support for logprobs_mode.
         super().__init__()
-        self.topk_topp_sampler = TopKTopPSampler()
 
     def forward(
         self,
@@ -65,13 +65,17 @@
         logits = self.apply_min_p(logits, sampling_metadata.min_p)
 
         # Apply top_k and/or top_p.
-        random_sampled, _ = self.topk_topp_sampler(
+        logits = apply_top_k_top_p(
             logits,
-            sampling_metadata.generators,
             sampling_metadata.top_k,
             sampling_metadata.top_p,
         )
 
+        # Random sample.
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        random_sampled = self.random_sample(probs,
+                                            sampling_metadata.generators)
+
         sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS,
                               greedy_sampled, random_sampled)
         return sampled
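The new random-sampling step relies on the exponential-noise trick used by random_sample (added in the next hunk): with q drawn elementwise from Exp(1), argmax(probs / q) follows the categorical distribution given by probs, so no torch.multinomial call is needed. A small standalone check of that property, assuming only PyTorch (this snippet is not part of the commit):

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])

counts = torch.zeros(3)
for _ in range(20_000):
    q = torch.empty_like(probs).exponential_()  # q ~ Exp(1), elementwise
    counts[(probs / q).argmax()] += 1           # same rule as random_sample()

print(counts / counts.sum())  # expected to be close to [0.1, 0.2, 0.7]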
@@ -144,3 +148,66 @@ class Sampler(nn.Module):
         # Apply mask using boolean indexing (xla friendly)
         logits.masked_fill_(~valid_token_mask, -float("inf"))
         return logits
+
+    def random_sample(
+        self,
+        probs: torch.Tensor,
+        generators: dict[int, torch.Generator],
+    ) -> torch.Tensor:
+        q = torch.empty_like(probs)
+        # NOTE(woosuk): To batch-process the requests without their own seeds,
+        # which is the common case, we first assume that every request does
+        # not have its own seed. Then, we overwrite the values for the requests
+        # that have their own seeds.
+        q.exponential_()
+        if generators:
+            for i, generator in generators.items():
+                q[i].exponential_(generator=generator)
+        return probs.div_(q).argmax(dim=-1).view(-1)
+
+
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """
+    Apply top-k and top-p optimized for TPU.
+
+    This algorithm avoids using torch.scatter, which is extremely slow on TPU.
+    It works by finding a "cut-off" element in the original logits; after
+    thresholding the logits with this cut-off, the remaining elements
+    constitute the top-p set.
+
+    Note: in the case of a tie (i.e. multiple cut-off elements present in the
+    logits), all tied elements are included in the top-p set. In other words,
+    this function does not break ties. Instead, the tied tokens have an equal
+    chance of being chosen during final sampling, so the tie can be considered
+    broken then.
+    """
+    probs = logits.softmax(dim=-1)
+    probs_sort, _ = probs.sort(dim=-1, descending=False)
+
+    if k is not None:
+        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
+        top_k_count = top_k_count.unsqueeze(dim=1)
+        top_k_cutoff = probs_sort.gather(-1, top_k_count)
+
+        # Make sure the no-top-k rows are no-ops.
+        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
+        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
+
+        elements_to_discard = probs < top_k_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
+
+    if p is not None:
+        cumprob = torch.cumsum(probs_sort, dim=-1)
+        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
+        top_p_mask[:, -1] = False  # at least one
+
+        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
+        top_p_cutoff = probs_sort.gather(-1, top_p_count)
+        elements_to_discard = probs < top_p_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
+
+    return logits
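For reference, a toy run of the relocated kernel. This assumes a vLLM build containing this commit and that vllm.v1.sample.tpu.sampler imports on the host (the function itself is plain PyTorch, so no TPU is needed for this check); the example values are made up to illustrate the tie behavior described in the docstring:

import torch

from vllm.v1.sample.tpu.sampler import apply_top_k_top_p

logits = torch.tensor([[1.0, 3.0, 2.0, 0.5, 2.0]])
k = torch.tensor([2])    # keep the 2 largest logits per row
p = torch.tensor([1.0])  # top-p == 1.0 filters nothing

out = apply_top_k_top_p(logits.clone(), k, p)
print(out)
# Expected: 3.0 and both 2.0 entries stay finite, 1.0 and 0.5 become -inf.
# The two 2.0 logits sit exactly at the cut-off, so three tokens survive even
# though k == 2: ties are kept, as the docstring notes.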