[CPU] Support FP8 KV cache (#14741)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
parent 877e352262
commit a2ae496589
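This commit wires FP8-E5M2 KV cache support through the CPU cache kernels, the Torch SDPA backend, the CPU platform checks, and the CPU cache engine. As a rough usage sketch (not part of the commit): it assumes a CPU build of vLLM and the same `kv_cache_dtype` engine argument the new test below passes through `vllm_runner`.

```python
# Hedged sketch: request the FP8-E5M2 KV cache on the CPU backend through the
# offline API. The model name and sampling settings are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    dtype="bfloat16",           # FP8 KV cache on CPU casts fp16 models to bf16
    kv_cache_dtype="fp8_e5m2",  # KV cache dtype this commit enables on CPU
)

params = SamplingParams(temperature=0.0, max_tokens=32)
for output in llm.generate(["The quick brown fox"], params):
    print(output.outputs[0].text)
```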
@@ -3,6 +3,12 @@
 
 #include "cpu_types.hpp"
 
+#if defined(__x86_64__)
+  #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
+#else
+  #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
+#endif
+
 namespace {
 template <typename scalar_t>
 void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
@@ -95,8 +101,7 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
   }
 
   const int element_num_per_block = key_caches[0][0].numel();
-  VLLM_DISPATCH_FLOATING_TYPES(
-      key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
+  DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
     CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
     copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
                                    element_num_per_block, num_layers);
@@ -118,14 +123,13 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
   int key_stride = key.stride(0);
   int value_stride = value.stride(0);
 
-  VLLM_DISPATCH_FLOATING_TYPES(
-      key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
+  DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
     CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
     reshape_and_cache_cpu_impl<scalar_t>(
         key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
         key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
-        slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride,
-        value_stride, num_heads, head_size, block_size, x);
+        slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride,
+        num_heads, head_size, block_size, x);
     CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
   });
 }
@@ -16,9 +16,18 @@ namespace vec_op {
   AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
 
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...)          \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)   \
+  AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)
+
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
+#define VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME,                                 \
+                     VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(__VA_ARGS__))
+
 #ifndef CPU_OP_GUARD
 #define CPU_KERNEL_GUARD_IN(NAME)
 #define CPU_KERNEL_GUARD_OUT(NAME)
@@ -189,7 +189,7 @@ vLLM CPU backend supports the following vLLM features:
 - Model Quantization (`INT8 W8A8, AWQ, GPTQ`)
 - Chunked-prefill
 - Prefix-caching
-- FP8-E5M2 KV-Caching (TODO)
+- FP8-E5M2 KV cache
 
 ## Related runtime environment variables
 
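The documentation entry above flips FP8-E5M2 KV caching from TODO to supported. To make the E5M2 name concrete, here is a small illustration (not from the commit; it assumes a PyTorch build with float8 dtypes) of what storing KV values in `torch.float8_e5m2` implies numerically:

```python
# E5M2 = 1 sign bit, 5 exponent bits, 2 mantissa bits: roughly fp16's dynamic
# range with far less precision, which is why the new CPU test below compares
# only a few tokens of logprobs.
import torch

kv = torch.randn(4, dtype=torch.bfloat16)
kv_fp8 = kv.to(torch.float8_e5m2)      # what the CPU KV cache would store
restored = kv_fp8.to(torch.bfloat16)   # lossy round-trip back to bf16

print(kv)
print(restored)
```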
@@ -266,7 +266,7 @@ def test_with_prefix_caching(
 
 
 @pytest.mark.parametrize("model", ["facebook/opt-125m"])
-@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("dtype", ["bfloat16", "half"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
 @pytest.mark.parametrize("enforce_eager", [False])
@@ -303,7 +303,7 @@ def test_models_cpu(
 @pytest.mark.parametrize("max_tokens", [16])
 @pytest.mark.parametrize("enforce_eager", [False])
 @pytest.mark.parametrize("chunk_size", [30, 32])
-@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("dtype", ["bfloat16", "half"])
 @pytest.mark.cpu_model
 @pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
 def test_with_prefix_caching_cpu(
@@ -11,6 +11,7 @@ import pytest
 
 from tests.kernels.utils import override_backend_env_variable
 from tests.quantization.utils import is_quant_method_supported
+from vllm.platforms import current_platform
 
 from ...utils import check_logprobs_close
 
@@ -93,3 +94,63 @@ def test_models(
         name_0="fp16_kv_cache",
         name_1="fp8_kv_cache",
     )
+
+
+@pytest.mark.cpu_model
+@pytest.mark.skipif(not current_platform.is_cpu(),
+                    reason="test for the CPU backend.")
+@pytest.mark.parametrize(
+    "kv_cache_dtype,base_model,test_model",
+    [
+        # Test BF16 checkpoint w. fp8_e5m2 kv-cache.
+        ("fp8_e5m2", "meta-llama/Llama-3.2-1B-Instruct",
+         "meta-llama/Llama-3.2-1B-Instruct"),
+    ])
+# Due to low-precision numerical divergence, we only test logprob of 4 tokens
+@pytest.mark.parametrize("max_tokens", [4])
+# Due to low-precision numerical divergence, this test is too sensitive for
+# the async postprocessor
+@pytest.mark.parametrize("disable_async_output_proc", [True])
+def test_cpu_models(
+    vllm_runner,
+    example_prompts,
+    kv_cache_dtype: str,
+    base_model: str,
+    test_model: str,
+    max_tokens: int,
+    disable_async_output_proc: bool,
+) -> None:
+    """
+    Only checks log probs match to cover the discrepancy in
+    numerical sensitive kernels.
+    """
+
+    MAX_MODEL_LEN = 1024
+    NUM_LOG_PROBS = 8
+
+    with vllm_runner(
+            base_model,
+            max_model_len=MAX_MODEL_LEN,
+            dtype="bfloat16",
+            kv_cache_dtype="auto",
+            disable_async_output_proc=disable_async_output_proc,
+    ) as vllm_model:
+        baseline_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS)
+
+    with vllm_runner(
+            test_model,
+            max_model_len=MAX_MODEL_LEN,
+            dtype="bfloat16",
+            kv_cache_dtype=kv_cache_dtype,
+            disable_async_output_proc=disable_async_output_proc,
+    ) as vllm_model:
+        test_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS)
+
+    check_logprobs_close(
+        outputs_0_lst=baseline_outputs,
+        outputs_1_lst=test_outputs,
+        name_0="bf16_kv_cache",
+        name_1="fp8_kv_cache",
+    )
@@ -17,7 +17,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               is_quantized_kv_cache)
 # yapf: enable
 from vllm.attention.backends.utils import CommonAttentionState
-from vllm.attention.ops.ipex_attn import PagedAttention
+from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
 from vllm.logger import init_logger
 from vllm.utils import make_tensor_with_pad
@@ -431,10 +431,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
-        if is_quantized_kv_cache(kv_cache_dtype):
+
+        if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex:
             raise NotImplementedError(
-                "Torch SDPA backend does not support FP8 KV cache. "
-                "Please use xFormers backend instead.")
+                "Torch SDPA backend FP8 KV cache requires "
+                "intel_extension_for_pytorch support.")
         self.attn_type = attn_type
 
     def forward(
@@ -60,9 +60,6 @@ class CpuPlatform(Platform):
         # Reminder: Please update docs/source/features/compatibility_matrix.md
         # If the feature combo become valid
         if not model_config.enforce_eager:
-            logger.warning(
-                "CUDA graph is not supported on CPU, fallback to the eager "
-                "mode.")
             model_config.enforce_eager = True
 
         cache_config = vllm_config.cache_config
@@ -70,6 +67,25 @@ class CpuPlatform(Platform):
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
+        scheduler_config = vllm_config.scheduler_config
+        if ((scheduler_config.chunked_prefill_enabled
+             or cache_config.enable_prefix_caching)
+                and cache_config.cache_dtype != "auto"):
+            raise RuntimeError("Chunked-prefill and prefix-cache on the CPU "
+                               "backend is not compatible with FP8 KV cache.")
+
+        if cache_config.cache_dtype == "fp8_e4m3":
+            cache_config.cache_dtype = "fp8_e5m2"
+            logger.warning(
+                "CPU backend doesn't support fp8_e4m3 KV cache type, "
+                "cast to fp8_e5m2.")
+
+        if (cache_config.cache_dtype != "auto"
+                and model_config.dtype == torch.half):
+            logger.warning("FP8 KV cache on the CPU backend only does not"
+                           " support fp16 for now, cast to bf16.")
+            model_config.dtype = torch.bfloat16
+
         kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
 
         if kv_cache_space >= 0:
@@ -85,14 +101,6 @@ class CpuPlatform(Platform):
                 "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
                 f" {kv_cache_space}, expect a positive integer value.")
 
-        scheduler_config = vllm_config.scheduler_config
-        if ((scheduler_config.chunked_prefill_enabled
-             or cache_config.enable_prefix_caching)
-                and model_config.dtype == torch.half):
-            logger.warning("Chunked-prefill on the CPU backend only does not"
-                           " support fp16 for now, cast to bf16.")
-            model_config.dtype = torch.bfloat16
-
         parallel_config = vllm_config.parallel_config
         if (parallel_config.distributed_executor_backend is not None
                 and parallel_config.distributed_executor_backend != "mp"):
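Taken together, the platform changes above add three CPU-specific rules: an FP8 KV cache cannot be combined with chunked-prefill or prefix-caching, `fp8_e4m3` is downgraded to `fp8_e5m2`, and fp16 models are cast to bf16 when an FP8 cache is requested. A hedged illustration of how this would surface through the offline API (argument names assumed from the test added in this commit):

```python
# Sketch only: expected behavior of the new CpuPlatform checks on a CPU build.
from vllm import LLM

# fp8_e4m3 should be coerced to fp8_e5m2 on CPU, with a warning in the log.
llm = LLM(model="facebook/opt-125m", dtype="bfloat16",
          kv_cache_dtype="fp8_e4m3")

# Combining an FP8 KV cache with prefix caching is rejected on CPU.
try:
    LLM(model="facebook/opt-125m", dtype="bfloat16",
        kv_cache_dtype="fp8_e5m2", enable_prefix_caching=True)
except RuntimeError as err:
    print(err)
```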
@@ -53,8 +53,11 @@ class CPUCacheEngine:
 
         if cache_config.cache_dtype == "auto":
             self.dtype = model_config.dtype
+        elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
+            self.dtype = torch.float8_e5m2
         else:
-            self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+            raise NotImplementedError(f"Unsupported KV cache type "
+                                      f"{cache_config.cache_dtype}.")
 
         # Get attention backend.
         self.attn_backend = get_attn_backend(