Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 07:45:01 +08:00)
[V1][BugFix] Clean up rejection sampler & Fix warning msg (#13362)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
commit 69e1d23e1e
parent d67cc21b78
@@ -3,7 +3,9 @@ import torch
 import torch.nn as nn
 from torch.nn.utils.rnn import pad_sequence
 
+from vllm import envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.v1.outputs import SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 
@@ -19,27 +21,50 @@ INVALID_TOKEN_ID = -1
 
 class RejectionSampler(nn.Module):
 
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda:
+            if is_flashinfer_available:
+                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
+                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
+                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
+                    # default it is unused). For backward compatibility, we set
+                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
+                    # interpret it differently in V0 and V1 samplers: In V0,
+                    # None means False, while in V1, None means True. This is
+                    # why we use the condition
+                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
+                    logger.info("Using FlashInfer for rejection sampling.")
+                    self.forward_method = self.flashinfer_sample
+                else:
+                    logger.warning(
+                        "FlashInfer is available, but it is not enabled. "
+                        "Falling back to the PyTorch-native implementation of "
+                        "rejection sampling. For the best performance, "
+                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
+                    self.forward_method = self.forward_native
+            else:
+                logger.warning(
+                    "FlashInfer is not available. Falling back to the PyTorch-"
+                    "native implementation of rejection sampling. For the "
+                    "best performance, please install FlashInfer.")
+                self.forward_method = self.forward_native
+        else:
+            self.forward_method = self.forward_native
+
     def forward(self, logits: torch.Tensor,
                 sampling_metadata: SamplingMetadata) -> SamplerOutput:
         if not sampling_metadata.all_greedy:
             raise NotImplementedError(
-                "Only greedy sampling is supported by rejection sampler.")
-
-        if is_flashinfer_available:
-            logger.info("User FlashInfer for rejection sampling.")
-            return RejectionSampler.flashinfer_sample(logits,
-                                                      sampling_metadata)
-        else:
-            logger.warning(
-                "FlashInfer is not available. Falling back to the PyTorch-"
-                "native implementation of rejection sampling.")
-            return RejectionSampler.greedy_sample_native(
-                logits, sampling_metadata)
+                "Currently, only greedy sampling is supported by "
+                "rejection sampler.")
+        return self.forward_method(logits, sampling_metadata)
 
-    @staticmethod
     def flashinfer_sample(
+        self,
         logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata) -> SamplerOutput:
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         # NOTE: The following input preparationg can be moved
         # to the model runner with a persistent manner for better
         # performance.
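Note on the new constructor: backend selection (and the accompanying log message) now happens once at construction time instead of on every forward() call, which also retires the old per-call "User FlashInfer" typo. Below is a minimal standalone sketch of this choose-once, dispatch-later pattern, assuming nothing beyond what the hunk shows; TinyRejectionSampler and its constructor flags are illustrative names, not vLLM code. The tri-state flag mirrors how the NOTE in the diff says VLLM_USE_FLASHINFER_SAMPLER is read in V1: None and True enable FlashInfer, only an explicit False disables it.

from typing import Optional

import torch
import torch.nn as nn


class TinyRejectionSampler(nn.Module):
    """Illustrative sketch of the dispatch pattern; not vLLM code."""

    def __init__(self, flashinfer_available: bool,
                 use_flashinfer_sampler: Optional[bool]):
        super().__init__()
        if flashinfer_available and use_flashinfer_sampler is not False:
            # In the real sampler, the one-time "Using FlashInfer" info log
            # is emitted here.
            self.forward_method = self.fast_sample
        else:
            # ...and the one-time fallback warning is emitted here.
            self.forward_method = self.native_sample

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        # Every call routes through the method chosen in __init__.
        return self.forward_method(logits)

    def fast_sample(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.argmax(dim=-1)

    def native_sample(self, logits: torch.Tensor) -> torch.Tensor:
        return torch.max(logits, dim=-1).indices


# An unset flag (None) still selects the fast path, matching the V1 reading.
sampler = TinyRejectionSampler(flashinfer_available=True,
                               use_flashinfer_sampler=None)
print(sampler(torch.randn(2, 8)).shape)  # torch.Size([2])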
@@ -71,10 +96,10 @@ class RejectionSampler(nn.Module):
         vocab_size = logits.size(-1)
         # NOTE: CPU <-> GPU synchronization happens here.
         draft_token_ids = draft_token_ids.to(logits.device)
-        draft_probs = RejectionSampler._create_greedy_token_probs(
-            draft_token_ids, vocab_size, logits.device)
-        target_probs = RejectionSampler._create_greedy_token_probs(
-            target_token_ids, vocab_size, logits.device)
+        draft_probs = _create_greedy_token_probs(draft_token_ids, vocab_size,
+                                                 logits.device)
+        target_probs = _create_greedy_token_probs(target_token_ids, vocab_size,
+                                                  logits.device)
         uniform_samples = torch.zeros(batch_size,
                                       max_spec_len + 1,
                                       device=logits.device)
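The "CPU <-> GPU synchronization happens here" note refers to moving the already-prepared draft-token batch onto the logits device. A rough sketch of the kind of preparation that batch goes through, under the assumption that ragged per-request spec_token_ids are padded with INVALID_TOKEN_ID via the pad_sequence import from the first hunk (the actual preparation code sits outside this diff):

import torch
from torch.nn.utils.rnn import pad_sequence

INVALID_TOKEN_ID = -1

# Ragged per-request draft (speculative) token ids.
spec_token_ids = [[11, 12, 13], [21], [31, 32]]
draft_token_ids = pad_sequence([torch.tensor(x) for x in spec_token_ids],
                               batch_first=True,
                               padding_value=INVALID_TOKEN_ID)
print(draft_token_ids)
# tensor([[11, 12, 13],
#         [21, -1, -1],
#         [31, 32, -1]])

Only after this CPU-side padding does the batch get copied to the GPU, which is the synchronization point the comment warns about.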
@@ -89,10 +114,11 @@ class RejectionSampler(nn.Module):
                              logprobs_tensors=None)
 
     # TODO: The following method can be optimized for better performance.
-    @staticmethod
-    def greedy_sample_native(
+    def forward_native(
+        self,
         logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata) -> SamplerOutput:
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
         # Add 1 to include the 'bonus' token.
         sample_lens = [x + 1 for x in spec_lens]
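For context on what forward_native (formerly greedy_sample_native) computes: with greedy sampling, rejection sampling reduces to accepting draft tokens while they match the target model's argmax, emitting the target token at the first mismatch, and emitting a final "bonus" target token when every draft token matched. A minimal scalar sketch of that acceptance rule follows; greedy_accept is an illustrative helper, not code from this file, and the real method is vectorized over the batch with the padding and masking the diff hints at.

import torch


def greedy_accept(draft_tokens: torch.Tensor,
                  target_argmax: torch.Tensor) -> list[int]:
    # target_argmax has one more position than draft_tokens: the extra
    # position supplies the "bonus" token when all drafts are accepted.
    out: list[int] = []
    for i, tok in enumerate(draft_tokens.tolist()):
        target_tok = int(target_argmax[i])
        if tok == target_tok:
            out.append(tok)          # draft token accepted
        else:
            out.append(target_tok)   # first mismatch: take target token, stop
            return out
    out.append(int(target_argmax[len(draft_tokens)]))  # bonus token
    return out


# Draft proposes [5, 7, 9]; target argmax over the 4 positions is [5, 7, 2, 8].
print(greedy_accept(torch.tensor([5, 7, 9]), torch.tensor([5, 7, 2, 8])))
# [5, 7, 2]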
@@ -137,9 +163,12 @@ class RejectionSampler(nn.Module):
         return SamplerOutput(sampled_token_ids=output_token_ids,
                              logprobs_tensors=None)
 
-    @staticmethod
-    def _create_greedy_token_probs(token_ids: torch.Tensor, vocab_size: int,
-                                   out_device: torch.device) -> torch.Tensor:
-        batch_size, num_tokens = token_ids.shape
-
-        token_probs = torch.zeros(batch_size,
+
+def _create_greedy_token_probs(
+    token_ids: torch.Tensor,
+    vocab_size: int,
+    out_device: torch.device,
+) -> torch.Tensor:
+    batch_size, num_tokens = token_ids.shape
+
+    token_probs = torch.zeros(batch_size,
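The helper moved to module level builds per-position "probability" tensors that place all mass on a single token id, which lets greedy rejection sampling reuse machinery written for full probability distributions (e.g. the FlashInfer path fed by draft_probs/target_probs above). A self-contained sketch of that idea; one_hot_token_probs is an illustrative stand-in, not the body of _create_greedy_token_probs, and it ignores INVALID_TOKEN_ID padding for brevity.

import torch


def one_hot_token_probs(token_ids: torch.Tensor, vocab_size: int,
                        out_device: torch.device) -> torch.Tensor:
    # Build a [batch, num_tokens, vocab] tensor with probability 1.0 on each
    # given token id and 0.0 everywhere else.
    batch_size, num_tokens = token_ids.shape
    probs = torch.zeros(batch_size, num_tokens, vocab_size, device=out_device)
    probs.scatter_(2, token_ids.unsqueeze(-1).to(out_device), 1.0)
    return probs


# Two sequences, two draft tokens each, toy vocabulary of 5.
ids = torch.tensor([[1, 3], [0, 4]])
print(one_hot_token_probs(ids, 5, torch.device("cpu")).shape)
# torch.Size([2, 2, 5])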