Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 18:55:01 +08:00)
[Core/Bugfix] Add FP8 K/V Scale and dtype conversion for prefix/prefill Triton Kernel (#7208)
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
parent 4ddc4743d7
commit a046f86397
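
The change teaches the Triton prefix/prefill attention kernel to read an FP8 KV cache directly: cached keys and values are loaded in their FP8 storage type, dequantized with per-tensor k/v scales, and cast to the query dtype before the dot products. A plain-PyTorch sketch of that arithmetic (illustrative only, not vLLM code; assumes a torch build with float8 dtypes):

import torch

# Cached keys in FP8 storage plus a per-tensor dequantization scale.
q = torch.randn(16, 64)                                # queries
k_fp8 = torch.randn(128, 64).to(torch.float8_e4m3fn)  # FP8 key block
k_scale = 2.0                                          # per-tensor scale

# Mirrors the kernel's (k_load * k_scale).to(q.dtype) before tl.dot(q, k).
k = (k_fp8.to(torch.float32) * k_scale).to(q.dtype)
scores = q @ k.t()
print(scores.shape)
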
@@ -45,5 +45,3 @@ Here is an example of how to enable this feature:
 # output w/ scaling factors: England, the United Kingdom, and one of the world's leading financial,
 # output w/o scaling factors: England, located in the southeastern part of the country. It is known
 
-Note, current prefix caching doesn't work with FP8 KV cache enabled, forward_prefix kernel should handle different KV and cache type.
-
@@ -32,5 +32,3 @@ Here is an example of how to enable this feature:
     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
 
-Note, current prefix caching doesn't work with FP8 KV cache enabled, forward_prefix kernel should handle different KV and cache type.
-
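With the kernel handling the cache dtype itself, the doc note removed above (and the CacheConfig check removed at the bottom of this diff) is obsolete: an FP8 KV cache and prefix caching can now be enabled together. A minimal usage sketch, assuming a build that includes this commit (model name and prompt are illustrative):

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
          kv_cache_dtype="fp8",
          enable_prefix_caching=True)
outputs = llm.generate(["London is the capital of"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
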
@@ -6,14 +6,27 @@ prefill requests are chunked.
 
 Run `pytest tests/models/test_chunked_prefill.py`.
 """
 
 import pytest
 
-from ..models.utils import check_outputs_equal
+from ..models.utils import check_logprobs_close, check_outputs_equal
 
 MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
 ]
+E5M2_KV_MODELS = [
+    "facebook/opt-125m",
+    "meta-llama/Llama-2-7b-chat-hf",
+]
+E4M3_KV_MODELS = [
+    "meta-llama/Llama-2-7b-chat-hf", "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
+    "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
+]
+KV_CACHE_QUANTIZATION_PATHS = {
+    "meta-llama/Llama-2-7b-chat-hf":
+    "./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json"
+}
 
 
 @pytest.mark.parametrize("model", MODELS)
@@ -35,11 +48,11 @@ def test_models(
     enforce_eager: bool,
     tensor_parallel_size: int,
 ) -> None:
-    max_num_seqs = min(chunked_prefill_token_size, 256)
-    enable_chunked_prefill = False
-    max_num_batched_tokens = None
-    if chunked_prefill_token_size != -1:
-        enable_chunked_prefill = True
-        max_num_batched_tokens = chunked_prefill_token_size
+    """
+    Checks exact match decode between huggingface model and vllm runner with
+    chunked prefill.
+    """
+    max_num_seqs = chunked_prefill_token_size
+    max_num_batched_tokens = chunked_prefill_token_size
 
     with hf_runner(model, dtype=dtype) as hf_model:
@@ -49,7 +62,7 @@ def test_models(
             model,
             dtype=dtype,
             max_num_batched_tokens=max_num_batched_tokens,
-            enable_chunked_prefill=enable_chunked_prefill,
+            enable_chunked_prefill=True,
             tensor_parallel_size=tensor_parallel_size,
             enforce_eager=enforce_eager,
             max_num_seqs=max_num_seqs,
@@ -62,3 +75,78 @@ def test_models(
         name_0="hf",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize("kv_cache_dtype,model",
+                         [("fp8_e5m2", m)
+                          for m in E5M2_KV_MODELS] + [("fp8_e4m3", m)
+                                                      for m in E4M3_KV_MODELS])
+# Due to low-precision numerical divergence, we only test logprob of 4 tokens
+@pytest.mark.parametrize("max_tokens", [4])
+@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
+@pytest.mark.parametrize("enforce_eager", [False, True])
+# NOTE: Increasing this in this suite will fail CI because we currently cannot
+# reset distributed env properly. Use a value > 1 just when you test.
+@pytest.mark.parametrize("tensor_parallel_size", [1])
+def test_models_with_fp8_kv_cache(
+    vllm_runner,
+    example_prompts,
+    kv_cache_dtype: str,
+    model: str,
+    max_tokens: int,
+    chunked_prefill_token_size: int,
+    enforce_eager: bool,
+    tensor_parallel_size: int,
+) -> None:
+    """
+    Only checks log probs match between chunked-prefill and
+    non-chunked-prefill version of vLLM model runner.
+
+    This test is used when there is discrepancy in kernels
+    / numerics (e.g. when using lower-precision types like FP8).
+    """
+    NUM_LOG_PROBS = 8
+
+    if model == "facebook/opt-125m":
+        pytest.skip(
+            "#7378: CUDA illegal memory access (undiagnosed) facebook/opt-125m"
+        )
+
+    max_num_seqs = chunked_prefill_token_size
+    max_num_batched_tokens = chunked_prefill_token_size
+
+    extra_kwargs = {}
+    if model in KV_CACHE_QUANTIZATION_PATHS:
+        extra_kwargs["quantization_param_path"] = KV_CACHE_QUANTIZATION_PATHS[
+            model]
+
+    with vllm_runner(
+            model,
+            tensor_parallel_size=tensor_parallel_size,
+            enforce_eager=enforce_eager,
+            max_num_seqs=max_num_seqs,
+            kv_cache_dtype=kv_cache_dtype,
+            **extra_kwargs,
+    ) as vllm_model:
+        no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS)
+
+    with vllm_runner(
+            model,
+            max_num_batched_tokens=max_num_batched_tokens,
+            enable_chunked_prefill=True,
+            tensor_parallel_size=tensor_parallel_size,
+            enforce_eager=enforce_eager,
+            max_num_seqs=max_num_seqs,
+            kv_cache_dtype=kv_cache_dtype,
+            **extra_kwargs,
+    ) as vllm_model:
+        chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS)
+
+    check_logprobs_close(
+        outputs_0_lst=no_chunked_prefill_outputs,
+        outputs_1_lst=chunked_prefill_outputs,
+        name_0="no_chunked_prefill",
+        name_1="chunked_prefill",
+    )
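The new test above (apparently tests/basic_correctness/test_chunked_prefill.py) runs the same model twice, with and without chunked prefill, and asserts the greedy logprobs stay close. Roughly the same comparison can be set up by hand with the public API; the snippet below mirrors the chunked_prefill_token_size=16 case and is only a sketch (model choice, prompt, and batch limits are illustrative):

from vllm import LLM, SamplingParams

params = SamplingParams(temperature=0.0, max_tokens=4, logprobs=8)
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
          kv_cache_dtype="fp8_e5m2",
          enable_chunked_prefill=True,
          max_num_batched_tokens=16,
          max_num_seqs=16)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].logprobs[0])  # top-8 candidates for the first generated token
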
@@ -9,6 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
 
 from vllm.attention.backends.xformers import _make_alibi_bias
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
 NUM_HEADS = [64]
 NUM_QUERIES_PER_KV = [1, 8, 64]

@@ -18,12 +19,14 @@ CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
 SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
+KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
 
 
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
 @torch.inference_mode()
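KV_CACHE_DTYPES above covers the default "auto" path plus both FP8 flavors ("fp8" generally resolving to an e4m3 variant, and the explicit "fp8_e5m2"). For reference, a quick look at what the two formats trade off (requires a torch build with float8 dtypes):

import torch

for dt in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dt)
    print(f"{dt}: max={info.max}, eps={info.eps}")  # more range vs. more precision
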
@@ -33,6 +36,7 @@ def test_contexted_kv_attention(
     head_size: int,
     sliding_window: int,
     dtype: torch.dtype,
+    kv_cache_dtype: str,
     device: str,
 ) -> None:
     random.seed(0)

@@ -67,16 +71,20 @@ def test_contexted_kv_attention(
     kv.uniform_(-1e-3, 1e-3)
     key, value = kv.unbind(dim=1)
 
+    if kv_cache_dtype == "auto":
+        cache_dtype = dtype
+    else:
+        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
     k_cache = torch.zeros(cache_size,
                           block_size,
                           num_kv_heads,
                           head_size,
-                          dtype=dtype)
+                          dtype=cache_dtype)
     v_cache = torch.zeros(cache_size,
                           block_size,
                           num_kv_heads,
                           head_size,
-                          dtype=dtype)
+                          dtype=cache_dtype)
     k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
     v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
     values = torch.arange(0, cache_size, dtype=torch.long)
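The cache_dtype lookup added above uses vllm.utils.STR_DTYPE_TO_TORCH_DTYPE, which, if memory serves, maps the fp8 aliases to the raw uint8 storage type rather than to a float8 dtype; the test therefore allocates k_cache/v_cache as byte buffers and leaves the reinterpretation to context_attention_fwd. Sketch of the allocation step in isolation (shapes are illustrative):

import torch
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

cache_dtype = STR_DTYPE_TO_TORCH_DTYPE["fp8_e5m2"]
# (num_blocks, block_size, num_kv_heads, head_size); dtype is the storage type
k_cache = torch.zeros(4, 16, 8, 64, dtype=cache_dtype)
print(cache_dtype, k_cache.dtype)
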
@@ -132,6 +140,7 @@ def test_contexted_kv_attention(
                           k,
                           v,
                           output,
+                          kv_cache_dtype,
                           k_cache,
                           v_cache,
                           block_table,

@@ -146,6 +155,7 @@ def test_contexted_kv_attention(
                           k,
                           v,
                           output,
+                          kv_cache_dtype,
                           k_cache,
                           v_cache,
                           block_table,

@@ -208,13 +218,15 @@ def test_contexted_kv_attention(
     end_time = time.time()
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
     output_ref = output_ref.reshape(output.shape)
-    assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
+    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
+    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
 
 
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_contexted_kv_attention_alibi(
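The looser atol for the fp8 cases above is roughly the quantization error itself: the test draws K/V from uniform(-1e-3, 1e-3), and values of that magnitude already pick up absolute rounding error close to 1e-3 when pushed through an FP8 format. An illustrative check (requires torch float8 support):

import torch

x = torch.empty(1 << 14).uniform_(-1e-3, 1e-3)
roundtrip = x.to(torch.float8_e4m3fn).to(torch.float32)
print(f"max |x - fp8(x)| = {(x - roundtrip).abs().max().item():.1e}")  # close to 1e-3
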
@@ -222,6 +234,7 @@ def test_contexted_kv_attention_alibi(
     num_queries_per_kv: int,
     head_size: int,
     dtype: torch.dtype,
+    kv_cache_dtype: str,
     device: str,
 ) -> None:
     random.seed(0)

@@ -282,17 +295,20 @@ def test_contexted_kv_attention_alibi(
     kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
     kv.uniform_(-1e-3, 1e-3)
     key, value = kv.unbind(dim=1)
+    if kv_cache_dtype == "auto":
+        cache_dtype = dtype
+    else:
+        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
     k_cache = torch.zeros(cache_size,
                           block_size,
                           num_kv_heads,
                           head_size,
-                          dtype=dtype)
+                          dtype=cache_dtype)
     v_cache = torch.zeros(cache_size,
                           block_size,
                           num_kv_heads,
                           head_size,
-                          dtype=dtype)
+                          dtype=cache_dtype)
     k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
     v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
     values = torch.arange(0, cache_size, dtype=torch.long)

@@ -348,6 +364,7 @@ def test_contexted_kv_attention_alibi(
                           k,
                           v,
                           output,
+                          kv_cache_dtype,
                           k_cache,
                           v_cache,
                           block_table,

@@ -362,6 +379,7 @@ def test_contexted_kv_attention_alibi(
                           k,
                           v,
                           output,
+                          kv_cache_dtype,
                           k_cache,
                           v_cache,
                           block_table,

@@ -447,4 +465,5 @@ def test_contexted_kv_attention_alibi(
     torch.cuda.synchronize()
     end_time = time.time()
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
-    assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
+    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
+    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
@@ -459,6 +459,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     query,
                     key,
                     value,
+                    self.kv_cache_dtype,
                     key_cache,
                     value_cache,
                     prefill_meta.block_tables,

@@ -468,6 +469,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     prefill_meta.max_query_len,
                     self.alibi_slopes,
                     self.sliding_window[0],
+                    k_scale,
+                    v_scale,
                 )

         if decode_meta := attn_metadata.decode_metadata:
@@ -604,6 +604,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                     query,
                     key,
                     value,
+                    self.kv_cache_dtype,
                     key_cache,
                     value_cache,
                     prefill_meta.block_tables,

@@ -613,6 +614,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                     prefill_meta.max_query_len,
                     self.alibi_slopes,
                     self.sliding_window,
+                    k_scale,
+                    v_scale,
                 )
                 assert output[:num_prefill_tokens].shape == out.shape
                 output[:num_prefill_tokens] = out
@@ -90,6 +90,7 @@ class PagedAttention:
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        kv_cache_dtype: str,
         key_cache: torch.Tensor,
         value_cache: torch.Tensor,
         block_tables: torch.Tensor,

@@ -194,6 +194,7 @@ class PagedAttention:
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        kv_cache_dtype: str,
         key_cache: torch.Tensor,
         value_cache: torch.Tensor,
         block_tables: torch.Tensor,

@@ -203,6 +204,8 @@ class PagedAttention:
         max_query_len: int,
         alibi_slopes: Optional[torch.Tensor],
         sliding_window: Optional[int],
+        k_scale: float,
+        v_scale: float,
     ) -> torch.Tensor:
         output = torch.empty_like(query)
         context_attention_fwd(

@@ -210,6 +213,7 @@ class PagedAttention:
             key,
             value,
             output,
+            kv_cache_dtype,
             key_cache,
             value_cache,
             block_tables,

@@ -218,6 +222,8 @@ class PagedAttention:
             seq_lens_tensor,
             context_lens,
             max_query_len,
+            k_scale,
+            v_scale,
             alibi_slopes,
             sliding_window,
         )
@@ -18,6 +18,8 @@ if triton.__version__ >= "2.1.0":
         V_cache,
         B_Loc,
         sm_scale,
+        k_scale,
+        v_scale,
         B_Start_Loc,
         B_Seqlen,
         B_Ctxlen,

@@ -117,11 +119,16 @@ if triton.__version__ >= "2.1.0":
                      cur_kv_head * stride_v_cache_h +
                      offs_d[None, :] * stride_v_cache_d +
                      (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
-            k = tl.load(K_cache + off_k,
+            k_load = tl.load(K_cache + off_k,
                         mask=dim_mask[:, None] &
                         ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                         other=0.0)  # [D,N]
 
+            if k_load.dtype.is_fp8():
+                k = (k_load.to(tl.float32) * k_scale).to(q.dtype)
+            else:
+                k = k_load
+
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # [M,N]
             qk += tl.dot(q, k)
             qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,

@@ -161,12 +168,16 @@ if triton.__version__ >= "2.1.0":
             acc_scale = l_i / l_i_new * alpha
             acc = acc * acc_scale[:, None]
             # update acc
-            v = tl.load(V_cache + off_v,
+            v_load = tl.load(V_cache + off_v,
                         mask=dim_mask[None, :] &
                         ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                         other=0.0)  # [N,D]
+            if v_load.dtype.is_fp8():
+                v = (v_load.to(tl.float32) * v_scale).to(q.dtype)
+            else:
+                v = v_load
             p = p.to(v.dtype)
 
             acc += tl.dot(p, v)
             # # update m_i and l_i
             l_i = l_i_new

@@ -225,8 +236,8 @@ if triton.__version__ >= "2.1.0":
                         mask=dim_mask[None, :] &
                         ((start_n + offs_n[:, None]) < cur_batch_query_len),
                         other=0.0)
-
             p = p.to(v.dtype)
+
             acc += tl.dot(p, v)
             # update m_i and l_i
             l_i = l_i_new

@@ -336,7 +347,6 @@ if triton.__version__ >= "2.1.0":
             k = tl.load(K_cache + off_k,
                         mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                         other=0.0)
-
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
             qk += tl.dot(q, k)
             qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
@@ -442,6 +452,8 @@ if triton.__version__ >= "2.1.0":
         V_cache,
         B_Loc,
         sm_scale,
+        k_scale,
+        v_scale,
         B_Start_Loc,
         B_Seqlen,
         B_Ctxlen,

@@ -537,11 +549,16 @@ if triton.__version__ >= "2.1.0":
                      cur_kv_head * stride_v_cache_h +
                      offs_d[None, :] * stride_v_cache_d +
                      (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
-            k = tl.load(K_cache + off_k,
+            k_load = tl.load(K_cache + off_k,
                         mask=dim_mask[:, None] &
                         ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                         other=0.0)  # [D,N]
 
+            if k_load.dtype.is_fp8():
+                k = (k_load.to(tl.float32) * k_scale).to(q.dtype)
+            else:
+                k = k_load
+
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
             qk += tl.dot(q, k)
             qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,

@@ -573,12 +590,16 @@ if triton.__version__ >= "2.1.0":
             # acc_scale = l_i / l_i_new * alpha
             acc = acc * acc_scale[:, None]
             # update acc
-            v = tl.load(V_cache + off_v,
+            v_load = tl.load(V_cache + off_v,
                         mask=dim_mask[None, :] &
                         ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                         other=0.0)
+            if v_load.dtype.is_fp8():
+                v = (v_load.to(tl.float32) * v_scale).to(q.dtype)
+            else:
+                v = v_load
             p = p.to(v.dtype)
 
             acc += tl.dot(p, v, allow_tf32=False)
             # update m_i and l_i
             l_i = l_i_new

@@ -650,8 +671,8 @@ if triton.__version__ >= "2.1.0":
                         ((start_n + offs_n[:, None]) <
                          cur_batch_seq_len - cur_batch_ctx_len),
                         other=0.0)
-
             p = p.to(v.dtype)
+
             acc += tl.dot(p, v, allow_tf32=False)
             # update m_i and l_i
             l_i = l_i_new

@@ -675,6 +696,7 @@ if triton.__version__ >= "2.1.0":
                               k,
                               v,
                               o,
+                              kv_cache_dtype: str,
                               k_cache,
                               v_cache,
                               b_loc,
@@ -682,17 +704,41 @@ if triton.__version__ >= "2.1.0":
                               b_seq_len,
                               b_ctx_len,
                               max_input_len,
+                              k_scale: float = 1.0,
+                              v_scale: float = 1.0,
                               alibi_slopes=None,
                               sliding_window=None):
 
         cap = current_platform.get_device_capability()
         BLOCK = 128 if cap[0] >= 8 else 64
+        NUM_WARPS = 8
 
         # need to reduce num. blocks when using fp32
         # due to increased use of GPU shared memory
         if q.dtype is torch.float32:
             BLOCK = BLOCK // 2
 
+        # Conversion of FP8 Tensor from uint8 storage to
+        # appropriate torch.dtype for interpretation by Triton
+        if "fp8" in kv_cache_dtype:
+            assert (k_cache.dtype == torch.uint8)
+            assert (v_cache.dtype == torch.uint8)
+
+            if kv_cache_dtype in ("fp8", "fp8_e4m3"):
+                target_dtype = torch.float8_e4m3fn
+            elif kv_cache_dtype == "fp8_e5m2":
+                target_dtype = torch.float8_e5m2
+            else:
+                raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
+
+            k_cache = k_cache.view(target_dtype)
+            v_cache = v_cache.view(target_dtype)
+
+        if (k_cache.dtype == torch.uint8
+                or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"):
+            raise ValueError("kv_cache_dtype='auto' unsupported for\
+                FP8 KV Cache prefill kernel")
+
         # shape constraints
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
         assert Lq == Lk and Lk == Lv
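The uint8-to-float8 view above is a pure reinterpretation: the engine stores the FP8 KV cache in byte tensors, and Tensor.view(dtype) relabels the same storage as a real float8 dtype (element sizes match, nothing is copied), which is what lets tl.load see FP8 elements and the k_scale/v_scale dequantization take effect. A standalone illustration (assumes torch float8 support):

import torch

raw = torch.randint(0, 256, (4, 16), dtype=torch.uint8)  # byte-backed cache block
as_fp8 = raw.view(torch.float8_e4m3fn)                    # same bytes, FP8 view
assert as_fp8.data_ptr() == raw.data_ptr()                # no copy
print(raw.dtype, "->", as_fp8.dtype, as_fp8.shape)
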
@@ -709,7 +755,6 @@ if triton.__version__ >= "2.1.0":
         if sliding_window is None or sliding_window <= 0:
             sliding_window = 0
 
-        num_warps = 8 if Lk <= 64 else 8
         if alibi_slopes is not None:
             _fwd_kernel_alibi[grid](
                 q,

@@ -719,6 +764,8 @@ if triton.__version__ >= "2.1.0":
                 v_cache,
                 b_loc,
                 sm_scale,
+                k_scale,
+                v_scale,
                 b_start_loc,
                 b_seq_len,
                 b_ctx_len,

@@ -757,7 +804,7 @@ if triton.__version__ >= "2.1.0":
                 BLOCK_DMODEL=Lk,
                 BLOCK_DMODEL_PADDED=Lk_padded,
                 BLOCK_N=BLOCK,
-                num_warps=num_warps,
+                num_warps=NUM_WARPS,
                 num_stages=1,
             )
             return

@@ -770,6 +817,8 @@ if triton.__version__ >= "2.1.0":
             v_cache,
             b_loc,
             sm_scale,
+            k_scale,
+            v_scale,
             b_start_loc,
             b_seq_len,
             b_ctx_len,

@@ -807,7 +856,7 @@ if triton.__version__ >= "2.1.0":
             BLOCK_DMODEL_PADDED=Lk_padded,
             BLOCK_N=BLOCK,
             SLIDING_WINDOW=sliding_window,
-            num_warps=num_warps,
+            num_warps=NUM_WARPS,
             num_stages=1,
         )
         return
@@ -545,10 +545,6 @@ class CacheConfig:
             raise NotImplementedError(
                 "Prefix caching is not supported with sliding window. "
                 "Run with --disable-sliding-window to use prefix caching.")
-        if self.cache_dtype == "fp8":
-            raise NotImplementedError(
-                "Prefix caching is not supported for fp8 cache_dtype. "
-                "Run with --kv-cache-dtype auto to use prefix caching.")
 
     def verify_with_parallel_config(
         self,