[Bugfix] Fix fusion for VL models (#30244)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
ElizaWszola 2025-12-14 14:22:37 +01:00 committed by GitHub
parent 48b8456ff9
commit 994acec0cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 143 additions and 72 deletions

View File

@ -27,6 +27,7 @@ is_blackwell = lambda: current_platform.is_device_capability_family(100)
class Matches(NamedTuple):
attention_fusion: int = 0
allreduce_fusion: int = 0
rms_quant_norm_fusion: int = 0
sequence_parallel: int = 0
async_tp: int = 0
@ -40,6 +41,7 @@ class ModelBackendTestCase(NamedTuple):
MODELS_FP8: list[ModelBackendTestCase] = []
MODELS_FP4: list[ModelBackendTestCase] = []
MODELS_GROUP_FP8: list[ModelBackendTestCase] = []
MODELS: list[ModelBackendTestCase] = [] # tp-only
if current_platform.is_cuda():
@ -498,3 +500,79 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
compilation_config.compile_ranges_split_points = (
llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points
)
if current_platform.is_cuda():
MODELS_GROUP_FP8 = [
ModelBackendTestCase(
model_name="Qwen/Qwen3-30B-A3B-FP8",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=AttentionBackendEnum.TRITON_ATTN,
matches=Matches(
rms_quant_norm_fusion=48,
),
),
]
CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"]
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, matches, custom_ops",
# Test rms norm+group quant_fp8 fusion
list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
def test_rms_group_quant(
model_name: str,
model_kwargs: dict[str, Any],
backend: AttentionBackendEnum,
matches: Matches,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
monkeypatch,
):
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition requires torch>=2.9")
custom_ops_list = custom_ops.split(",") if custom_ops else []
if inductor_graph_partition:
mode = CUDAGraphMode.FULL_AND_PIECEWISE
splitting_ops: list[str] | None = None
else:
mode = CUDAGraphMode.FULL_DECODE_ONLY
splitting_ops = []
# Disable, compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
# To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
compilation_config = CompilationConfig(
# Testing properties
custom_ops=custom_ops_list,
use_inductor_graph_partition=inductor_graph_partition,
cudagraph_mode=mode,
splitting_ops=splitting_ops,
# Common
mode=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(eliminate_noops=True, enable_fusion=True),
# Inductor caches custom passes by default as well via uuid
inductor_compile_config={"force_disable_caches": True},
)
with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(compilation_config, model_name, **model_kwargs)
log_matches = re.findall(
r"\[fusion.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
assert len(log_matches) == 1, log_holder.text
assert int(log_matches[0]) == matches.rms_quant_norm_fusion

View File

@ -23,17 +23,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kNvfp4Quant,
kStaticTensorScale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
should_use_deepgemm_for_fp8_linear_for_nk,
)
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .matcher_utils import (
MatcherFusedAddRMSNorm,
MatcherQuantFP8,
MatcherRMSNorm,
)
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
@ -118,21 +115,18 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
class RMSNormQuantPattern:
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
def __init__(
self,
epsilon: float,
key: FusedRMSQuantKey,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
):
self.epsilon = epsilon
self.quant_dtype = key.quant.dtype
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
# groupwise FP8 linear uses col major scales if deepgemm and cutlass
using_deepgemm = should_use_deepgemm_for_fp8_linear_for_nk(
self.model_dtype,
config.model_config.hf_config.intermediate_size,
config.model_config.hf_config.hidden_size,
)
use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported()
use_e8m0 = is_deep_gemm_e8m0_used() if using_deepgemm else False
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
self.FUSED_OP = FUSED_OPS[key]
@ -142,7 +136,7 @@ class RMSNormQuantPattern:
else MatcherFusedAddRMSNorm(epsilon)
)
self.quant_matcher = MatcherQuantFP8(
key.quant, use_col_major_scales=use_col_major_scales, use_e8m0=use_e8m0
key.quant, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
)
@ -260,6 +254,8 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
group_shape: GroupShape,
symmetric=True,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
@ -267,7 +263,11 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
self.group_shape = group_shape
super().__init__(epsilon, key)
self.has_col_major_scales = has_col_major_scales
self.is_e8m0 = is_e8m0
super().__init__(
epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
)
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
@ -283,9 +283,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(
input, transposed=self.quant_matcher.use_col_major_scales
)
scale = self.quant_matcher.make_scale(input, self.has_col_major_scales)
at = auto_functionalized(
self.FUSED_OP,
result=result,
@ -296,7 +294,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
scale_ub=None,
residual=residual,
group_size=self.group_shape[1],
is_scale_transposed=self.quant_matcher.use_col_major_scales,
is_scale_transposed=self.has_col_major_scales,
)
# result, residual, scale
@ -318,6 +316,8 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
group_shape: GroupShape,
symmetric=True,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
@ -325,7 +325,9 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
self.group_shape = group_shape
super().__init__(epsilon, key)
super().__init__(
epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
)
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor):
@ -340,7 +342,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(
input, transposed=self.quant_matcher.use_col_major_scales
input, transposed=self.quant_matcher.has_col_major_scales
)
at = auto_functionalized(
self.FUSED_OP,
@ -352,7 +354,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
scale_ub=None,
residual=None,
group_size=self.group_shape[1],
is_scale_transposed=self.quant_matcher.use_col_major_scales,
is_scale_transposed=self.quant_matcher.has_col_major_scales,
)
# result, scale
@ -489,27 +491,6 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse fused_add_rms_norm + fp8 group quant
# Only register group quant patterns on CUDA where the C++ op exists
if current_platform.is_cuda():
FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)
FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)
# Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns
@ -526,6 +507,29 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Only register group quant patterns on CUDA where the C++ op exists
if current_platform.is_cuda():
for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
for has_col_major_scales in [True, False]:
for is_e8m0 in [True, False]:
# Fuse fused_add_rms_norm + fp8 group quant
FusedAddRMSNormGroupQuantPattern(
epsilon,
FP8_DTYPE,
group_shape=group_shape,
has_col_major_scales=has_col_major_scales,
is_e8m0=is_e8m0,
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon,
FP8_DTYPE,
group_shape=group_shape,
has_col_major_scales=has_col_major_scales,
is_e8m0=is_e8m0,
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log

View File

@ -234,24 +234,30 @@ class MatcherQuantFP8(MatcherCustomOp):
self,
quant_key: QuantKey,
enabled: bool | None = None,
use_col_major_scales: bool = False,
use_e8m0: bool = False,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
):
if enabled is None:
enabled = QuantFP8.enabled()
super().__init__(enabled)
self.quant_key = quant_key
self.use_col_major_scales = use_col_major_scales
self.use_e8m0 = use_e8m0
assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
self.QUANT_OP = QUANT_OPS[quant_key]
self.has_col_major_scales = has_col_major_scales
self.is_e8m0 = is_e8m0
assert quant_key.dtype == current_platform.fp8_dtype(), (
"Only QuantFP8 supported by"
)
assert quant_key.scale2 is None
self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape)
self.quant_fp8 = QuantFP8(
quant_key.scale.static,
quant_key.scale.group_shape,
column_major_scales=has_col_major_scales,
use_ue8m0=is_e8m0,
)
def forward_custom(
self,
@ -264,7 +270,7 @@ class MatcherQuantFP8(MatcherCustomOp):
if self.quant_key.scale.group_shape.is_per_group():
assert scale is None
scale = self.make_scale(input, transposed=self.use_col_major_scales)
scale = self.make_scale(input, transposed=self.has_col_major_scales)
finfo = torch.finfo(self.quant_key.dtype)
fp8_min = finfo.min
@ -279,7 +285,7 @@ class MatcherQuantFP8(MatcherCustomOp):
eps=1e-10,
fp8_min=fp8_min,
fp8_max=fp8_max,
scale_ue8m0=self.use_e8m0,
scale_ue8m0=self.is_e8m0,
)
return result, scale

View File

@ -381,22 +381,6 @@ def should_use_deepgemm_for_fp8_linear(
)
def should_use_deepgemm_for_fp8_linear_for_nk(
output_dtype: torch.dtype,
shape0: int,
shape1: int,
supports_deep_gemm: bool | None = None,
):
if supports_deep_gemm is None:
supports_deep_gemm = is_deep_gemm_supported()
return (
supports_deep_gemm
and output_dtype == torch.bfloat16
and shape0 % 128 == 0
and shape1 % 128 == 0
)
__all__ = [
"calc_diff",
"DeepGemmQuantScaleFMT",
@ -411,7 +395,6 @@ __all__ = [
"is_deep_gemm_supported",
"get_num_sms",
"should_use_deepgemm_for_fp8_linear",
"should_use_deepgemm_for_fp8_linear_for_nk",
"get_col_major_tma_aligned_tensor",
"get_mk_alignment_for_contiguous_layout",
]