[Core][Kernels] Enable FP8 KV Cache with Flashinfer backend. + BugFix for kv_cache_dtype=auto (#7985)
Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
parent 3f60f2244e, commit 6b3421567d
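
For orientation (not part of the commit): once this change is in place, an FP8 KV cache should be usable through the normal vLLM entry points when the FlashInfer backend is selected. A minimal sketch, assuming the offline LLM API and the VLLM_ATTENTION_BACKEND environment variable referenced in the selector hunk at the end of this diff; the model name is only a placeholder:

# Hedged usage sketch, not part of this diff. Select the FlashInfer backend
# before vLLM chooses an attention backend, then request an FP8 KV cache.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # env var named in the selector hunk below

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8")  # placeholder model
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)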
@@ -73,11 +73,14 @@ def ref_paged_attn(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
 @torch.inference_mode
-def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
-                                         num_heads: Tuple[int,
-                                                          int], head_size: int,
-                                         dtype: torch.dtype, block_size: int,
-                                         soft_cap: Optional[float]) -> None:
+def test_flashinfer_decode_with_paged_kv(
+    kv_lens: List[int],
+    num_heads: Tuple[int, int],
+    head_size: int,
+    dtype: torch.dtype,
+    block_size: int,
+    soft_cap: Optional[float],
+) -> None:
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
     num_seqs = len(kv_lens)
@@ -88,6 +91,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
     scale = head_size**-0.5

     query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
+
     key_value_cache = torch.randn(NUM_BLOCKS,
                                   2,
                                   block_size,
@@ -125,7 +129,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
     wrapper = flashinfer.\
         BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
                 use_tensor_cores=(
-                    (num_query_heads//num_kv_heads) not in (1, 2, 4, 8))
+                    (num_query_heads//num_kv_heads) > 4)
                 )
     wrapper.begin_forward(kv_indptr,
                           kv_indices,
@@ -249,3 +253,215 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
                                 soft_cap=soft_cap)
     torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
+
+
+@pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]])
+@pytest.mark.parametrize("num_heads", [(32, 8), (6, 1)])
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
+def test_flashinfer_prefill_with_paged_fp8_kv(
+        seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int],
+        head_size: int, dtype: torch.dtype, block_size: int,
+        soft_cap: Optional[float]) -> None:
+    torch.set_default_device("cuda")
+    torch.cuda.manual_seed_all(0)
+    num_seqs = len(seq_lens)
+    query_lens = [x[0] for x in seq_lens]
+    kv_lens = [x[1] for x in seq_lens]
+    num_query_heads = num_heads[0]
+    num_kv_heads = num_heads[1]
+    assert num_query_heads % num_kv_heads == 0
+    max_kv_len = max(kv_lens)
+    scale = head_size**-0.5
+
+    kv_cache_dtype = torch.float8_e4m3fn
+
+    query = torch.randn(sum(query_lens),
+                        num_query_heads,
+                        head_size,
+                        dtype=dtype)
+    NUM_BLOCKS_FP8 = 2048
+    key_value_cache = torch.randn(NUM_BLOCKS_FP8,
+                                  2,
+                                  block_size,
+                                  num_kv_heads,
+                                  head_size,
+                                  dtype=dtype)
+    key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
+    key_cache /= head_size**0.5
+    value_cache /= head_size**0.5
+
+    k_scale = key_cache.amax().item() / 448.0
+    v_scale = value_cache.amax().item() / 448.0
+
+    kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale],
+                             dim=1).to(kv_cache_dtype)
+
+    assert (kv_cache_fp8.shape == key_value_cache.shape)
+    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
+    block_tables = torch.randint(0,
+                                 NUM_BLOCKS_FP8,
+                                 (num_seqs, max_num_blocks_per_seq),
+                                 dtype=torch.int32)
+
+    qo_indptr = [0]
+    kv_indptr = [0]
+    kv_indices = []
+    kv_last_page_lens = []
+    for i in range(num_seqs):
+        seq_len = kv_lens[i]
+        assert seq_len > 0
+        num_blocks = (seq_len + block_size - 1) // block_size
+        kv_indices.extend(block_tables[i, :num_blocks])
+        kv_indptr.append(kv_indptr[-1] + num_blocks)
+        kv_last_page_len = seq_len % block_size
+        if kv_last_page_len == 0:
+            kv_last_page_len = block_size
+        kv_last_page_lens.append(kv_last_page_len)
+        qo_indptr.append(qo_indptr[-1] + query_lens[i])
+
+    qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32)
+    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
+    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
+    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
+
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
+    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+        workspace_buffer, "NHD")
+    wrapper.begin_forward(
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        kv_last_page_lens,
+        num_query_heads,
+        num_kv_heads,
+        head_size,
+        block_size,
+    )
+
+    output = wrapper.forward(query,
+                             kv_cache_fp8,
+                             logits_soft_cap=soft_cap,
+                             k_scale=k_scale,
+                             v_scale=v_scale)
+
+    ref_output = ref_paged_attn(query=query,
+                                key_cache=key_cache.squeeze(1),
+                                value_cache=value_cache.squeeze(1),
+                                query_lens=query_lens,
+                                kv_lens=kv_lens,
+                                block_tables=block_tables,
+                                scale=scale,
+                                soft_cap=soft_cap)
+    del query
+    del block_tables
+    # verify prefill fp8
+    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
+        f"{torch.max(torch.abs(output - ref_output))}"
+
+
+@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
+@pytest.mark.parametrize("num_heads", [(32, 8), (64, 8), (6, 1)])
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
+@torch.inference_mode
+def test_flashinfer_decode_with_paged_fp8_kv(
+    kv_lens: List[int],
+    num_heads: Tuple[int, int],
+    head_size: int,
+    dtype: torch.dtype,
+    block_size: int,
+    soft_cap: Optional[float],
+) -> None:
+    # test doesn't work for num_heads = (16,16)
+    torch.set_default_device("cuda")
+    torch.cuda.manual_seed_all(0)
+    num_seqs = len(kv_lens)
+    num_query_heads = num_heads[0]
+    num_kv_heads = num_heads[1]
+    assert num_query_heads % num_kv_heads == 0
+    max_kv_len = max(kv_lens)
+    scale = head_size**-0.5
+    use_tensor_cores = (num_query_heads // num_kv_heads) > 4
+    kv_cache_dtype = torch.float8_e4m3fn
+
+    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
+    NUM_BLOCKS_FP8 = 2048
+    key_value_cache = torch.randn(NUM_BLOCKS_FP8,
+                                  2,
+                                  block_size,
+                                  num_kv_heads,
+                                  head_size,
+                                  dtype=dtype)
+    key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
+    key_cache /= head_size**0.5
+    value_cache /= head_size**0.5
+
+    k_scale = key_cache.amax().item() / 448.0
+    v_scale = value_cache.amax().item() / 448.0
+
+    key_cache_fp8 = (key_cache / k_scale).to(kv_cache_dtype)
+    value_cache_fp8 = (value_cache / v_scale).to(kv_cache_dtype)
+    assert (key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1)
+    kv_cache_fp8 = torch.cat([key_cache_fp8, value_cache_fp8], dim=1)
+
+    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
+    block_tables = torch.randint(0,
+                                 NUM_BLOCKS_FP8,
+                                 (num_seqs, max_num_blocks_per_seq),
+                                 dtype=torch.int32)
+
+    kv_indptr = [0]
+    kv_indices = []
+    kv_last_page_lens = []
+    for i in range(num_seqs):
+        seq_len = kv_lens[i]
+        assert seq_len > 0
+        num_blocks = (seq_len + block_size - 1) // block_size
+        kv_indices.extend(block_tables[i, :num_blocks])
+        kv_indptr.append(kv_indptr[-1] + num_blocks)
+        kv_last_page_len = seq_len % block_size
+        if kv_last_page_len == 0:
+            kv_last_page_len = block_size
+        kv_last_page_lens.append(kv_last_page_len)
+
+    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
+    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
+    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
+
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
+    wrapper = flashinfer.\
+        BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
+                                           use_tensor_cores=use_tensor_cores)
+    wrapper.begin_forward(kv_indptr,
+                          kv_indices,
+                          kv_last_page_lens,
+                          num_query_heads,
+                          num_kv_heads,
+                          head_size,
+                          block_size,
+                          "NONE",
+                          data_type=dtype)
+    output = wrapper.forward(query,
+                             kv_cache_fp8,
+                             logits_soft_cap=soft_cap,
+                             k_scale=k_scale,
+                             v_scale=v_scale)
+    key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
+    value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
+
+    ref_output = ref_paged_attn(query=query,
+                                key_cache=key_cache,
+                                value_cache=value_cache,
+                                query_lens=[1] * num_seqs,
+                                kv_lens=kv_lens,
+                                block_tables=block_tables,
+                                scale=scale,
+                                soft_cap=soft_cap)
+    # Temporary fix: Increasing the tolerance. Seems like a flashinfer issue
+    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
+        f"{torch.max(torch.abs(output - ref_output))}"
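
The two new tests above quantize the KV cache per tensor: a scale is derived from the cache maximum divided by 448.0 (the largest finite value of torch.float8_e4m3fn), the cache is divided by that scale before the cast to FP8, and k_scale / v_scale are passed to wrapper.forward so the kernel can undo the scaling. A minimal round-trip sketch of that scheme (standalone PyTorch, not part of the diff; the tests call .amax() on the raw cache, the sketch uses the absolute maximum for clarity):

# Hedged sketch of the per-tensor FP8 scaling used by the new tests; 448.0 is
# the largest finite value representable in torch.float8_e4m3fn.
import torch

x = torch.randn(16, 8, 128, dtype=torch.float16, device="cuda")
scale = x.abs().amax().item() / 448.0        # per-tensor scale
x_fp8 = (x / scale).to(torch.float8_e4m3fn)  # what the tests store in the cache
x_hat = x_fp8.to(torch.float16) * scale      # what the kernel effectively reads back
print(torch.max(torch.abs(x - x_hat)))       # small quantization error

The remaining hunks change the FlashInfer attention backend and the attention backend selector.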
@@ -83,6 +83,15 @@ class FlashInferBackend(AttentionBackend):
     def get_supported_head_sizes() -> List[int]:
         return [64, 128, 256]

+    @staticmethod
+    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
+        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
+            return torch.float8_e4m3fn
+        elif kv_cache_dtype == "fp8_e5m2":
+            return torch.float8_e5m2
+        else:
+            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
+

 class FlashInferState(AttentionState):

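
The new static method only maps vLLM's kv_cache_dtype string onto a torch FP8 dtype. A quick sanity check of that mapping (hedged; it assumes the backend is importable as vllm.attention.backends.flashinfer, a module path this diff view does not show):

# Hedged sketch: exercising the string-to-dtype mapping added above.
import torch
from vllm.attention.backends.flashinfer import FlashInferBackend

assert FlashInferBackend.get_fp8_dtype_for_flashinfer("fp8") is torch.float8_e4m3fn
assert FlashInferBackend.get_fp8_dtype_for_flashinfer("fp8_e4m3") is torch.float8_e4m3fn
assert FlashInferBackend.get_fp8_dtype_for_flashinfer("fp8_e5m2") is torch.float8_e5m2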
@@ -177,9 +186,9 @@ class FlashInferState(AttentionState):
                 self._graph_decode_workspace_buffer, _indptr_buffer,
                 self._graph_indices_buffer, _last_page_len_buffer, "NHD",
                 use_tensor_cores)
-        kv_cache_dtype = get_kv_cache_torch_dtype(
-            self.runner.kv_cache_dtype, self.runner.model_config.dtype)

+        kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+            self.runner.kv_cache_dtype)
         paged_kv_indptr_tensor_host = torch.arange(0,
                                                    batch_size + 1,
                                                    dtype=torch.int32)
@@ -340,7 +349,7 @@ class FlashInferMetadata(AttentionMetadata):
                 self.page_size,
                 # Disable flashinfer's pos encoding and use vllm's rope.
                 pos_encoding_mode="NONE",
-                data_type=self.data_type)
+            )

     def asdict_zerocopy(self,
                         skip_fields: Optional[Set[str]] = None
@@ -366,7 +375,8 @@ class FlashInferMetadata(AttentionMetadata):
     def decode_metadata(self) -> Optional["FlashInferMetadata"]:
         # Currently chunked prefill is not supported
         if self.num_prefills > 0:
-            assert self.num_decode_tokens == 0
+            assert self.num_decode_tokens == 0, (
+                "Chunked prefill is not supported with flashinfer yet.")
             return None

         return self
@@ -578,6 +588,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

         kv_cache_dtype = get_kv_cache_torch_dtype(
             self.runner.kv_cache_dtype, self.runner.model_config.dtype)
+
         return FlashInferMetadata(
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
@@ -661,7 +672,6 @@ class FlashInferImpl(AttentionImpl):
         if attn_metadata.num_decode_tokens > 0:
             assert attn_metadata.num_prefill_tokens == 0, (
                 "Chunked prefill is not supported with flashinfer yet.")
-
         if kv_cache is not None:
             # Use the same reshape and cache kernel as flash attention.
             ops.reshape_and_cache_flash(
@@ -674,6 +684,12 @@ class FlashInferImpl(AttentionImpl):
                 k_scale,
                 v_scale,
             )
+            # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
+            # to process the cache when the kv_cache_dtype is fp8
+            if self.kv_cache_dtype.startswith("fp8"):
+                torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                    self.kv_cache_dtype)
+                kv_cache = kv_cache.view(torch_dtype)

         query = query.contiguous(
         )  # Flashinfer requires query to be contiguous
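
The kv_cache.view(torch_dtype) call added above reinterprets the cache tensor's storage as an FP8 dtype rather than copying it, which works because both the allocated element type and the FP8 dtype are one byte wide. A standalone sketch of that reinterpretation (the uint8 allocation is an illustrative assumption; only the view(...) call appears in this diff):

# Hedged sketch: Tensor.view(dtype) reinterprets the same storage when the
# element sizes match (1 byte -> 1 byte here), so no data is copied.
import torch

raw_cache = torch.zeros(2, 16, dtype=torch.uint8)   # assumed byte-sized allocation
fp8_cache = raw_cache.view(torch.float8_e4m3fn)     # same memory, new dtype
assert fp8_cache.data_ptr() == raw_cache.data_ptr()
assert fp8_cache.dtype == torch.float8_e4m3fn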
@@ -711,5 +727,7 @@ class FlashInferImpl(AttentionImpl):
                 query,
                 kv_cache,
                 sm_scale=self.scale,
-                logits_soft_cap=self.logits_soft_cap)
+                logits_soft_cap=self.logits_soft_cap,
+                k_scale=k_scale,
+                v_scale=v_scale)
         return output.view(num_tokens, hidden_size)
@@ -226,6 +226,10 @@ def which_attn_to_use(
     elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
         logger.info(
             "Cannot use FlashAttention-2 backend for FP8 KV cache.")
+        logger.warning(
+            "Please use FlashInfer backend with FP8 KV Cache for "
+            "better performance by setting environment variable "
+            "VLLM_ATTENTION_BACKEND=FLASHINFER")
         selected_backend = _Backend.XFORMERS
     elif block_size % 16 != 0:
         logger.info(