[TPU] Remove TopKTopPSampler dependency for TPU sampler (#24391)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 62f66be1f7
commit 105d3d62ef
@@ -6,8 +6,12 @@ import pytest
 import torch
 
 from vllm.platforms import current_platform
-from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
-                                                  apply_top_k_top_p_tpu)
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+
+# isort: off
+from vllm.v1.sample.tpu.sampler import (apply_top_k_top_p as
+                                        apply_top_k_top_p_tpu)
+# isort: on
 
 if not current_platform.is_tpu():
     pytest.skip("This test needs a TPU.", allow_module_level=True)
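Not part of the commit: a minimal CPU-only sketch of how the two implementations imported above could be cross-checked. The shapes, the seed, and the restriction to top-k only are editorial assumptions, not the assertions of the actual test.

import torch

from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.tpu.sampler import apply_top_k_top_p as apply_top_k_top_p_tpu

torch.manual_seed(0)
logits = torch.randn(4, 128)
k = torch.randint(1, 128, (4,))

# With p=None (top-k only), the probability order equals the logit order, so
# both implementations should keep the same token set; ties are effectively
# impossible with random float logits.
masked_native = apply_top_k_top_p(logits.clone(), k, None)
masked_tpu = apply_top_k_top_p_tpu(logits.clone(), k, None)
assert torch.equal(masked_native.isinf(), masked_tpu.isinf())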
@@ -73,10 +73,8 @@ class TopKTopPSampler(nn.Module):
                 self.forward = self.forward_native
         else:
             self.forward = self.forward_native
-        if current_platform.is_tpu():
-            self.apply_top_k_top_p = apply_top_k_top_p_tpu
-        else:
-            self.apply_top_k_top_p = apply_top_k_top_p
+        self.apply_top_k_top_p = apply_top_k_top_p
 
     def forward_native(
         self,
@@ -125,53 +123,6 @@ class TopKTopPSampler(nn.Module):
         return flashinfer_sample(logits.contiguous(), k, p, generators), None
 
 
-def apply_top_k_top_p_tpu(
-    logits: torch.Tensor,
-    k: torch.Tensor,
-    p: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Apply top-k and top-p optimized for TPU.
-
-    This algorithm avoids using torch.scatter which is extremely slow on TPU.
-    This is achieved by finding a "cut-off" element in the original logit, and
-    after thresholding the logit using this cut-off, the remaining elements
-    shall constitute the top-p set.
-
-    Note: in the case of a tie (i.e. multiple cut-off elements present in the
-    logit), all tie elements are included in the top-p set. In other words,
-    this function does not break ties. Instead, these tie tokens have equal
-    chance of being chosen during final sampling, so we can consider the tie
-    being broken then.
-    """
-    probs = logits.softmax(dim=-1)
-    probs_sort, _ = probs.sort(dim=-1, descending=False)
-
-    if k is not None:
-        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
-        top_k_count = top_k_count.unsqueeze(dim=1)
-        top_k_cutoff = probs_sort.gather(-1, top_k_count)
-
-        # Make sure the no top-k rows are no-op.
-        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
-        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
-
-        elements_to_discard = probs < top_k_cutoff
-        logits.masked_fill_(elements_to_discard, -float("inf"))
-
-    if p is not None:
-        cumprob = torch.cumsum(probs_sort, dim=-1)
-        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
-        top_p_mask[:, -1] = False  # at least one
-
-        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
-        top_p_cutoff = probs_sort.gather(-1, top_p_count)
-        elements_to_discard = probs < top_p_cutoff
-        logits.masked_fill_(elements_to_discard, -float("inf"))
-
-    return logits
-
-
 def apply_top_k_top_p(
     logits: torch.Tensor,
     k: Optional[torch.Tensor],
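The docstring of the function removed above (and re-added to vllm.v1.sample.tpu.sampler below) describes the scatter-free cut-off trick in prose. The following worked example is an editorial illustration of the top-p half of that idea with made-up numbers; only the sequence of operations mirrors the function.

import torch

logits = torch.tensor([[2.0, 1.0, 0.0, -1.0]])
p = torch.tensor([0.8])

probs = logits.softmax(dim=-1)                  # ~[0.64, 0.24, 0.09, 0.03]
probs_sort, _ = probs.sort(dim=-1, descending=False)
cumprob = torch.cumsum(probs_sort, dim=-1)      # ~[0.03, 0.12, 0.36, 1.00]
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)  # True where cumulative mass <= 0.2
top_p_mask[:, -1] = False                       # keep at least one token
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)  # 2 tokens fall below the cut
top_p_cutoff = probs_sort.gather(-1, top_p_count)  # ~0.24 is the cut-off prob
# Everything strictly below the cut-off is discarded -- no torch.scatter needed.
masked = logits.masked_fill(probs < top_p_cutoff, -float("inf"))
print(masked)  # tensor([[2., 1., -inf, -inf]]): the top-p set survives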
@@ -2,11 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Sampler layer implementing TPU supported operations."""
 
+from typing import Optional
+
 import torch
 import torch.nn as nn
 
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
-from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
 
 _SAMPLING_EPS = 1e-5
@@ -17,7 +18,6 @@ class Sampler(nn.Module):
     def __init__(self):
         # TODO(houseroad): Add support for logprobs_mode.
         super().__init__()
-        self.topk_topp_sampler = TopKTopPSampler()
 
     def forward(
         self,
@@ -65,13 +65,17 @@
         logits = self.apply_min_p(logits, sampling_metadata.min_p)
 
         # Apply top_k and/or top_p.
-        random_sampled, _ = self.topk_topp_sampler(
+        logits = apply_top_k_top_p(
             logits,
-            sampling_metadata.generators,
             sampling_metadata.top_k,
             sampling_metadata.top_p,
         )
 
+        # Random sample.
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        random_sampled = self.random_sample(probs,
+                                            sampling_metadata.generators)
+
         sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS,
                               greedy_sampled, random_sampled)
         return sampled
@@ -144,3 +148,66 @@ class Sampler(nn.Module):
         # Apply mask using boolean indexing (xla friendly)
         logits.masked_fill_(~valid_token_mask, -float("inf"))
         return logits
+
+    def random_sample(
+        self,
+        probs: torch.Tensor,
+        generators: dict[int, torch.Generator],
+    ) -> torch.Tensor:
+        q = torch.empty_like(probs)
+        # NOTE(woosuk): To batch-process the requests without their own seeds,
+        # which is the common case, we first assume that every request does
+        # not have its own seed. Then, we overwrite the values for the requests
+        # that have their own seeds.
+        q.exponential_()
+        if generators:
+            for i, generator in generators.items():
+                q[i].exponential_(generator=generator)
+        return probs.div_(q).argmax(dim=-1).view(-1)
+
+
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """
+    Apply top-k and top-p optimized for TPU.
+
+    This algorithm avoids using torch.scatter which is extremely slow on TPU.
+    This is achieved by finding a "cut-off" element in the original logit, and
+    after thresholding the logit using this cut-off, the remaining elements
+    shall constitute the top-p set.
+
+    Note: in the case of a tie (i.e. multiple cut-off elements present in the
+    logit), all tie elements are included in the top-p set. In other words,
+    this function does not break ties. Instead, these tie tokens have equal
+    chance of being chosen during final sampling, so we can consider the tie
+    being broken then.
+    """
+    probs = logits.softmax(dim=-1)
+    probs_sort, _ = probs.sort(dim=-1, descending=False)
+
+    if k is not None:
+        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
+        top_k_count = top_k_count.unsqueeze(dim=1)
+        top_k_cutoff = probs_sort.gather(-1, top_k_count)
+
+        # Make sure the no top-k rows are no-op.
+        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
+        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
+
+        elements_to_discard = probs < top_k_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
+
+    if p is not None:
+        cumprob = torch.cumsum(probs_sort, dim=-1)
+        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
+        top_p_mask[:, -1] = False  # at least one
+
+        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
+        top_p_cutoff = probs_sort.gather(-1, top_p_count)
+        elements_to_discard = probs < top_p_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
+
+    return logits
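The new random_sample method above relies on the exponential-noise form of the Gumbel-max trick: drawing q_i ~ Exp(1) and taking argmax(p_i / q_i) selects index i with probability p_i, which is why the sampler can avoid torch.multinomial and simply overwrite q row by row for seeded requests. The check below is an editorial sketch of that identity, not part of vLLM's tests; the distribution and sample count are arbitrary. (The cut-off algorithm in apply_top_k_top_p itself is illustrated earlier on this page.)

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])
n = 200_000

# argmax(p_i / q_i) == argmax(log p_i - log q_i), and -log q_i is standard
# Gumbel noise, so each row below is an exact categorical sample from probs.
q = torch.empty(n, probs.numel()).exponential_()
samples = (probs / q).argmax(dim=-1)
freq = torch.bincount(samples, minlength=probs.numel()).float() / n
print(freq)  # expected to be close to tensor([0.1000, 0.2000, 0.7000])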