From 6e672daf62e7b03ff1dcf74e4206dad07d39d4ec Mon Sep 17 00:00:00 2001 From: Ilya Markov Date: Thu, 31 Jul 2025 22:58:38 +0200 Subject: [PATCH] Add FlashInfer allreduce RMSNorm Quant fusion (#21069) Signed-off-by: ilmarkov Signed-off-by: ilmarkov Co-authored-by: ilmarkov --- .buildkite/test-pipeline.yaml | 1 + tests/compile/test_fusion_all_reduce.py | 126 +++++- tests/utils.py | 12 + vllm/compilation/collective_fusion.py | 533 ++++++++++++++++++++++-- vllm/config.py | 2 +- 5 files changed, 606 insertions(+), 68 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index a7fe200559305..2f6cc45be77e6 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -353,6 +353,7 @@ steps: - pytest -v -s compile/test_silu_mul_quant_fusion.py - pytest -v -s compile/test_sequence_parallelism.py - pytest -v -s compile/test_async_tp.py + - pytest -v -s compile/test_fusion_all_reduce.py - label: PyTorch Fullgraph Smoke Test # 9min mirror_hardwares: [amdexperimental] diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index b8d64247f6beb..b394e0035c689 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -7,22 +7,26 @@ import torch import vllm.envs as envs from vllm.compilation.collective_fusion import AllReduceFusionPass +from vllm.compilation.fix_functionalization import FixFunctionalizationPass +from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig, ModelConfig, PassConfig, VllmConfig) from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import (init_distributed_environment, initialize_model_parallel) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + GroupShape, QuantFP8) from vllm.platforms import current_platform from vllm.utils import update_environment_variables -from ..utils import multi_gpu_test +from ..utils import has_module_attribute, multi_gpu_test from .backend import TestBackend class TestAllReduceRMSNormModel(torch.nn.Module): - def __init__(self, hidden_size=16, eps=1e-6): + def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps @@ -43,7 +47,7 @@ class TestAllReduceRMSNormModel(torch.nn.Module): class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): - def __init__(self, hidden_size=16, eps=1e-6): + def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps @@ -62,24 +66,101 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] +class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module): + + def __init__(self, hidden_size=16, token_num=16, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.norm = RMSNorm(hidden_size, eps) + self.quant_fp8 = QuantFP8(static=True, + group_shape=GroupShape.PER_TENSOR) + self.scale = torch.rand(1, dtype=torch.float32) + self.output = torch.empty((token_num, hidden_size), + dtype=torch.float32) + + def forward(self, hidden_states, residual): + view = hidden_states.reshape(-1, self.hidden_size) + all_reduce = tensor_model_parallel_all_reduce(view) + norm_output, residual_output = self.norm(all_reduce, residual) + 
torch.ops._C.static_scaled_fp8_quant(self.output, + norm_output.contiguous(), + self.scale) + return self.output, residual_output + + def ops_in_model_after(self): + return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] + + def ops_in_model_before(self): + return [ + torch.ops.vllm.all_reduce.default, + torch.ops._C.static_scaled_fp8_quant.default + ] + + +class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): + + def __init__(self, hidden_size=16, token_num=16, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.norm = RMSNorm(hidden_size, eps) + self.scale = torch.rand(1, dtype=torch.float32) + self.output = torch.empty((token_num, hidden_size), + dtype=torch.float32) + + round_up = lambda x, y: (x + y - 1) // y * y + rounded_m = round_up(token_num, 128) + scale_n = hidden_size // 16 + rounded_n = round_up(scale_n, 4) + self.output_scale = torch.empty((rounded_m, rounded_n // 4), + dtype=torch.int32) + + def forward(self, hidden_states, residual): + view = hidden_states.reshape(-1, self.hidden_size) + all_reduce = tensor_model_parallel_all_reduce(view) + norm_output, residual_output = self.norm(all_reduce, residual) + norm_output = norm_output.reshape(-1, norm_output.shape[-1]) + torch.ops._C.scaled_fp4_quant(self.output, norm_output, + self.output_scale, self.scale) + return self.output, residual_output, self.output_scale + + def ops_in_model_after(self): + return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] + + def ops_in_model_before(self): + return [ + torch.ops.vllm.all_reduce.default, + torch.ops._C.scaled_fp4_quant.default + ] + + @multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize( - "test_model", - [TestAllReduceRMSNormModel, TestAllReduceFusedAddRMSNormModel]) +@pytest.mark.parametrize("test_model", [ + TestAllReduceRMSNormModel, + TestAllReduceFusedAddRMSNormModel, + TestAllReduceFusedAddRMSNormStaticQuantFP8Model, + TestAllReduceFusedAddRMSNormStaticQuantFP4Model, +]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [8]) -@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") -@pytest.mark.skipif(not find_spec("flashinfer"), - reason="flashinfer is not installed") -@pytest.mark.skipif(not current_platform.is_device_capability(100), - reason="Only test on SM100") +@pytest.mark.skipif( + not find_spec("flashinfer") + or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"), + reason="flashinfer is not found or flashinfer " + "is not compiled with trtllm_allreduce_fusion") def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype): num_processes = 2 + if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model + and not current_platform.has_device_capability(100)): + pytest.skip("Skip as nvfp4 is only supported on " + "devices with compute capability 10.0 (Blackwell)") def run_torch_spawn(fn, nprocs): torch.multiprocessing.spawn(fn, @@ -113,12 +194,11 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=world_size) - vllm_config = VllmConfig( - compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE, - custom_ops=["+rms_norm"], - 
compile_sizes=[2, 4, 8])) + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + custom_ops=["+rms_norm", "+quant_fp8"])) vllm_config.compilation_config.pass_config = PassConfig( - enable_fi_allreduce_fusion=True) + enable_fi_allreduce_fusion=True, enable_noop=True) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config @@ -130,14 +210,16 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, seed=42) all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) - backend = TestBackend(all_reduce_fusion_pass) + noop_pass = NoOpEliminationPass(vllm_config) + func_pass = FixFunctionalizationPass(vllm_config) - model = test_model_cls(hidden_size) + backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass) - hidden_states = torch.randn((batch_size * seq_len, hidden_size), - requires_grad=False) - residual = torch.randn((batch_size * seq_len, hidden_size), - requires_grad=False) + token_num = batch_size * seq_len + model = test_model_cls(hidden_size, token_num) + + hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) + residual = torch.randn((token_num, hidden_size), requires_grad=False) compiled_model = torch.compile(model, backend=backend) compiled_model(hidden_states, residual) diff --git a/tests/utils.py b/tests/utils.py index f4317e6bdb406..1c1a1cc6014ec 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,6 +4,7 @@ import asyncio import copy import functools +import importlib import os import signal import subprocess @@ -974,3 +975,14 @@ def get_client_text_logprob_generations( return [(text_generations, text, (None if x.logprobs is None else x.logprobs.top_logprobs)) for completion in completions for x in completion.choices] + + +def has_module_attribute(module_name, attribute_name): + """ + Helper function to check if a module has a specific attribute. 
+ """ + try: + module = importlib.import_module(module_name) + return hasattr(module, attribute_name) + except ImportError: + return False diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index cb99fe8310e73..6ae50245ed3a8 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -37,6 +37,8 @@ logger = init_logger(__name__) ALLREDUCE_OP = torch.ops.vllm.all_reduce.default RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default +STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default +STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default class BasePattern: @@ -394,7 +396,7 @@ if flashinfer_comm is not None: # Max size of the input tensor per world size # to use flashinfer fused allreduce _FI_MAX_SIZES = { - 2: MiB, # 1MB + 2: 64 * MiB, # 64MB 4: MiB, # 1MB 6: MiB // 2, # 512KB 8: MiB // 2, # 512KB @@ -414,9 +416,13 @@ if flashinfer_comm is not None: trigger_completion_at_end: bool, fp32_acc: bool, max_token_num: int, + pattern_code: int, + fuse_rms_quant: bool, norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, + scale_out: Optional[torch.Tensor] = None, + scale_factor: Optional[torch.Tensor] = None, ) -> None: - num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size @@ -425,7 +431,6 @@ if flashinfer_comm is not None: _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE), max_fusion_size, ) - if use_flashinfer: assert (_FI_WORKSPACE_TENSOR is not None ), "Flashinfer must be enabled when using flashinfer" @@ -455,37 +460,65 @@ if flashinfer_comm is not None: use_oneshot=True, trigger_completion_at_end=trigger_completion_at_end, fp32_acc=fp32_acc, - pattern_code=flashinfer_comm.AllReduceFusionPattern. 
- kARResidualRMSNorm, + pattern_code=pattern_code, allreduce_out=None, - quant_out=None, - scale_out=None, - layout_code=None, - scale_factor=None, + quant_out=quant_out, + scale_out=scale_out, + # in vllm we only support swizzled layout + layout_code=flashinfer_comm.FP4QuantizationSFLayout.SWIZZLED, + scale_factor=scale_factor, ) else: allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if norm_out is None: - torch.ops._C.fused_add_rms_norm(allreduce_out, residual, - rms_gamma, rms_eps) + if (scale_factor is not None and scale_out is None + and fuse_rms_quant): + # Do fused rms norm static fp8 quant fused op + if norm_out is None: + torch.ops._C.fused_add_rms_norm_static_fp8_quant( + quant_out, allreduce_out, residual, rms_gamma, + scale_factor, rms_eps) + else: + torch.ops._C.rms_norm_static_fp8_quant( + quant_out, allreduce_out, rms_gamma, scale_factor, + rms_eps) else: - torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, - rms_eps) - allreduce_in.copy_(allreduce_out) + if norm_out is None: + torch.ops._C.fused_add_rms_norm(allreduce_out, residual, + rms_gamma, rms_eps) + norm_out = allreduce_out + else: + torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, + rms_eps) + if scale_factor is not None: + if scale_out is not None: + torch.ops._C.scaled_fp4_quant(quant_out, norm_out, + scale_out, scale_factor) + else: + torch.ops._C.static_scaled_fp8_quant( + quant_out, norm_out, scale_factor) + if scale_factor is None or norm_out is not None: + # we need to return allreduce outpput + # in cases of non quant fused AR + RMS norm + # and fused AR + RMS norm + quant without fused add + allreduce_in.copy_(allreduce_out) def call_trtllm_fused_allreduce_norm_fake( - allreduce_in: torch.Tensor, - residual: torch.Tensor, - rms_gamma: torch.Tensor, - rms_eps: float, - world_rank: int, - world_size: int, - launch_with_pdl: bool, - trigger_completion_at_end: bool, - fp32_acc: bool, - max_token_num: int, - norm_out: Optional[torch.Tensor] = None, - ) -> None: + allreduce_in: torch.Tensor, + residual: torch.Tensor, + rms_gamma: torch.Tensor, + rms_eps: float, + world_rank: int, + world_size: int, + launch_with_pdl: bool, + trigger_completion_at_end: bool, + fp32_acc: bool, + max_token_num: int, + pattern_code: int, + fuse_rms_quant: bool, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, + scale_out: Optional[torch.Tensor] = None, + scale_factor: Optional[torch.Tensor] = None) -> None: pass direct_register_custom_op( @@ -495,6 +528,8 @@ if flashinfer_comm is not None: "allreduce_in", "residual", "norm_out", + "quant_out", + "scale_out", ], fake_impl=call_trtllm_fused_allreduce_norm_fake, dispatch_key=current_platform.dispatch_key, @@ -512,6 +547,7 @@ class FlashInferFusedAllReduceParams: world_size: int, use_fp32_lamport: bool = False, max_token_num: int = 1024, + fuse_rms_quant: bool = False, ): self.rank = rank self.world_size = world_size @@ -521,6 +557,7 @@ class FlashInferFusedAllReduceParams: self.fp32_acc = True self.use_oneshot = False self.max_token_num = max_token_num + self.fuse_rms_quant = fuse_rms_quant def get_trtllm_fused_allreduce_kwargs(self): return { @@ -530,10 +567,16 @@ class FlashInferFusedAllReduceParams: "trigger_completion_at_end": self.trigger_completion_at_end, "fp32_acc": self.fp32_acc, "max_token_num": self.max_token_num, + "fuse_rms_quant": self.fuse_rms_quant, } -class AllReduceRMSNORMPattern(BasePattern): +class AllReduceRMSNormPattern(BasePattern): + """ + This pattern replaces the allreduce + rms norm (without 
residual) + with fused flashinfer implementation. + Applies to allreduce + rmsnorm before attn in the first Transformer block. + """ def __init__( self, @@ -559,29 +602,34 @@ class AllReduceRMSNORMPattern(BasePattern): def pattern(input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor): - all_reduce_output = tensor_model_parallel_all_reduce(input) + allreduce_output = tensor_model_parallel_all_reduce(input) rms = auto_functionalized( RMS_OP, result=rms_result, - input=all_reduce_output, + input=allreduce_output, weight=weight, epsilon=self.epsilon, ) - return rms[1], all_reduce_output + # rms_result, allreduce_output + return rms[1], allreduce_output def replacement(input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor): residual = torch.zeros_like(input) allreduce = auto_functionalized( - torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default, + flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, residual=residual, norm_out=rms_result, + quant_out=None, + scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNorm, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) - + # rms_result, allreduce_in return allreduce[3], allreduce[1] pm.register_replacement(pattern, replacement, self.get_inputs(), @@ -589,6 +637,11 @@ class AllReduceRMSNORMPattern(BasePattern): class AllReduceFusedAddRMSNormPattern(BasePattern): + """ + This pattern replaces the allreduce + rms norm (with residual) + with fused flashinfer implementation. + Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn. + """ def __init__( self, @@ -615,33 +668,390 @@ class AllReduceFusedAddRMSNormPattern(BasePattern): def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor): - all_reduce_output = tensor_model_parallel_all_reduce(input) + allreduce_output = tensor_model_parallel_all_reduce(input) rms = auto_functionalized( RMS_ADD_OP, - input=all_reduce_output, + input=allreduce_output, residual=residual, weight=weight, epsilon=self.epsilon, ) + # input, residual return rms[1], rms[2] def replacement(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor): allreduce = auto_functionalized( - torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default, + flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, residual=residual, + norm_out=None, + quant_out=None, + scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, - norm_out=None, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNorm, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) + # allreduce_in, residual return allreduce[1], allreduce[2] pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) +class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): + """ + This pattern replaces the allreduce + rms norm (without residual) + + static fp8 quant with fused flashinfer implementation. + Applies to allreduce + rmsnorm + quant before attn + in the first Transformer block. 
+ """ + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + allreduce_params: FlashInferFusedAllReduceParams): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + self.quant_dtype = torch.float8_e4m3fn + + def register(self, pm_pass: PatternMatcherPass): + + def get_inputs(): + input = torch.zeros([1, 8, 4], + device=self.device, + dtype=self.dtype) + rmsnorm_result = torch.empty([1, 8, 4], + device=self.device, + dtype=self.dtype) + quant_result = torch.empty([1, 8, 4], + device=self.device, + dtype=self.quant_dtype) + weight = torch.empty([4], device=self.device, dtype=self.dtype) + scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) + return [input, rmsnorm_result, quant_result, weight, scale] + + def pattern( + input: torch.Tensor, + rmsnorm_result: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + all_reduce = tensor_model_parallel_all_reduce(input) + rmsnorm_out_tuple = auto_functionalized(RMS_OP, + result=rmsnorm_result, + input=all_reduce, + weight=weight, + epsilon=self.epsilon) + + quant_out_tuple = auto_functionalized(STATIC_FP8_QUANT_OP, + result=quant_result, + input=rmsnorm_out_tuple[1], + scale=scale) + + # quant_out, allreduce_output + return quant_out_tuple[1], all_reduce + + def replacement(input: torch.Tensor, result_rms: torch.Tensor, + quant_result: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + residual = torch.zeros_like(input) + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=result_rms, + quant_out=quant_result, + scale_out=None, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards + scale_factor=scale, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + # quant_out, allreduce_output + return allreduce[4], allreduce[1] + + pm.register_replacement(pattern, replacement, get_inputs(), + pm.fwd_only, pm_pass) + + +class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern): + """ + This pattern replaces the allreduce + rms norm (with residual) + + static fp8 quant with fused flashinfer implementation. + Applies to o_proj + rmsnorm after attn + quant and + mlp + rmsnorm + quant before attn. 
+ """ + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + allreduce_params: FlashInferFusedAllReduceParams): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + self.quant_dtype = torch.float8_e4m3fn + + def register(self, pm_pass: PatternMatcherPass): + + def get_inputs(): + input = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + quant_result = torch.empty([4, 4], + device=self.device, + dtype=self.quant_dtype) + scale = torch.empty([1, 1], + device=self.device, + dtype=torch.float32) + + return [ + quant_result, + residual, + input, + weight, + scale, + ] + + def pattern( + quant_result: torch.Tensor, + residual: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + allreduce_output = tensor_model_parallel_all_reduce(input) + + fused_add_rmsnorm_out_tuple = \ + auto_functionalized( + RMS_ADD_OP, + input=allreduce_output, + residual=residual, + weight=weight, + epsilon=self.epsilon) + quant_out_tuple = auto_functionalized( + STATIC_FP8_QUANT_OP, + result=quant_result, + input=fused_add_rmsnorm_out_tuple[1], + scale=scale) + + # quant_out, allreduce_output + return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2] + + def replacement(quant_result: torch.Tensor, residual: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=None, + quant_out=quant_result, + scale_out=None, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards + scale_factor=scale, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + # # quant_out, rms_norm_residual + return allreduce[4], allreduce[2] + + pm.register_replacement(pattern, replacement, get_inputs(), + pm.fwd_only, pm_pass) + + +class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): + """ + This pattern replaces the allreduce + rms norm (without residual) + + static nvfp4 quant with fused flashinfer implementation. + Applies to allreduce + rmsnorm + quant before attn + in the first Transformer block. 
+ """ + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + allreduce_params: FlashInferFusedAllReduceParams): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + + def register(self, pm_pass: PatternMatcherPass): + + def get_inputs(): + input = torch.empty([1, 16, 16], + device=self.device, + dtype=self.dtype) + + rmsnorm_result = torch.empty([1, 16, 16], + device=self.device, + dtype=self.dtype) + quant_result = torch.empty((16, 8), + device=self.device, + dtype=torch.uint8) + input_global_scale = torch.empty([1, 1], + device=self.device, + dtype=torch.float32) + weight = torch.empty([16], device=self.device, dtype=self.dtype) + output_scale = torch.empty([128, 4], + device=self.device, + dtype=torch.int32) + + return [ + input, rmsnorm_result, quant_result, weight, + input_global_scale, output_scale + ] + + def pattern( + input: torch.Tensor, + rmsnorm_result: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + input_global_scale: torch.Tensor, + output_scale: torch.Tensor, + ): + all_reduce = tensor_model_parallel_all_reduce(input) + rmsnorm_out_tuple = auto_functionalized(RMS_OP, + result=rmsnorm_result, + input=all_reduce, + weight=weight, + epsilon=self.epsilon) + + quant_out_tuple = auto_functionalized( + STATIC_FP4_QUANT_OP, + output=quant_result, + input=rmsnorm_out_tuple[1], + output_scale=output_scale, + input_scale=input_global_scale) + + # quant_out, allreduce_output, output_scale + return quant_out_tuple[1], all_reduce, quant_out_tuple[2] + + def replacement(input: torch.Tensor, result_rms: torch.Tensor, + quant_result: torch.Tensor, weight: torch.Tensor, + input_global_scale: torch.Tensor, + output_scale: torch.Tensor): + residual = torch.zeros_like(input) + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=result_rms, + quant_out=quant_result, + scale_out=output_scale, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards + scale_factor=input_global_scale, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + # quant_out, allreduce_output, output_scale + return allreduce[4], allreduce[1], allreduce[5] + + pm.register_replacement(pattern, replacement, get_inputs(), + pm.fwd_only, pm_pass) + + +class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): + """ + This pattern replaces the allreduce + rms norm (with residual) + + static nvfp4 quant with fused flashinfer implementation. + Applies to o_proj + rmsnorm after attn + quant and + mlp + rmsnorm + quant before attn. 
+ """ + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + allreduce_params: FlashInferFusedAllReduceParams): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + + def register(self, pm_pass: PatternMatcherPass): + + def get_inputs(): + input = torch.empty([16, 16], device=self.device, dtype=self.dtype) + + residual = torch.empty([16, 16], + device=self.device, + dtype=self.dtype) + weight = torch.empty([16, 16], + device=self.device, + dtype=self.dtype) + quant_result = torch.empty((16, 8), + device=self.device, + dtype=torch.uint8) + input_global_scale = torch.empty([1, 1], + device=self.device, + dtype=torch.float32) + output_scale = torch.empty([128, 4], + device=self.device, + dtype=torch.int32) + + return [ + quant_result, + residual, + input, + output_scale, + weight, + input_global_scale, + ] + + def pattern(quant_result: torch.Tensor, residual: torch.Tensor, + input: torch.Tensor, output_scale: torch.Tensor, + weight: torch.Tensor, input_global_scale: torch.Tensor): + allreduce_output = tensor_model_parallel_all_reduce(input) + + fused_add_rmsnorm_out_tuple = \ + auto_functionalized( + RMS_ADD_OP, + input=allreduce_output, + residual=residual, + weight=weight, + epsilon=self.epsilon) + quant_out_tuple = auto_functionalized( + STATIC_FP4_QUANT_OP, + output=quant_result, + input=fused_add_rmsnorm_out_tuple[1], + output_scale=output_scale, + input_scale=input_global_scale) + + # quant_out, allreduce_output, output_scale + return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[ + 2], quant_out_tuple[2] + + def replacement(quant_result: torch.Tensor, residual: torch.Tensor, + input: torch.Tensor, output_scale: torch.Tensor, + weight: torch.Tensor, + input_global_scale: torch.Tensor): + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=None, + quant_out=quant_result, + scale_out=output_scale, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards + scale_factor=input_global_scale, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + # quant_out, rms_norm_residual, output_scale + return allreduce[4], allreduce[2], allreduce[5] + + pm.register_replacement(pattern, replacement, get_inputs(), + pm.fwd_only, pm_pass) + + class AllReduceFusionPass(VllmInductorPass): def __init__(self, config: VllmConfig): @@ -671,13 +1081,16 @@ class AllReduceFusionPass(VllmInductorPass): self.tp_size, ) return - + max_num_token = min( + _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) // + (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)), + config.compilation_config.pass_config. + fi_allreduce_fusion_max_token_num) self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( tp_rank=rank, tp_size=self.tp_size, - max_token_num=config.compilation_config.pass_config. - fi_allreduce_fusion_max_token_num, + max_token_num=max_num_token, hidden_dim=self.hidden_dim, group=self.group, use_fp32_lamport=use_fp32_lamport, @@ -689,12 +1102,38 @@ class AllReduceFusionPass(VllmInductorPass): rank=rank, world_size=self.tp_size, use_fp32_lamport=use_fp32_lamport, - max_token_num=config.compilation_config.pass_config. 
- fi_allreduce_fusion_max_token_num, - ) + max_token_num=max_num_token, + # fuse rms norm static fp8 quant fused op + # in fallback path, when we don't use flashinfer + fuse_rms_quant=config.compilation_config.pass_config.enable_fusion) for epsilon in [1e-5, 1e-6]: - AllReduceRMSNORMPattern( + AllReduceFusedRMSNormStaticQuantFP8Pattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + AllReduceFusedAddRMSNormStaticQuantFP8Pattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + if current_platform.has_device_capability(100): + AllReduceFusedRMSNormStaticQuantNVFP4Pattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + AllReduceRMSNormPattern( epsilon, self.model_dtype, self.device, @@ -707,6 +1146,10 @@ class AllReduceFusionPass(VllmInductorPass): self.allreduce_params, ).register(self.patterns) + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. + torch._inductor.pattern_matcher._seen_patterns.clear() + self.disabled = False def __call__(self, graph: fx.Graph): @@ -723,5 +1166,5 @@ class AllReduceFusionPass(VllmInductorPass): if self.disabled: return if flashinfer_comm is not None: - flashinfer_comm.trtllm_destroy_ipc_workspace( + flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce( self.ipc_handles, self.group) diff --git a/vllm/config.py b/vllm/config.py index 27dde5f1b1f6f..edad5dd0406bf 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4051,7 +4051,7 @@ class PassConfig: """Whether to enable async TP.""" enable_fi_allreduce_fusion: bool = False """Whether to enable flashinfer allreduce fusion.""" - fi_allreduce_fusion_max_token_num: int = 1024 + fi_allreduce_fusion_max_token_num: int = 16384 """Max number of tokens to used in flashinfer allreduce fusion.""" # TODO(luka) better pass enabling system.
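
Usage note (editor's addition, not part of the commit): the sketch below shows how the fusion pass introduced above would typically be switched on from user code. The CompilationConfig/PassConfig fields and the custom_ops names are taken directly from this diff and its test; the model name and the assumption that an LLM instance accepts a CompilationConfig object via the compilation_config argument are illustrative and may differ between vLLM versions, so treat this as a minimal sketch rather than a verified recipe.

# Minimal sketch, assuming LLM() accepts a CompilationConfig object directly
# (the test in this patch builds the same config by hand inside VllmConfig).
from vllm import LLM
from vllm.config import CompilationConfig, CompilationLevel, PassConfig

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model, assumed
    tensor_parallel_size=2,                    # fusion only matters with TP > 1
    compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        # Keep the custom RMSNorm / FP8-quant ops in the graph so the patterns
        # registered by AllReduceFusionPass can match them, mirroring the
        # test configuration in this patch.
        custom_ops=["+rms_norm", "+quant_fp8"],
        pass_config=PassConfig(
            # enable the FlashInfer allreduce + RMSNorm (+ quant) fusion pass
            enable_fi_allreduce_fusion=True,
            # no-op elimination helps the patterns match (see the test setup)
            enable_noop=True,
            # lets the non-FlashInfer fallback use the fused
            # rms_norm + static fp8 quant kernels (fuse_rms_quant above)
            enable_fusion=True,
            # upper bound on tokens per fused allreduce; 16384 is the new
            # default introduced by this patch
            fi_allreduce_fusion_max_token_num=16384,
        ),
    ),
)

outputs = llm.generate("Hello, my name is")

Whether the FlashInfer kernel is actually taken at runtime still depends on flashinfer.comm providing trtllm_allreduce_fusion, on the device (the NVFP4 patterns are only registered on compute capability 10.0), and on the per-world-size input size caps in _FI_MAX_SIZES; the workspace is sized as min(_FI_MAX_SIZES[tp_size] // (hidden_dim * tp_size * element_size), fi_allreduce_fusion_max_token_num) tokens. When any of these conditions fail, flashinfer_trtllm_fused_allreduce_norm falls back to a regular all-reduce followed by the (optionally fused) RMSNorm and quantization ops, as implemented in collective_fusion.py above.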