diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py
index b3f81715461b1..72b54b40a2d1e 100644
--- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py
+++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py
@@ -110,7 +110,7 @@ def benchmark_decode(
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer,
         kv_layout,
-        use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
+        use_tensor_cores=True,
     )
     wrapper.plan(
         kv_indptr,
diff --git a/tests/kernels/attention/test_flashinfer.py b/tests/kernels/attention/test_flashinfer.py
index be78f0e4fcc62..a821a74aba93d 100644
--- a/tests/kernels/attention/test_flashinfer.py
+++ b/tests/kernels/attention/test_flashinfer.py
@@ -137,9 +137,7 @@ def test_flashinfer_decode_with_paged_kv(
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
     wrapper = flashinfer.\
         BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
-                use_tensor_cores=(
-                    (num_query_heads//num_kv_heads) > 4)
-                )
+                use_tensor_cores=True)
     wrapper.plan(
         kv_indptr,
         kv_indices,
@@ -411,7 +409,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
     assert num_query_heads % num_kv_heads == 0
     max_kv_len = max(kv_lens)
     scale = head_size**-0.5
-    use_tensor_cores = (num_query_heads // num_kv_heads) > 4
+    use_tensor_cores = True
     kv_cache_dtype = torch.float8_e4m3fn
 
     query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
index 619822f3ee43b..69e44264cd440 100644
--- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py
+++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
@@ -136,9 +136,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
 
     # Baseline Decode
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
-        workspace_buffer,
-        kv_layout,
-        use_tensor_cores=((num_qo_heads // num_kv_heads) > 4))
+        workspace_buffer, kv_layout, use_tensor_cores=True)
     wrapper.plan(kv_indptr,
                  kv_indices,
                  kv_last_page_lens,
diff --git a/vllm/envs.py b/vllm/envs.py
index a844aa8af61e3..296c1730892da 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -42,7 +42,6 @@ if TYPE_CHECKING:
     VLLM_TRACE_FUNCTION: int = 0
     VLLM_ATTENTION_BACKEND: Optional[str] = None
     VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
-    VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
     VLLM_PP_LAYER_PARTITION: Optional[str] = None
     VLLM_CPU_KVCACHE_SPACE: Optional[int] = 0
     VLLM_CPU_OMP_THREADS_BIND: str = ""
@@ -465,11 +464,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
     if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None,
 
-    # If set, vllm will force flashinfer to use tensor cores;
-    # otherwise will use heuristic based on model architecture.
-    "VLLM_FLASHINFER_FORCE_TENSOR_CORES":
-    lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))),
-
     # Pipeline stage partition strategy
     "VLLM_PP_LAYER_PARTITION":
     lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
@@ -1221,7 +1215,6 @@ def compute_hash() -> str:
         "VLLM_USE_AITER_UNIFIED_ATTENTION",
         "VLLM_ATTENTION_BACKEND",
         "VLLM_USE_FLASHINFER_SAMPLER",
-        "VLLM_FLASHINFER_FORCE_TENSOR_CORES",
         "VLLM_DISABLED_KERNELS",
         "VLLM_USE_DEEP_GEMM",
         "VLLM_USE_TRTLLM_FP4_GEMM",
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 8a25088848a44..1e6e3f1d0abf4 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -13,7 +13,6 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
 from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 
-import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
@@ -228,8 +227,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             self.q_data_type = self.kv_cache_dtype
         else:
             self.kv_cache_dtype = self.kv_cache_spec.dtype
-        self.use_tensor_cores = (envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or
-                                 (self.num_qo_heads // self.num_kv_heads > 4))
 
         self._cascade_wrapper = None  # Wrapper for cascade attention
 
@@ -308,7 +305,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 paged_kv_indptr_buffer=paged_kv_indptr,
                 paged_kv_indices_buffer=paged_kv_indices,
                 paged_kv_last_page_len_buffer=paged_kv_last_page_len,
-                use_tensor_cores=self.use_tensor_cores)
+                # Tensor cores are enabled by default because their performance
+                # is at least as good as CUDA cores for all attention ops on
+                # recent GPUs.
+                use_tensor_cores=True,
+            )
 
             # save the decode wrapper
             if use_cudagraph:
@@ -984,52 +985,29 @@ def fast_plan_decode(
     self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu,
                                            non_blocking=True)
 
-    if self.use_tensor_cores:
-        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
+    qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
 
-        try:
-            # Make sure we pass exactly 15 arguments for tensor core version
-            self._plan_info = self._cached_module.plan(
-                self._float_workspace_buffer,
-                self._int_workspace_buffer,
-                self._pin_memory_int_workspace_buffer,
-                qo_indptr_host,
-                indptr_cpu,
-                seq_lens_cpu,
-                batch_size,  # total_num_rows
-                batch_size,
-                num_qo_heads,
-                num_kv_heads,
-                page_size,
-                self.is_cuda_graph_enabled,
-                head_dim,
-                head_dim,
-                False,  # causal
-            )
-        except Exception as e:
-            raise RuntimeError(f"Error in tensor core plan: {e}") from e
-    else:
-        try:
-            # Make sure we pass exactly 15 arguments for standard version
-            self._plan_info = self._cached_module.plan(
-                self._float_workspace_buffer,
-                self._int_workspace_buffer,
-                self._pin_memory_int_workspace_buffer,
-                indptr_cpu,
-                batch_size,
-                num_qo_heads,
-                num_kv_heads,
-                page_size,
-                self.is_cuda_graph_enabled,
-                window_left,
-                logits_soft_cap,
-                head_dim,
-                head_dim,
-                torch.empty(0, dtype=q_data_type),
-                torch.empty(0, dtype=kv_data_type),
-            )
-        except Exception as e:
-            raise RuntimeError(f"Error in standard plan: {e}") from e
+    try:
+        # Make sure we pass exactly 15 arguments for tensor core version
+        self._plan_info = self._cached_module.plan(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_host,
+            indptr_cpu,
+            seq_lens_cpu,
+            batch_size,  # total_num_rows
+            batch_size,
+            num_qo_heads,
+            num_kv_heads,
+            page_size,
+            self.is_cuda_graph_enabled,
+            head_dim,
+            head_dim,
+            False,  # causal
+        )
+    except Exception as e:
+        raise RuntimeError(f"Error in tensor core plan: {e}") from e
 
     self._pos_encoding_mode = pos_encoding_mode
     self._window_left = window_left
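
To make the behavioral change concrete, here is a minimal sketch of the decode path these files exercise: the FlashInfer decode wrapper is now always built with `use_tensor_cores=True` instead of gating it on the GQA ratio `(num_qo_heads // num_kv_heads) > 4`, and the `VLLM_FLASHINFER_FORCE_TENSOR_CORES` override is dropped. The sizes, tensor names, and single-page-per-sequence bookkeeping below are illustrative only; it assumes a CUDA device and the FlashInfer paged-KV decode API as used in the tests above, not vLLM code.

```python
import torch
import flashinfer

# Hypothetical sizes for illustration only.
num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16
batch_size, pages_per_seq = 2, 4
device, dtype = "cuda", torch.float16

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer,
    "NHD",
    # Previously: use_tensor_cores=((num_qo_heads // num_kv_heads) > 4)
    use_tensor_cores=True,
)

# Toy paged-KV bookkeeping: every sequence owns `pages_per_seq` full pages.
kv_indptr = torch.arange(0, (batch_size + 1) * pages_per_seq, pages_per_seq,
                         dtype=torch.int32, device=device)
kv_indices = torch.arange(batch_size * pages_per_seq,
                          dtype=torch.int32, device=device)
kv_last_page_len = torch.full((batch_size,), page_size,
                              dtype=torch.int32, device=device)

wrapper.plan(kv_indptr, kv_indices, kv_last_page_len,
             num_qo_heads, num_kv_heads, head_dim, page_size,
             q_data_type=dtype, kv_data_type=dtype)

query = torch.randn(batch_size, num_qo_heads, head_dim, dtype=dtype, device=device)
# NHD paged cache layout: (num_pages, 2, page_size, num_kv_heads, head_dim)
kv_cache = torch.randn(batch_size * pages_per_seq, 2, page_size,
                       num_kv_heads, head_dim, dtype=dtype, device=device)
output = wrapper.run(query, kv_cache)  # (batch_size, num_qo_heads, head_dim)
```

Since the tensor-core decode kernel is expected to be at least as fast as the CUDA-core one on recent GPUs, the heuristic and the environment-variable escape hatch are no longer needed, which also lets `fast_plan_decode` keep only the tensor-core planning call.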