diff --git a/vllm/envs.py b/vllm/envs.py
index 165cd32721fe5..c5688a72e11d5 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -163,6 +163,7 @@ if TYPE_CHECKING:
     VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
     VLLM_ENABLE_RESPONSES_API_STORE: bool = False
     VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None
+    VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False
     VLLM_HAS_FLASHINFER_CUBIN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
@@ -1155,6 +1156,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_TRTLLM_ATTENTION":
     lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
 
+    # If set to 1, do not quantize Q to fp8 even when an fp8 KV cache is used
+    "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION":
+    lambda: bool(int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0"))),
+
     # If set, it means we pre-downloaded cubin files and flashinfer will
     # read the cubin files directly.
     "VLLM_HAS_FLASHINFER_CUBIN":
@@ -1310,6 +1315,7 @@ def compute_hash() -> str:
         "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
         "VLLM_USE_CUDNN_PREFILL",
         "VLLM_USE_TRTLLM_ATTENTION",
+        "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
         "VLLM_ROCM_USE_AITER",
         "VLLM_ROCM_USE_AITER_PAGED_ATTN",
         "VLLM_ROCM_USE_AITER_LINEAR",
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index fab134733d4fd..83ec65c9b4594 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -200,11 +200,6 @@ def use_trtllm_attention(
         logger.info_once("Using TRTLLM attention (query is quantized).")
         return True
 
-    # TRTLLM prefill attention does not support FP8 kv cache with
-    # non-quantized query
-    if is_prefill and kv_cache_dtype.startswith("fp8"):
-        return False
-
     # If sinks are being used, we must use TRTLLM attention as it's
     # the only backend that supports them
     if has_sinks:
@@ -353,6 +348,12 @@ def flashinfer_scaled_fp8_mm(
     return output
 
 
+@functools.cache
+def flashinfer_disable_q_quantization() -> bool:
+    """Cache the result, which depends only on the environment."""
+    return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
+
+
 __all__ = [
     "has_flashinfer",
     "flashinfer_trtllm_fp8_block_scale_moe",
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 9e05cc8ab2f18..98a4cf38bc195 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -25,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, is_pin_memory_available
-from vllm.utils.flashinfer import (supports_trtllm_attention,
+from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
+                                   supports_trtllm_attention,
                                    use_trtllm_attention)
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 # yapf conflicts with isort for this block
@@ -48,8 +49,89 @@ FP4_DTYPE = torch.uint8
 logger = init_logger(__name__)
 
 
-class FlashInferBackend(AttentionBackend):
+@triton.jit
+def _trtllm_prefill_attn_kvfp8_dequant(
+    kv_cache_ptr,
+    block_tables_prefill_ptr,
+    block_table_stride,
+    mock_kv_cache_ptr,
+    k_scale_ptr,
+    v_scale_ptr,
+    K_CACHE_STRIDE: tl.constexpr,
+    KV_CACHE_STRIDE: tl.constexpr,
+):
+    batch_idx = tl.program_id(0).to(tl.int64)
+    mock_block_table_idx = tl.program_id(1).to(tl.int64)
+    orig_page_num = tl.load(block_tables_prefill_ptr +
+                            batch_idx * block_table_stride +
+                            mock_block_table_idx).to(tl.int64)
+    if orig_page_num <= 0:
+        return
+    dequant_dtype = mock_kv_cache_ptr.dtype.element_ty
+    # Dequantize K
+    k_scale_val = tl.load(k_scale_ptr)
+    offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
+    fp8_vals = tl.load(kv_cache_ptr + offset)
+    dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val
+    mock_cache_offset = (batch_idx * block_table_stride + mock_block_table_idx +
+                         1) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
+    dequantized_vals = dequantized_vals.to(dequant_dtype)
+    tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
+
+    # Dequantize V
+    v_scale_val = tl.load(v_scale_ptr)
+    offset = (orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE +
+              tl.arange(0, K_CACHE_STRIDE))
+    fp8_vals = tl.load(kv_cache_ptr + offset)
+    dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val
+    mock_cache_offset = (
+        (batch_idx * block_table_stride + mock_block_table_idx + 1) *
+        KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE))
+    dequantized_vals = dequantized_vals.to(dequant_dtype)
+    tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
+
+
+def trtllm_prefill_attn_kvfp8_dequant(
+    kv_cache: torch.Tensor,
+    block_tables_prefill: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    dequant_dtype: torch.dtype,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    batch_size, num_of_page_per_token = block_tables_prefill.shape
+    s = kv_cache.shape
+    assert s[1] == 2
+    assert dequant_dtype in (torch.bfloat16, torch.float16)
+    k_cache_stride = s[2] * s[3] * s[4]
+    kv_cache_stride = k_cache_stride * s[1]
+    new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
+    # The mock kv cache contains just the pages needed by this prefill
+    mock_kv_cache = torch.empty(new_s,
+                                dtype=dequant_dtype,
+                                device=kv_cache.device)
+    # The mock block table simply indexes these pages sequentially
+    mock_block_table = torch.arange(
+        start=1,
+        end=batch_size * num_of_page_per_token + 1,
+        dtype=torch.int32,
+        device=block_tables_prefill.device,
+    ).reshape(batch_size, num_of_page_per_token)
+    grid = (batch_size, num_of_page_per_token)
+    _trtllm_prefill_attn_kvfp8_dequant[grid](
+        kv_cache,
+        block_tables_prefill,
+        num_of_page_per_token,
+        mock_kv_cache,
+        k_scale,
+        v_scale,
+        k_cache_stride,
+        kv_cache_stride,
+    )
+    return mock_kv_cache, mock_block_table
+
+
+class FlashInferBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
 
     @classmethod
@@ -122,7 +204,6 @@ class FlashInferBackend(AttentionBackend):
 
 @dataclass
 class FlashInferMetadata:
-
     num_actual_tokens: int  # Number of tokens excluding padding.
 
     # The data type of the query
@@ -175,8 +256,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                          self.kv_cache_spec.block_size)
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
-        self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
-            decode_mode() == CUDAGraphMode.FULL
+        self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\
+            decode_mode() == CUDAGraphMode.FULL)
         if self.enable_cuda_graph:
             # For full cudagraph capture, one `decode_wrapper` for each batch
             # size is needed for FlashInfer.
@@ -201,7 +282,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             assert self.kv_cache_spec.dtype == self.model_config.dtype
             self.kv_cache_dtype = self.kv_cache_spec.dtype
 
-        if supports_trtllm_attention()[0]:
+        if supports_trtllm_attention()[0] and \
+                not flashinfer_disable_q_quantization():
             self.q_data_type = self.kv_cache_dtype
         else:
             self.q_data_type = self.model_config.dtype
@@ -795,11 +877,29 @@ class FlashInferImpl(AttentionImpl):
                     assert self.o_sf_scale is None
                     out = output[num_decode_tokens:]
 
+                if attn_metadata.q_data_type != FP8_DTYPE \
+                    and self.kv_cache_dtype.startswith("fp8"):
+                    # TRTLLM prefill attention does not support a BF16 query
+                    # with an fp8 KV cache. To run prefill with an fp8 KV
+                    # cache anyway, build a mock block table and a mock KV
+                    # cache holding the dequantized (BF16) prefill pages.
+                    mock_kv_cache, mock_block_table = (
+                        trtllm_prefill_attn_kvfp8_dequant(
+                            kv_cache_permute,
+                            block_tables_prefill,
+                            layer._k_scale,
+                            layer._v_scale,
+                            attn_metadata.q_data_type,
+                        ))
+                else:
+                    mock_kv_cache = kv_cache_permute
+                    mock_block_table = block_tables_prefill
+
                 trtllm_batch_context_with_kv_cache(
                     query=prefill_query,
-                    kv_cache=kv_cache_permute,
+                    kv_cache=mock_kv_cache,
                     workspace_buffer=workspace_buffer,
-                    block_tables=block_tables_prefill,
+                    block_tables=mock_block_table,
                     seq_lens=seq_lens_prefill,
                     max_q_len=attn_metadata.max_q_len,
                     max_kv_len=attn_metadata.max_seq_len,
@@ -837,7 +937,7 @@ class FlashInferImpl(AttentionImpl):
                 decode_query = decode_query.contiguous()
                 workspace_buffer = decode_wrapper._float_workspace_buffer
                 block_tables_decode = attn_metadata.\
-                block_table_tensor[:num_decode_tokens]
+                    block_table_tensor[:num_decode_tokens]
                 seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
 
                 # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
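
Usage note (not part of the patch): the new flag only matters when TRTLLM attention is usable and the KV cache is fp8; with it set, the query keeps the model dtype instead of being quantized to the KV cache dtype. The sketch below mirrors that selection logic with stand-in names (pick_q_dtype and supports_trtllm are illustrative, not vLLM APIs). Because flashinfer_disable_q_quantization() is wrapped in functools.cache, the real environment variable must be set before the engine starts.

import os

import torch

FP8_KV_DTYPE = torch.float8_e4m3fn   # fp8 KV cache dtype
MODEL_DTYPE = torch.bfloat16         # unquantized model/query dtype


def disable_q_quantization() -> bool:
    # Same parsing as the new env entry: "1" -> True, unset/"0" -> False.
    return bool(int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0")))


def pick_q_dtype(supports_trtllm: bool, kv_cache_dtype: torch.dtype) -> torch.dtype:
    # Mirrors the builder change: quantize Q to the KV cache dtype only when
    # TRTLLM attention is supported and the flag is not set.
    if supports_trtllm and not disable_q_quantization():
        return kv_cache_dtype
    return MODEL_DTYPE


os.environ["VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION"] = "1"
assert pick_q_dtype(True, FP8_KV_DTYPE) is MODEL_DTYPE

os.environ["VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION"] = "0"
assert pick_q_dtype(True, FP8_KV_DTYPE) is FP8_KV_DTYPE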
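
For readers of the Triton kernel, here is an eager PyTorch sketch of the same transformation (illustrative only; prefill_kvfp8_dequant_reference is not part of the patch). It assumes the (num_pages, 2, num_heads, page_size, head_dim) layout implied by the host wrapper's shape checks, uses zeros instead of torch.empty, and replaces the one-program-per-page Triton launch with nested loops.

import torch


def prefill_kvfp8_dequant_reference(
    kv_cache: torch.Tensor,              # (num_pages, 2, H, page_size, head_dim)
    block_tables_prefill: torch.Tensor,  # (batch_size, pages_per_req), int
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    dequant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    batch_size, pages_per_req = block_tables_prefill.shape
    num_pages, two, num_heads, page_size, head_dim = kv_cache.shape
    assert two == 2
    # Page 0 of the mock cache is a dummy; real pages start at index 1, so the
    # sequential mock block table below never points at page 0.
    mock_kv_cache = torch.zeros(
        (batch_size * pages_per_req + 1, 2, num_heads, page_size, head_dim),
        dtype=dequant_dtype,
        device=kv_cache.device)
    mock_block_table = torch.arange(
        1,
        batch_size * pages_per_req + 1,
        dtype=torch.int32,
        device=block_tables_prefill.device).reshape(batch_size, pages_per_req)
    for b in range(batch_size):
        for j in range(pages_per_req):
            page = int(block_tables_prefill[b, j])
            if page <= 0:  # slot not used by this request
                continue
            dst = b * pages_per_req + j + 1
            mock_kv_cache[dst, 0] = (kv_cache[page, 0].to(torch.float32) *
                                     k_scale).to(dequant_dtype)
            mock_kv_cache[dst, 1] = (kv_cache[page, 1].to(torch.float32) *
                                     v_scale).to(dequant_dtype)
    return mock_kv_cache, mock_block_table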
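
A small worked check of the flat-offset arithmetic shared by the kernel and host wrapper (the page geometry below is made up for illustration): K_CACHE_STRIDE is the element count of one K (or V) page, KV_CACHE_STRIDE of one (K, V) page pair, and program (b, j) writes its K page at offset (b * pages_per_req + j + 1) * KV_CACHE_STRIDE, with V following K_CACHE_STRIDE elements later.

import torch

num_heads, page_size, head_dim = 8, 16, 128      # assumed page geometry
k_cache_stride = num_heads * page_size * head_dim
kv_cache_stride = 2 * k_cache_stride

batch_size, pages_per_req = 2, 3
mock_kv_cache = torch.zeros(batch_size * pages_per_req + 1, 2, num_heads,
                            page_size, head_dim)

b, j = 1, 2                            # request 1, its third page slot
dst_page = b * pages_per_req + j + 1   # sequential slot; slot 0 is a dummy
k_offset = dst_page * kv_cache_stride  # flat element offset of the K page
v_offset = k_offset + k_cache_stride   # flat element offset of the V page


def element_offset(view: torch.Tensor, base: torch.Tensor) -> int:
    return (view.data_ptr() - base.data_ptr()) // base.element_size()


# The flat offsets land exactly on mock_kv_cache[dst_page, 0] and [dst_page, 1].
assert element_offset(mock_kv_cache[dst_page, 0], mock_kv_cache) == k_offset
assert element_offset(mock_kv_cache[dst_page, 1], mock_kv_cache) == v_offset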