diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp
index e3809acad745..d726ee9307fe 100644
--- a/csrc/cpu/cache.cpp
+++ b/csrc/cpu/cache.cpp
@@ -3,6 +3,12 @@
 
 #include "cpu_types.hpp"
 
+#if defined(__x86_64__)
+  #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
+#else
+  #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
+#endif
+
 namespace {
 template <typename scalar_t>
 void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
@@ -95,13 +101,12 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
   }
 
   const int element_num_per_block = key_caches[0][0].numel();
-  VLLM_DISPATCH_FLOATING_TYPES(
-      key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
-        CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
-        copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
-                                       element_num_per_block, num_layers);
-        CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
-      });
+  DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
+    CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
+    copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
+                                   element_num_per_block, num_layers);
+    CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
+  });
 }
 
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
@@ -118,16 +123,15 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
   int key_stride = key.stride(0);
   int value_stride = value.stride(0);
 
-  VLLM_DISPATCH_FLOATING_TYPES(
-      key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
-        CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
-        reshape_and_cache_cpu_impl<scalar_t>(
-            key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
-            key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
-            slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride,
-            value_stride, num_heads, head_size, block_size, x);
-        CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
-      });
+  DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
+    CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
+    reshape_and_cache_cpu_impl<scalar_t>(
+        key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
+        key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
+        slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride,
+        num_heads, head_size, block_size, x);
+    CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
+  });
 }
 
 void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp
index a4ef2be2a58c..a9369e1fd101 100644
--- a/csrc/cpu/cpu_types_x86.hpp
+++ b/csrc/cpu/cpu_types_x86.hpp
@@ -16,9 +16,18 @@ namespace vec_op {
   AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
 
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...)            \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)        \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)
+
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
+#define VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME,                                \
+                     VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(__VA_ARGS__))
+
 #ifndef CPU_OP_GUARD
   #define CPU_KERNEL_GUARD_IN(NAME)
   #define CPU_KERNEL_GUARD_OUT(NAME)
diff --git a/docs/source/getting_started/installation/cpu.md b/docs/source/getting_started/installation/cpu.md
index 43c9187f072e..65af7b50bdc1 100644
--- a/docs/source/getting_started/installation/cpu.md
+++ b/docs/source/getting_started/installation/cpu.md
@@ -189,7 +189,7 @@ vLLM CPU backend supports the following vLLM features:
 - Model Quantization (`INT8 W8A8, AWQ, GPTQ`)
 - Chunked-prefill
 - Prefix-caching
-- FP8-E5M2 KV-Caching (TODO)
+- FP8-E5M2 KV cache
 
 ## Related runtime environment variables
 
diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py
index 5bf48b5cced4..be007de321c8 100644
--- a/tests/basic_correctness/test_chunked_prefill.py
+++ b/tests/basic_correctness/test_chunked_prefill.py
@@ -266,7 +266,7 @@ def test_with_prefix_caching(
 
 
 @pytest.mark.parametrize("model", ["facebook/opt-125m"])
-@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("dtype", ["bfloat16", "half"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
 @pytest.mark.parametrize("enforce_eager", [False])
@@ -303,7 +303,7 @@ def test_models_cpu(
 @pytest.mark.parametrize("max_tokens", [16])
 @pytest.mark.parametrize("enforce_eager", [False])
 @pytest.mark.parametrize("chunk_size", [30, 32])
-@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("dtype", ["bfloat16", "half"])
 @pytest.mark.cpu_model
 @pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
 def test_with_prefix_caching_cpu(
diff --git a/tests/models/decoder_only/language/test_fp8.py b/tests/models/decoder_only/language/test_fp8.py
index 27c125160aa1..faca7a566e79 100644
--- a/tests/models/decoder_only/language/test_fp8.py
+++ b/tests/models/decoder_only/language/test_fp8.py
@@ -11,6 +11,7 @@ import pytest
 
 from tests.kernels.utils import override_backend_env_variable
 from tests.quantization.utils import is_quant_method_supported
+from vllm.platforms import current_platform
 
 from ...utils import check_logprobs_close
 
@@ -93,3 +94,63 @@ def test_models(
         name_0="fp16_kv_cache",
         name_1="fp8_kv_cache",
     )
+
+
+@pytest.mark.cpu_model
+@pytest.mark.skipif(not current_platform.is_cpu(),
+                    reason="test for the CPU backend.")
+@pytest.mark.parametrize(
+    "kv_cache_dtype,base_model,test_model",
+    [
+        # Test BF16 checkpoint w. fp8_e5m2 kv-cache.
+        ("fp8_e5m2", "meta-llama/Llama-3.2-1B-Instruct",
+         "meta-llama/Llama-3.2-1B-Instruct"),
+    ])
+# Due to low-precision numerical divergence, we only test logprob of 4 tokens
+@pytest.mark.parametrize("max_tokens", [4])
+# Due to low-precision numerical divergence, this test is too sensitive for
+# the async postprocessor
+@pytest.mark.parametrize("disable_async_output_proc", [True])
+def test_cpu_models(
+    vllm_runner,
+    example_prompts,
+    kv_cache_dtype: str,
+    base_model: str,
+    test_model: str,
+    max_tokens: int,
+    disable_async_output_proc: bool,
+) -> None:
+    """
+    Only checks that log probs match, to cover the discrepancy in
+    numerically sensitive kernels.
+    """
+
+    MAX_MODEL_LEN = 1024
+    NUM_LOG_PROBS = 8
+
+    with vllm_runner(
+            base_model,
+            max_model_len=MAX_MODEL_LEN,
+            dtype="bfloat16",
+            kv_cache_dtype="auto",
+            disable_async_output_proc=disable_async_output_proc,
+    ) as vllm_model:
+        baseline_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS)
+
+    with vllm_runner(
+            test_model,
+            max_model_len=MAX_MODEL_LEN,
+            dtype="bfloat16",
+            kv_cache_dtype=kv_cache_dtype,
+            disable_async_output_proc=disable_async_output_proc,
+    ) as vllm_model:
+        test_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS)
+
+    check_logprobs_close(
+        outputs_0_lst=baseline_outputs,
+        outputs_1_lst=test_outputs,
+        name_0="bf16_kv_cache",
+        name_1="fp8_kv_cache",
+    )
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index 37dd75da2759..afe2acff4ab3 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -17,7 +17,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               is_quantized_kv_cache)
 # yapf: enable
 from vllm.attention.backends.utils import CommonAttentionState
-from vllm.attention.ops.ipex_attn import PagedAttention
+from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
 from vllm.logger import init_logger
 from vllm.utils import make_tensor_with_pad
@@ -431,10 +431,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
-        if is_quantized_kv_cache(kv_cache_dtype):
+
+        if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex:
             raise NotImplementedError(
-                "Torch SDPA backend does not support FP8 KV cache. "
-                "Please use xFormers backend instead.")
+                "Torch SDPA backend FP8 KV cache requires "
+                "intel_extension_for_pytorch support.")
         self.attn_type = attn_type
 
     def forward(
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 140335dfb64a..40eacfd080e1 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -60,9 +60,6 @@ class CpuPlatform(Platform):
         # Reminder: Please update docs/source/features/compatibility_matrix.md
         # If the feature combo become valid
         if not model_config.enforce_eager:
-            logger.warning(
-                "CUDA graph is not supported on CPU, fallback to the eager "
-                "mode.")
             model_config.enforce_eager = True
 
         cache_config = vllm_config.cache_config
@@ -70,6 +67,25 @@ class CpuPlatform(Platform):
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
+        scheduler_config = vllm_config.scheduler_config
+        if ((scheduler_config.chunked_prefill_enabled
+             or cache_config.enable_prefix_caching)
+                and cache_config.cache_dtype != "auto"):
+            raise RuntimeError("Chunked-prefill and prefix-cache on the CPU "
+                               "backend are not compatible with FP8 KV cache.")
+
+        if cache_config.cache_dtype == "fp8_e4m3":
+            cache_config.cache_dtype = "fp8_e5m2"
+            logger.warning(
+                "CPU backend doesn't support fp8_e4m3 KV cache type, "
+                "cast to fp8_e5m2.")
+
+        if (cache_config.cache_dtype != "auto"
+                and model_config.dtype == torch.half):
+            logger.warning("FP8 KV cache on the CPU backend does not"
+                           " support fp16 for now, cast to bf16.")
+            model_config.dtype = torch.bfloat16
+
         kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
 
         if kv_cache_space >= 0:
@@ -85,14 +101,6 @@ class CpuPlatform(Platform):
                 "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
                 f" {kv_cache_space}, expect a positive integer value.")
 
-        scheduler_config = vllm_config.scheduler_config
-        if ((scheduler_config.chunked_prefill_enabled
-             or cache_config.enable_prefix_caching)
-                and model_config.dtype == torch.half):
-            logger.warning("Chunked-prefill on the CPU backend only does not"
-                           " support fp16 for now, cast to bf16.")
-            model_config.dtype = torch.bfloat16
-
         parallel_config = vllm_config.parallel_config
         if (parallel_config.distributed_executor_backend is not None
                 and parallel_config.distributed_executor_backend != "mp"):
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index 27b1a2dd1be8..70d2924a045b 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -53,8 +53,11 @@ class CPUCacheEngine:
 
         if cache_config.cache_dtype == "auto":
             self.dtype = model_config.dtype
+        elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
+            self.dtype = torch.float8_e5m2
         else:
-            self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+            raise NotImplementedError(f"Unsupported KV cache type "
+                                      f"{cache_config.cache_dtype}.")
 
         # Get attention backend.
         self.attn_backend = get_attn_backend(