diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index f2d01099097a..afd7c47e8ac0 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 9bfa9869829d8c593527eb34c5271d0090f7ccc9 + GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 95424e25732b..572563c0bd82 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -15,6 +15,7 @@ NUM_HEADS = [(4, 4), (8, 2), (16, 2)] HEAD_SIZES = [128, 256] BLOCK_SIZES = [16, 32] DTYPES = [torch.float16, torch.bfloat16] +QDTYPES = [None, torch.float8_e4m3fn] # one value large enough to test overflow in index calculation. # one value small enough to test the schema op check NUM_BLOCKS = [32768, 2048] @@ -85,6 +86,7 @@ def ref_paged_attn( @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("sliding_window", [None, 256]) @pytest.mark.parametrize("fa_version", [2, 3]) +@pytest.mark.parametrize("q_dtype", QDTYPES) @torch.inference_mode() def test_flash_attn_with_paged_kv( use_out: bool, @@ -97,11 +99,15 @@ def test_flash_attn_with_paged_kv( num_blocks: int, sliding_window: Optional[int], fa_version: int, + q_dtype: Optional[torch.dtype], ) -> None: torch.set_default_device("cuda") if not is_fa_version_supported(fa_version): pytest.skip(f"Flash attention version {fa_version} not supported due " f"to: \"{fa_version_unsupported_reason(fa_version)}\"") + if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2): + pytest.skip("Flash attention with quantized inputs is only " + "supported on version 3 with bfloat16 base type") current_platform.seed_everything(0) num_seqs = len(kv_lens) @@ -130,10 +136,28 @@ def test_flash_attn_with_paged_kv( q = query.unsqueeze(1) out = torch.empty_like(q) if use_out else None + + maybe_quantized_query = q + maybe_quantized_key_cache = key_cache + maybe_quantized_value_cache = value_cache + q_descale = None + k_descale = None + v_descale = None + if q_dtype is not None: + # QKV are drawn from N(0, 1): no need for a fp8 scaling factor + maybe_quantized_query = query.to(q_dtype) + maybe_quantized_key_cache = key_cache.to(q_dtype) + maybe_quantized_value_cache = value_cache.to(q_dtype) + + scale_shape = (num_seqs, num_kv_heads) + q_descale = torch.ones(scale_shape, dtype=torch.float32) + k_descale = torch.ones(scale_shape, dtype=torch.float32) + v_descale = torch.ones(scale_shape, dtype=torch.float32) + output = flash_attn_with_kvcache( - q=q, - k_cache=key_cache, - v_cache=value_cache, + q=maybe_quantized_query, + k_cache=maybe_quantized_key_cache, + v_cache=maybe_quantized_value_cache, out=out, softmax_scale=scale, causal=True, @@ -142,10 +166,17 @@ def test_flash_attn_with_paged_kv( softcap=soft_cap if soft_cap is not None else 0, window_size=window_size, fa_version=fa_version, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, ) output = output if not use_out else out output = output.squeeze(1) + atol, rtol = 1.5e-2, 1e-2 + if q_dtype is not None: + atol, rtol = 1.5e-1, 1.5e-1 + ref_output = ref_paged_attn(query=query, key_cache=key_cache, value_cache=value_cache, @@ -155,7 +186,7 @@ def 
test_flash_attn_with_paged_kv( scale=scale, soft_cap=soft_cap, sliding_window=sliding_window) - torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ f"{torch.max(torch.abs(output - ref_output))}" @@ -171,6 +202,7 @@ def test_flash_attn_with_paged_kv( @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("fa_version", [2, 3]) +@pytest.mark.parametrize("q_dtype", QDTYPES) @torch.inference_mode() def test_varlen_with_paged_kv( use_out: bool, @@ -183,11 +215,15 @@ def test_varlen_with_paged_kv( soft_cap: Optional[float], num_blocks: int, fa_version: int, + q_dtype: Optional[torch.dtype], ) -> None: torch.set_default_device("cuda") if not is_fa_version_supported(fa_version): pytest.skip(f"Flash attention version {fa_version} not supported due " f"to: \"{fa_version_unsupported_reason(fa_version)}\"") + if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2): + pytest.skip("Flash attention with quantized inputs is only " + "supported on version 3 with bfloat16 base type") current_platform.seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] @@ -223,10 +259,28 @@ def test_varlen_with_paged_kv( dtype=torch.int32) out = torch.empty_like(query) if use_out else None + + maybe_quantized_query = query + maybe_quantized_key_cache = key_cache + maybe_quantized_value_cache = value_cache + q_descale = None + k_descale = None + v_descale = None + if q_dtype is not None: + # QKV are drawn from N(0, 1): no need for a fp8 scaling factor + maybe_quantized_query = query.to(q_dtype) + maybe_quantized_key_cache = key_cache.to(q_dtype) + maybe_quantized_value_cache = value_cache.to(q_dtype) + + scale_shape = (num_seqs, num_kv_heads) + q_descale = torch.ones(scale_shape, dtype=torch.float32) + k_descale = torch.ones(scale_shape, dtype=torch.float32) + v_descale = torch.ones(scale_shape, dtype=torch.float32) + output = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, + q=maybe_quantized_query, + k=maybe_quantized_key_cache, + v=maybe_quantized_value_cache, out=out, cu_seqlens_q=cu_query_lens, seqused_k=kv_lens, @@ -238,6 +292,9 @@ def test_varlen_with_paged_kv( block_table=block_tables, softcap=soft_cap if soft_cap is not None else 0, fa_version=fa_version, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, ) output = output if not use_out else out @@ -252,5 +309,8 @@ def test_varlen_with_paged_kv( sliding_window=sliding_window, soft_cap=soft_cap, ) - torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ + atol, rtol = 1.5e-2, 1e-2 + if q_dtype is not None: + atol, rtol = 1.5e-1, 1.5e-1 + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ f"{torch.max(torch.abs(output - ref_output))}" diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 89229e7b87a0..85c5715faba7 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -4,12 +4,16 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionMetadataBuilder, AttentionState, AttentionType) -from vllm.attention.backends.utils import get_flash_attn_version from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend __all__ = [ - "Attention", "AttentionBackend", "AttentionMetadata", "AttentionType", - "AttentionMetadataBuilder", "Attention", "AttentionState", - 
"get_attn_backend", "get_flash_attn_version" + "Attention", + "AttentionBackend", + "AttentionMetadata", + "AttentionType", + "AttentionMetadataBuilder", + "Attention", + "AttentionState", + "get_attn_backend", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 0cd95e0749d1..82d60f9da7da 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -232,6 +232,7 @@ class AttentionMetadataBuilder(ABC, Generic[T]): class AttentionLayer(Protocol): + _q_scale: torch.Tensor _k_scale: torch.Tensor _v_scale: torch.Tensor _k_scale_float: float diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 0e331efa6a39..e981ac780b00 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -19,10 +19,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, # yapf: enable from vllm.attention.backends.utils import ( PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, - compute_slot_mapping_start_idx, get_flash_attn_version, - get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, - is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set, - is_block_tables_empty) + compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, + get_seq_len_block_table_args, is_all_cross_attn_metadata_set, + is_all_encoder_attn_metadata_set, is_block_tables_empty) +from vllm.fa_utils import get_flash_attn_version from vllm.logger import init_logger from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import async_tensor_h2d, make_tensor_with_pad @@ -630,9 +630,11 @@ class FlashAttentionImpl(AttentionImpl): self.sliding_window = ((sliding_window - 1, 0) if sliding_window is not None else (-1, -1)) self.kv_cache_dtype = kv_cache_dtype - if is_quantized_kv_cache(self.kv_cache_dtype): + self.vllm_flash_attn_version = get_flash_attn_version() + if (is_quantized_kv_cache(self.kv_cache_dtype) + and self.vllm_flash_attn_version != 3): raise NotImplementedError( - "FlashAttention with FP8 KV cache not yet supported") + "Only FlashAttention3 supports FP8 KV cache") if logits_soft_cap is None: # In flash-attn, setting logits_soft_cap as 0 means no soft cap. logits_soft_cap = 0 @@ -647,7 +649,6 @@ class FlashAttentionImpl(AttentionImpl): f"Head size {head_size} is not supported by FlashAttention. " f"Supported head sizes are: {support_head_sizes}.") self.attn_type = attn_type - self.vllm_flash_attn_version = get_flash_attn_version() def forward( self, @@ -671,13 +672,19 @@ class FlashAttentionImpl(AttentionImpl): for profiling run. attn_metadata: Metadata for attention. NOTE: It in-place updates the output tensor. + NOTE: FP8 quantization, flash-attn expect the size of + {q,k,v}_descale to be (num_sequences, num_kv_heads). + We use torch's .expand() to avoid duplicating values """ - # NOTE(woosuk): FlashAttention does not support FP8 KV cache. - assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, ( - "key/v_scale is not supported in FlashAttention.") - assert output is not None, "Output tensor must be provided." + # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache. 
+ if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16: + assert ( + layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), ( + "key/v_scale is only supported in FlashAttention 3 with " + "base dtype bfloat16") + attn_type = self.attn_type if (attn_type == AttentionType.ENCODER and (not attn_metadata.is_all_encoder_attn_metadata_set)): @@ -694,6 +701,7 @@ class FlashAttentionImpl(AttentionImpl): window_size = self.sliding_window alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes logits_soft_cap: Optional[float] = self.logits_soft_cap + fp8_attention = kv_cache_dtype.startswith("fp8") if kv_cache.numel() > 0: key_cache = kv_cache[0] @@ -729,6 +737,19 @@ class FlashAttentionImpl(AttentionImpl): layer._v_scale, ) + if fp8_attention: + kv_cache = kv_cache.view(torch.float8_e4m3fn) + key_cache = key_cache.view(torch.float8_e4m3fn) + value_cache = value_cache.view(torch.float8_e4m3fn) + + if fp8_attention: + num_tokens, num_heads, head_size = query.shape + query, _ = ops.scaled_fp8_quant( + query.reshape( + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) + (num_prefill_query_tokens, num_prefill_kv_tokens, num_decode_query_tokens) = \ get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) @@ -753,6 +774,23 @@ class FlashAttentionImpl(AttentionImpl): key = key[:num_prefill_kv_tokens] value = value[:num_prefill_kv_tokens] + if fp8_attention: + num_kv_tokens, num_kv_heads, head_size = key.shape + + key, _ = ops.scaled_fp8_quant( + key.reshape((num_kv_tokens, + num_kv_heads * head_size)).contiguous(), + layer._k_scale) + key = key.reshape((num_kv_tokens, num_kv_heads, head_size)) + + value, _ = ops.scaled_fp8_quant( + value.reshape((num_kv_tokens, + num_kv_heads * head_size)).contiguous(), + layer._v_scale) + value = value.reshape( + (num_kv_tokens, num_kv_heads, head_size)) + + descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1]) flash_attn_varlen_func( q=query, k=key, @@ -768,13 +806,19 @@ class FlashAttentionImpl(AttentionImpl): softcap=logits_soft_cap, out=prefill_output, fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), ) else: # prefix-enabled attention assert attn_type == AttentionType.DECODER, ( "Only decoder-only models support prefix caching") assert prefill_meta.seq_lens is not None + assert prefill_meta.query_start_loc is not None max_seq_len = max(prefill_meta.seq_lens) + descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, + key.shape[1]) flash_attn_varlen_func( # noqa q=query, k=key_cache, @@ -791,6 +835,9 @@ class FlashAttentionImpl(AttentionImpl): softcap=logits_soft_cap, out=prefill_output, fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), ) if decode_meta := attn_metadata.decode_metadata: @@ -804,6 +851,9 @@ class FlashAttentionImpl(AttentionImpl): assert attn_type == AttentionType.DECODER, ( "Only decoder-only models support max_decode_query_len > 1" ) + assert decode_meta.query_start_loc is not None + descale_shape = (decode_meta.query_start_loc.shape[0] - 1, + key.shape[1]) flash_attn_varlen_func( q=decode_query, k=key_cache, @@ -820,6 +870,9 @@ class FlashAttentionImpl(AttentionImpl): block_table=decode_meta.block_tables, out=decode_output, 
fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), ) else: # Use flash_attn_with_kvcache for normal decoding. @@ -828,6 +881,7 @@ class FlashAttentionImpl(AttentionImpl): _, block_tables_arg, ) = get_seq_len_block_table_args(decode_meta, False, attn_type) + descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2]) flash_attn_with_kvcache( q=decode_query.unsqueeze(1), k_cache=key_cache, @@ -841,6 +895,9 @@ class FlashAttentionImpl(AttentionImpl): softcap=logits_soft_cap, out=decode_output.unsqueeze(1), fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), ) return output diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index ff411f75ae7f..258090d3e80e 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -203,9 +203,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionState, MLAAttentionImpl) from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, compute_slot_mapping_start_idx, - get_flash_attn_version, is_block_tables_empty) from vllm.attention.ops.triton_merge_attn_states import merge_attn_states +from vllm.fa_utils import get_flash_attn_version from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, RowParallelLinear, UnquantizedLinearMethod) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 4374b5422254..b4413c36b64a 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -8,13 +8,11 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union import numpy as np import torch -from vllm import envs from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, AttentionState) from vllm.attention.backends.abstract import AttentionType from vllm.logger import init_logger from vllm.multimodal import MultiModalPlaceholderMap -from vllm.platforms import current_platform from vllm.utils import async_tensor_h2d, make_tensor_with_pad logger = init_logger(__name__) @@ -585,35 +583,3 @@ def get_num_prefill_decode_query_kv_tokens( return (num_prefill_query_tokens, num_prefill_kv_tokens, num_decode_query_tokens) - - -def get_flash_attn_version(): - try: - from vllm.vllm_flash_attn.flash_attn_interface import ( - fa_version_unsupported_reason, is_fa_version_supported) - - # if hopper default to FA3, otherwise stick to FA2 for now - # TODO(lucas): profile FA3 on ampere to see if it makes sense to - # use FA3 as default for both - if current_platform.get_device_capability()[0] == 9: - fa_version = 3 if is_fa_version_supported(3) else 2 - else: - fa_version = 2 - - if envs.VLLM_FLASH_ATTN_VERSION is not None: - assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3] - fa_version = envs.VLLM_FLASH_ATTN_VERSION - if (current_platform.get_device_capability()[0] == 10 - and envs.VLLM_FLASH_ATTN_VERSION == 3): - logger.warning("Cannot use FA version 3 on Blackwell platform", - "defaulting to FA version 2.") - fa_version = 2 - - if not is_fa_version_supported(fa_version): - logger.error("Cannot use FA version %d is not supported due to %s", - fa_version, fa_version_unsupported_reason(fa_version)) - - assert is_fa_version_supported(fa_version) - return fa_version - except 
(ImportError, AssertionError): - return None diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 3cbd38dbd46a..946c07d508a3 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -84,6 +84,9 @@ class Attention(nn.Module): self.calculate_kv_scales = calculate_kv_scales self._k_scale = torch.tensor(1.0, dtype=torch.float32) self._v_scale = torch.tensor(1.0, dtype=torch.float32) + # FlashAttn doesn't support quantizing the kv-cache only + # but requires q to be quantized as well. + self._q_scale = torch.tensor(1.0, dtype=torch.float32) # We also keep the float32 versions of k/v_scale for attention # backends that don't support tensors (Flashinfer) @@ -153,6 +156,7 @@ class Attention(nn.Module): ).parallel_config.pipeline_parallel_size) ] + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) @@ -178,7 +182,7 @@ class Attention(nn.Module): if self.calculate_kv_scales: attn_metadata = get_forward_context().attn_metadata if attn_metadata.enable_kv_scales_calculation: - self.calc_kv_scales(key, value) + self.calc_kv_scales(query, key, value) if self.use_output: output_shape = (output_shape if output_shape is not None else query.shape) @@ -225,7 +229,8 @@ class Attention(nn.Module): return torch.ops.vllm.unified_attention( query, key, value, self.layer_name) - def calc_kv_scales(self, key, value): + def calc_kv_scales(self, query, key, value): + self._q_scale.copy_(torch.abs(query).max() / self.q_range) self._k_scale.copy_(torch.abs(key).max() / self.k_range) self._v_scale.copy_(torch.abs(value).max() / self.v_range) self._k_scale_float = self._k_scale.item() diff --git a/vllm/envs.py b/vllm/envs.py index b2937462ad36..56bf86267476 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -78,6 +78,7 @@ if TYPE_CHECKING: VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False + Q_SCALE_CONSTANT: int = 200 K_SCALE_CONSTANT: int = 200 V_SCALE_CONSTANT: int = 100 VLLM_SERVER_DEV_MODE: bool = False @@ -524,13 +525,17 @@ environment_variables: dict[str, Callable[[], Any]] = { # Pad the fp8 weights to 256 bytes for ROCm "VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), + + # Divisor for dynamic query scale factor calculation for FP8 KV Cache + "Q_SCALE_CONSTANT": + lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")), # Divisor for dynamic key scale factor calculation for FP8 KV Cache "K_SCALE_CONSTANT": lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), - # Divisor for dynamic value scale factor calculation for FP8 KV Cache "V_SCALE_CONSTANT": lambda: int(os.getenv("V_SCALE_CONSTANT", "100")), + # If set, enable multiprocessing in LLM for the V1 code path. 
"VLLM_ENABLE_V1_MULTIPROCESSING": lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))), diff --git a/vllm/fa_utils.py b/vllm/fa_utils.py new file mode 100644 index 000000000000..028c96b839fb --- /dev/null +++ b/vllm/fa_utils.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +from vllm import envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def get_flash_attn_version() -> Optional[int]: + # import here to avoid circular dependencies + from vllm.platforms import current_platform + try: + from vllm.vllm_flash_attn.flash_attn_interface import ( + fa_version_unsupported_reason, is_fa_version_supported) + device_capability = current_platform.get_device_capability() + + assert device_capability is not None + + # 1. default version depending on platform + fa_version = 3 if (device_capability.major == 9 + and is_fa_version_supported(3)) else 2 + + # 2. override if passed by environment + if envs.VLLM_FLASH_ATTN_VERSION is not None: + assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3] + fa_version = envs.VLLM_FLASH_ATTN_VERSION + + # 3. fallback for unsupported combinations + if device_capability.major == 10 and fa_version == 3: + logger.warning("Cannot use FA version 3 on Blackwell platform", + "defaulting to FA version 2.") + fa_version = 2 + + if not is_fa_version_supported(fa_version): + logger.error("Cannot use FA version %d is not supported due to %s", + fa_version, fa_version_unsupported_reason(fa_version)) + + assert is_fa_version_supported(fa_version) + return fa_version + except (ImportError, AssertionError): + return None diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index 92990487885b..5d766c2c27ac 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -26,11 +26,14 @@ class BaseKVCacheMethod(QuantizeMethodBase): def create_weights(self, layer: torch.nn.Module): """ - Create "weight" (aka k_scale and v_scale) for an attention layer. + Create "weight" (aka q_scale, k_scale and v_scale) + for an attention layer. """ - # Initialize the KV cache scales to -1.0, which is an invalid value. - # If the k/v_scale appears in the checkpoint, it will be + # Initialize the Q and KV cache scales to -1.0, an invalid value. + # If the q and k/v_scales appear in the checkpoint, it will be # overwritten when loading weights. + layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), @@ -75,6 +78,13 @@ class BaseKVCacheMethod(QuantizeMethodBase): raise ValueError("Only support per-tensor scaling factor " "for fp8 KV cache") + if layer.q_scale < 0.0: + logger.warning_once( + "Checkpoint does not provide a q scaling factor. " + "Setting it to k_scale. 
This only matters for " + "the flash-attn backend.") + layer._q_scale.copy_(k_scale) + # These are used in the final Attention.forward() layer._k_scale.copy_(k_scale) layer._v_scale.copy_(v_scale) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 8a53337ebc08..dd2a9cb6161e 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -14,6 +14,7 @@ from typing_extensions import ParamSpec # import custom ops, trigger op registration import vllm._C # noqa import vllm.envs as envs +from vllm.fa_utils import get_flash_attn_version from vllm.logger import init_logger from vllm.utils import import_pynvml @@ -240,15 +241,6 @@ class CudaPlatformBase(Platform): "Cannot use FlashAttention-2 backend for dtype other than " "torch.float16 or torch.bfloat16.") target_backend = _Backend.XFORMERS - elif kv_cache_dtype is not None and \ - kv_cache_dtype.startswith("fp8"): - logger.info( - "Cannot use FlashAttention-2 backend for FP8 KV cache.") - logger.warning( - "Please use FlashInfer backend with FP8 KV Cache for " - "better performance by setting environment variable " - "VLLM_ATTENTION_BACKEND=FLASHINFER") - target_backend = _Backend.XFORMERS elif block_size % 16 != 0: logger.info( "Cannot use FlashAttention-2 backend for block size not " @@ -270,6 +262,17 @@ class CudaPlatformBase(Platform): "Cannot use FlashAttention-2 backend for head size %d.", head_size) target_backend = _Backend.XFORMERS + fp8_kv_cache = (kv_cache_dtype is not None + and kv_cache_dtype.startswith("fp8")) + if (fp8_kv_cache and get_flash_attn_version() != 3): + logger.info( + "Cannot use FlashAttention-2 backend for FP8 KV cache." + ) + logger.warning( + "Please use FlashInfer backend with FP8 KV Cache for " + "better performance by setting environment variable " + "VLLM_ATTENTION_BACKEND=FLASHINFER") + target_backend = _Backend.XFORMERS except ImportError: logger.info( "Cannot use FlashAttention-2 backend because the " diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ad44f256a7b9..637c01556ac1 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -6,11 +6,12 @@ from typing import TYPE_CHECKING, Any, Optional import numpy as np import torch +from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType, is_quantized_kv_cache) -from vllm.attention.backends.utils import get_flash_attn_version from vllm.attention.ops.triton_merge_attn_states import merge_attn_states +from vllm.fa_utils import get_flash_attn_version from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv @@ -226,6 +227,9 @@ class FlashAttentionImpl(AttentionImpl): attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] + NOTE: FP8 quantization, flash-attn expect the size of + {q,k,v}_descale to be (num_sequences, num_kv_heads). + We use torch's .expand() to avoid duplicating values """ assert output is not None, "Output tensor must be provided." 
@@ -259,6 +263,17 @@ class FlashAttentionImpl(AttentionImpl): layer._k_scale, layer._v_scale, ) + descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, + key.shape[1]) + if self.kv_cache_dtype.startswith("fp8"): + key_cache = key_cache.view(torch.float8_e4m3fn) + value_cache = value_cache.view(torch.float8_e4m3fn) + num_tokens, num_heads, head_size = query.shape + query, _ = ops.scaled_fp8_quant( + query.reshape( + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) # Compute attention and update output up to `num_actual_tokens`. if not attn_metadata.use_cascade: @@ -279,6 +294,9 @@ class FlashAttentionImpl(AttentionImpl): block_table=attn_metadata.block_table, softcap=self.logits_soft_cap, fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), ) return output @@ -301,6 +319,9 @@ class FlashAttentionImpl(AttentionImpl): block_table=attn_metadata.block_table, common_prefix_len=attn_metadata.common_prefix_len, fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale, + k_descale=layer._k_scale, + v_descale=layer._v_scale, ) return output @@ -391,6 +412,9 @@ def cascade_attention( block_table: torch.Tensor, common_prefix_len: int, fa_version: int, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert alibi_slopes is None, ("Cascade attention does not support ALiBi.") # TODO: Support sliding window. @@ -402,6 +426,7 @@ def cascade_attention( assert common_prefix_len % block_size == 0 num_common_kv_blocks = common_prefix_len // block_size assert num_common_kv_blocks > 0 + descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2]) # Process shared prefix. prefix_output, prefix_lse = flash_attn_varlen_func( @@ -419,8 +444,16 @@ def cascade_attention( softcap=logits_soft_cap, return_softmax_lse=True, fa_version=fa_version, + q_descale=q_descale.expand(descale_shape) + if q_descale is not None else None, + k_descale=k_descale.expand(descale_shape) + if k_descale is not None else None, + v_descale=v_descale.expand(descale_shape) + if v_descale is not None else None, ) + descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) + # Process suffix per query. suffix_output, suffix_lse = flash_attn_varlen_func( q=query, @@ -437,6 +470,12 @@ def cascade_attention( softcap=logits_soft_cap, return_softmax_lse=True, fa_version=fa_version, + q_descale=q_descale.expand(descale_shape) + if q_descale is not None else None, + k_descale=k_descale.expand(descale_shape) + if k_descale is not None else None, + v_descale=v_descale.expand(descale_shape) + if v_descale is not None else None, ) # Merge prefix and suffix outputs, and store the result in output. 
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index b2cbba518036..21e7d26506d3 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -5,6 +5,7 @@ import pickle import signal import sys import time +import traceback import weakref from dataclasses import dataclass from enum import Enum, auto @@ -370,6 +371,9 @@ class WorkerProc: func = partial(cloudpickle.loads(method), self.worker) output = func(*args, **kwargs) except Exception as e: + # Exception notes (add_note) were introduced in Python 3.11 + if hasattr(e, "add_note"): + e.add_note(traceback.format_exc()) self.worker_response_mq.enqueue( (WorkerProc.ResponseStatus.FAILURE, e)) logger.exception("WorkerProc hit an exception: %s", exc_info=e) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 657333c6d84c..7faf666dc61c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1558,7 +1558,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=attn_module.dtype, + dtype=self.kv_cache_dtype, use_mla=use_mla) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY):
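Not part of the patch itself, but as an illustration of the descale-shape NOTE above: the per-tensor q/k/v scales are broadcast with torch's .expand() to the (num_sequences, num_kv_heads) shape that the FA3 kernels take, which produces a zero-stride view rather than a copy. A minimal sketch with made-up sizes and scale values (not vLLM code):

# Illustrative sketch only: broadcasting a per-tensor scale to the
# (num_seqs, num_kv_heads) descale shape without duplicating data.
import torch

num_seqs, num_kv_heads = 4, 8
q_scale = torch.tensor(0.02, dtype=torch.float32)  # per-tensor scale

descale_shape = (num_seqs, num_kv_heads)
q_descale = q_scale.expand(descale_shape)  # broadcast view, no copy

assert q_descale.shape == descale_shape
assert q_descale.stride() == (0, 0)  # zero strides: all elements alias q_scale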