From fec2b341ad76d322ffe956fc9581e958bd6887bc Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 17 Oct 2025 12:48:18 +0800 Subject: [PATCH] [Kernel] Lazy import FlashInfer (#26977) --- tests/v1/sample/test_topk_topp_sampler.py | 17 +++++---- vllm/v1/sample/ops/topk_topp_sampler.py | 46 ++++++++--------------- 2 files changed, 25 insertions(+), 38 deletions(-) diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py index c70cbebe22caa..f50ef61022040 100644 --- a/tests/v1/sample/test_topk_topp_sampler.py +++ b/tests/v1/sample/test_topk_topp_sampler.py @@ -5,20 +5,13 @@ import torch from torch import Generator from vllm.platforms import current_platform -from vllm.v1.sample.ops.topk_topp_sampler import ( - apply_top_k_top_p, - is_flashinfer_available, -) +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p DEVICE = current_platform.device_type BATCH_SIZE = 1024 VOCAB_SIZE = 128 * 1024 -FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available -if is_flashinfer_available: - from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs - @pytest.fixture(autouse=True) def reset_default_device(): @@ -65,6 +58,14 @@ def test_flashinfer_sampler(): sampling results due to randomness), so we will compare the probability renormed consequently by top-k and then top-p of FlashInfer implementation. """ + try: + from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs + + is_flashinfer_available = True + except ImportError: + is_flashinfer_available = False + + FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available if not FLASHINFER_ENABLED: pytest.skip("FlashInfer not installed or not available on this platform.") diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 950cde82fb9d9..cefe372decf9b 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -13,13 +13,6 @@ from vllm.platforms import CpuArchEnum, current_platform logger = init_logger(__name__) -try: - import flashinfer.sampling - - is_flashinfer_available = True -except ImportError: - is_flashinfer_available = False - class TopKTopPSampler(nn.Module): """ @@ -38,32 +31,18 @@ class TopKTopPSampler(nn.Module): logprobs_mode not in ("processed_logits", "processed_logprobs") and current_platform.is_cuda() ): - if is_flashinfer_available: - flashinfer_version = flashinfer.__version__ - if version.parse(flashinfer_version) < version.parse("0.2.3"): - logger.warning_once( - "FlashInfer version >= 0.2.3 required. " - "Falling back to default sampling implementation." - ) - self.forward = self.forward_native - elif envs.VLLM_USE_FLASHINFER_SAMPLER: - # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1. - logger.info_once("Using FlashInfer for top-p & top-k sampling.") - self.forward = self.forward_cuda - else: - logger.debug_once( - "FlashInfer top-p/top-k sampling is available but disabled " - "by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in " - "after verifying accuracy for your workloads." - ) - self.forward = self.forward_native + if envs.VLLM_USE_FLASHINFER_SAMPLER: + # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1. + logger.info_once("Using FlashInfer for top-p & top-k sampling.") + self.forward = self.forward_cuda else: - logger.warning_once( - "FlashInfer is not available. Falling back to the PyTorch-" - "native implementation of top-p & top-k sampling. For the " - "best performance, please install FlashInfer." + logger.debug_once( + "FlashInfer top-p/top-k sampling is available but disabled " + "by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in " + "after verifying accuracy for your workloads." ) self.forward = self.forward_native + elif current_platform.is_cpu(): arch = current_platform.get_cpu_architecture() # Fall back to native implementation for POWERPC and RISCV. @@ -278,6 +257,13 @@ def flashinfer_sample( does not. Call this function at the end of the forward pass to minimize the synchronization overhead. """ + import flashinfer + + if version.parse(flashinfer.__version__) < version.parse("0.2.3"): + raise ImportError( + "FlashInfer version >= 0.2.3 required for top-k and top-p sampling. " + ) + assert not (k is None and p is None) if k is None: # Top-p only.