mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 20:22:15 +08:00
[NVIDIA][torch.compile] Support Flashinfer TRTLLM FP8-q/kv NVFP4-out Attention Kernel (#22703)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
parent
cc7ae5e7ca
commit
24d0c9e6ed
@ -9,8 +9,11 @@ from typing import Optional
|
|||||||
import flashinfer
|
import flashinfer
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.utils import round_up
|
||||||
|
|
||||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||||
FP8_DTYPE = torch.float8_e4m3fn
|
FP8_DTYPE = torch.float8_e4m3fn
|
||||||
|
FP4_DTYPE = torch.uint8
|
||||||
|
|
||||||
|
|
||||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||||
@ -61,13 +64,13 @@ def benchmark_decode(
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||||
|
|
||||||
query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
|
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||||
|
q_scale = 1.0
|
||||||
|
ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
|
||||||
if q_quant_dtype == FP8_DTYPE:
|
if q_quant_dtype == FP8_DTYPE:
|
||||||
query, q_scale = to_float8(query)
|
query, _ = to_float8(ref_query)
|
||||||
ref_query = query.to(dtype) * q_scale
|
|
||||||
else:
|
else:
|
||||||
q_scale = 1.0
|
query = ref_query
|
||||||
ref_query = query
|
|
||||||
|
|
||||||
kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
|
kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
|
||||||
kv_lens[-1] = max_seq_len
|
kv_lens[-1] = max_seq_len
|
||||||
@ -75,14 +78,13 @@ def benchmark_decode(
|
|||||||
seq_lens = kv_lens
|
seq_lens = kv_lens
|
||||||
max_seq_len = torch.max(seq_lens).item()
|
max_seq_len = torch.max(seq_lens).item()
|
||||||
|
|
||||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||||
|
k_scale = v_scale = 1.0
|
||||||
|
ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||||
if kv_quant_dtype == FP8_DTYPE:
|
if kv_quant_dtype == FP8_DTYPE:
|
||||||
kv_cache, kv_scale = to_float8(kv_cache)
|
kv_cache, _ = to_float8(ref_kv_cache)
|
||||||
ref_kv_cache = kv_cache.to(dtype) * kv_scale
|
|
||||||
else:
|
else:
|
||||||
kv_scale = 1.0
|
kv_cache = ref_kv_cache
|
||||||
ref_kv_cache = kv_cache
|
|
||||||
k_scale = v_scale = kv_scale
|
|
||||||
|
|
||||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||||
block_tables = torch.randint(
|
block_tables = torch.randint(
|
||||||
@ -142,11 +144,31 @@ def benchmark_decode(
|
|||||||
return sum(times) / len(times), torch.std(torch.tensor(times))
|
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||||
|
|
||||||
o_scale = 1.0
|
o_scale = 1.0
|
||||||
|
o_sf_scale = None
|
||||||
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
|
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
|
||||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
if o_quant_dtype == FP4_DTYPE:
|
||||||
|
o_sf_scale = 500.0
|
||||||
|
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||||
|
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
|
||||||
|
torch.empty(
|
||||||
|
(
|
||||||
|
round_up(query.shape[0], 128),
|
||||||
|
round_up(query.shape[1] * query.shape[2] // 16, 4),
|
||||||
|
),
|
||||||
|
dtype=torch.float8_e4m3fn,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||||
|
|
||||||
def baseline_decode():
|
def baseline_decode():
|
||||||
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
|
return wrapper.run(
|
||||||
|
ref_query,
|
||||||
|
ref_kv_cache,
|
||||||
|
k_scale=k_scale,
|
||||||
|
v_scale=v_scale,
|
||||||
|
out=output_baseline,
|
||||||
|
)
|
||||||
|
|
||||||
def trtllm_decode():
|
def trtllm_decode():
|
||||||
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||||
@ -158,6 +180,7 @@ def benchmark_decode(
|
|||||||
max_seq_len=max_seq_len,
|
max_seq_len=max_seq_len,
|
||||||
bmm1_scale=q_scale * k_scale * sm_scale,
|
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||||
bmm2_scale=v_scale / o_scale,
|
bmm2_scale=v_scale / o_scale,
|
||||||
|
o_sf_scale=o_sf_scale,
|
||||||
out=output_trtllm,
|
out=output_trtllm,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -237,6 +260,7 @@ if __name__ == "__main__":
|
|||||||
(None, None, None),
|
(None, None, None),
|
||||||
(None, FP8_DTYPE, None),
|
(None, FP8_DTYPE, None),
|
||||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||||
|
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||||
]
|
]
|
||||||
|
|
||||||
for quant_dtype in quant_dtypes:
|
for quant_dtype in quant_dtypes:
|
||||||
|
|||||||
@ -9,8 +9,11 @@ from typing import Optional
|
|||||||
import flashinfer
|
import flashinfer
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.utils import round_up
|
||||||
|
|
||||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||||
FP8_DTYPE = torch.float8_e4m3fn
|
FP8_DTYPE = torch.float8_e4m3fn
|
||||||
|
FP4_DTYPE = torch.uint8
|
||||||
|
|
||||||
|
|
||||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||||
@ -72,13 +75,15 @@ def benchmark_prefill(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
|
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||||
|
q_scale = 1.0
|
||||||
|
ref_query = torch.randn(
|
||||||
|
torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype
|
||||||
|
)
|
||||||
if q_quant_dtype == FP8_DTYPE:
|
if q_quant_dtype == FP8_DTYPE:
|
||||||
query, q_scale = to_float8(query)
|
query, _ = to_float8(ref_query)
|
||||||
ref_query = query.to(dtype) * q_scale
|
|
||||||
else:
|
else:
|
||||||
q_scale = 1.0
|
query = ref_query
|
||||||
ref_query = query
|
|
||||||
|
|
||||||
kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
|
kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
|
||||||
kv_lens[-1] = max_kv_len
|
kv_lens[-1] = max_kv_len
|
||||||
@ -86,14 +91,13 @@ def benchmark_prefill(
|
|||||||
seq_lens = kv_lens + q_lens
|
seq_lens = kv_lens + q_lens
|
||||||
max_seq_len = torch.max(seq_lens).item()
|
max_seq_len = torch.max(seq_lens).item()
|
||||||
|
|
||||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||||
|
k_scale = v_scale = 1.0
|
||||||
|
ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||||
if kv_quant_dtype == FP8_DTYPE:
|
if kv_quant_dtype == FP8_DTYPE:
|
||||||
kv_cache, kv_scale = to_float8(kv_cache)
|
kv_cache, _ = to_float8(ref_kv_cache)
|
||||||
ref_kv_cache = kv_cache.to(dtype) * kv_scale
|
|
||||||
else:
|
else:
|
||||||
kv_scale = 1.0
|
kv_cache = ref_kv_cache
|
||||||
ref_kv_cache = kv_cache
|
|
||||||
k_scale = v_scale = kv_scale
|
|
||||||
|
|
||||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||||
block_tables = torch.randint(
|
block_tables = torch.randint(
|
||||||
@ -152,11 +156,31 @@ def benchmark_prefill(
|
|||||||
return sum(times) / len(times), torch.std(torch.tensor(times))
|
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||||
|
|
||||||
o_scale = 1.0
|
o_scale = 1.0
|
||||||
|
o_sf_scale = None
|
||||||
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
|
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
|
||||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
if o_quant_dtype == FP4_DTYPE:
|
||||||
|
o_sf_scale = 500.0
|
||||||
|
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||||
|
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
|
||||||
|
torch.empty(
|
||||||
|
(
|
||||||
|
round_up(query.shape[0], 128),
|
||||||
|
round_up(query.shape[1] * query.shape[2] // 16, 4),
|
||||||
|
),
|
||||||
|
dtype=torch.float8_e4m3fn,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||||
|
|
||||||
def baseline_prefill():
|
def baseline_prefill():
|
||||||
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
|
return wrapper.run(
|
||||||
|
ref_query,
|
||||||
|
ref_kv_cache,
|
||||||
|
k_scale=k_scale,
|
||||||
|
v_scale=v_scale,
|
||||||
|
out=output_baseline,
|
||||||
|
)
|
||||||
|
|
||||||
def trtllm_prefill():
|
def trtllm_prefill():
|
||||||
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
|
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
|
||||||
@ -172,6 +196,7 @@ def benchmark_prefill(
|
|||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
cum_seq_lens_q=q_indptr,
|
cum_seq_lens_q=q_indptr,
|
||||||
cum_seq_lens_kv=kv_indptr,
|
cum_seq_lens_kv=kv_indptr,
|
||||||
|
o_sf_scale=o_sf_scale,
|
||||||
out=output_trtllm,
|
out=output_trtllm,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -250,6 +275,7 @@ if __name__ == "__main__":
|
|||||||
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
||||||
(None, None, None),
|
(None, None, None),
|
||||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||||
|
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||||
]
|
]
|
||||||
|
|
||||||
for quant_dtype in quant_dtypes:
|
for quant_dtype in quant_dtypes:
|
||||||
|
|||||||
@ -8,11 +8,12 @@ import vllm.envs as envs
|
|||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
|
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
|
||||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||||
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
|
from vllm.compilation.fusion import FUSED_OPS, FusionPass
|
||||||
kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
|
||||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
||||||
|
|
||||||
from .backend import TestBackend
|
from .backend import TestBackend
|
||||||
|
|
||||||
|
|||||||
@ -7,11 +7,13 @@ import torch
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
import vllm.plugins
|
import vllm.plugins
|
||||||
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
|
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
|
||||||
FusionPass, GroupShape, QuantKey)
|
FusionPass)
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
|
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
|
||||||
VllmConfig)
|
VllmConfig)
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
GroupShape, QuantKey, ScaleDesc)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
|
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -30,10 +32,8 @@ class TestModel(torch.nn.Module):
|
|||||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
|
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
|
||||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
|
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
|
||||||
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
|
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
|
||||||
self.key = QuantKey(dtype=FP8_DTYPE,
|
quant_scale = ScaleDesc(torch.float32, static, group_shape)
|
||||||
static=static,
|
self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
|
||||||
group_shape=group_shape,
|
|
||||||
symmetric=True)
|
|
||||||
if static:
|
if static:
|
||||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
|
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -11,9 +11,10 @@ from tests.models.utils import check_outputs_equal
|
|||||||
from tests.v1.attention.utils import (BatchSpec, _Backend,
|
from tests.v1.attention.utils import (BatchSpec, _Backend,
|
||||||
create_common_attn_metadata)
|
create_common_attn_metadata)
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||||
from vllm.attention import Attention
|
from vllm.attention import Attention
|
||||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||||
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
|
from vllm.compilation.fusion import QUANT_OPS
|
||||||
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
||||||
from vllm.compilation.fx_utils import find_op_nodes
|
from vllm.compilation.fx_utils import find_op_nodes
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
@ -22,13 +23,14 @@ from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
|||||||
set_current_vllm_config)
|
set_current_vllm_config)
|
||||||
from vllm.forward_context import get_forward_context, set_forward_context
|
from vllm.forward_context import get_forward_context, set_forward_context
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape)
|
QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
Fp8LinearOp)
|
Fp8LinearOp)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
FP8_DTYPE = current_platform.fp8_dtype()
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
FP4_DTYPE = torch.uint8
|
||||||
|
|
||||||
# globals needed for string-import custom Dynamo backend field
|
# globals needed for string-import custom Dynamo backend field
|
||||||
backend: Optional[TestBackend] = None
|
backend: Optional[TestBackend] = None
|
||||||
@ -105,9 +107,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
|
|||||||
|
|
||||||
# check support
|
# check support
|
||||||
attn_fusion_supported = [
|
attn_fusion_supported = [
|
||||||
layer.impl.fused_output_quant_supported(quant_key.dtype,
|
layer.impl.fused_output_quant_supported(quant_key)
|
||||||
quant_key.static,
|
|
||||||
quant_key.group_shape)
|
|
||||||
for key, layer in compile_config.static_forward_context.items()
|
for key, layer in compile_config.static_forward_context.items()
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -149,12 +149,12 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
|
|||||||
backend = None
|
backend = None
|
||||||
|
|
||||||
|
|
||||||
class TestAttentionStaticQuantPatternModel(torch.nn.Module):
|
class AttentionQuantPatternModel(torch.nn.Module):
|
||||||
"""Test model for AttentionStaticQuantPattern fusion."""
|
"""Base model for AttentionQuantPattern fusion."""
|
||||||
|
|
||||||
def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
|
def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
|
||||||
kv_cache_dtype: torch.dtype, device: torch.device,
|
kv_cache_dtype: torch.dtype, device: torch.device,
|
||||||
vllm_config: VllmConfig):
|
vllm_config: VllmConfig, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_qo_heads = num_qo_heads
|
self.num_qo_heads = num_qo_heads
|
||||||
self.num_kv_heads = num_kv_heads
|
self.num_kv_heads = num_kv_heads
|
||||||
@ -172,11 +172,6 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
|
|||||||
prefix="model.layers.0.self_attn.attn",
|
prefix="model.layers.0.self_attn.attn",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.fp8_linear = Fp8LinearOp(
|
|
||||||
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
|
|
||||||
self.wscale = torch.tensor([1.0], dtype=torch.float32)
|
|
||||||
self.scale = torch.tensor([1.0], dtype=torch.float32)
|
|
||||||
|
|
||||||
self.block_size = 16
|
self.block_size = 16
|
||||||
|
|
||||||
# Initialize attn MetadataBuilder
|
# Initialize attn MetadataBuilder
|
||||||
@ -230,23 +225,86 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
|
|||||||
|
|
||||||
return self.attn_metadata
|
return self.attn_metadata
|
||||||
|
|
||||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
w: torch.Tensor):
|
class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
|
||||||
|
"""Test model for AttentionFp8StaticQuantPattern fusion."""
|
||||||
|
|
||||||
|
quant_key = kFp8StaticTensorSym
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self.fp8_linear = Fp8LinearOp(
|
||||||
|
act_quant_static=self.quant_key.scale.static,
|
||||||
|
act_quant_group_shape=self.quant_key.scale.group_shape)
|
||||||
|
|
||||||
|
hidden_size = self.num_qo_heads * self.head_size
|
||||||
|
self.w = kwargs.get(
|
||||||
|
"w", {
|
||||||
|
"weight":
|
||||||
|
torch.randn(hidden_size, hidden_size).to(
|
||||||
|
dtype=FP8_DTYPE, device=self.device).t(),
|
||||||
|
"wscale":
|
||||||
|
torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||||
|
"scale":
|
||||||
|
torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||||
|
})
|
||||||
|
|
||||||
|
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||||
"""Forward pass that creates the pattern to be fused."""
|
"""Forward pass that creates the pattern to be fused."""
|
||||||
attn_output = self.attn(q, k, v)
|
attn_output = self.attn(q, k, v)
|
||||||
return self.fp8_linear.apply(input=attn_output,
|
return self.fp8_linear.apply(input=attn_output,
|
||||||
weight=w,
|
weight=self.w["weight"],
|
||||||
weight_scale=self.wscale,
|
weight_scale=self.w["wscale"],
|
||||||
input_scale=self.scale)
|
input_scale=self.w["scale"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
||||||
|
"""Test model for AttentionNvfp4QuantPattern fusion."""
|
||||||
|
|
||||||
|
quant_key = kNvfp4Quant
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
hidden_size = self.num_qo_heads * self.head_size
|
||||||
|
self.w = kwargs.get(
|
||||||
|
"w", {
|
||||||
|
"weight":
|
||||||
|
torch.randint(256, (hidden_size, hidden_size // 2),
|
||||||
|
dtype=FP4_DTYPE,
|
||||||
|
device=self.device),
|
||||||
|
"wscale_swizzled":
|
||||||
|
torch.randn(hidden_size, hidden_size // 16).to(
|
||||||
|
dtype=FP8_DTYPE, device=self.device),
|
||||||
|
"wscale":
|
||||||
|
torch.tensor([500], dtype=torch.float32, device=self.device),
|
||||||
|
"scale":
|
||||||
|
torch.tensor([0.002], dtype=torch.float32, device=self.device),
|
||||||
|
})
|
||||||
|
|
||||||
|
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||||
|
"""Forward pass that creates the pattern to be fused."""
|
||||||
|
attn_output = self.attn(q, k, v)
|
||||||
|
quant_output, output_block_scale = scaled_fp4_quant(
|
||||||
|
attn_output, 1 / self.w["scale"])
|
||||||
|
return cutlass_scaled_fp4_mm(a=quant_output,
|
||||||
|
b=self.w["weight"],
|
||||||
|
block_scale_a=output_block_scale,
|
||||||
|
block_scale_b=self.w["wscale_swizzled"],
|
||||||
|
alpha=self.w["scale"] * self.w["wscale"],
|
||||||
|
out_dtype=attn_output.dtype)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)])
|
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)])
|
||||||
@pytest.mark.parametrize("head_size", [128])
|
@pytest.mark.parametrize("head_size", [128])
|
||||||
@pytest.mark.parametrize("batch_size", [7, 256, 533])
|
@pytest.mark.parametrize("batch_size", [7, 256, 533])
|
||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("model_name, model_class",
|
||||||
"model_name, quant_key",
|
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||||
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", kFp8StaticTensorSym)])
|
TestAttentionFp8StaticQuantPatternModel),
|
||||||
|
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
|
||||||
|
TestAttentionNvfp4QuantPatternModel)])
|
||||||
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER])
|
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER])
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
|
||||||
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
||||||
@ -255,8 +313,8 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
|
|||||||
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||||
head_size: int, batch_size: int,
|
head_size: int, batch_size: int,
|
||||||
dtype: torch.dtype, model_name: str,
|
dtype: torch.dtype, model_name: str,
|
||||||
quant_key: QuantKey, backend: _Backend,
|
model_class: type[AttentionQuantPatternModel],
|
||||||
monkeypatch, dist_init):
|
backend: _Backend, monkeypatch, dist_init):
|
||||||
"""Test AttentionStaticQuantPattern fusion pass"""
|
"""Test AttentionStaticQuantPattern fusion pass"""
|
||||||
|
|
||||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
@ -277,8 +335,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
|||||||
cache_config=CacheConfig(cache_dtype="fp8"))
|
cache_config=CacheConfig(cache_dtype="fp8"))
|
||||||
|
|
||||||
# Create test inputs
|
# Create test inputs
|
||||||
hidden_size = num_qo_heads * head_size
|
q = torch.randn(batch_size,
|
||||||
q = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
|
num_qo_heads * head_size,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device)
|
||||||
k = torch.randn(batch_size,
|
k = torch.randn(batch_size,
|
||||||
num_kv_heads * head_size,
|
num_kv_heads * head_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
@ -287,7 +347,6 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
|||||||
num_kv_heads * head_size,
|
num_kv_heads * head_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device)
|
device=device)
|
||||||
linear_w = torch.randn(hidden_size, hidden_size).to(FP8_DTYPE).t()
|
|
||||||
|
|
||||||
# Mark first dimension as dynamic for realistic testing
|
# Mark first dimension as dynamic for realistic testing
|
||||||
torch._dynamo.mark_dynamic(q, 0)
|
torch._dynamo.mark_dynamic(q, 0)
|
||||||
@ -299,9 +358,12 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
|||||||
with set_current_vllm_config(vllm_config_unfused), set_forward_context(
|
with set_current_vllm_config(vllm_config_unfused), set_forward_context(
|
||||||
attn_metadata=None, vllm_config=vllm_config_unfused
|
attn_metadata=None, vllm_config=vllm_config_unfused
|
||||||
), global_force_attn_backend_context_manager(backend):
|
), global_force_attn_backend_context_manager(backend):
|
||||||
model_unfused = TestAttentionStaticQuantPatternModel(
|
model_unfused = model_class(num_qo_heads=num_qo_heads,
|
||||||
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device,
|
num_kv_heads=num_kv_heads,
|
||||||
vllm_config_unfused)
|
head_size=head_size,
|
||||||
|
kv_cache_dtype=FP8_DTYPE,
|
||||||
|
device=device,
|
||||||
|
vllm_config=vllm_config_unfused)
|
||||||
model_unfused = model_unfused.to(device)
|
model_unfused = model_unfused.to(device)
|
||||||
|
|
||||||
forward_ctx = get_forward_context()
|
forward_ctx = get_forward_context()
|
||||||
@ -309,7 +371,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
|||||||
batch_size)
|
batch_size)
|
||||||
|
|
||||||
# Run model directly without compilation and fusion
|
# Run model directly without compilation and fusion
|
||||||
result_unfused = model_unfused(q, k, v, linear_w)
|
result_unfused = model_unfused(q, k, v)
|
||||||
|
|
||||||
# Run model with attn fusion enabled
|
# Run model with attn fusion enabled
|
||||||
vllm_config.compilation_config.pass_config = PassConfig(
|
vllm_config.compilation_config.pass_config = PassConfig(
|
||||||
@ -317,9 +379,13 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
|||||||
with set_current_vllm_config(vllm_config), set_forward_context(
|
with set_current_vllm_config(vllm_config), set_forward_context(
|
||||||
attn_metadata=None, vllm_config=vllm_config
|
attn_metadata=None, vllm_config=vllm_config
|
||||||
), global_force_attn_backend_context_manager(backend):
|
), global_force_attn_backend_context_manager(backend):
|
||||||
model_fused = TestAttentionStaticQuantPatternModel(
|
model_fused = model_class(num_qo_heads=num_qo_heads,
|
||||||
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device,
|
num_kv_heads=num_kv_heads,
|
||||||
vllm_config)
|
head_size=head_size,
|
||||||
|
kv_cache_dtype=FP8_DTYPE,
|
||||||
|
device=device,
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
w=model_unfused.w)
|
||||||
model_fused = model_fused.to(device)
|
model_fused = model_fused.to(device)
|
||||||
|
|
||||||
forward_ctx = get_forward_context()
|
forward_ctx = get_forward_context()
|
||||||
@ -336,21 +402,20 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
|||||||
backend=test_backend,
|
backend=test_backend,
|
||||||
fullgraph=True)
|
fullgraph=True)
|
||||||
assert model_compiled.attn._o_scale_float is None
|
assert model_compiled.attn._o_scale_float is None
|
||||||
result_fused_1 = model_compiled(q, k, v, linear_w)
|
result_fused_1 = model_compiled(q, k, v)
|
||||||
|
|
||||||
# After the 1st round of the forward pass, output quant scale should be
|
# After the 1st round of the forward pass, output quant scale should be
|
||||||
# loaded into the attn layer's _o_scale_float, the 2nd round should
|
# loaded into the attn layer's _o_scale_float, the 2nd round should
|
||||||
# reuse the loaded _o_scale_float
|
# reuse the loaded _o_scale_float
|
||||||
assert model_compiled.attn._o_scale_float is not None
|
assert model_compiled.attn._o_scale_float is not None
|
||||||
result_fused_2 = model_compiled(q, k, v, linear_w)
|
result_fused_2 = model_compiled(q, k, v)
|
||||||
assert model_compiled.attn._o_scale_float is not None
|
assert model_compiled.attn._o_scale_float is not None
|
||||||
|
|
||||||
# Check attn fusion support
|
# Check attn fusion support
|
||||||
|
quant_key = model_class.quant_key
|
||||||
attn_fusion_supported = [
|
attn_fusion_supported = [
|
||||||
layer.impl.fused_output_quant_supported(quant_key.dtype,
|
layer.impl.fused_output_quant_supported(quant_key) for key, layer in
|
||||||
quant_key.static,
|
vllm_config.compilation_config.static_forward_context.items()
|
||||||
quant_key.group_shape) for key,
|
|
||||||
layer in vllm_config.compilation_config.static_forward_context.items()
|
|
||||||
]
|
]
|
||||||
if any(attn_fusion_supported):
|
if any(attn_fusion_supported):
|
||||||
# Check quantization ops in the graph before and after fusion
|
# Check quantization ops in the graph before and after fusion
|
||||||
@ -370,6 +435,15 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
|||||||
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \
|
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \
|
||||||
"Attention should have output_scale after fusion"
|
"Attention should have output_scale after fusion"
|
||||||
|
|
||||||
|
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \
|
||||||
|
"Attention should not have output_block_scale before fusion"
|
||||||
|
if quant_key.dtype == FP8_DTYPE:
|
||||||
|
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \
|
||||||
|
"Attention should not have output_block_scale after FP8 fusion"
|
||||||
|
elif quant_key.dtype == FP4_DTYPE:
|
||||||
|
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \
|
||||||
|
"Attention should have output_block_scale after FP4 fusion" # noqa: E501
|
||||||
|
|
||||||
# Check that results are closed
|
# Check that results are closed
|
||||||
torch.testing.assert_close(result_unfused,
|
torch.testing.assert_close(result_unfused,
|
||||||
result_fused_1,
|
result_fused_1,
|
||||||
|
|||||||
@ -6,7 +6,11 @@ import flashinfer
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||||
|
FLOAT8_E4M3_MAX,
|
||||||
|
dequantize_nvfp4_to_dtype)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import round_up
|
||||||
|
|
||||||
if not current_platform.is_device_capability(100):
|
if not current_platform.is_device_capability(100):
|
||||||
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
|
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
|
||||||
@ -14,6 +18,7 @@ if not current_platform.is_device_capability(100):
|
|||||||
|
|
||||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||||
FP8_DTYPE = current_platform.fp8_dtype()
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
FP4_DTYPE = torch.uint8
|
||||||
|
|
||||||
|
|
||||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||||
@ -29,7 +34,9 @@ DTYPE = [torch.bfloat16]
|
|||||||
QUANT_DTYPES = [
|
QUANT_DTYPES = [
|
||||||
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
||||||
(None, None, None),
|
(None, None, None),
|
||||||
|
(None, FP8_DTYPE, None),
|
||||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||||
|
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||||
]
|
]
|
||||||
BATCH_SIZE = [4, 12]
|
BATCH_SIZE = [4, 12]
|
||||||
MAX_SEQ_LENS = [(1024, 4096)]
|
MAX_SEQ_LENS = [(1024, 4096)]
|
||||||
@ -153,11 +160,25 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
|||||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||||
wrapper.run(ref_query, ref_kv_cache, out=output)
|
wrapper.run(ref_query, ref_kv_cache, out=output)
|
||||||
o_scale = 1.0
|
o_scale = 1.0
|
||||||
|
o_sf_scale = None
|
||||||
if o_quant_dtype == FP8_DTYPE:
|
if o_quant_dtype == FP8_DTYPE:
|
||||||
_, o_scale = to_float8(output)
|
_, o_scale = to_float8(output)
|
||||||
|
elif o_quant_dtype == FP4_DTYPE:
|
||||||
|
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||||
|
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
|
||||||
|
|
||||||
# TRTLLM Decode
|
# TRTLLM Decode
|
||||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
if o_quant_dtype == FP4_DTYPE:
|
||||||
|
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||||
|
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
|
||||||
|
dtype=torch.uint8),
|
||||||
|
torch.empty((round_up(query.shape[0], 128),
|
||||||
|
round_up(query.shape[1] * query.shape[2] // 16, 4)),
|
||||||
|
dtype=torch.float8_e4m3fn),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||||
|
|
||||||
flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||||
query=query,
|
query=query,
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
@ -167,15 +188,27 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
|||||||
max_seq_len=max_seq_len,
|
max_seq_len=max_seq_len,
|
||||||
bmm1_scale=q_scale * k_scale * sm_scale,
|
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||||
bmm2_scale=v_scale / o_scale,
|
bmm2_scale=v_scale / o_scale,
|
||||||
|
o_sf_scale=o_sf_scale,
|
||||||
out=output_trtllm,
|
out=output_trtllm,
|
||||||
)
|
)
|
||||||
if o_quant_dtype == FP8_DTYPE:
|
if o_quant_dtype == FP8_DTYPE:
|
||||||
output_trtllm = output_trtllm.to(dtype) * o_scale
|
output_trtllm = output_trtllm.to(dtype) * o_scale
|
||||||
|
elif o_quant_dtype == FP4_DTYPE:
|
||||||
|
output_trtllm.data = output_trtllm.data.reshape(
|
||||||
|
-1, query.shape[1] * query.shape[2] // 2)
|
||||||
|
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
|
||||||
|
output_trtllm.scale,
|
||||||
|
o_sf_scale, dtype,
|
||||||
|
query.device)
|
||||||
|
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
|
||||||
|
query.shape[2])
|
||||||
|
|
||||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||||
|
rtol, atol = 3e-1, 1e0
|
||||||
|
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||||
rtol, atol = 5e-2, 7e-2
|
rtol, atol = 5e-2, 7e-2
|
||||||
else:
|
else:
|
||||||
rtol, atol = 1e-2, 1e-2
|
rtol, atol = 1e-2, 2e-2
|
||||||
|
|
||||||
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
|
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
|
||||||
f"{torch.max(torch.abs(output - output_trtllm))}"
|
f"{torch.max(torch.abs(output - output_trtllm))}"
|
||||||
@ -211,6 +244,9 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
|||||||
kv_quant_dtype = kv_quant_dtype or dtype
|
kv_quant_dtype = kv_quant_dtype or dtype
|
||||||
o_quant_dtype = o_quant_dtype or dtype
|
o_quant_dtype = o_quant_dtype or dtype
|
||||||
|
|
||||||
|
if q_quant_dtype != kv_quant_dtype:
|
||||||
|
pytest.skip("Skipped mixed QKV dtypes for prefill")
|
||||||
|
|
||||||
max_q_len, max_kv_len = max_seq_lens
|
max_q_len, max_kv_len = max_seq_lens
|
||||||
|
|
||||||
num_qo_heads, num_kv_heads = num_heads
|
num_qo_heads, num_kv_heads = num_heads
|
||||||
@ -303,11 +339,25 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
|||||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||||
wrapper.run(ref_query, ref_kv_cache, out=output)
|
wrapper.run(ref_query, ref_kv_cache, out=output)
|
||||||
o_scale = 1.0
|
o_scale = 1.0
|
||||||
|
o_sf_scale = None
|
||||||
if o_quant_dtype == FP8_DTYPE:
|
if o_quant_dtype == FP8_DTYPE:
|
||||||
_, o_scale = to_float8(output)
|
_, o_scale = to_float8(output)
|
||||||
|
elif o_quant_dtype == FP4_DTYPE:
|
||||||
|
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||||
|
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
|
||||||
|
|
||||||
# TRTLLM Prefill
|
# TRTLLM Prefill
|
||||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
if o_quant_dtype == FP4_DTYPE:
|
||||||
|
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||||
|
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
|
||||||
|
dtype=torch.uint8),
|
||||||
|
torch.empty((round_up(query.shape[0], 128),
|
||||||
|
round_up(query.shape[1] * query.shape[2] // 16, 4)),
|
||||||
|
dtype=torch.float8_e4m3fn),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||||
|
|
||||||
flashinfer.prefill.trtllm_batch_context_with_kv_cache(
|
flashinfer.prefill.trtllm_batch_context_with_kv_cache(
|
||||||
query=query,
|
query=query,
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
@ -321,12 +371,24 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
|||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
cum_seq_lens_q=q_indptr,
|
cum_seq_lens_q=q_indptr,
|
||||||
cum_seq_lens_kv=kv_indptr,
|
cum_seq_lens_kv=kv_indptr,
|
||||||
|
o_sf_scale=o_sf_scale,
|
||||||
out=output_trtllm,
|
out=output_trtllm,
|
||||||
)
|
)
|
||||||
if o_quant_dtype == FP8_DTYPE:
|
if o_quant_dtype == FP8_DTYPE:
|
||||||
output_trtllm = output_trtllm.to(dtype) * o_scale
|
output_trtllm = output_trtllm.to(dtype) * o_scale
|
||||||
|
elif o_quant_dtype == FP4_DTYPE:
|
||||||
|
output_trtllm.data = output_trtllm.data.reshape(
|
||||||
|
-1, query.shape[1] * query.shape[2] // 2)
|
||||||
|
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
|
||||||
|
output_trtllm.scale,
|
||||||
|
o_sf_scale, dtype,
|
||||||
|
query.device)
|
||||||
|
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
|
||||||
|
query.shape[2])
|
||||||
|
|
||||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||||
|
rtol, atol = 4e-1, 1e0
|
||||||
|
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||||
rtol, atol = 5e-2, 7e-2
|
rtol, atol = 5e-2, 7e-2
|
||||||
else:
|
else:
|
||||||
rtol, atol = 1e-2, 1e-2
|
rtol, atol = 1e-2, 1e-2
|
||||||
|
|||||||
@ -9,8 +9,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||||
GroupShape)
|
|
||||||
from vllm.multimodal import MultiModalPlaceholderMap
|
from vllm.multimodal import MultiModalPlaceholderMap
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -285,20 +284,17 @@ class AttentionImpl(ABC, Generic[T]):
|
|||||||
attn_metadata: T,
|
attn_metadata: T,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
|
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||||
group_shape: GroupShape):
|
|
||||||
"""
|
"""
|
||||||
Does this attention implementation support fused output quantization.
|
Does this attention implementation support fused output quantization.
|
||||||
This is used by the AttnFusionPass to only fuse output quantization
|
This is used by the AttnFusionPass to only fuse output quantization
|
||||||
onto implementations that support it.
|
onto implementations that support it.
|
||||||
|
|
||||||
TODO(luka) merge parameters into QuantDescriptor
|
:param quant_key: QuantKey object that describes the quantization op
|
||||||
:param dtype: quantized dtype
|
|
||||||
:param static: static or dynamic quantization
|
|
||||||
:param group_shape: quant group shape.
|
|
||||||
:return: is fusion supported for this type of quantization
|
:return: is fusion supported for this type of quantization
|
||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
@ -317,6 +313,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
|
|||||||
attn_metadata: T,
|
attn_metadata: T,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@ -800,6 +800,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
|
|||||||
attn_metadata: DifferentialFlashAttentionMetadata,
|
attn_metadata: DifferentialFlashAttentionMetadata,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention.
|
"""Forward pass with FlashAttention.
|
||||||
|
|
||||||
@ -817,6 +818,11 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
|
|||||||
{q,k,v}_descale to be (num_sequences, num_kv_heads).
|
{q,k,v}_descale to be (num_sequences, num_kv_heads).
|
||||||
We use torch's .expand() to avoid duplicating values
|
We use torch's .expand() to avoid duplicating values
|
||||||
"""
|
"""
|
||||||
|
if output_scale is not None or output_block_scale is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"fused output quantization is not yet supported"
|
||||||
|
" for DifferentialFlashAttentionImpl")
|
||||||
|
|
||||||
if self.lambda_full is None:
|
if self.lambda_full is None:
|
||||||
self.lambda_init = self.differential_flash_attention_config[
|
self.lambda_init = self.differential_flash_attention_config[
|
||||||
"lambda_init"]
|
"lambda_init"]
|
||||||
|
|||||||
@ -371,6 +371,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
|
|||||||
attn_metadata: DualChunkFlashAttentionMetadata,
|
attn_metadata: DualChunkFlashAttentionMetadata,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with DualChunkFlashAttention.
|
"""Forward pass with DualChunkFlashAttention.
|
||||||
Args:
|
Args:
|
||||||
@ -386,7 +387,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
|
|||||||
"""
|
"""
|
||||||
assert output is None, "Output tensor not supported for DualChunk"
|
assert output is None, "Output tensor not supported for DualChunk"
|
||||||
|
|
||||||
if output_scale is not None:
|
if output_scale is not None or output_block_scale is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"fused output quantization is not yet supported"
|
"fused output quantization is not yet supported"
|
||||||
" for FlashAttentionImpl")
|
" for FlashAttentionImpl")
|
||||||
|
|||||||
@ -596,6 +596,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
attn_metadata: FlashAttentionMetadata,
|
attn_metadata: FlashAttentionMetadata,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention.
|
"""Forward pass with FlashAttention.
|
||||||
|
|
||||||
@ -615,7 +616,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
"""
|
"""
|
||||||
assert output is not None, "Output tensor must be provided."
|
assert output is not None, "Output tensor must be provided."
|
||||||
|
|
||||||
if output_scale is not None:
|
if output_scale is not None or output_block_scale is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"fused output quantization is not yet supported"
|
"fused output quantization is not yet supported"
|
||||||
" for FlashAttentionImpl")
|
" for FlashAttentionImpl")
|
||||||
|
|||||||
@ -1238,12 +1238,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|||||||
attn_metadata: T,
|
attn_metadata: T,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if output is not None:
|
if output is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"output is not yet supported for MLAImplBase")
|
"output is not yet supported for MLAImplBase")
|
||||||
|
|
||||||
if output_scale is not None:
|
if output_scale is not None or output_block_scale is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"fused output quantization is not yet supported"
|
"fused output quantization is not yet supported"
|
||||||
" for MLAImplBase")
|
" for MLAImplBase")
|
||||||
|
|||||||
@ -20,7 +20,7 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
|
|||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape)
|
QuantKey, kFp8StaticTensorSym)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -529,11 +529,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
head_dim).reshape(tokens, n_kv_heads * n_rep,
|
head_dim).reshape(tokens, n_kv_heads * n_rep,
|
||||||
head_dim))
|
head_dim))
|
||||||
|
|
||||||
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
|
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||||
group_shape: GroupShape):
|
|
||||||
if self.use_triton_flash_attn:
|
if self.use_triton_flash_attn:
|
||||||
return dtype == current_platform.fp8_dtype(
|
return quant_key == kFp8StaticTensorSym
|
||||||
) and static and group_shape == GroupShape.PER_TENSOR
|
|
||||||
|
|
||||||
# Only supported in the Triton backend
|
# Only supported in the Triton backend
|
||||||
return False
|
return False
|
||||||
@ -548,6 +546,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
attn_metadata: ROCmFlashAttentionMetadata,
|
attn_metadata: ROCmFlashAttentionMetadata,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention and PagedAttention.
|
"""Forward pass with FlashAttention and PagedAttention.
|
||||||
|
|
||||||
@ -606,6 +605,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
"fused output quantization only supported for Triton"
|
"fused output quantization only supported for Triton"
|
||||||
" implementation in ROCMFlashAttentionImpl for now")
|
" implementation in ROCMFlashAttentionImpl for now")
|
||||||
|
|
||||||
|
if output_block_scale is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"fused nvfp4 output quantization is not supported"
|
||||||
|
" for ROCMFlashAttentionImpl")
|
||||||
|
|
||||||
query = query.view(-1, self.num_heads, self.head_size)
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
if key is not None:
|
if key is not None:
|
||||||
assert value is not None
|
assert value is not None
|
||||||
|
|||||||
@ -432,6 +432,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
attn_metadata: "XFormersMetadata",
|
attn_metadata: "XFormersMetadata",
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with xFormers and PagedAttention.
|
"""Forward pass with xFormers and PagedAttention.
|
||||||
|
|
||||||
@ -484,7 +485,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
Returns:
|
Returns:
|
||||||
shape = [num_tokens, num_heads * head_size]
|
shape = [num_tokens, num_heads * head_size]
|
||||||
"""
|
"""
|
||||||
if output_scale is not None:
|
if output_scale is not None or output_block_scale is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"fused output quantization is not yet supported"
|
"fused output quantization is not yet supported"
|
||||||
" for XFormersImpl")
|
" for XFormersImpl")
|
||||||
|
|||||||
@ -495,6 +495,7 @@ def unified_attention_with_output(
|
|||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
wait_for_kv_layer_from_connector(layer_name)
|
wait_for_kv_layer_from_connector(layer_name)
|
||||||
forward_context: ForwardContext = get_forward_context()
|
forward_context: ForwardContext = get_forward_context()
|
||||||
@ -510,7 +511,8 @@ def unified_attention_with_output(
|
|||||||
kv_cache,
|
kv_cache,
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
output=output,
|
output=output,
|
||||||
output_scale=output_scale)
|
output_scale=output_scale,
|
||||||
|
output_block_scale=output_block_scale)
|
||||||
|
|
||||||
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||||
|
|
||||||
@ -522,6 +524,7 @@ def unified_attention_with_output_fake(
|
|||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -529,7 +532,7 @@ def unified_attention_with_output_fake(
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="unified_attention_with_output",
|
op_name="unified_attention_with_output",
|
||||||
op_func=unified_attention_with_output,
|
op_func=unified_attention_with_output,
|
||||||
mutates_args=["output"],
|
mutates_args=["output", "output_block_scale"],
|
||||||
fake_impl=unified_attention_with_output_fake,
|
fake_impl=unified_attention_with_output_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
dispatch_key=current_platform.dispatch_key,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -12,7 +12,8 @@ from torch._ops import OpOverload
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape)
|
GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym,
|
||||||
|
kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .fx_utils import find_getitem_maybe
|
from .fx_utils import find_getitem_maybe
|
||||||
@ -21,6 +22,7 @@ from .vllm_inductor_pass import VllmInductorPass
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
FP8_DTYPE = current_platform.fp8_dtype()
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
FP4_DTYPE = torch.uint8
|
||||||
|
|
||||||
|
|
||||||
def empty_bf16(*args, **kwargs):
|
def empty_bf16(*args, **kwargs):
|
||||||
@ -31,42 +33,13 @@ def empty_fp32(*args, **kwargs):
|
|||||||
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
|
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
|
||||||
|
|
||||||
|
|
||||||
|
def empty_i32(*args, **kwargs):
|
||||||
|
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
|
||||||
RMS_OP = torch.ops._C.rms_norm.default
|
RMS_OP = torch.ops._C.rms_norm.default
|
||||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||||
|
|
||||||
|
|
||||||
class QuantKey(NamedTuple):
|
|
||||||
"""
|
|
||||||
Named tuple for identifying the type of quantization.
|
|
||||||
dtype: quantized data type
|
|
||||||
static: static quantization if True, dynamic if False
|
|
||||||
group_shape: quantization group shape
|
|
||||||
symmetric: symmetric if True, asymmetric if False
|
|
||||||
|
|
||||||
TODO(luka) use QuantDescriptor once standardized:
|
|
||||||
https://github.com/vllm-project/vllm/issues/8913
|
|
||||||
|
|
||||||
"""
|
|
||||||
dtype: torch.dtype
|
|
||||||
static: bool
|
|
||||||
group_shape: GroupShape
|
|
||||||
symmetric: bool = True
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
group_shape = ('per_tensor'
|
|
||||||
if self.group_shape == GroupShape.PER_TENSOR else
|
|
||||||
('per_token' if self.group_shape == GroupShape.PER_TOKEN
|
|
||||||
else str(self.group_shape)))
|
|
||||||
|
|
||||||
return (f"QuantKey({'static' if self.static else 'dynamic'},"
|
|
||||||
f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape},"
|
|
||||||
f"{'a' if not self.symmetric else ''}symmetric)")
|
|
||||||
|
|
||||||
|
|
||||||
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
|
|
||||||
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
|
|
||||||
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
|
|
||||||
|
|
||||||
QUANT_OPS: dict[QuantKey, OpOverload] = {
|
QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||||
kFp8StaticTensorSym:
|
kFp8StaticTensorSym:
|
||||||
torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||||
@ -74,6 +47,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
|||||||
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||||
kFp8DynamicTokenSym:
|
kFp8DynamicTokenSym:
|
||||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||||
|
kNvfp4Quant: torch.ops._C.scaled_fp4_quant.default, # noqa: E501
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -187,11 +161,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
|||||||
quant_dtype: torch.dtype,
|
quant_dtype: torch.dtype,
|
||||||
symmetric=True):
|
symmetric=True):
|
||||||
fused_key = FusedRMSQuantKey(fused_add=False,
|
fused_key = FusedRMSQuantKey(fused_add=False,
|
||||||
quant=QuantKey(
|
quant=QuantKey(dtype=quant_dtype,
|
||||||
dtype=quant_dtype,
|
scale=kStaticTensorScale,
|
||||||
static=True,
|
symmetric=symmetric))
|
||||||
group_shape=GroupShape.PER_TENSOR,
|
|
||||||
symmetric=symmetric))
|
|
||||||
super().__init__(epsilon, fused_key)
|
super().__init__(epsilon, fused_key)
|
||||||
|
|
||||||
def register(self, pm_pass: PatternMatcherPass):
|
def register(self, pm_pass: PatternMatcherPass):
|
||||||
@ -244,11 +216,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
|||||||
quant_dtype: torch.dtype,
|
quant_dtype: torch.dtype,
|
||||||
symmetric=True):
|
symmetric=True):
|
||||||
key = FusedRMSQuantKey(fused_add=True,
|
key = FusedRMSQuantKey(fused_add=True,
|
||||||
quant=QuantKey(
|
quant=QuantKey(dtype=quant_dtype,
|
||||||
dtype=quant_dtype,
|
scale=kStaticTensorScale,
|
||||||
static=True,
|
symmetric=symmetric))
|
||||||
group_shape=GroupShape.PER_TENSOR,
|
|
||||||
symmetric=symmetric))
|
|
||||||
super().__init__(epsilon, key)
|
super().__init__(epsilon, key)
|
||||||
|
|
||||||
def register(self, pm_pass: PatternMatcherPass,
|
def register(self, pm_pass: PatternMatcherPass,
|
||||||
@ -337,10 +307,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
|||||||
quant_dtype: torch.dtype,
|
quant_dtype: torch.dtype,
|
||||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||||
symmetric=True):
|
symmetric=True):
|
||||||
|
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||||
key = FusedRMSQuantKey(fused_add=False,
|
key = FusedRMSQuantKey(fused_add=False,
|
||||||
quant=QuantKey(dtype=quant_dtype,
|
quant=QuantKey(dtype=quant_dtype,
|
||||||
static=False,
|
scale=scale,
|
||||||
group_shape=group_shape,
|
|
||||||
symmetric=symmetric))
|
symmetric=symmetric))
|
||||||
super().__init__(epsilon, key)
|
super().__init__(epsilon, key)
|
||||||
|
|
||||||
@ -435,10 +405,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
|||||||
quant_dtype: torch.dtype,
|
quant_dtype: torch.dtype,
|
||||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||||
symmetric=True):
|
symmetric=True):
|
||||||
|
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||||
key = FusedRMSQuantKey(fused_add=True,
|
key = FusedRMSQuantKey(fused_add=True,
|
||||||
quant=QuantKey(dtype=quant_dtype,
|
quant=QuantKey(dtype=quant_dtype,
|
||||||
static=False,
|
scale=scale,
|
||||||
group_shape=group_shape,
|
|
||||||
symmetric=symmetric))
|
symmetric=symmetric))
|
||||||
super().__init__(epsilon, key)
|
super().__init__(epsilon, key)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._inductor.pattern_matcher as pm
|
import torch._inductor.pattern_matcher as pm
|
||||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||||
@ -11,44 +13,41 @@ from torch._subclasses.fake_tensor import (FakeTensorMode,
|
|||||||
from vllm.attention import Attention
|
from vllm.attention import Attention
|
||||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
QuantKey, kNvfp4Quant, kStaticTensorScale)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import round_up
|
||||||
|
|
||||||
from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32
|
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||||
from .vllm_inductor_pass import VllmInductorPass
|
from .vllm_inductor_pass import VllmInductorPass
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
FP8_DTYPE = current_platform.fp8_dtype()
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
FP4_DTYPE = torch.uint8
|
||||||
|
|
||||||
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
|
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
|
||||||
RESHAPE_OP = torch.ops.aten.reshape.default
|
RESHAPE_OP = torch.ops.aten.reshape.default
|
||||||
|
|
||||||
|
|
||||||
class AttentionStaticQuantPattern:
|
class AttentionQuantPattern(ABC):
|
||||||
"""
|
"""
|
||||||
Fusion for Attention+StaticQuant.
|
The base class for Attn+Quant fusions.
|
||||||
|
Should not be used directly.
|
||||||
Only triggers when the attention implementation returns True in
|
|
||||||
`fused_output_quant_supported()`. If the pattern is found, the StaticQuant
|
|
||||||
op will be removed from the graph, and its scale will be passed into
|
|
||||||
Attention op as the `output_scale` argument.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
layer: Attention,
|
layer: Attention,
|
||||||
quant_dtype: torch.dtype,
|
quant_key: QuantKey,
|
||||||
symmetric=True,
|
|
||||||
):
|
):
|
||||||
self.layer = layer
|
self.layer = layer
|
||||||
self.layer_name = layer.layer_name
|
self.layer_name = layer.layer_name
|
||||||
self.num_heads = layer.num_heads
|
self.num_heads = layer.num_heads
|
||||||
self.head_size = layer.head_size
|
self.head_size = layer.head_size
|
||||||
self.quant_dtype = quant_dtype
|
self.quant_key = quant_key
|
||||||
self.quant_key = QuantKey(dtype=quant_dtype,
|
self.quant_dtype = quant_key.dtype
|
||||||
static=True,
|
|
||||||
group_shape=GroupShape.PER_TENSOR,
|
|
||||||
symmetric=symmetric)
|
|
||||||
assert self.quant_key in QUANT_OPS, \
|
assert self.quant_key in QUANT_OPS, \
|
||||||
f"unsupported quantization scheme {self.quant_key}"
|
f"unsupported quantization scheme {self.quant_key}"
|
||||||
self.QUANT_OP = QUANT_OPS[self.quant_key]
|
self.QUANT_OP = QUANT_OPS[self.quant_key]
|
||||||
@ -57,12 +56,49 @@ class AttentionStaticQuantPattern:
|
|||||||
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
|
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
|
||||||
return torch.empty(*args, **kwargs)
|
return torch.empty(*args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def wrap_trace_fn(process_fx, trace_fn):
|
||||||
|
|
||||||
|
def wrapped(*args, **kwargs):
|
||||||
|
return process_fx(trace_fn(*args, **kwargs))
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def fx_view_to_reshape(gm: torch.fx.GraphModule):
|
||||||
|
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||||
|
view_to_reshape(gm)
|
||||||
|
return gm
|
||||||
|
|
||||||
def register_if_supported(self, pm_pass: PatternMatcherPass):
|
def register_if_supported(self, pm_pass: PatternMatcherPass):
|
||||||
if self.layer.impl.fused_output_quant_supported(
|
if self.layer.impl.fused_output_quant_supported(self.quant_key):
|
||||||
self.quant_dtype, self.quant_key.static,
|
|
||||||
self.quant_key.group_shape):
|
|
||||||
self._register(pm_pass)
|
self._register(pm_pass)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _register(self, pm_pass: PatternMatcherPass):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||||
|
"""
|
||||||
|
Fusion for Attention+Fp8StaticQuant.
|
||||||
|
|
||||||
|
Only triggers when the attention implementation returns True in
|
||||||
|
`fused_output_quant_supported()`. If the pattern is found, the
|
||||||
|
Fp8StaticQuant op will be removed from the graph, and its scale
|
||||||
|
will be passed into Attention op as the `output_scale` argument.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer: Attention,
|
||||||
|
symmetric: bool = True,
|
||||||
|
):
|
||||||
|
quant_key = QuantKey(dtype=FP8_DTYPE,
|
||||||
|
scale=kStaticTensorScale,
|
||||||
|
symmetric=symmetric)
|
||||||
|
super().__init__(layer, quant_key)
|
||||||
|
|
||||||
def _register(self, pm_pass: PatternMatcherPass):
|
def _register(self, pm_pass: PatternMatcherPass):
|
||||||
|
|
||||||
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||||
@ -74,9 +110,10 @@ class AttentionStaticQuantPattern:
|
|||||||
value=v,
|
value=v,
|
||||||
output=output_attn,
|
output=output_attn,
|
||||||
layer_name=self.layer_name,
|
layer_name=self.layer_name,
|
||||||
output_scale=None)
|
output_scale=None,
|
||||||
attn_out_view = RESHAPE_OP(at1[1],
|
output_block_scale=None)
|
||||||
[-1, self.num_heads * self.head_size])
|
attn_out_view = RESHAPE_OP(
|
||||||
|
at1[1], [q.shape[0], self.num_heads * self.head_size])
|
||||||
at2 = auto_functionalized(self.QUANT_OP,
|
at2 = auto_functionalized(self.QUANT_OP,
|
||||||
result=output_quant,
|
result=output_quant,
|
||||||
input=attn_out_view,
|
input=attn_out_view,
|
||||||
@ -98,7 +135,8 @@ class AttentionStaticQuantPattern:
|
|||||||
value=v,
|
value=v,
|
||||||
output=output_attn,
|
output=output_attn,
|
||||||
layer_name=self.layer_name,
|
layer_name=self.layer_name,
|
||||||
output_scale=scale)
|
output_scale=scale,
|
||||||
|
output_block_scale=None)
|
||||||
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
||||||
|
|
||||||
# Need custom fake mode, otherwise tracing happens with real tensors.
|
# Need custom fake mode, otherwise tracing happens with real tensors.
|
||||||
@ -114,21 +152,94 @@ class AttentionStaticQuantPattern:
|
|||||||
empty_fp32(1, 1) # scale
|
empty_fp32(1, 1) # scale
|
||||||
]
|
]
|
||||||
|
|
||||||
def wrap_trace_fn(process_fx, trace_fn):
|
pm.register_replacement(
|
||||||
|
pattern, replacement, inputs,
|
||||||
|
AttentionQuantPattern.wrap_trace_fn(
|
||||||
|
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
|
||||||
|
pm_pass)
|
||||||
|
|
||||||
def wrapped(*args, **kwargs):
|
|
||||||
return process_fx(trace_fn(*args, **kwargs))
|
|
||||||
|
|
||||||
return wrapped
|
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||||
|
"""
|
||||||
|
Fusion for Attention+Nvfp4Quant.
|
||||||
|
|
||||||
def fx_view_to_reshape(gm: torch.fx.GraphModule):
|
Only triggers when the attention implementation returns True in
|
||||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
`fused_output_quant_supported()`. If the pattern is found, the
|
||||||
view_to_reshape(gm)
|
Nvfp4Quant op will be removed from the graph, and its scale
|
||||||
return gm
|
will be passed into Attention op as the `output_scale` argument.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, layer: Attention):
|
||||||
|
super().__init__(layer, kNvfp4Quant)
|
||||||
|
|
||||||
|
def _register(self, pm_pass: PatternMatcherPass):
|
||||||
|
|
||||||
|
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||||
|
output_attn: torch.Tensor, output_quant: torch.Tensor,
|
||||||
|
output_scale: torch.Tensor, input_scale: torch.Tensor):
|
||||||
|
at1 = auto_functionalized(ATTN_OP,
|
||||||
|
query=q,
|
||||||
|
key=k,
|
||||||
|
value=v,
|
||||||
|
output=output_attn,
|
||||||
|
layer_name=self.layer_name,
|
||||||
|
output_scale=None,
|
||||||
|
output_block_scale=None)
|
||||||
|
attn_out_view = RESHAPE_OP(
|
||||||
|
at1[1], [q.shape[0], self.num_heads * self.head_size])
|
||||||
|
at2 = auto_functionalized(self.QUANT_OP,
|
||||||
|
output=output_quant,
|
||||||
|
input=attn_out_view,
|
||||||
|
output_scale=output_scale,
|
||||||
|
input_scale=input_scale)
|
||||||
|
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
|
||||||
|
return at2[1], output_scale_view
|
||||||
|
|
||||||
|
def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||||
|
output_attn: torch.Tensor, output_quant: torch.Tensor,
|
||||||
|
output_scale: torch.Tensor, input_scale: torch.Tensor):
|
||||||
|
# attention output in quant_dtype
|
||||||
|
output_attn = torch.ops.aten.full.default(
|
||||||
|
[q.shape[0], self.num_heads, self.head_size // 2],
|
||||||
|
0.0,
|
||||||
|
dtype=self.quant_dtype,
|
||||||
|
device=q.device)
|
||||||
|
# attention output block scale
|
||||||
|
output_scale_view = torch.ops.aten.view.dtype(
|
||||||
|
output_scale, FP8_DTYPE)
|
||||||
|
at2 = auto_functionalized(ATTN_OP,
|
||||||
|
query=q,
|
||||||
|
key=k,
|
||||||
|
value=v,
|
||||||
|
output=output_attn,
|
||||||
|
layer_name=self.layer_name,
|
||||||
|
output_scale=input_scale,
|
||||||
|
output_block_scale=output_scale_view)
|
||||||
|
output = RESHAPE_OP(at2[1],
|
||||||
|
[-1, self.num_heads * self.head_size // 2])
|
||||||
|
return output, at2[2]
|
||||||
|
|
||||||
|
# Need custom fake mode, otherwise tracing happens with real tensors.
|
||||||
|
# That would not work for the unified_attention custom op.
|
||||||
|
with unset_fake_temporarily(), FakeTensorMode():
|
||||||
|
inputs = [
|
||||||
|
empty_bf16(5, self.num_heads, self.head_size), # q
|
||||||
|
empty_bf16(5, self.num_heads, self.head_size), # k
|
||||||
|
empty_bf16(5, self.num_heads, self.head_size), # v
|
||||||
|
empty_bf16(5, self.num_heads, self.head_size), # output_attn
|
||||||
|
self.empty_quant(5, self.num_heads * self.head_size //
|
||||||
|
2), # output_quant
|
||||||
|
empty_i32(128,
|
||||||
|
round_up(self.num_heads * self.head_size // 16,
|
||||||
|
4)), # output_scale
|
||||||
|
empty_fp32(1, 1), # input_scale
|
||||||
|
]
|
||||||
|
|
||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
pattern, replacement, inputs,
|
pattern, replacement, inputs,
|
||||||
wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass)
|
AttentionQuantPattern.wrap_trace_fn(
|
||||||
|
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
|
||||||
|
pm_pass)
|
||||||
|
|
||||||
|
|
||||||
class AttnFusionPass(VllmInductorPass):
|
class AttnFusionPass(VllmInductorPass):
|
||||||
@ -151,8 +262,12 @@ class AttnFusionPass(VllmInductorPass):
|
|||||||
|
|
||||||
attn_layers = get_layers_from_vllm_config(config, Attention)
|
attn_layers = get_layers_from_vllm_config(config, Attention)
|
||||||
for layer_name, layer in attn_layers.items():
|
for layer_name, layer in attn_layers.items():
|
||||||
pattern = AttentionStaticQuantPattern(layer, FP8_DTYPE)
|
pattern_fp8 = AttentionFp8StaticQuantPattern(layer)
|
||||||
pattern.register_if_supported(self.patterns)
|
pattern_fp8.register_if_supported(self.patterns)
|
||||||
|
|
||||||
|
pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
|
||||||
|
pattern_nvfp4.register_if_supported(self.patterns)
|
||||||
|
|
||||||
if len(attn_layers) == 0:
|
if len(attn_layers) == 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Attention + quant fusion is enabled, but no attention layers "
|
"Attention + quant fusion is enabled, but no attention layers "
|
||||||
@ -175,4 +290,6 @@ class AttnFusionPass(VllmInductorPass):
|
|||||||
self.end_and_log()
|
self.end_and_log()
|
||||||
|
|
||||||
def uuid(self):
|
def uuid(self):
|
||||||
return VllmInductorPass.hash_source(self, AttentionStaticQuantPattern)
|
return VllmInductorPass.hash_source(self, AttentionQuantPattern,
|
||||||
|
AttentionFp8StaticQuantPattern,
|
||||||
|
AttentionNvfp4QuantPattern)
|
||||||
|
|||||||
@ -2,16 +2,21 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""This file is used for /tests and /benchmarks"""
|
"""This file is used for /tests and /benchmarks"""
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
from dataclasses import dataclass
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import ClassVar, NamedTuple, Optional
|
from typing import ClassVar, NamedTuple, Optional
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import torch
|
import torch
|
||||||
|
from torch import fx
|
||||||
|
|
||||||
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
|
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.scalar_type import ScalarType, scalar_types
|
from vllm.scalar_type import ScalarType, scalar_types
|
||||||
|
|
||||||
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
FP4_DTYPE = torch.uint8
|
||||||
|
|
||||||
|
|
||||||
# Use proxy as NamedTuple direct subclasses cannot have static members
|
# Use proxy as NamedTuple direct subclasses cannot have static members
|
||||||
class _GroupShape(NamedTuple):
|
class _GroupShape(NamedTuple):
|
||||||
@ -34,6 +39,64 @@ GroupShape.PER_TENSOR = GroupShape(-1, -1)
|
|||||||
GroupShape.PER_TOKEN = GroupShape(1, -1)
|
GroupShape.PER_TOKEN = GroupShape(1, -1)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ScaleDesc:
|
||||||
|
"""
|
||||||
|
Class for describing a single quantization scaling factor.
|
||||||
|
dtype: data type of the scale
|
||||||
|
static: static scale if True, dynamic if False
|
||||||
|
group_shape: group shape of the scale
|
||||||
|
"""
|
||||||
|
dtype: torch.dtype
|
||||||
|
static: bool
|
||||||
|
group_shape: GroupShape
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
group_shape = ('per_tensor'
|
||||||
|
if self.group_shape == GroupShape.PER_TENSOR else
|
||||||
|
('per_token' if self.group_shape == GroupShape.PER_TOKEN
|
||||||
|
else str(self.group_shape)))
|
||||||
|
|
||||||
|
return (f"{fx.graph.dtype_abbrs[self.dtype]},"
|
||||||
|
f"{'static' if self.static else 'dynamic'},{group_shape}")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class QuantKey:
|
||||||
|
"""
|
||||||
|
Class for identifying the type of quantization.
|
||||||
|
dtype: quantized data type
|
||||||
|
scale: scale descriptor
|
||||||
|
scale2: second-level scale descriptor
|
||||||
|
symmetric: symmetric if True, asymmetric if False
|
||||||
|
"""
|
||||||
|
dtype: torch.dtype
|
||||||
|
scale: ScaleDesc
|
||||||
|
scale2: Optional[ScaleDesc] = None
|
||||||
|
symmetric: bool = True
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
scale2_str = f"scale2({self.scale2})," if self.scale2 else ""
|
||||||
|
return (f"QuantKey({fx.graph.dtype_abbrs[self.dtype]},"
|
||||||
|
f"scale({self.scale}),{scale2_str}"
|
||||||
|
f"{'a' if not self.symmetric else ''}symmetric)")
|
||||||
|
|
||||||
|
|
||||||
|
kStaticTensorScale = ScaleDesc(torch.float32, True, GroupShape.PER_TENSOR)
|
||||||
|
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True)
|
||||||
|
|
||||||
|
kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR)
|
||||||
|
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True)
|
||||||
|
|
||||||
|
kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN)
|
||||||
|
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)
|
||||||
|
|
||||||
|
kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
|
||||||
|
kNvfp4Quant = QuantKey(FP4_DTYPE,
|
||||||
|
scale=kNvfp4GroupScale,
|
||||||
|
scale2=kStaticTensorScale)
|
||||||
|
|
||||||
|
|
||||||
# Normalize the group_shape to the full extent for any dims that are -1
|
# Normalize the group_shape to the full extent for any dims that are -1
|
||||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
|
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
|
||||||
# -1 means full extent
|
# -1 means full extent
|
||||||
|
|||||||
@ -483,6 +483,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
|||||||
attn_metadata: TorchSDPAMetadata, # type: ignore
|
attn_metadata: TorchSDPAMetadata, # type: ignore
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with torch SDPA and PagedAttention.
|
"""Forward pass with torch SDPA and PagedAttention.
|
||||||
|
|
||||||
@ -497,7 +498,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
|||||||
Returns:
|
Returns:
|
||||||
shape = [num_tokens, num_heads * head_size]
|
shape = [num_tokens, num_heads * head_size]
|
||||||
"""
|
"""
|
||||||
if output_scale is not None:
|
if output_scale is not None or output_block_scale is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"fused output quantization is not yet supported"
|
"fused output quantization is not yet supported"
|
||||||
" for TorchSDPABackendImpl")
|
" for TorchSDPABackendImpl")
|
||||||
|
|||||||
@ -430,6 +430,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
attn_metadata: FlashAttentionMetadata,
|
attn_metadata: FlashAttentionMetadata,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention.
|
"""Forward pass with FlashAttention.
|
||||||
|
|
||||||
@ -447,7 +448,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
"""
|
"""
|
||||||
assert output is not None, "Output tensor must be provided."
|
assert output is not None, "Output tensor must be provided."
|
||||||
|
|
||||||
if output_scale is not None:
|
if output_scale is not None or output_block_scale is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"fused output quantization is not yet supported"
|
"fused output quantization is not yet supported"
|
||||||
" for FlashAttentionImpl")
|
" for FlashAttentionImpl")
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
|
|||||||
MultiLevelCascadeAttentionWrapper)
|
MultiLevelCascadeAttentionWrapper)
|
||||||
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
|
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
|
||||||
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
|
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
|
||||||
|
from flashinfer.utils import FP4Tensor
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
@ -19,7 +20,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|||||||
from vllm.config import CUDAGraphMode, VllmConfig
|
from vllm.config import CUDAGraphMode, VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape)
|
QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import cdiv, is_pin_memory_available
|
from vllm.utils import cdiv, is_pin_memory_available
|
||||||
from vllm.utils.flashinfer import (supports_trtllm_attention,
|
from vllm.utils.flashinfer import (supports_trtllm_attention,
|
||||||
@ -40,6 +41,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
|||||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||||
|
|
||||||
FP8_DTYPE = current_platform.fp8_dtype()
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
FP4_DTYPE = torch.uint8
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -653,14 +655,12 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
and num_heads % num_kv_heads == 0)
|
and num_heads % num_kv_heads == 0)
|
||||||
self.bmm1_scale: Optional[float] = None
|
self.bmm1_scale: Optional[float] = None
|
||||||
self.bmm2_scale: Optional[float] = None
|
self.bmm2_scale: Optional[float] = None
|
||||||
|
self.o_sf_scale: Optional[float] = None
|
||||||
|
|
||||||
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
|
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||||
group_shape: GroupShape):
|
|
||||||
supported_quant_type = (dtype == FP8_DTYPE and static
|
|
||||||
and group_shape == GroupShape.PER_TENSOR)
|
|
||||||
return (self.support_trtllm_attn
|
return (self.support_trtllm_attn
|
||||||
and self.kv_cache_dtype.startswith("fp8")
|
and self.kv_cache_dtype.startswith("fp8")
|
||||||
and supported_quant_type)
|
and quant_key in (kFp8StaticTensorSym, kNvfp4Quant))
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -672,6 +672,7 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
attn_metadata: FlashInferMetadata,
|
attn_metadata: FlashInferMetadata,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashInfer.
|
"""Forward pass with FlashInfer.
|
||||||
|
|
||||||
@ -705,19 +706,32 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
if output_scale is None:
|
if output_scale is None:
|
||||||
assert attn_metadata.q_data_type != FP8_DTYPE, \
|
assert attn_metadata.q_data_type != FP8_DTYPE, \
|
||||||
"Query can only be FP8 if output fusion happened."
|
"Query can only be FP8 if output fusion happened."
|
||||||
|
assert output_block_scale is None, "output_block_scale "\
|
||||||
|
"is not supported when fusion has not happened"
|
||||||
else:
|
else:
|
||||||
assert attn_metadata.q_data_type == FP8_DTYPE, \
|
assert attn_metadata.q_data_type == FP8_DTYPE, \
|
||||||
"Query must be FP8 when attn+quant fusion happened."
|
"Query must be FP8 when attn+quant fusion happened."
|
||||||
assert (attn_metadata.prefill_use_trtllm and
|
assert (attn_metadata.prefill_use_trtllm and
|
||||||
attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn"
|
attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn"
|
||||||
assert output.dtype == FP8_DTYPE, \
|
|
||||||
"Output must be FP8 when attn+quant fusion happened."
|
|
||||||
|
|
||||||
# TRTLLM attn kernel requires o scale as a host scalar, store the
|
if output.dtype == FP8_DTYPE:
|
||||||
# o scale to host scalar in warmup run with cuda graph not enabled
|
assert output_block_scale is None, \
|
||||||
|
"output_block_scale should not be provided for fp8 output"
|
||||||
|
elif output.dtype == FP4_DTYPE:
|
||||||
|
assert output_block_scale is not None, \
|
||||||
|
"output_block_scale is required for nvfp4 output"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported output dtype: {output.dtype}")
|
||||||
|
|
||||||
|
# TRTLLM attn kernel requires o scale to pass as a host scalar,
|
||||||
|
# store the o scale as a host scalar in warmup run with cuda graph
|
||||||
|
# not enabled
|
||||||
if layer._o_scale_float is None:
|
if layer._o_scale_float is None:
|
||||||
layer._o_scale_float = output_scale.cpu().item()
|
layer._o_scale_float = output_scale.cpu().item()
|
||||||
self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
|
if output.dtype == FP8_DTYPE:
|
||||||
|
self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
|
||||||
|
elif output.dtype == FP4_DTYPE:
|
||||||
|
self.o_sf_scale = layer._o_scale_float
|
||||||
|
|
||||||
# Insert FP8 quant for query
|
# Insert FP8 quant for query
|
||||||
num_tokens, num_heads, head_size = query.shape
|
num_tokens, num_heads, head_size = query.shape
|
||||||
@ -818,6 +832,16 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
assert block_tables_prefill.is_contiguous()
|
assert block_tables_prefill.is_contiguous()
|
||||||
assert seq_lens_prefill.is_contiguous()
|
assert seq_lens_prefill.is_contiguous()
|
||||||
|
|
||||||
|
if output.dtype == FP4_DTYPE:
|
||||||
|
assert self.o_sf_scale is not None
|
||||||
|
out = FP4Tensor(data=output[num_decode_tokens:],
|
||||||
|
scale=output_block_scale,
|
||||||
|
scale_start_index=num_decode_tokens,
|
||||||
|
original_shape=prefill_query.shape)
|
||||||
|
else:
|
||||||
|
assert self.o_sf_scale is None
|
||||||
|
out = output[num_decode_tokens:]
|
||||||
|
|
||||||
trtllm_batch_context_with_kv_cache(
|
trtllm_batch_context_with_kv_cache(
|
||||||
query=prefill_query,
|
query=prefill_query,
|
||||||
kv_cache=kv_cache_permute,
|
kv_cache=kv_cache_permute,
|
||||||
@ -833,7 +857,8 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
|
cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
|
||||||
window_left=self.window_left,
|
window_left=self.window_left,
|
||||||
sinks=self.sinks,
|
sinks=self.sinks,
|
||||||
out=output[num_decode_tokens:],
|
o_sf_scale=self.o_sf_scale,
|
||||||
|
out=out,
|
||||||
)
|
)
|
||||||
|
|
||||||
if num_decode_tokens > 0:
|
if num_decode_tokens > 0:
|
||||||
@ -870,6 +895,16 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
assert block_tables_decode.is_contiguous()
|
assert block_tables_decode.is_contiguous()
|
||||||
assert seq_lens_decode.is_contiguous()
|
assert seq_lens_decode.is_contiguous()
|
||||||
|
|
||||||
|
if output.dtype == FP4_DTYPE:
|
||||||
|
assert self.o_sf_scale is not None
|
||||||
|
out = FP4Tensor(data=output[:num_decode_tokens],
|
||||||
|
scale=output_block_scale,
|
||||||
|
scale_start_index=0,
|
||||||
|
original_shape=decode_query.shape)
|
||||||
|
else:
|
||||||
|
assert self.o_sf_scale is None
|
||||||
|
out = output[:num_decode_tokens]
|
||||||
|
|
||||||
trtllm_batch_decode_with_kv_cache(
|
trtllm_batch_decode_with_kv_cache(
|
||||||
query=decode_query,
|
query=decode_query,
|
||||||
kv_cache=kv_cache_permute,
|
kv_cache=kv_cache_permute,
|
||||||
@ -881,7 +916,8 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
bmm2_scale=self.bmm2_scale,
|
bmm2_scale=self.bmm2_scale,
|
||||||
window_left=self.window_left,
|
window_left=self.window_left,
|
||||||
sinks=self.sinks,
|
sinks=self.sinks,
|
||||||
out=output[:num_decode_tokens],
|
o_sf_scale=self.o_sf_scale,
|
||||||
|
out=out,
|
||||||
)
|
)
|
||||||
return output_padded
|
return output_padded
|
||||||
|
|
||||||
|
|||||||
@ -428,6 +428,7 @@ class FlexAttentionImpl(AttentionImpl):
|
|||||||
attn_metadata: FlexAttentionMetadata,
|
attn_metadata: FlexAttentionMetadata,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FLexAttention.
|
"""Forward pass with FLexAttention.
|
||||||
|
|
||||||
@ -441,7 +442,7 @@ class FlexAttentionImpl(AttentionImpl):
|
|||||||
shape = [num_tokens, num_heads * head_size]
|
shape = [num_tokens, num_heads * head_size]
|
||||||
"""
|
"""
|
||||||
assert output is not None, "Output tensor must be provided."
|
assert output is not None, "Output tensor must be provided."
|
||||||
if output_scale is not None:
|
if output_scale is not None or output_block_scale is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"fused output quantization is not yet supported"
|
"fused output quantization is not yet supported"
|
||||||
" for FlexAttentionImpl")
|
" for FlexAttentionImpl")
|
||||||
|
|||||||
@ -1138,10 +1138,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
attn_metadata: M,
|
attn_metadata: M,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert output is not None, "Output tensor must be provided."
|
assert output is not None, "Output tensor must be provided."
|
||||||
|
|
||||||
if output_scale is not None:
|
if output_scale is not None or output_block_scale is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"fused output quantization is not yet supported"
|
"fused output quantization is not yet supported"
|
||||||
" for MLACommonImpl")
|
" for MLACommonImpl")
|
||||||
|
|||||||
@ -227,6 +227,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
attn_metadata: PallasMetadata,
|
attn_metadata: PallasMetadata,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with Pallas attention.
|
"""Forward pass with Pallas attention.
|
||||||
|
|
||||||
@ -239,7 +240,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
Returns:
|
Returns:
|
||||||
shape = [num_tokens, num_heads * head_size]
|
shape = [num_tokens, num_heads * head_size]
|
||||||
"""
|
"""
|
||||||
if output_scale is not None:
|
if output_scale is not None or output_block_scale is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"fused output quantization is not yet supported"
|
"fused output quantization is not yet supported"
|
||||||
" for PallasAttentionBackendImpl")
|
" for PallasAttentionBackendImpl")
|
||||||
|
|||||||
@ -421,6 +421,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
attn_metadata: AiterFlashAttentionMetadata,
|
attn_metadata: AiterFlashAttentionMetadata,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with AiterFlashAttention.
|
"""Forward pass with AiterFlashAttention.
|
||||||
|
|
||||||
@ -438,7 +439,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
"""
|
"""
|
||||||
assert output is not None, "Output tensor must be provided."
|
assert output is not None, "Output tensor must be provided."
|
||||||
|
|
||||||
if output_scale is not None:
|
if output_scale is not None or output_block_scale is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"fused output quantization is not yet supported"
|
"fused output quantization is not yet supported"
|
||||||
" for FlashAttentionImpl")
|
" for FlashAttentionImpl")
|
||||||
|
|||||||
@ -354,6 +354,7 @@ class TreeAttentionImpl(AttentionImpl):
|
|||||||
attn_metadata: TreeAttentionMetadata,
|
attn_metadata: TreeAttentionMetadata,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with TreeAttention.
|
"""Forward pass with TreeAttention.
|
||||||
|
|
||||||
@ -368,7 +369,7 @@ class TreeAttentionImpl(AttentionImpl):
|
|||||||
"""
|
"""
|
||||||
assert output is not None, "Output tensor must be provided."
|
assert output is not None, "Output tensor must be provided."
|
||||||
|
|
||||||
if output_scale is not None:
|
if output_scale is not None or output_block_scale is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"fused output quantization is not yet supported"
|
"fused output quantization is not yet supported"
|
||||||
" for TreeAttentionImpl")
|
" for TreeAttentionImpl")
|
||||||
|
|||||||
@ -277,6 +277,7 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
attn_metadata: FlashAttentionMetadata,
|
attn_metadata: FlashAttentionMetadata,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention.
|
"""Forward pass with FlashAttention.
|
||||||
|
|
||||||
@ -291,7 +292,7 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
"""
|
"""
|
||||||
assert output is not None, "Output tensor must be provided."
|
assert output is not None, "Output tensor must be provided."
|
||||||
|
|
||||||
if output_scale is not None:
|
if output_scale is not None or output_block_scale is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"fused output quantization is not yet supported"
|
"fused output quantization is not yet supported"
|
||||||
" for TritonAttentionImpl")
|
" for TritonAttentionImpl")
|
||||||
|
|||||||
@ -322,6 +322,7 @@ class XFormersAttentionImpl(AttentionImpl):
|
|||||||
attn_metadata: XFormersAttentionMetadata,
|
attn_metadata: XFormersAttentionMetadata,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with XFormers.
|
"""Forward pass with XFormers.
|
||||||
|
|
||||||
@ -336,7 +337,7 @@ class XFormersAttentionImpl(AttentionImpl):
|
|||||||
"""
|
"""
|
||||||
assert output is not None, "Output tensor must be provided."
|
assert output is not None, "Output tensor must be provided."
|
||||||
|
|
||||||
if output_scale is not None:
|
if output_scale is not None or output_block_scale is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"fused output quantization is not yet supported"
|
"fused output quantization is not yet supported"
|
||||||
" for XFormersAttentionImpl")
|
" for XFormersAttentionImpl")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user