From 32ef4983cd029d613172dbcf1edf91e62920bbc8 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 13 Mar 2025 20:40:35 -0700
Subject: [PATCH] [V1] Temporarily disable FlashInfer Rejection Sampler (#14788)

Signed-off-by: Woosuk Kwon
---
 vllm/v1/sample/ops/topk_topp_sampler.py |  2 +-
 vllm/v1/sample/rejection_sampler.py     | 15 ++++++++++++---
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index 1bb950be822c1..7d70e839b6f4e 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -22,7 +22,7 @@ class TopKTopPSampler(nn.Module):
 
     def __init__(self):
         super().__init__()
-        if current_platform.is_cuda:
+        if current_platform.is_cuda():
             if is_flashinfer_available:
                 if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                     # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 80a4b24186ab7..ea7f3353c115f 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -24,9 +24,18 @@ class RejectionSampler(nn.Module):
 
     def __init__(self):
         super().__init__()
-        if current_platform.is_cuda:
+        if current_platform.is_cuda():
             if is_flashinfer_available:
                 if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
+                    # FIXME(woosuk): Currently, we have errors when using
+                    # FlashInfer for rejection sampling. As a workaround, we
+                    # disable FlashInfer for rejection sampling by default.
+                    logger.info("Currently, FlashInfer rejection sampler is "
+                                "disabled because of a bug. Falling back to "
+                                "the PyTorch-native implementation of "
+                                "rejection sampling.")
+                    self.forward_method = self.forward_native
+
                     # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                     # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                     # default it is unused). For backward compatibility, we set
@@ -35,8 +44,8 @@ class RejectionSampler(nn.Module):
                     # None means False, while in V1, None means True. This is
                     # why we use the condition
                     # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
-                    logger.info("Using FlashInfer for rejection sampling.")
-                    self.forward_method = self.flashinfer_sample
+                    # logger.info("Using FlashInfer for rejection sampling.")
+                    # self.forward_method = self.flashinfer_sample
                 else:
                     logger.warning(
                         "FlashInfer is available, but it is not enabled. "