[Core][Kernels] Enable FP8 KV Cache with Flashinfer backend. + BugFix for kv_cache_dtype=auto (#7985)

Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
Pavani Majety 2024-08-29 11:53:11 -07:00 committed by GitHub
parent 3f60f2244e
commit 6b3421567d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 250 additions and 12 deletions

View File

@ -73,11 +73,14 @@ def ref_paged_attn(
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode @torch.inference_mode
def test_flashinfer_decode_with_paged_kv(kv_lens: List[int], def test_flashinfer_decode_with_paged_kv(
num_heads: Tuple[int, kv_lens: List[int],
int], head_size: int, num_heads: Tuple[int, int],
dtype: torch.dtype, block_size: int, head_size: int,
soft_cap: Optional[float]) -> None: dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0) torch.cuda.manual_seed_all(0)
num_seqs = len(kv_lens) num_seqs = len(kv_lens)
@ -88,6 +91,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
scale = head_size**-0.5 scale = head_size**-0.5
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_value_cache = torch.randn(NUM_BLOCKS, key_value_cache = torch.randn(NUM_BLOCKS,
2, 2,
block_size, block_size,
@ -125,7 +129,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
wrapper = flashinfer.\ wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
use_tensor_cores=( use_tensor_cores=(
(num_query_heads//num_kv_heads) not in (1, 2, 4, 8)) (num_query_heads//num_kv_heads) > 4)
) )
wrapper.begin_forward(kv_indptr, wrapper.begin_forward(kv_indptr,
kv_indices, kv_indices,
@ -249,3 +253,215 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
soft_cap=soft_cap) soft_cap=soft_cap)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
@pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]])
@pytest.mark.parametrize("num_heads", [(32, 8), (6, 1)])
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
def test_flashinfer_prefill_with_paged_fp8_kv(
seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int],
head_size: int, dtype: torch.dtype, block_size: int,
soft_cap: Optional[float]) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5
kv_cache_dtype = torch.float8_e4m3fn
query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)
NUM_BLOCKS_FP8 = 2048
key_value_cache = torch.randn(NUM_BLOCKS_FP8,
2,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
key_cache /= head_size**0.5
value_cache /= head_size**0.5
k_scale = key_cache.amax().item() / 448.0
v_scale = value_cache.amax().item() / 448.0
kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale],
dim=1).to(kv_cache_dtype)
assert (kv_cache_fp8.shape == key_value_cache.shape)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS_FP8,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
qo_indptr = [0]
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
qo_indptr.append(qo_indptr[-1] + query_lens[i])
qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD")
wrapper.begin_forward(
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
)
output = wrapper.forward(query,
kv_cache_fp8,
logits_soft_cap=soft_cap,
k_scale=k_scale,
v_scale=v_scale)
ref_output = ref_paged_attn(query=query,
key_cache=key_cache.squeeze(1),
value_cache=value_cache.squeeze(1),
query_lens=query_lens,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
del query
del block_tables
# verify prefill fp8
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", [(32, 8), (64, 8), (6, 1)])
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode
def test_flashinfer_decode_with_paged_fp8_kv(
kv_lens: List[int],
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
) -> None:
# test doesn't work for num_heads = (16,16)
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5
use_tensor_cores = (num_query_heads // num_kv_heads) > 4
kv_cache_dtype = torch.float8_e4m3fn
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
NUM_BLOCKS_FP8 = 2048
key_value_cache = torch.randn(NUM_BLOCKS_FP8,
2,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
key_cache /= head_size**0.5
value_cache /= head_size**0.5
k_scale = key_cache.amax().item() / 448.0
v_scale = value_cache.amax().item() / 448.0
key_cache_fp8 = (key_cache / k_scale).to(kv_cache_dtype)
value_cache_fp8 = (value_cache / v_scale).to(kv_cache_dtype)
assert (key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1)
kv_cache_fp8 = torch.cat([key_cache_fp8, value_cache_fp8], dim=1)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS_FP8,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
use_tensor_cores=use_tensor_cores)
wrapper.begin_forward(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
data_type=dtype)
output = wrapper.forward(query,
kv_cache_fp8,
logits_soft_cap=soft_cap,
k_scale=k_scale,
v_scale=v_scale)
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
# Temporary fix: Increasing the tolerance. Seems like a flashinfer issue
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"

View File

@ -83,6 +83,15 @@ class FlashInferBackend(AttentionBackend):
def get_supported_head_sizes() -> List[int]: def get_supported_head_sizes() -> List[int]:
return [64, 128, 256] return [64, 128, 256]
@staticmethod
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
return torch.float8_e4m3fn
elif kv_cache_dtype == "fp8_e5m2":
return torch.float8_e5m2
else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
class FlashInferState(AttentionState): class FlashInferState(AttentionState):
@ -177,9 +186,9 @@ class FlashInferState(AttentionState):
self._graph_decode_workspace_buffer, _indptr_buffer, self._graph_decode_workspace_buffer, _indptr_buffer,
self._graph_indices_buffer, _last_page_len_buffer, "NHD", self._graph_indices_buffer, _last_page_len_buffer, "NHD",
use_tensor_cores) use_tensor_cores)
kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype)
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.runner.kv_cache_dtype)
paged_kv_indptr_tensor_host = torch.arange(0, paged_kv_indptr_tensor_host = torch.arange(0,
batch_size + 1, batch_size + 1,
dtype=torch.int32) dtype=torch.int32)
@ -340,7 +349,7 @@ class FlashInferMetadata(AttentionMetadata):
self.page_size, self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope. # Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE", pos_encoding_mode="NONE",
data_type=self.data_type) )
def asdict_zerocopy(self, def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None skip_fields: Optional[Set[str]] = None
@ -366,7 +375,8 @@ class FlashInferMetadata(AttentionMetadata):
def decode_metadata(self) -> Optional["FlashInferMetadata"]: def decode_metadata(self) -> Optional["FlashInferMetadata"]:
# Currently chunked prefill is not supported # Currently chunked prefill is not supported
if self.num_prefills > 0: if self.num_prefills > 0:
assert self.num_decode_tokens == 0 assert self.num_decode_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
return None return None
return self return self
@ -578,6 +588,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
kv_cache_dtype = get_kv_cache_torch_dtype( kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype) self.runner.kv_cache_dtype, self.runner.model_config.dtype)
return FlashInferMetadata( return FlashInferMetadata(
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor, slot_mapping=slot_mapping_tensor,
@ -661,7 +672,6 @@ class FlashInferImpl(AttentionImpl):
if attn_metadata.num_decode_tokens > 0: if attn_metadata.num_decode_tokens > 0:
assert attn_metadata.num_prefill_tokens == 0, ( assert attn_metadata.num_prefill_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.") "Chunked prefill is not supported with flashinfer yet.")
if kv_cache is not None: if kv_cache is not None:
# Use the same reshape and cache kernel as flash attention. # Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash( ops.reshape_and_cache_flash(
@ -674,6 +684,12 @@ class FlashInferImpl(AttentionImpl):
k_scale, k_scale,
v_scale, v_scale,
) )
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if self.kv_cache_dtype.startswith("fp8"):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)
query = query.contiguous( query = query.contiguous(
) # Flashinfer requires query to be contiguous ) # Flashinfer requires query to be contiguous
@ -711,5 +727,7 @@ class FlashInferImpl(AttentionImpl):
query, query,
kv_cache, kv_cache,
sm_scale=self.scale, sm_scale=self.scale,
logits_soft_cap=self.logits_soft_cap) logits_soft_cap=self.logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale)
return output.view(num_tokens, hidden_size) return output.view(num_tokens, hidden_size)

View File

@ -226,6 +226,10 @@ def which_attn_to_use(
elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
logger.info( logger.info(
"Cannot use FlashAttention-2 backend for FP8 KV cache.") "Cannot use FlashAttention-2 backend for FP8 KV cache.")
logger.warning(
"Please use FlashInfer backend with FP8 KV Cache for "
"better performance by setting environment variable "
"VLLM_ATTENTION_BACKEND=FLASHINFER")
selected_backend = _Backend.XFORMERS selected_backend = _Backend.XFORMERS
elif block_size % 16 != 0: elif block_size % 16 != 0:
logger.info( logger.info(