[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass (#16756)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Luka Govedič 2025-06-12 11:31:04 -04:00 committed by GitHub
parent 96846bb360
commit f98548b9da
33 changed files with 622 additions and 79 deletions

View File

@ -305,6 +305,7 @@ steps:
commands:
- pytest -v -s compile/test_pass_manager.py
- pytest -v -s compile/test_fusion.py
- pytest -v -s compile/test_fusion_attn.py
- pytest -v -s compile/test_silu_mul_quant_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py
- pytest -v -s compile/test_async_tp.py

View File

@ -1,13 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from copy import deepcopy
from typing import Callable, Union
from torch import fx
from torch._ops import OpOverload
from vllm.compilation.fx_utils import (find_specified_fn,
find_specified_fn_maybe)
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.inductor_pass import InductorPass
from vllm.config import get_current_vllm_config
@ -48,18 +49,19 @@ class TestBackend:
# assign by reference, will reflect the final state of the graph
self.final_graph = graph
def check_before_ops(self, ops,
find_fn=find_specified_fn, \
find_fn_maybe=find_specified_fn_maybe, \
ops_fully_replaced=True):
def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
for op in ops:
find_fn(self.graph_pre_pass.nodes, op)
if ops_fully_replaced:
assert find_fn_maybe(self.graph_post_pass.nodes, op) is None
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
if fully_replaced:
assert num_post == 0, \
f"Unexpected op {op.name()} in post-pass graph"
def check_after_ops(self, ops,
find_fn=find_specified_fn, \
find_fn_maybe=find_specified_fn_maybe):
def check_after_ops(self, ops: Sequence[OpOverload]):
for op in ops:
find_fn(self.graph_post_pass.nodes, op)
assert find_fn_maybe(self.graph_pre_pass.nodes, op) is None
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
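For reference, a minimal sketch of how a fusion test drives these helpers (vllm_config, fusion_pass, model and inputs are placeholders; ops_in_model_before/after mirror the test models in this diff):
# fusion_pass is whichever pass is under test (e.g. FusionPass or AttnFusionPass)
backend = TestBackend(NoOpEliminationPass(vllm_config), fusion_pass)
compiled_model = torch.compile(model, backend=backend)
compiled_model(*inputs)
# unfused quant ops must appear pre-pass and be fully replaced post-pass
backend.check_before_ops(model.ops_in_model_before())
# fused kernels must only appear post-pass
backend.check_after_ops(model.ops_in_model_after())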

View File

@ -169,8 +169,7 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
# In pre-nodes, all gather or reduce scatter should exist,
# fused_matmul_reduce_scatter or fused_all_gather_matmul should not
backend.check_before_ops(model.ops_in_model_before(),
ops_fully_replaced=False)
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
# In post-nodes, fused_matmul_reduce_scatter or \
# fused_all_gather_matmul should exist

View File

@ -7,8 +7,7 @@ import torch
import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass, QuantKey)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
FusionPass, GroupShape, QuantKey)
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
@ -30,9 +29,10 @@ class TestModel(torch.nn.Module):
self.cutlass_fp8_enabled = cutlass_fp8_enabled
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
self.key = QuantKey(dtype=FP8_DTYPE,
static=static,
per_tensor=static,
group_shape=group_shape,
symmetric=True)
if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
@ -122,9 +122,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
# In pre-nodes, fp8 quant should be there and fused kernels should not
backend.check_before_ops(model.ops_in_model_before(), find_auto_fn,
find_auto_fn_maybe)
backend.check_before_ops(model.ops_in_model_before())
# In post-nodes, fused kernels should be there and fp8 quant should not
backend.check_after_ops(model.ops_in_model_after(), find_auto_fn,
find_auto_fn_maybe)
backend.check_after_ops(model.ops_in_model_after())

View File

@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import pytest
import torch._dynamo
from tests.compile.backend import TestBackend
from tests.models.utils import check_outputs_equal
from vllm import LLM, SamplingParams
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.platforms import current_platform
# globals needed for string-import custom Dynamo backend field
backend: Optional[TestBackend] = None
backend_unfused: Optional[TestBackend] = None
@pytest.mark.parametrize(
"model, quant_key",
[("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
@pytest.mark.parametrize(
"use_triton_fa", [True, False] if current_platform.is_rocm() else [False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test CUDA and ROCm")
def test_attention_fusion(example_prompts, monkeypatch, model: str,
quant_key: QuantKey, use_triton_fa: bool):
# Clean Dynamo cache to avoid reusing other test cases
# (for some reason the reset at the end is not enough)
torch._dynamo.reset()
# Use global backends
global backend, backend_unfused
use_v1 = False # can be made a param once V1 support added
monkeypatch.setenv("VLLM_USE_V1", str(int(use_v1)))
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))
# Prompt 4 seems too open-ended, differs between fused and unfused
# (both outputs look reasonable though)
prompts = example_prompts[:4] + example_prompts[5:]
compile_config = CompilationConfig(
# DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
# DYNAMO_ONCE does not properly propagate shapes.
level=CompilationLevel.DYNAMO_AS_IS,
backend="tests.compile.test_fusion_attn.backend_unfused",
)
vllm_config = VllmConfig(compilation_config=compile_config)
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
llm = LLM(model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.9,
max_model_len=2048)
sampling_params = SamplingParams(temperature=0.0,
max_tokens=10,
top_p=0.95)
unfused_output = llm.generate(prompts, sampling_params)
backend_unfused = None # Reset backend to make sure llm gets released
del llm
compile_config = CompilationConfig(
# DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
# DYNAMO_ONCE does not properly propagate shapes.
level=CompilationLevel.DYNAMO_AS_IS,
backend="tests.compile.test_fusion_attn.backend",
)
vllm_config = VllmConfig(compilation_config=compile_config)
# AttnFusionPass needs attention layers to be registered in config upon init
# so we initialize it during compilation.
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
llm2 = LLM(model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.9,
max_model_len=2048)
# check support
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key.dtype,
quant_key.static,
quant_key.group_shape)
for key, layer in compile_config.static_forward_context.items()
]
print(f"{attn_fusion_supported=}")
if any(attn_fusion_supported):
# Check quant ops
backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
# attention ops present in both, just output_scale param changes
attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
assert len(attn_nodes_pre) == len(attn_nodes_post)
for i in range(len(attn_nodes_pre)):
assert attn_nodes_pre[i].kwargs["output_scale"] is None
fused = attn_nodes_post[i].kwargs["output_scale"] is not None
assert fused == attn_fusion_supported[i], \
f"Node {i} {'' if fused else 'not '} expected " \
f"to have fused output quant"
# check outputs
fused_output = llm2.generate(prompts, sampling_params)
# transform outputs to format expected by check_outputs_equal
sample_outs = lambda s: (list(s.token_ids), s.text)
outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]
check_outputs_equal(
outputs_0_lst=outs_lst(unfused_output),
outputs_1_lst=outs_lst(fused_output),
name_0="unfused",
name_1="fused",
)
# Clean Dynamo cache to avoid polluting other case(s)
torch._dynamo.reset()
# Reset backend to make sure llm2 gets released
backend = None

View File

@ -1225,6 +1225,7 @@ def scaled_fp8_quant(
num_token_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
@ -1256,7 +1257,12 @@ def scaled_fp8_quant(
out_dtype: torch.dtype = current_platform.fp8_dtype()
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype)
if output is None:
output = torch.empty(shape, device=input.device, dtype=out_dtype)
else:
assert num_token_padding is None, \
"padding not supported if output passed in"
assert output.dtype == out_dtype
if scale is None:
if use_per_token_if_dynamic:
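The new output parameter lets callers quantize into a preallocated buffer (e.g. a slice of the attention output, as the ROCm backend does further down). A minimal sketch, assuming a CUDA/ROCm build with the _C ops available:
import torch
from vllm import _custom_ops as ops

x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")

# default path: the op allocates the fp8 output tensor
out, _ = ops.scaled_fp8_quant(x, scale)

# new path: write into an existing buffer of the fp8 dtype (no padding allowed)
buf = torch.empty_like(out)
out2, _ = ops.scaled_fp8_quant(x, scale, output=buf)
assert out2.data_ptr() == buf.data_ptr()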

View File

@ -284,9 +284,25 @@ class AttentionImpl(ABC, Generic[T]):
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: tuple[int, int]):
"""
Does this attention implementation support fused output quantization?
This is used by the AttnFusionPass to only fuse output quantization
onto implementations that support it.
TODO(luka) merge parameters into QuantDescriptor
:param dtype: quantized dtype
:param static: static or dynamic quantization
:param group_shape: quant group shape. (-1, -1) for per-tensor.
:return: is fusion supported for this type of quantization
"""
return False
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
@ -300,6 +316,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
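For illustration, a backend that can fold static per-tensor fp8 quantization into its attention kernel would opt in roughly like this (a hedged sketch modeled on the ROCm override later in this diff; current_platform comes from vllm.platforms):
class MyAttentionImpl(AttentionImpl):

    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                     group_shape: tuple[int, int]):
        # only static per-tensor fp8 output quant is fusable in this sketch
        return (dtype == current_platform.fp8_dtype() and static
                and group_shape == (-1, -1))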

View File

@ -374,6 +374,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: BlocksparseFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@ -388,6 +389,11 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for BlocksparseFlashAttentionImpl")
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)

View File

@ -370,6 +370,8 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: DualChunkFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with DualChunkFlashAttention.
Args:
@ -383,6 +385,13 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is None, "Output tensor not supported for DualChunk"
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
(
query,
query_succ,

View File

@ -673,6 +673,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@ -692,6 +693,11 @@ class FlashAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
assert (

View File

@ -975,8 +975,14 @@ class FlashInferImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashInferImpl")
# TODO: directly write to output tensor
num_heads: int = self.num_heads
head_size: int = self.head_size

View File

@ -181,6 +181,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
@ -193,6 +194,11 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for HPUAttentionImpl")
batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape

View File

@ -192,6 +192,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
kv_cache: torch.Tensor,
attn_metadata: IpexAttnMetadata, # type: ignore
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
@ -206,6 +207,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for IpexAttentionImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.

View File

@ -1319,11 +1319,17 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output is not None:
raise NotImplementedError(
"output is not yet supported for MLAImplBase")
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for MLAImplBase")
if attn_metadata.is_profile_run and \
attn_metadata.context_chunk_workspace is not None:
# During the profile run, try to simulate the worst-case output size

View File

@ -172,6 +172,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: PallasMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
@ -187,6 +188,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for PallasAttentionImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)

View File

@ -38,11 +38,11 @@ def is_rocm_aiter_paged_attn_enabled() -> bool:
@cache
def _get_paged_attn_module() -> PagedAttention:
"""
Initializes the appropriate PagedAttention module from `attention/ops`,
Initializes the appropriate PagedAttention module from `attention/ops`,
which is used as helper function
by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
The choice of attention module depends on whether
The choice of attention module depends on whether
AITER paged attention is enabled:
- If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
- Otherwise, it defaults to using the original `PagedAttention`.
@ -598,6 +598,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
head_dim).reshape(tokens, n_kv_heads * n_rep,
head_dim))
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: tuple[int, int]):
if self.use_triton_flash_attn:
return dtype == current_platform.fp8_dtype(
) and static and group_shape == (-1, -1) # per-tensor
# Only supported in the Triton backend
return False
def forward(
self,
layer: AttentionLayer,
@ -607,6 +616,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@ -660,6 +670,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None and not self.use_triton_flash_attn:
raise NotImplementedError(
"fused output quantization only supported for Triton"
" implementation in ROCMFlashAttentionImpl for now")
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
@ -799,6 +814,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks[0][None]
if attn_masks is not None else None,
full_scales,
output_scale,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
@ -876,6 +892,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len, self.sliding_window,
self.kv_cache_dtype, self.alibi_slopes)
if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else
@ -887,7 +904,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
assert _PARTITION_SIZE_ROCM % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
dtype=query.dtype,
device=output.device,
)
exp_sums = torch.empty(
@ -921,9 +938,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
output_scale,
)
else:
output[num_prefill_tokens:] = paged_attn.forward_decode(
# PagedAttention does not support fused quant, manually quantize
if output_scale is None:
out_pa = output[num_prefill_tokens:]
else:
out_pa = torch.empty_like(output[num_prefill_tokens:],
dtype=query.dtype)
out_pa[:] = paged_attn.forward_decode(
decode_query,
key_cache,
value_cache,
@ -944,6 +969,14 @@ class ROCmFlashAttentionImpl(AttentionImpl):
layer._v_scale,
)
# Manually perform quantization
if output_scale is not None:
out_uq = out_pa.view(-1, self.num_heads * self.head_size)
out_q = output.view(-1, self.num_heads * self.head_size)
ops.scaled_fp8_quant(out_uq,
output_scale,
output=out_q[num_prefill_tokens:])
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)

View File

@ -459,6 +459,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
kv_cache: torch.Tensor,
attn_metadata: TorchSDPAMetadata, # type: ignore
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
@ -473,6 +474,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for TorchSDPABackendImpl")
# For warming-up
if attn_metadata is None:

View File

@ -435,6 +435,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
kv_cache: torch.Tensor,
attn_metadata: "XFormersMetadata",
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
@ -487,6 +488,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for XFormersImpl")
attn_type = self.attn_type
# Check that appropriate attention metadata attributes are
# selected for the desired attention type

View File

@ -430,6 +430,7 @@ def unified_attention_with_output(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
@ -444,7 +445,8 @@ def unified_attention_with_output(
value,
kv_cache,
attn_metadata,
output=output)
output=output,
output_scale=output_scale)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
@ -455,6 +457,7 @@ def unified_attention_with_output_fake(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
return

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, NamedTuple, Optional
from typing import Callable, ClassVar, NamedTuple, Optional
import torch
import torch._inductor.pattern_matcher as pm
@ -34,36 +33,66 @@ RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
# Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple):
row: int
col: int
class GroupShape(_GroupShape):
"""
This class describes the quantization group shape.
It includes static members for common shapes (per-tensor, per-token).
"""
# Aliases for common quantization group shapes
PER_TENSOR: ClassVar['GroupShape']
PER_TOKEN: ClassVar['GroupShape']
GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)
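A quick illustration of the aliases (the proxy class is needed because, as the comment above notes, static members cannot be declared directly on the NamedTuple subclass that defines the fields):
# aliases are ordinary instances, so they compare equal to constructed values
assert GroupShape(-1, -1) == GroupShape.PER_TENSOR  # one scale per tensor
assert GroupShape(1, -1) == GroupShape.PER_TOKEN    # one scale per token (row)
assert GroupShape.PER_TOKEN.row == 1 and GroupShape.PER_TOKEN.col == -1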
class QuantKey(NamedTuple):
"""
Named tuple for identifying the type of quantization.
dtype: quantized data type
static: static quantization if True, dynamic if False
per_tensor: per-tensor quantization if True, per-token if False
group_shape: quantization group shape
symmetric: symmetric if True, asymmetric if False
TODO(luka) use QuantDescriptor once standardized:
https://github.com/vllm-project/vllm/issues/8913
"""
dtype: torch.dtype
static: bool
per_tensor: bool = True
group_shape: GroupShape
symmetric: bool = True
def __str__(self):
group_shape = ('per_tensor'
if self.group_shape == GroupShape.PER_TENSOR else
('per_token' if self.group_shape == GroupShape.PER_TOKEN
else str(self.group_shape)))
return (f"QuantKey({'static' if self.static else 'dynamic'},"
f"{fx.graph.dtype_abbrs[self.dtype]},"
f"{'per_tensor' if self.per_tensor else 'per_token'},"
f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape},"
f"{'a' if not self.symmetric else ''}symmetric)")
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True)
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa
kFp8StaticTensorSym:
torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym:
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym:
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
@ -83,13 +112,13 @@ class FusedRMSQuantKey(NamedTuple):
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
FusedRMSQuantKey(kFp8StaticTensorSym, False):
torch.ops._C.rms_norm_static_fp8_quant.default, # noqa
torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501
FusedRMSQuantKey(kFp8StaticTensorSym, True):
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501
FusedRMSQuantKey(kFp8DynamicTokenSym, False):
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
FusedRMSQuantKey(kFp8DynamicTokenSym, True):
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
}
@ -177,10 +206,11 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
symmetric=True):
fused_key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(dtype=quant_dtype,
static=True,
per_tensor=True,
symmetric=symmetric))
quant=QuantKey(
dtype=quant_dtype,
static=True,
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric))
super().__init__(epsilon, fused_key)
def register(self, pm_pass: PatternMatcherPass):
@ -233,10 +263,11 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
symmetric=True):
key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(dtype=quant_dtype,
static=True,
per_tensor=True,
symmetric=symmetric))
quant=QuantKey(
dtype=quant_dtype,
static=True,
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric))
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass,
@ -323,12 +354,12 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(self,
epsilon: float,
quant_dtype: torch.dtype,
per_tensor: bool,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True):
key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(dtype=quant_dtype,
static=False,
per_tensor=per_tensor,
group_shape=group_shape,
symmetric=symmetric))
super().__init__(epsilon, key)
@ -421,12 +452,12 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(self,
epsilon: float,
quant_dtype: torch.dtype,
per_tensor: bool = True,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True):
key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(dtype=quant_dtype,
static=False,
per_tensor=per_tensor,
group_shape=group_shape,
symmetric=symmetric))
super().__init__(epsilon, key)
@ -566,16 +597,12 @@ class FusionPass(VllmInductorPass):
self.patterns, self.record_match)
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE,
per_tensor=False).register(
self.patterns, self.record_match)
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern(epsilon,
FP8_DTYPE,
per_tensor=False).register(
self.patterns,
self.record_match)
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.

View File

@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
import torch
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._subclasses.fake_tensor import (FakeTensorMode,
unset_fake_temporarily)
from vllm.attention import Attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default
class AttentionStaticQuantPattern:
def __init__(
self,
layer_name: str,
num_heads: int,
head_size: int,
quant_dtype: torch.dtype,
symmetric=True,
):
self.layer_name = layer_name
self.num_heads = num_heads
self.head_size = head_size
self.quant_dtype = quant_dtype
self.quant_key = QuantKey(dtype=quant_dtype,
static=True,
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric)
assert self.quant_key in QUANT_OPS, \
f"unsupported quantization scheme {self.quant_key}"
self.QUANT_OP = QUANT_OPS[self.quant_key]
def empty_quant(self, *args, **kwargs):
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
return torch.empty(*args, **kwargs)
def register_if_supported(self, pm_pass: PatternMatcherPass,
layer: Attention):
if layer.impl.fused_output_quant_supported(self.quant_dtype,
self.quant_key.static,
self.quant_key.group_shape):
self._register(pm_pass)
def _register(self, pm_pass: PatternMatcherPass):
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
scale: torch.Tensor):
view_7 = RESHAPE_OP(output_attn,
[-1, self.num_heads, self.head_size])
at1 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=view_7,
layer_name=self.layer_name,
output_scale=None)
attn_out_view = RESHAPE_OP(at1[1],
[-1, self.num_heads * self.head_size])
at2 = auto_functionalized(self.QUANT_OP,
result=output_quant,
input=attn_out_view,
scale=scale)
return at2[1]
def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
scale: torch.Tensor):
view_7 = RESHAPE_OP(output_quant,
[-1, self.num_heads, self.head_size])
at1 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=view_7,
layer_name=self.layer_name,
output_scale=scale)
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
# Need custom fake mode, otherwise tracing happens with real tensors.
# That would not work for the unified_attention custom op.
with unset_fake_temporarily(), FakeTensorMode():
inputs = [
empty_bf16(5, self.num_heads, self.head_size), # q
empty_bf16(5, self.num_heads, self.head_size), # k
empty_bf16(5, self.num_heads, self.head_size), # v
empty_bf16(5, self.num_heads * self.head_size), # attn_output
self.empty_quant(5, self.num_heads *
self.head_size), # quant_output
empty_fp32(1, 1) # scale
]
def wrap_trace_fn(process_fx, trace_fn):
def wrapped(*args, **kwargs):
return process_fx(trace_fn(*args, **kwargs))
return wrapped
def fx_view_to_reshape(gm: torch.fx.GraphModule):
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
return gm
pm.register_replacement(
pattern, replacement, inputs,
wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass)
class AttnFusionPass(VllmInductorPass):
"""
This pass fuses post-attention quantization onto attention if supported.
It uses the pattern matcher and matches each layer manually, as strings
cannot be wildcarded. This also lets us check support on attention layers
upon registration instead of during pattern matching.
Currently, only static fp8 quant is supported, but patterns could easily be
added for other quant schemes and dtypes. The bigger hurdle for wider
support is attention kernels, which need to support fusing output quant.
"""
def __init__(self, config: VllmConfig):
super().__init__(config)
self.static_fwd_ctx = config.compilation_config.static_forward_context
self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")
for key, layer in self.static_fwd_ctx.items():
pattern = AttentionStaticQuantPattern(key, layer.num_heads,
layer.head_size,
current_platform.fp8_dtype())
pattern.register_if_supported(self.patterns, layer)
if len(self.static_fwd_ctx) == 0:
logger.warning(
"Attention + quant fusion is enabled, but "
"CompilationConfig.static_forward_context is empty. "
"Cannot access attention layers so no fusion "
"patterns were registered.")
def __call__(self, graph: torch.fx.graph.Graph) -> None:
self.begin()
self.dump_graph(graph, "before_attn_fusion")
count = self.patterns.apply(graph)
logger.debug("Fused quantization onto %s attention nodes", count)
self.dump_graph(graph, "after_attn_fusion")
self.end_and_log()

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import operator
from collections.abc import Iterable
from collections.abc import Iterable, Iterator
from typing import Optional
from torch import fx
@ -14,6 +14,10 @@ def is_func(node: fx.Node, target) -> bool:
return node.op == "call_function" and node.target == target
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
return is_func(node, auto_functionalized) and node.args[0] == op
# Returns the first specified node with the given op (if it exists)
def find_specified_fn_maybe(nodes: Iterable[fx.Node],
op: OpOverload) -> Optional[fx.Node]:
@ -60,3 +64,21 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node:
ret = find_getitem_maybe(node, idx)
assert ret is not None, f"Could not find getitem {idx} in node {node}"
return ret
# An auto-functionalization-aware utility for finding nodes with a specific op
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
if not op._schema.is_mutable:
yield from graph.find_nodes(op="call_function", target=op)
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
if n.args[0] == op:
yield n
# Asserts that the node only has one user and returns it
# Even if a node has only 1 user, it might share storage with another node,
# which might need to be taken into account.
def get_only_user(node: fx.Node) -> fx.Node:
assert len(node.users) == 1
return next(iter(node.users))
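A small usage sketch (graph is any fx.Graph produced during vLLM compilation; the overload below is one of the QUANT_OPS used elsewhere in this diff):
quant_op = torch.ops._C.static_scaled_fp8_quant.default
# counts both plain call_function nodes and auto_functionalized wrappers
num_quant_nodes = len(list(find_op_nodes(quant_op, graph)))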

View File

@ -23,7 +23,23 @@ class NoOpEliminationPass(VllmInductorPass):
in the 2D case. Additionally, the torch-internal no-op elimination pass does
not handle certain slice variants.
Cases handled:
1. A chain of reshapes is equivalent to the last reshape called on the
base tensor (input of the first reshape).
2. A reshape that produces the shape of the input is redundant.
3. A slice that produces the shape of the input is redundant.
Example graph 1:
mul_1: "f16[s0, 4096]" = ...
view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
view_2: "f16[s0, 4096]" = torch.reshape(view_1, [-1, 4096])
view_3: "f16[s0, 128, 32]" = torch.reshape(view_2, [-1, 128, 32])
Can be replaced with:
mul_1: "f16[s0, 4096]" = ...
view_3: "f16[s0, 128, 32]" = ...
Example graph 2:
getitem_1: "f16[s0, 4096]" = ...
view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
@ -34,7 +50,7 @@ class NoOpEliminationPass(VllmInductorPass):
at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
out: "f8e4m3fn[s0, 4096]" = at[1]
Example graph 2:
Example graph 3:
arg0: "s0" = SymInt(s0)
scaled_mm: "f16[s0, 4096]" = ...
slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
@ -58,6 +74,18 @@ class NoOpEliminationPass(VllmInductorPass):
# Remove no-op reshapes/views:
for node in graph.nodes:
if is_func(node, torch.ops.aten.reshape.default):
# Case 1: rewrite reshape chains to reshapes on the base tensor
input = node.args[0]
# If the input is a reshape, rebind to that node
if is_func(input, torch.ops.aten.reshape.default):
# The new input is guaranteed not to be a reshape,
# because we process nodes in order
node.update_arg(0, input.args[0])
if len(input.users) == 0:
graph.erase_node(input)
count += 1
# Case 2: remove this reshape if it produces the original shape
input, shape = node.args[:2]
input_shape = input.meta["val"].shape
if len(shape) != len(input_shape):

View File

@ -10,6 +10,7 @@ from .activation_quant_fusion import ActivationQuantFusionPass
from .collective_fusion import AsyncTPPass
from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
from .fusion_attn import AttnFusionPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
from .noop_elimination import NoOpEliminationPass
from .sequence_parallelism import SequenceParallelismPass
@ -59,6 +60,9 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.enable_async_tp:
self.passes += [AsyncTPPass(config)]
if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]
self.fix_functionalization = FixFunctionalizationPass(config)
def add(self, pass_: InductorPass):

View File

@ -4,6 +4,7 @@
import time
import torch
from torch._dynamo.utils import lazy_format_graph_code
from vllm.config import PassConfig, VllmConfig
# yapf: disable
@ -34,6 +35,8 @@ class VllmInductorPass(InductorPass):
self.pass_name = self.__class__.__name__
def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False):
lazy_format_graph_code(stage, graph.owning_module)
if stage in self.pass_config.dump_graph_stages or always:
# Make sure filename includes rank in the distributed setting
parallel = p_is_init() and get_tp_world_size() > 1

View File

@ -3804,9 +3804,10 @@ class PassConfig:
its own stages (before, after, maybe in-between)."""
dump_graph_dir: Path = Path(".")
"""Directory to dump the graphs."""
# TODO(luka) better pass enabling system.
enable_fusion: bool = True
"""Whether to enable the custom fusion pass."""
"""Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
enable_attn_fusion: bool = False
"""Whether to enable the custom attention+quant fusion pass."""
enable_noop: bool = True
"""Whether to enable the custom no-op elimination pass."""
enable_sequence_parallelism: bool = False
@ -3814,6 +3815,8 @@ class PassConfig:
enable_async_tp: bool = False
"""Whether to enable async TP."""
# TODO(luka) better pass enabling system.
def uuid(self):
"""
Produces a hash unique to the pass configuration.
@ -3821,18 +3824,20 @@ class PassConfig:
Do not include dump_graph_* in the hash - they don't affect
compilation.
"""
include = {
"enable_fusion", "enable_noop", "enable_sequence_parallelism",
"enable_async_tp"
}
dict_ = {k: v for k, v in asdict(self).items() if k in include}
exclude = {"dump_graph_stages", "dump_graph_dir"}
dict_ = {k: v for k, v in asdict(self).items() if k not in exclude}
return InductorPass.hash_dict(dict_)
def __post_init__(self) -> None:
if not self.enable_noop and self.enable_fusion:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"RMSNorm + quant (fp8) fusion might not work")
if not self.enable_noop:
if self.enable_fusion:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"RMSNorm/SiluMul + quant (fp8) fusion might not work")
if self.enable_attn_fusion:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"Attention + quant (fp8) fusion might not work")
@config
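A hedged sketch of enabling the new pass from user code (the flag names match this diff; the CompilationConfig/LLM plumbing is assumed to follow the existing API used in the tests above):
from vllm import LLM
from vllm.config import CompilationConfig, CompilationLevel, PassConfig

compilation_config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,  # custom passes run in the Inductor post-grad stage
    pass_config=PassConfig(enable_noop=True, enable_attn_fusion=True))
llm = LLM("amd/Llama-3.1-8B-Instruct-FP8-KV",
          compilation_config=compilation_config)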

View File

@ -15,7 +15,7 @@ if TYPE_CHECKING:
VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_TRITON_FLASH_ATTN: bool = True
VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
VLLM_FLASH_ATTN_VERSION: Optional[int] = None
LOCAL_RANK: int = 0

View File

@ -569,6 +569,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@ -586,6 +587,11 @@ class FlashAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
if attn_metadata is None:
# Profiling run.
return output

View File

@ -547,6 +547,7 @@ class FlashInferImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashInfer.
@ -561,6 +562,11 @@ class FlashInferImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashInferImpl")
if attn_metadata is None:
# Profiling run.
return output

View File

@ -414,6 +414,7 @@ class FlexAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: FlexAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FLexAttention.
@ -427,6 +428,12 @@ class FlexAttentionImpl(AttentionImpl):
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlexAttentionImpl")
enable_gqa = self.num_kv_heads != self.num_heads
if attn_metadata is None:

View File

@ -865,10 +865,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache: torch.Tensor,
attn_metadata: M,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for MLACommonImpl")
if attn_metadata is None:
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the

View File

@ -161,6 +161,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: PallasMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
@ -173,6 +174,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for PallasAttentionBackendImpl")
# For determine_available_memory case.
if kv_cache.numel() == 0:
if output is None:

View File

@ -142,6 +142,7 @@ class TritonAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@ -156,6 +157,11 @@ class TritonAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for TritonAttentionImpl")
if attn_metadata is None:
# Profiling run.
return output