diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py
index 1fcafe1840cd3..bd326f1157d8f 100644
--- a/tests/compile/distributed/test_fusions_e2e.py
+++ b/tests/compile/distributed/test_fusions_e2e.py
@@ -27,6 +27,7 @@ is_blackwell = lambda: current_platform.is_device_capability_family(100)
 class Matches(NamedTuple):
     attention_fusion: int = 0
     allreduce_fusion: int = 0
+    rms_quant_norm_fusion: int = 0
     sequence_parallel: int = 0
     async_tp: int = 0
 
@@ -40,6 +41,7 @@ class ModelBackendTestCase(NamedTuple):
 
 MODELS_FP8: list[ModelBackendTestCase] = []
 MODELS_FP4: list[ModelBackendTestCase] = []
+MODELS_GROUP_FP8: list[ModelBackendTestCase] = []
 MODELS: list[ModelBackendTestCase] = []  # tp-only
 
 if current_platform.is_cuda():
@@ -498,3 +500,79 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
         compilation_config.compile_ranges_split_points = (
             llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points
         )
+
+
+if current_platform.is_cuda():
+    MODELS_GROUP_FP8 = [
+        ModelBackendTestCase(
+            model_name="Qwen/Qwen3-30B-A3B-FP8",
+            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
+            backend=AttentionBackendEnum.TRITON_ATTN,
+            matches=Matches(
+                rms_quant_norm_fusion=48,
+            ),
+        ),
+    ]
+
+CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"]
+
+
+@pytest.mark.parametrize(
+    "model_name, model_kwargs, backend, matches, custom_ops",
+    # Test rms_norm + group quant_fp8 fusion
+    list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
+)
+@pytest.mark.parametrize("inductor_graph_partition", [True, False])
+def test_rms_group_quant(
+    model_name: str,
+    model_kwargs: dict[str, Any],
+    backend: AttentionBackendEnum,
+    matches: Matches,
+    custom_ops: str,
+    inductor_graph_partition: bool,
+    caplog_mp_spawn,
+    monkeypatch,
+):
+    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition requires torch>=2.9")
+
+    custom_ops_list = custom_ops.split(",") if custom_ops else []
+
+    if inductor_graph_partition:
+        mode = CUDAGraphMode.FULL_AND_PIECEWISE
+        splitting_ops: list[str] | None = None
+    else:
+        mode = CUDAGraphMode.FULL_DECODE_ONLY
+        splitting_ops = []
+
+    # Disable compile cache to make sure custom passes run.
+    # Otherwise, we can't verify fusion happened through the logs.
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+    # To capture subprocess logs, we need to know whether spawn or fork is used.
+    # Force spawn as it is more general.
+ monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + compilation_config = CompilationConfig( + # Testing properties + custom_ops=custom_ops_list, + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + splitting_ops=splitting_ops, + # Common + mode=CompilationMode.VLLM_COMPILE, + pass_config=PassConfig(eliminate_noops=True, enable_fusion=True), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, + ) + + with caplog_mp_spawn(logging.DEBUG) as log_holder: + run_model(compilation_config, model_name, **model_kwargs) + + log_matches = re.findall( + r"\[fusion.py:\d+] Replaced (\d+) patterns", + log_holder.text, + ) + assert len(log_matches) == 1, log_holder.text + assert int(log_matches[0]) == matches.rms_quant_norm_fusion diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index a7e6a69e64c91..d121106334cb9 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -23,17 +23,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kNvfp4Quant, kStaticTensorScale, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - cutlass_block_fp8_supported, -) from vllm.platforms import current_platform -from vllm.utils.deep_gemm import ( - is_deep_gemm_e8m0_used, - should_use_deepgemm_for_fp8_linear_for_nk, -) from .inductor_pass import enable_fake_mode -from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm +from .matcher_utils import ( + MatcherFusedAddRMSNorm, + MatcherQuantFP8, + MatcherRMSNorm, +) from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -118,21 +115,18 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { class RMSNormQuantPattern: - def __init__(self, epsilon: float, key: FusedRMSQuantKey): + def __init__( + self, + epsilon: float, + key: FusedRMSQuantKey, + has_col_major_scales: bool = False, + is_e8m0: bool = False, + ): self.epsilon = epsilon self.quant_dtype = key.quant.dtype config = get_current_vllm_config() self.model_dtype = config.model_config.dtype if config.model_config else None - # groupwise FP8 linear uses col major scales if deepgemm and cutlass - using_deepgemm = should_use_deepgemm_for_fp8_linear_for_nk( - self.model_dtype, - config.model_config.hf_config.intermediate_size, - config.model_config.hf_config.hidden_size, - ) - use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported() - use_e8m0 = is_deep_gemm_e8m0_used() if using_deepgemm else False - assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] @@ -142,7 +136,7 @@ class RMSNormQuantPattern: else MatcherFusedAddRMSNorm(epsilon) ) self.quant_matcher = MatcherQuantFP8( - key.quant, use_col_major_scales=use_col_major_scales, use_e8m0=use_e8m0 + key.quant, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0 ) @@ -260,6 +254,8 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, group_shape: GroupShape, symmetric=True, + has_col_major_scales: bool = False, + is_e8m0: bool = False, ): scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey( @@ -267,7 +263,11 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), ) self.group_shape = group_shape - super().__init__(epsilon, key) + 
+        self.has_col_major_scales = has_col_major_scales
+        self.is_e8m0 = is_e8m0
+        super().__init__(
+            epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
+        )
 
     def register(self, pm_pass: PatternMatcherPass):
         def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
@@ -283,9 +283,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
                 input = input.to(dtype=self.model_dtype)
 
             result = torch.empty_like(input, dtype=self.quant_dtype)
-            scale = self.quant_matcher.make_scale(
-                input, transposed=self.quant_matcher.use_col_major_scales
-            )
+            scale = self.quant_matcher.make_scale(input, self.has_col_major_scales)
             at = auto_functionalized(
                 self.FUSED_OP,
                 result=result,
@@ -296,7 +294,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
                 scale_ub=None,
                 residual=residual,
                 group_size=self.group_shape[1],
-                is_scale_transposed=self.quant_matcher.use_col_major_scales,
+                is_scale_transposed=self.has_col_major_scales,
             )
 
             # result, residual, scale
@@ -318,6 +316,8 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
         quant_dtype: torch.dtype,
         group_shape: GroupShape,
         symmetric=True,
+        has_col_major_scales: bool = False,
+        is_e8m0: bool = False,
     ):
         scale = ScaleDesc(torch.float32, False, group_shape)
         key = FusedRMSQuantKey(
@@ -325,7 +325,9 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
         )
         self.group_shape = group_shape
-        super().__init__(epsilon, key)
+        super().__init__(
+            epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
+        )
 
     def register(self, pm_pass: PatternMatcherPass):
         def pattern(input: torch.Tensor, weight: torch.Tensor):
@@ -340,7 +342,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
 
             result = torch.empty_like(input, dtype=self.quant_dtype)
             scale = self.quant_matcher.make_scale(
-                input, transposed=self.quant_matcher.use_col_major_scales
+                input, transposed=self.quant_matcher.has_col_major_scales
             )
             at = auto_functionalized(
                 self.FUSED_OP,
@@ -352,7 +354,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
                 scale_ub=None,
                 residual=None,
                 group_size=self.group_shape[1],
-                is_scale_transposed=self.quant_matcher.use_col_major_scales,
+                is_scale_transposed=self.quant_matcher.has_col_major_scales,
             )
 
             # result, scale
@@ -489,27 +491,6 @@
         # Make sure fused add patterns are before simple rms norm,
         # as the latter is a subset of the former in torch ops
         for epsilon in [1e-5, 1e-6]:
-            # Fuse fused_add_rms_norm + fp8 group quant
-            # Only register group quant patterns on CUDA where the C++ op exists
-            if current_platform.is_cuda():
-                FusedAddRMSNormGroupQuantPattern(
-                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
-                ).register(self.patterns)
-
-                # Fuse rms_norm + fp8 group quant
-                RMSNormGroupQuantPattern(
-                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
-                ).register(self.patterns)
-
-                FusedAddRMSNormGroupQuantPattern(
-                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
-                ).register(self.patterns)
-
-                # Fuse rms_norm + fp8 group quant
-                RMSNormGroupQuantPattern(
-                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
-                ).register(self.patterns)
-
             # Fuse fused_add_rms_norm + static fp8 quant
             FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                 self.patterns
@@ -526,6 +507,29 @@
             # Fuse rms_norm + dynamic per-token fp8 quant
             RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
 
+            # Only register group quant patterns on CUDA where the C++ op exists
+            if current_platform.is_cuda():
+                for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
+                    for has_col_major_scales in [True, False]:
+                        for is_e8m0 in [True, False]:
+                            # Fuse fused_add_rms_norm + fp8 group quant
+                            FusedAddRMSNormGroupQuantPattern(
+                                epsilon,
+                                FP8_DTYPE,
+                                group_shape=group_shape,
+                                has_col_major_scales=has_col_major_scales,
+                                is_e8m0=is_e8m0,
+                            ).register(self.patterns)
+
+                            # Fuse rms_norm + fp8 group quant
+                            RMSNormGroupQuantPattern(
+                                epsilon,
+                                FP8_DTYPE,
+                                group_shape=group_shape,
+                                has_col_major_scales=has_col_major_scales,
+                                is_e8m0=is_e8m0,
+                            ).register(self.patterns)
+
         self.dump_patterns(config, self.patterns)
 
     @VllmInductorPass.time_and_log
diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py
index 0c0bece9b3fda..ec9ed34f561b4 100644
--- a/vllm/compilation/matcher_utils.py
+++ b/vllm/compilation/matcher_utils.py
@@ -234,24 +234,30 @@ class MatcherQuantFP8(MatcherCustomOp):
         self,
         quant_key: QuantKey,
         enabled: bool | None = None,
-        use_col_major_scales: bool = False,
-        use_e8m0: bool = False,
+        has_col_major_scales: bool = False,
+        is_e8m0: bool = False,
     ):
         if enabled is None:
             enabled = QuantFP8.enabled()
         super().__init__(enabled)
 
         self.quant_key = quant_key
-        self.use_col_major_scales = use_col_major_scales
-        self.use_e8m0 = use_e8m0
         assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
         self.QUANT_OP = QUANT_OPS[quant_key]
+        self.has_col_major_scales = has_col_major_scales
+        self.is_e8m0 = is_e8m0
+
         assert quant_key.dtype == current_platform.fp8_dtype(), (
             "Only QuantFP8 supported by"
         )
         assert quant_key.scale2 is None
 
-        self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape)
+        self.quant_fp8 = QuantFP8(
+            quant_key.scale.static,
+            quant_key.scale.group_shape,
+            column_major_scales=has_col_major_scales,
+            use_ue8m0=is_e8m0,
+        )
 
     def forward_custom(
         self,
@@ -264,7 +270,7 @@ class MatcherQuantFP8(MatcherCustomOp):
 
         if self.quant_key.scale.group_shape.is_per_group():
             assert scale is None
-            scale = self.make_scale(input, transposed=self.use_col_major_scales)
+            scale = self.make_scale(input, transposed=self.has_col_major_scales)
 
         finfo = torch.finfo(self.quant_key.dtype)
         fp8_min = finfo.min
@@ -279,7 +285,7 @@ class MatcherQuantFP8(MatcherCustomOp):
             eps=1e-10,
             fp8_min=fp8_min,
             fp8_max=fp8_max,
-            scale_ue8m0=self.use_e8m0,
+            scale_ue8m0=self.is_e8m0,
         )
 
         return result, scale
diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index 46be3e2cd5c54..3d4f8449ad3b6 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -381,22 +381,6 @@ def should_use_deepgemm_for_fp8_linear(
     )
 
 
-def should_use_deepgemm_for_fp8_linear_for_nk(
-    output_dtype: torch.dtype,
-    shape0: int,
-    shape1: int,
-    supports_deep_gemm: bool | None = None,
-):
-    if supports_deep_gemm is None:
-        supports_deep_gemm = is_deep_gemm_supported()
-    return (
-        supports_deep_gemm
-        and output_dtype == torch.bfloat16
-        and shape0 % 128 == 0
-        and shape1 % 128 == 0
-    )
-
-
 __all__ = [
     "calc_diff",
     "DeepGemmQuantScaleFMT",
@@ -411,7 +395,6 @@ __all__ = [
     "is_deep_gemm_supported",
     "get_num_sms",
     "should_use_deepgemm_for_fp8_linear",
-    "should_use_deepgemm_for_fp8_linear_for_nk",
     "get_col_major_tma_aligned_tensor",
     "get_mk_alignment_for_contiguous_layout",
 ]
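
A note on the reworked registration in RMSNormQuantFusionPass: rather than deriving the scale layout and UE8M0 mode once at pass-construction time through the removed DeepGEMM/CUTLASS helpers, the pass now registers a pattern variant for every combination and lets the pattern matcher pick whichever form actually appears in the graph. The sketch below only illustrates the resulting combination count; it uses plain tuples and strings in place of vLLM's GroupShape and pattern classes and is not vLLM code.

# Illustrative sketch only: counts the group-quant pattern variants registered
# by the nested loops added above. Placeholder values, not vLLM APIs.
from itertools import product

epsilons = [1e-5, 1e-6]
group_shapes = [(1, 128), (1, 64)]   # stands in for GroupShape(1, 128) / GroupShape(1, 64)
scale_layouts = [True, False]        # has_col_major_scales
e8m0_modes = [True, False]           # is_e8m0
pattern_kinds = ["fused_add_rms_norm+group_quant", "rms_norm+group_quant"]

variants = list(product(epsilons, group_shapes, scale_layouts, e8m0_modes, pattern_kinds))
print(len(variants))  # 32 registrations in total, 16 per epsilon value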
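
The new test_rms_group_quant verifies fusion indirectly by grepping the captured worker logs for the fusion pass's "Replaced N patterns" line. A minimal illustration of that regex check, using a made-up log line rather than real vLLM output:

# Made-up log line; only the regex is taken from the test above.
import re

sample_log = "DEBUG [fusion.py:123] Replaced 48 patterns"
found = re.findall(r"\[fusion.py:\d+] Replaced (\d+) patterns", sample_log)
assert found == ["48"]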