[V1][BugFix] Clean up rejection sampler & Fix warning msg (#13362)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent d67cc21b78
commit 69e1d23e1e
vllm/v1/sample/rejection_sampler.py

@@ -3,7 +3,9 @@ import torch
 import torch.nn as nn
 from torch.nn.utils.rnn import pad_sequence

+from vllm import envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.v1.outputs import SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata

@@ -19,27 +21,50 @@ INVALID_TOKEN_ID = -1

 class RejectionSampler(nn.Module):

+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda:
+            if is_flashinfer_available:
+                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
+                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
+                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
+                    # default it is unused). For backward compatibility, we set
+                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
+                    # interpret it differently in V0 and V1 samplers: In V0,
+                    # None means False, while in V1, None means True. This is
+                    # why we use the condition
+                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
+                    logger.info("Using FlashInfer for rejection sampling.")
+                    self.forward_method = self.flashinfer_sample
+                else:
+                    logger.warning(
+                        "FlashInfer is available, but it is not enabled. "
+                        "Falling back to the PyTorch-native implementation of "
+                        "rejection sampling. For the best performance, "
+                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
+                    self.forward_method = self.forward_native
+            else:
+                logger.warning(
+                    "FlashInfer is not available. Falling back to the PyTorch-"
+                    "native implementation of rejection sampling. For the "
+                    "best performance, please install FlashInfer.")
+                self.forward_method = self.forward_native
+        else:
+            self.forward_method = self.forward_native
+
     def forward(self, logits: torch.Tensor,
                 sampling_metadata: SamplingMetadata) -> SamplerOutput:
         if not sampling_metadata.all_greedy:
             raise NotImplementedError(
-                "Only greedy sampling is supported by rejection sampler.")
-
-        if is_flashinfer_available:
-            logger.info("User FlashInfer for rejection sampling.")
-            return RejectionSampler.flashinfer_sample(logits,
-                                                      sampling_metadata)
-        else:
-            logger.warning(
-                "FlashInfer is not available. Falling back to the PyTorch-"
-                "native implementation of rejection sampling.")
-            return RejectionSampler.greedy_sample_native(
-                logits, sampling_metadata)
+                "Currently, only greedy sampling is supported by "
+                "rejection sampler.")
+        return self.forward_method(logits, sampling_metadata)

-    @staticmethod
     def flashinfer_sample(
-            logits: torch.Tensor,
-            sampling_metadata: SamplingMetadata) -> SamplerOutput:
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         # NOTE: The following input preparation can be moved
         # to the model runner with a persistent manner for better
         # performance.
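
The NOTE above hinges on a tri-state flag: VLLM_USE_FLASHINFER_SAMPLER can be unset (None), 0, or 1, and the V1 sampler treats "unset" as enabled via the `is not False` check. A minimal standalone sketch of that dispatch logic, with read_flag and pick_backend as illustrative stand-ins rather than vLLM APIs:

import os
from typing import Optional


def read_flag(name: str) -> Optional[bool]:
    """None if the variable is unset; otherwise '1' means True, else False."""
    raw = os.environ.get(name)
    if raw is None:
        return None
    return raw == "1"


def pick_backend(flag: Optional[bool], flashinfer_available: bool,
                 is_cuda: bool) -> str:
    # Mirrors the V1 condition `flag is not False`: unset (None) and True
    # both select FlashInfer; only an explicit opt-out falls back.
    if is_cuda and flashinfer_available and flag is not False:
        return "flashinfer"
    return "torch-native"


# Unset -> FlashInfer in V1, whereas V0 treats the same unset value as "off".
print(pick_backend(read_flag("VLLM_USE_FLASHINFER_SAMPLER"), True, True))

Moving the check into __init__ also means the backend choice is logged once at construction instead of on every forward call, which is the other half of this cleanup.
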
@@ -71,10 +96,10 @@ class RejectionSampler(nn.Module):
         vocab_size = logits.size(-1)
         # NOTE: CPU <-> GPU synchronization happens here.
         draft_token_ids = draft_token_ids.to(logits.device)
-        draft_probs = RejectionSampler._create_greedy_token_probs(
-            draft_token_ids, vocab_size, logits.device)
-        target_probs = RejectionSampler._create_greedy_token_probs(
-            target_token_ids, vocab_size, logits.device)
+        draft_probs = _create_greedy_token_probs(draft_token_ids, vocab_size,
+                                                 logits.device)
+        target_probs = _create_greedy_token_probs(target_token_ids, vocab_size,
+                                                  logits.device)
         uniform_samples = torch.zeros(batch_size,
                                       max_spec_len + 1,
                                       device=logits.device)
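
Why one-hot tensors are enough here: with greedy verification, the standard accept test u < p_target(x) / p_draft(x) accepts a draft token exactly when it equals the target argmax, because the one-hot target probability is 1 for the argmax token and 0 everywhere else. A small self-contained check of that equivalence (illustrative only; the real kernel is FlashInfer's, and whether it compares with < or <= makes no difference in the greedy case):

import torch

torch.manual_seed(0)
num_pos, vocab_size = 5, 8
target_logits = torch.randn(num_pos, vocab_size)
draft_tokens = torch.randint(0, vocab_size, (num_pos,))

# One-hot target "probabilities", in the spirit of _create_greedy_token_probs.
target_tokens = target_logits.argmax(dim=-1)
target_probs = torch.zeros(num_pos, vocab_size)
target_probs[torch.arange(num_pos), target_tokens] = 1.0

u = torch.rand(num_pos)                                   # uniform samples
accept_prob = target_probs[torch.arange(num_pos), draft_tokens]  # 1.0 or 0.0
assert torch.equal(u < accept_prob, draft_tokens == target_tokens)
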
@@ -89,10 +114,11 @@ class RejectionSampler(nn.Module):
                              logprobs_tensors=None)

     # TODO: The following method can be optimized for better performance.
-    @staticmethod
-    def greedy_sample_native(
-            logits: torch.Tensor,
-            sampling_metadata: SamplingMetadata) -> SamplerOutput:
+    def forward_native(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
         # Add 1 to include the 'bonus' token.
         sample_lens = [x + 1 for x in spec_lens]
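
The PyTorch-native path implements the same greedy verification without FlashInfer: keep speculated tokens while they match the target argmax, emit the target token at the first mismatch, and emit the extra "bonus" token when everything matches (hence sample_lens = spec_lens + 1). A simplified per-request sketch of those semantics; greedy_verify is a hypothetical helper, not the vectorized vLLM code:

import torch


def greedy_verify(draft_tokens: list[int],
                  target_logits: torch.Tensor) -> list[int]:
    """target_logits has len(draft_tokens) + 1 rows: one per draft position
    plus one 'bonus' position after the last draft token."""
    target_tokens = target_logits.argmax(dim=-1).tolist()
    out = []
    for i, draft in enumerate(draft_tokens):
        if draft == target_tokens[i]:
            out.append(draft)             # accepted: keep going
        else:
            out.append(target_tokens[i])  # rejected: emit target token, stop
            return out
    out.append(target_tokens[-1])         # all accepted: emit the bonus token
    return out


print(greedy_verify([3, 7, 2], torch.randn(4, 16)))
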
@@ -137,24 +163,27 @@ class RejectionSampler(nn.Module):
         return SamplerOutput(sampled_token_ids=output_token_ids,
                              logprobs_tensors=None)

-    @staticmethod
-    def _create_greedy_token_probs(token_ids: torch.Tensor, vocab_size: int,
-                                   out_device: torch.device) -> torch.Tensor:
-        batch_size, num_tokens = token_ids.shape
-
-        token_probs = torch.zeros(batch_size,
-                                  num_tokens,
-                                  vocab_size,
-                                  dtype=torch.float,
-                                  device=out_device)
-
-        # Ignore INVALID_TOKEN_ID.
-        valid_mask = (token_ids != INVALID_TOKEN_ID)
-        valid_indices = token_ids.clone()
-        valid_indices[~valid_mask] = 0
-
-        token_probs.scatter_(dim=2,
-                             index=valid_indices.unsqueeze(-1),
-                             src=valid_mask.unsqueeze(-1).float())
-
-        return token_probs
+
+def _create_greedy_token_probs(
+    token_ids: torch.Tensor,
+    vocab_size: int,
+    out_device: torch.device,
+) -> torch.Tensor:
+    batch_size, num_tokens = token_ids.shape
+
+    token_probs = torch.zeros(batch_size,
+                              num_tokens,
+                              vocab_size,
+                              dtype=torch.float,
+                              device=out_device)
+
+    # Ignore INVALID_TOKEN_ID.
+    valid_mask = (token_ids != INVALID_TOKEN_ID)
+    valid_indices = token_ids.clone()
+    valid_indices[~valid_mask] = 0
+
+    token_probs.scatter_(dim=2,
+                         index=valid_indices.unsqueeze(-1),
+                         src=valid_mask.unsqueeze(-1).float())
+
+    return token_probs
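
To see what the relocated helper produces: each valid position becomes a one-hot row over the vocabulary, while positions padded with INVALID_TOKEN_ID become all-zero rows, so padding can never win an accept test. A quick illustration that inlines the same scatter logic for a single padded request (the sizes are made up):

import torch

INVALID_TOKEN_ID = -1
token_ids = torch.tensor([[5, 2, INVALID_TOKEN_ID]])  # one request, padded
vocab_size = 8

valid_mask = token_ids != INVALID_TOKEN_ID
valid_indices = token_ids.clone()
valid_indices[~valid_mask] = 0

token_probs = torch.zeros(1, 3, vocab_size)
token_probs.scatter_(dim=2,
                     index=valid_indices.unsqueeze(-1),
                     src=valid_mask.unsqueeze(-1).float())

print(token_probs[0, 0].nonzero())  # tensor([[5]]) -> one-hot at token 5
print(token_probs[0, 2].sum())      # tensor(0.)   -> padded slot stays all zeros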