Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 02:15:01 +08:00)
[CPU] Support FP8 KV cache (#14741)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
Parent: 877e352262
Commit: a2ae496589
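For context, a minimal usage sketch of the feature this commit enables on the CPU backend. This is not taken from the commit; it assumes the standard vllm.LLM Python API, the VLLM_CPU_KVCACHE_SPACE variable documented for the CPU backend, and reuses a model name that appears in the tests below.

    import os

    # Size of the CPU KV cache in GiB, read by the CPU backend at engine construction.
    os.environ["VLLM_CPU_KVCACHE_SPACE"] = "8"

    from vllm import LLM, SamplingParams

    # FP8 E5M2 KV cache on the CPU backend; bfloat16 weights, since the CPU
    # backend casts fp16 to bf16 when an FP8 KV cache is requested.
    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
              dtype="bfloat16",
              kv_cache_dtype="fp8_e5m2",
              enforce_eager=True)

    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=32))
    print(outputs[0].outputs[0].text)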
@@ -3,6 +3,12 @@

 #include "cpu_types.hpp"

+#if defined(__x86_64__)
+  #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
+#else
+  #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
+#endif
+
 namespace {
 template <typename scalar_t>
 void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
@@ -95,13 +101,12 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
   }

   const int element_num_per_block = key_caches[0][0].numel();
-  VLLM_DISPATCH_FLOATING_TYPES(
-      key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
-        CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
-        copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
-                                       element_num_per_block, num_layers);
-        CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
-      });
+  DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
+    CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
+    copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
+                                   element_num_per_block, num_layers);
+    CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
+  });
 }

 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
@@ -118,16 +123,15 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
   int key_stride = key.stride(0);
   int value_stride = value.stride(0);

-  VLLM_DISPATCH_FLOATING_TYPES(
-      key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
-        CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
-        reshape_and_cache_cpu_impl<scalar_t>(
-            key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
-            key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
-            slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride,
-            value_stride, num_heads, head_size, block_size, x);
-        CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
-      });
+  DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
+    CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
+    reshape_and_cache_cpu_impl<scalar_t>(
+        key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
+        key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
+        slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride,
+        num_heads, head_size, block_size, x);
+    CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
+  });
 }

 void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
@@ -16,9 +16,18 @@ namespace vec_op {
   AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)

+#define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...)              \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)
+
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

+#define VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME,                                \
+                     VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(__VA_ARGS__))
+
 #ifndef CPU_OP_GUARD
 #define CPU_KERNEL_GUARD_IN(NAME)
 #define CPU_KERNEL_GUARD_OUT(NAME)
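The new dispatch case maps to PyTorch's at::ScalarType::Float8_e5m2 (5 exponent bits, 2 mantissa bits), exposed in Python as torch.float8_e5m2. A small standalone sketch, independent of this commit, of the rounding that storing the KV cache in this format introduces:

    import torch

    x = torch.tensor([0.1234, 1.0, 1.3, 300.0], dtype=torch.bfloat16)

    # Store in e5m2 (as the FP8 KV cache does), then read back at higher precision.
    q = x.to(torch.float8_e5m2)
    print(q.to(torch.bfloat16))  # values rounded to the nearest e5m2-representable number

This precision loss is why the tests below only compare log probabilities loosely between the bf16 and FP8 KV cache runs.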
@@ -189,7 +189,7 @@ vLLM CPU backend supports the following vLLM features:
 - Model Quantization (`INT8 W8A8, AWQ, GPTQ`)
 - Chunked-prefill
 - Prefix-caching
-- FP8-E5M2 KV-Caching (TODO)
+- FP8-E5M2 KV cache

 ## Related runtime environment variables

@@ -266,7 +266,7 @@ def test_with_prefix_caching(


 @pytest.mark.parametrize("model", ["facebook/opt-125m"])
-@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("dtype", ["bfloat16", "half"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
 @pytest.mark.parametrize("enforce_eager", [False])
@@ -303,7 +303,7 @@ def test_models_cpu(
 @pytest.mark.parametrize("max_tokens", [16])
 @pytest.mark.parametrize("enforce_eager", [False])
 @pytest.mark.parametrize("chunk_size", [30, 32])
-@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("dtype", ["bfloat16", "half"])
 @pytest.mark.cpu_model
 @pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
 def test_with_prefix_caching_cpu(
@@ -11,6 +11,7 @@ import pytest

 from tests.kernels.utils import override_backend_env_variable
 from tests.quantization.utils import is_quant_method_supported
+from vllm.platforms import current_platform

 from ...utils import check_logprobs_close

@@ -93,3 +94,63 @@ def test_models(
         name_0="fp16_kv_cache",
         name_1="fp8_kv_cache",
     )
+
+
+@pytest.mark.cpu_model
+@pytest.mark.skipif(not current_platform.is_cpu(),
+                    reason="test for the CPU backend.")
+@pytest.mark.parametrize(
+    "kv_cache_dtype,base_model,test_model",
+    [
+        # Test BF16 checkpoint w. fp8_e5m2 kv-cache.
+        ("fp8_e5m2", "meta-llama/Llama-3.2-1B-Instruct",
+         "meta-llama/Llama-3.2-1B-Instruct"),
+    ])
+# Due to low-precision numerical divergence, we only test logprobs of 4 tokens
+@pytest.mark.parametrize("max_tokens", [4])
+# Due to low-precision numerical divergence, this test is too sensitive for
+# the async postprocessor
+@pytest.mark.parametrize("disable_async_output_proc", [True])
+def test_cpu_models(
+    vllm_runner,
+    example_prompts,
+    kv_cache_dtype: str,
+    base_model: str,
+    test_model: str,
+    max_tokens: int,
+    disable_async_output_proc: bool,
+) -> None:
+    """
+    Only checks that log probs match, to cover the discrepancy in
+    numerically sensitive kernels.
+    """
+
+    MAX_MODEL_LEN = 1024
+    NUM_LOG_PROBS = 8
+
+    with vllm_runner(
+            base_model,
+            max_model_len=MAX_MODEL_LEN,
+            dtype="bfloat16",
+            kv_cache_dtype="auto",
+            disable_async_output_proc=disable_async_output_proc,
+    ) as vllm_model:
+        baseline_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS)
+
+    with vllm_runner(
+            test_model,
+            max_model_len=MAX_MODEL_LEN,
+            dtype="bfloat16",
+            kv_cache_dtype=kv_cache_dtype,
+            disable_async_output_proc=disable_async_output_proc,
+    ) as vllm_model:
+        test_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS)
+
+    check_logprobs_close(
+        outputs_0_lst=baseline_outputs,
+        outputs_1_lst=test_outputs,
+        name_0="bf16_kv_cache",
+        name_1="fp8_kv_cache",
+    )
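One way to run just this new CPU test, assuming pytest is invoked from a checkout where the test module is collected (the marker and test name come from the decorators above):

    import pytest

    # Select the CPU-only model tests and, within them, the new FP8 KV cache test.
    pytest.main(["-m", "cpu_model", "-k", "test_cpu_models", "-v"])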
@@ -17,7 +17,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               is_quantized_kv_cache)
 # yapf: enable
 from vllm.attention.backends.utils import CommonAttentionState
-from vllm.attention.ops.ipex_attn import PagedAttention
+from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
 from vllm.logger import init_logger
 from vllm.utils import make_tensor_with_pad
@@ -431,10 +431,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
-        if is_quantized_kv_cache(kv_cache_dtype):
+
+        if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex:
             raise NotImplementedError(
-                "Torch SDPA backend does not support FP8 KV cache. "
-                "Please use xFormers backend instead.")
+                "Torch SDPA backend FP8 KV cache requires "
+                "intel_extension_for_pytorch support.")
         self.attn_type = attn_type

     def forward(
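The relaxed guard above relies on `_use_ipex` imported from vllm.attention.ops.ipex_attn; the commit does not show its definition. A plausible sketch of how such a flag is commonly derived (an assumption for illustration, not the actual vLLM code):

    # Hypothetical sketch: detect whether intel_extension_for_pytorch is importable.
    try:
        import intel_extension_for_pytorch  # noqa: F401
        _use_ipex = True
    except ImportError:
        _use_ipex = False

With a flag like this, the backend check above reads: a quantized (FP8) KV cache is only allowed when IPEX is available to provide the FP8-capable paged attention ops; otherwise the backend still raises NotImplementedError.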
@@ -60,9 +60,6 @@ class CpuPlatform(Platform):
         # Reminder: Please update docs/source/features/compatibility_matrix.md
         # If the feature combo become valid
         if not model_config.enforce_eager:
-            logger.warning(
-                "CUDA graph is not supported on CPU, fallback to the eager "
-                "mode.")
             model_config.enforce_eager = True

         cache_config = vllm_config.cache_config
@@ -70,6 +67,25 @@
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16

+        scheduler_config = vllm_config.scheduler_config
+        if ((scheduler_config.chunked_prefill_enabled
+             or cache_config.enable_prefix_caching)
+                and cache_config.cache_dtype != "auto"):
+            raise RuntimeError("Chunked-prefill and prefix-caching on the CPU "
+                               "backend are not compatible with FP8 KV cache.")
+
+        if cache_config.cache_dtype == "fp8_e4m3":
+            cache_config.cache_dtype = "fp8_e5m2"
+            logger.warning(
+                "CPU backend doesn't support the fp8_e4m3 KV cache type, "
+                "casting to fp8_e5m2.")
+
+        if (cache_config.cache_dtype != "auto"
+                and model_config.dtype == torch.half):
+            logger.warning("FP8 KV cache on the CPU backend does not "
+                           "support fp16 for now, casting to bf16.")
+            model_config.dtype = torch.bfloat16
+
         kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE

         if kv_cache_space >= 0:
@@ -85,14 +101,6 @@ class CpuPlatform(Platform):
                     "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
                     f" {kv_cache_space}, expect a positive integer value.")

-        scheduler_config = vllm_config.scheduler_config
-        if ((scheduler_config.chunked_prefill_enabled
-             or cache_config.enable_prefix_caching)
-                and model_config.dtype == torch.half):
-            logger.warning("Chunked-prefill on the CPU backend only does not"
-                           " support fp16 for now, cast to bf16.")
-            model_config.dtype = torch.bfloat16
-
         parallel_config = vllm_config.parallel_config
         if (parallel_config.distributed_executor_backend is not None
                 and parallel_config.distributed_executor_backend != "mp"):
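Taken together, the new platform checks mean an FP8 KV cache on the CPU backend currently excludes chunked prefill and prefix caching, rewrites fp8_e4m3 to fp8_e5m2, and casts fp16 weights to bf16. A hedged usage sketch that satisfies those constraints explicitly (model name reused from the tests above; the kwargs are standard vllm.LLM engine arguments):

    from vllm import LLM

    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
              dtype="bfloat16",              # fp16 would be cast to bf16 anyway
              kv_cache_dtype="fp8_e5m2",     # fp8_e4m3 would be rewritten to fp8_e5m2
              enable_chunked_prefill=False,  # incompatible with FP8 KV cache on CPU
              enable_prefix_caching=False)   # incompatible with FP8 KV cache on CPU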
@@ -53,8 +53,11 @@ class CPUCacheEngine:

         if cache_config.cache_dtype == "auto":
             self.dtype = model_config.dtype
+        elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
+            self.dtype = torch.float8_e5m2
         else:
-            self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+            raise NotImplementedError(f"Unsupported KV cache type "
+                                      f"{cache_config.cache_dtype}.")

         # Get attention backend.
         self.attn_backend = get_attn_backend(
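With this change the CPU cache engine allocates its KV blocks directly as torch.float8_e5m2, halving KV cache memory relative to bf16 or fp16. A quick standalone check (the shape is illustrative only, not the engine's actual cache layout):

    import torch

    num_blocks, block_size, num_heads, head_size = 128, 16, 8, 64
    shape = (2, num_blocks, block_size, num_heads, head_size)

    kv_bf16 = torch.empty(shape, dtype=torch.bfloat16)
    kv_fp8 = torch.empty(shape, dtype=torch.float8_e5m2)

    # 2 bytes per element for bf16 vs. 1 byte per element for the e5m2 cache.
    print(kv_bf16.element_size(), kv_fp8.element_size())  # 2 1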