Add FlashInfer allreduce RMSNorm Quant fusion (#21069)

Signed-off-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
This commit is contained in:
Ilya Markov 2025-07-31 22:58:38 +02:00 committed by GitHub
parent 2dff2e21d9
commit 6e672daf62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 606 additions and 68 deletions

View File

@ -353,6 +353,7 @@ steps:
- pytest -v -s compile/test_silu_mul_quant_fusion.py - pytest -v -s compile/test_silu_mul_quant_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py - pytest -v -s compile/test_sequence_parallelism.py
- pytest -v -s compile/test_async_tp.py - pytest -v -s compile/test_async_tp.py
- pytest -v -s compile/test_fusion_all_reduce.py
- label: PyTorch Fullgraph Smoke Test # 9min - label: PyTorch Fullgraph Smoke Test # 9min
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]

View File

@ -7,22 +7,26 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.collective_fusion import AllReduceFusionPass from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig, from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
ModelConfig, PassConfig, VllmConfig) ModelConfig, PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment, from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel) initialize_model_parallel)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
GroupShape, QuantFP8)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import update_environment_variables from vllm.utils import update_environment_variables
from ..utils import multi_gpu_test from ..utils import has_module_attribute, multi_gpu_test
from .backend import TestBackend from .backend import TestBackend
class TestAllReduceRMSNormModel(torch.nn.Module): class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, eps=1e-6): def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.eps = eps self.eps = eps
@ -43,7 +47,7 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, eps=1e-6): def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.eps = eps self.eps = eps
@ -62,24 +66,101 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.quant_fp8 = QuantFP8(static=True,
group_shape=GroupShape.PER_TENSOR)
self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size),
dtype=torch.float32)
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual)
torch.ops._C.static_scaled_fp8_quant(self.output,
norm_output.contiguous(),
self.scale)
return self.output, residual_output
def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.static_scaled_fp8_quant.default
]
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size),
dtype=torch.float32)
round_up = lambda x, y: (x + y - 1) // y * y
rounded_m = round_up(token_num, 128)
scale_n = hidden_size // 16
rounded_n = round_up(scale_n, 4)
self.output_scale = torch.empty((rounded_m, rounded_n // 4),
dtype=torch.int32)
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual)
norm_output = norm_output.reshape(-1, norm_output.shape[-1])
torch.ops._C.scaled_fp4_quant(self.output, norm_output,
self.output_scale, self.scale)
return self.output, residual_output, self.output_scale
def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.scaled_fp4_quant.default
]
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize( @pytest.mark.parametrize("test_model", [
"test_model", TestAllReduceRMSNormModel,
[TestAllReduceRMSNormModel, TestAllReduceFusedAddRMSNormModel]) TestAllReduceFusedAddRMSNormModel,
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
])
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8]) @pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA") reason="Only test on CUDA")
@pytest.mark.skipif(not find_spec("flashinfer"), @pytest.mark.skipif(
reason="flashinfer is not installed") not find_spec("flashinfer")
@pytest.mark.skipif(not current_platform.is_device_capability(100), or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
reason="Only test on SM100") reason="flashinfer is not found or flashinfer "
"is not compiled with trtllm_allreduce_fusion")
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module, def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
batch_size: int, seq_len: int, batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype): hidden_size: int, dtype: torch.dtype):
num_processes = 2 num_processes = 2
if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
and not current_platform.has_device_capability(100)):
pytest.skip("Skip as nvfp4 is only supported on "
"devices with compute capability 10.0 (Blackwell)")
def run_torch_spawn(fn, nprocs): def run_torch_spawn(fn, nprocs):
torch.multiprocessing.spawn(fn, torch.multiprocessing.spawn(fn,
@ -113,12 +194,11 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
init_distributed_environment() init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size) initialize_model_parallel(tensor_model_parallel_size=world_size)
vllm_config = VllmConfig( vllm_config = VllmConfig(compilation_config=CompilationConfig(
compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE, level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm"], custom_ops=["+rms_norm", "+quant_fp8"]))
compile_sizes=[2, 4, 8]))
vllm_config.compilation_config.pass_config = PassConfig( vllm_config.compilation_config.pass_config = PassConfig(
enable_fi_allreduce_fusion=True) enable_fi_allreduce_fusion=True, enable_noop=True)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config # this is a fake model name to construct the model config
@ -130,14 +210,16 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
seed=42) seed=42)
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
backend = TestBackend(all_reduce_fusion_pass) noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
model = test_model_cls(hidden_size) backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)
hidden_states = torch.randn((batch_size * seq_len, hidden_size), token_num = batch_size * seq_len
requires_grad=False) model = test_model_cls(hidden_size, token_num)
residual = torch.randn((batch_size * seq_len, hidden_size),
requires_grad=False) hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
residual = torch.randn((token_num, hidden_size), requires_grad=False)
compiled_model = torch.compile(model, backend=backend) compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states, residual) compiled_model(hidden_states, residual)

View File

@ -4,6 +4,7 @@
import asyncio import asyncio
import copy import copy
import functools import functools
import importlib
import os import os
import signal import signal
import subprocess import subprocess
@ -974,3 +975,14 @@ def get_client_text_logprob_generations(
return [(text_generations, text, return [(text_generations, text,
(None if x.logprobs is None else x.logprobs.top_logprobs)) (None if x.logprobs is None else x.logprobs.top_logprobs))
for completion in completions for x in completion.choices] for completion in completions for x in completion.choices]
def has_module_attribute(module_name, attribute_name):
"""
Helper function to check if a module has a specific attribute.
"""
try:
module = importlib.import_module(module_name)
return hasattr(module, attribute_name)
except ImportError:
return False

View File

@ -37,6 +37,8 @@ logger = init_logger(__name__)
ALLREDUCE_OP = torch.ops.vllm.all_reduce.default ALLREDUCE_OP = torch.ops.vllm.all_reduce.default
RMS_OP = torch.ops._C.rms_norm.default RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
class BasePattern: class BasePattern:
@ -394,7 +396,7 @@ if flashinfer_comm is not None:
# Max size of the input tensor per world size # Max size of the input tensor per world size
# to use flashinfer fused allreduce # to use flashinfer fused allreduce
_FI_MAX_SIZES = { _FI_MAX_SIZES = {
2: MiB, # 1MB 2: 64 * MiB, # 64MB
4: MiB, # 1MB 4: MiB, # 1MB
6: MiB // 2, # 512KB 6: MiB // 2, # 512KB
8: MiB // 2, # 512KB 8: MiB // 2, # 512KB
@ -414,9 +416,13 @@ if flashinfer_comm is not None:
trigger_completion_at_end: bool, trigger_completion_at_end: bool,
fp32_acc: bool, fp32_acc: bool,
max_token_num: int, max_token_num: int,
pattern_code: int,
fuse_rms_quant: bool,
norm_out: Optional[torch.Tensor] = None, norm_out: Optional[torch.Tensor] = None,
quant_out: Optional[torch.Tensor] = None,
scale_out: Optional[torch.Tensor] = None,
scale_factor: Optional[torch.Tensor] = None,
) -> None: ) -> None:
num_tokens, hidden_size = allreduce_in.shape num_tokens, hidden_size = allreduce_in.shape
element_size = allreduce_in.element_size() element_size = allreduce_in.element_size()
current_tensor_size = num_tokens * hidden_size * element_size current_tensor_size = num_tokens * hidden_size * element_size
@ -425,7 +431,6 @@ if flashinfer_comm is not None:
_FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE), _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
max_fusion_size, max_fusion_size,
) )
if use_flashinfer: if use_flashinfer:
assert (_FI_WORKSPACE_TENSOR is not None assert (_FI_WORKSPACE_TENSOR is not None
), "Flashinfer must be enabled when using flashinfer" ), "Flashinfer must be enabled when using flashinfer"
@ -455,37 +460,65 @@ if flashinfer_comm is not None:
use_oneshot=True, use_oneshot=True,
trigger_completion_at_end=trigger_completion_at_end, trigger_completion_at_end=trigger_completion_at_end,
fp32_acc=fp32_acc, fp32_acc=fp32_acc,
pattern_code=flashinfer_comm.AllReduceFusionPattern. pattern_code=pattern_code,
kARResidualRMSNorm,
allreduce_out=None, allreduce_out=None,
quant_out=None, quant_out=quant_out,
scale_out=None, scale_out=scale_out,
layout_code=None, # in vllm we only support swizzled layout
scale_factor=None, layout_code=flashinfer_comm.FP4QuantizationSFLayout.SWIZZLED,
scale_factor=scale_factor,
) )
else: else:
allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
if norm_out is None: if (scale_factor is not None and scale_out is None
torch.ops._C.fused_add_rms_norm(allreduce_out, residual, and fuse_rms_quant):
rms_gamma, rms_eps) # Do fused rms norm static fp8 quant fused op
if norm_out is None:
torch.ops._C.fused_add_rms_norm_static_fp8_quant(
quant_out, allreduce_out, residual, rms_gamma,
scale_factor, rms_eps)
else:
torch.ops._C.rms_norm_static_fp8_quant(
quant_out, allreduce_out, rms_gamma, scale_factor,
rms_eps)
else: else:
torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, if norm_out is None:
rms_eps) torch.ops._C.fused_add_rms_norm(allreduce_out, residual,
allreduce_in.copy_(allreduce_out) rms_gamma, rms_eps)
norm_out = allreduce_out
else:
torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma,
rms_eps)
if scale_factor is not None:
if scale_out is not None:
torch.ops._C.scaled_fp4_quant(quant_out, norm_out,
scale_out, scale_factor)
else:
torch.ops._C.static_scaled_fp8_quant(
quant_out, norm_out, scale_factor)
if scale_factor is None or norm_out is not None:
# we need to return allreduce outpput
# in cases of non quant fused AR + RMS norm
# and fused AR + RMS norm + quant without fused add
allreduce_in.copy_(allreduce_out)
def call_trtllm_fused_allreduce_norm_fake( def call_trtllm_fused_allreduce_norm_fake(
allreduce_in: torch.Tensor, allreduce_in: torch.Tensor,
residual: torch.Tensor, residual: torch.Tensor,
rms_gamma: torch.Tensor, rms_gamma: torch.Tensor,
rms_eps: float, rms_eps: float,
world_rank: int, world_rank: int,
world_size: int, world_size: int,
launch_with_pdl: bool, launch_with_pdl: bool,
trigger_completion_at_end: bool, trigger_completion_at_end: bool,
fp32_acc: bool, fp32_acc: bool,
max_token_num: int, max_token_num: int,
norm_out: Optional[torch.Tensor] = None, pattern_code: int,
) -> None: fuse_rms_quant: bool,
norm_out: Optional[torch.Tensor] = None,
quant_out: Optional[torch.Tensor] = None,
scale_out: Optional[torch.Tensor] = None,
scale_factor: Optional[torch.Tensor] = None) -> None:
pass pass
direct_register_custom_op( direct_register_custom_op(
@ -495,6 +528,8 @@ if flashinfer_comm is not None:
"allreduce_in", "allreduce_in",
"residual", "residual",
"norm_out", "norm_out",
"quant_out",
"scale_out",
], ],
fake_impl=call_trtllm_fused_allreduce_norm_fake, fake_impl=call_trtllm_fused_allreduce_norm_fake,
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
@ -512,6 +547,7 @@ class FlashInferFusedAllReduceParams:
world_size: int, world_size: int,
use_fp32_lamport: bool = False, use_fp32_lamport: bool = False,
max_token_num: int = 1024, max_token_num: int = 1024,
fuse_rms_quant: bool = False,
): ):
self.rank = rank self.rank = rank
self.world_size = world_size self.world_size = world_size
@ -521,6 +557,7 @@ class FlashInferFusedAllReduceParams:
self.fp32_acc = True self.fp32_acc = True
self.use_oneshot = False self.use_oneshot = False
self.max_token_num = max_token_num self.max_token_num = max_token_num
self.fuse_rms_quant = fuse_rms_quant
def get_trtllm_fused_allreduce_kwargs(self): def get_trtllm_fused_allreduce_kwargs(self):
return { return {
@ -530,10 +567,16 @@ class FlashInferFusedAllReduceParams:
"trigger_completion_at_end": self.trigger_completion_at_end, "trigger_completion_at_end": self.trigger_completion_at_end,
"fp32_acc": self.fp32_acc, "fp32_acc": self.fp32_acc,
"max_token_num": self.max_token_num, "max_token_num": self.max_token_num,
"fuse_rms_quant": self.fuse_rms_quant,
} }
class AllReduceRMSNORMPattern(BasePattern): class AllReduceRMSNormPattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (without residual)
with fused flashinfer implementation.
Applies to allreduce + rmsnorm before attn in the first Transformer block.
"""
def __init__( def __init__(
self, self,
@ -559,29 +602,34 @@ class AllReduceRMSNORMPattern(BasePattern):
def pattern(input: torch.Tensor, rms_result: torch.Tensor, def pattern(input: torch.Tensor, rms_result: torch.Tensor,
weight: torch.Tensor): weight: torch.Tensor):
all_reduce_output = tensor_model_parallel_all_reduce(input) allreduce_output = tensor_model_parallel_all_reduce(input)
rms = auto_functionalized( rms = auto_functionalized(
RMS_OP, RMS_OP,
result=rms_result, result=rms_result,
input=all_reduce_output, input=allreduce_output,
weight=weight, weight=weight,
epsilon=self.epsilon, epsilon=self.epsilon,
) )
return rms[1], all_reduce_output # rms_result, allreduce_output
return rms[1], allreduce_output
def replacement(input: torch.Tensor, rms_result: torch.Tensor, def replacement(input: torch.Tensor, rms_result: torch.Tensor,
weight: torch.Tensor): weight: torch.Tensor):
residual = torch.zeros_like(input) residual = torch.zeros_like(input)
allreduce = auto_functionalized( allreduce = auto_functionalized(
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default, flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input, allreduce_in=input,
residual=residual, residual=residual,
norm_out=rms_result, norm_out=rms_result,
quant_out=None,
scale_out=None,
rms_gamma=weight, rms_gamma=weight,
rms_eps=self.epsilon, rms_eps=self.epsilon,
pattern_code=flashinfer_comm.AllReduceFusionPattern.
kARResidualRMSNorm,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
) )
# rms_result, allreduce_in
return allreduce[3], allreduce[1] return allreduce[3], allreduce[1]
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.register_replacement(pattern, replacement, self.get_inputs(),
@ -589,6 +637,11 @@ class AllReduceRMSNORMPattern(BasePattern):
class AllReduceFusedAddRMSNormPattern(BasePattern): class AllReduceFusedAddRMSNormPattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (with residual)
with fused flashinfer implementation.
Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
"""
def __init__( def __init__(
self, self,
@ -615,33 +668,390 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
def pattern(residual: torch.Tensor, input: torch.Tensor, def pattern(residual: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor): weight: torch.Tensor):
all_reduce_output = tensor_model_parallel_all_reduce(input) allreduce_output = tensor_model_parallel_all_reduce(input)
rms = auto_functionalized( rms = auto_functionalized(
RMS_ADD_OP, RMS_ADD_OP,
input=all_reduce_output, input=allreduce_output,
residual=residual, residual=residual,
weight=weight, weight=weight,
epsilon=self.epsilon, epsilon=self.epsilon,
) )
# input, residual
return rms[1], rms[2] return rms[1], rms[2]
def replacement(residual: torch.Tensor, input: torch.Tensor, def replacement(residual: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor): weight: torch.Tensor):
allreduce = auto_functionalized( allreduce = auto_functionalized(
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default, flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input, allreduce_in=input,
residual=residual, residual=residual,
norm_out=None,
quant_out=None,
scale_out=None,
rms_gamma=weight, rms_gamma=weight,
rms_eps=self.epsilon, rms_eps=self.epsilon,
norm_out=None, pattern_code=flashinfer_comm.AllReduceFusionPattern.
kARResidualRMSNorm,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
) )
# allreduce_in, residual
return allreduce[1], allreduce[2] return allreduce[1], allreduce[2]
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass) pm.fwd_only, pm_pass)
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (without residual)
+ static fp8 quant with fused flashinfer implementation.
Applies to allreduce + rmsnorm + quant before attn
in the first Transformer block.
"""
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
allreduce_params: FlashInferFusedAllReduceParams):
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.quant_dtype = torch.float8_e4m3fn
def register(self, pm_pass: PatternMatcherPass):
def get_inputs():
input = torch.zeros([1, 8, 4],
device=self.device,
dtype=self.dtype)
rmsnorm_result = torch.empty([1, 8, 4],
device=self.device,
dtype=self.dtype)
quant_result = torch.empty([1, 8, 4],
device=self.device,
dtype=self.quant_dtype)
weight = torch.empty([4], device=self.device, dtype=self.dtype)
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
return [input, rmsnorm_result, quant_result, weight, scale]
def pattern(
input: torch.Tensor,
rmsnorm_result: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
all_reduce = tensor_model_parallel_all_reduce(input)
rmsnorm_out_tuple = auto_functionalized(RMS_OP,
result=rmsnorm_result,
input=all_reduce,
weight=weight,
epsilon=self.epsilon)
quant_out_tuple = auto_functionalized(STATIC_FP8_QUANT_OP,
result=quant_result,
input=rmsnorm_out_tuple[1],
scale=scale)
# quant_out, allreduce_output
return quant_out_tuple[1], all_reduce
def replacement(input: torch.Tensor, result_rms: torch.Tensor,
quant_result: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
residual = torch.zeros_like(input)
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=result_rms,
quant_out=quant_result,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
pattern_code=flashinfer_comm.AllReduceFusionPattern.
kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards
scale_factor=scale,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# quant_out, allreduce_output
return allreduce[4], allreduce[1]
pm.register_replacement(pattern, replacement, get_inputs(),
pm.fwd_only, pm_pass)
class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (with residual)
+ static fp8 quant with fused flashinfer implementation.
Applies to o_proj + rmsnorm after attn + quant and
mlp + rmsnorm + quant before attn.
"""
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
allreduce_params: FlashInferFusedAllReduceParams):
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.quant_dtype = torch.float8_e4m3fn
def register(self, pm_pass: PatternMatcherPass):
def get_inputs():
input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4],
device=self.device,
dtype=self.dtype)
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
quant_result = torch.empty([4, 4],
device=self.device,
dtype=self.quant_dtype)
scale = torch.empty([1, 1],
device=self.device,
dtype=torch.float32)
return [
quant_result,
residual,
input,
weight,
scale,
]
def pattern(
quant_result: torch.Tensor,
residual: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
allreduce_output = tensor_model_parallel_all_reduce(input)
fused_add_rmsnorm_out_tuple = \
auto_functionalized(
RMS_ADD_OP,
input=allreduce_output,
residual=residual,
weight=weight,
epsilon=self.epsilon)
quant_out_tuple = auto_functionalized(
STATIC_FP8_QUANT_OP,
result=quant_result,
input=fused_add_rmsnorm_out_tuple[1],
scale=scale)
# quant_out, allreduce_output
return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2]
def replacement(quant_result: torch.Tensor, residual: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=None,
quant_out=quant_result,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
pattern_code=flashinfer_comm.AllReduceFusionPattern.
kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards
scale_factor=scale,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# # quant_out, rms_norm_residual
return allreduce[4], allreduce[2]
pm.register_replacement(pattern, replacement, get_inputs(),
pm.fwd_only, pm_pass)
class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (without residual)
+ static nvfp4 quant with fused flashinfer implementation.
Applies to allreduce + rmsnorm + quant before attn
in the first Transformer block.
"""
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
allreduce_params: FlashInferFusedAllReduceParams):
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
def register(self, pm_pass: PatternMatcherPass):
def get_inputs():
input = torch.empty([1, 16, 16],
device=self.device,
dtype=self.dtype)
rmsnorm_result = torch.empty([1, 16, 16],
device=self.device,
dtype=self.dtype)
quant_result = torch.empty((16, 8),
device=self.device,
dtype=torch.uint8)
input_global_scale = torch.empty([1, 1],
device=self.device,
dtype=torch.float32)
weight = torch.empty([16], device=self.device, dtype=self.dtype)
output_scale = torch.empty([128, 4],
device=self.device,
dtype=torch.int32)
return [
input, rmsnorm_result, quant_result, weight,
input_global_scale, output_scale
]
def pattern(
input: torch.Tensor,
rmsnorm_result: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
input_global_scale: torch.Tensor,
output_scale: torch.Tensor,
):
all_reduce = tensor_model_parallel_all_reduce(input)
rmsnorm_out_tuple = auto_functionalized(RMS_OP,
result=rmsnorm_result,
input=all_reduce,
weight=weight,
epsilon=self.epsilon)
quant_out_tuple = auto_functionalized(
STATIC_FP4_QUANT_OP,
output=quant_result,
input=rmsnorm_out_tuple[1],
output_scale=output_scale,
input_scale=input_global_scale)
# quant_out, allreduce_output, output_scale
return quant_out_tuple[1], all_reduce, quant_out_tuple[2]
def replacement(input: torch.Tensor, result_rms: torch.Tensor,
quant_result: torch.Tensor, weight: torch.Tensor,
input_global_scale: torch.Tensor,
output_scale: torch.Tensor):
residual = torch.zeros_like(input)
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=result_rms,
quant_out=quant_result,
scale_out=output_scale,
rms_gamma=weight,
rms_eps=self.epsilon,
pattern_code=flashinfer_comm.AllReduceFusionPattern.
kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards
scale_factor=input_global_scale,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# quant_out, allreduce_output, output_scale
return allreduce[4], allreduce[1], allreduce[5]
pm.register_replacement(pattern, replacement, get_inputs(),
pm.fwd_only, pm_pass)
class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (with residual)
+ static nvfp4 quant with fused flashinfer implementation.
Applies to o_proj + rmsnorm after attn + quant and
mlp + rmsnorm + quant before attn.
"""
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
allreduce_params: FlashInferFusedAllReduceParams):
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
def register(self, pm_pass: PatternMatcherPass):
def get_inputs():
input = torch.empty([16, 16], device=self.device, dtype=self.dtype)
residual = torch.empty([16, 16],
device=self.device,
dtype=self.dtype)
weight = torch.empty([16, 16],
device=self.device,
dtype=self.dtype)
quant_result = torch.empty((16, 8),
device=self.device,
dtype=torch.uint8)
input_global_scale = torch.empty([1, 1],
device=self.device,
dtype=torch.float32)
output_scale = torch.empty([128, 4],
device=self.device,
dtype=torch.int32)
return [
quant_result,
residual,
input,
output_scale,
weight,
input_global_scale,
]
def pattern(quant_result: torch.Tensor, residual: torch.Tensor,
input: torch.Tensor, output_scale: torch.Tensor,
weight: torch.Tensor, input_global_scale: torch.Tensor):
allreduce_output = tensor_model_parallel_all_reduce(input)
fused_add_rmsnorm_out_tuple = \
auto_functionalized(
RMS_ADD_OP,
input=allreduce_output,
residual=residual,
weight=weight,
epsilon=self.epsilon)
quant_out_tuple = auto_functionalized(
STATIC_FP4_QUANT_OP,
output=quant_result,
input=fused_add_rmsnorm_out_tuple[1],
output_scale=output_scale,
input_scale=input_global_scale)
# quant_out, allreduce_output, output_scale
return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[
2], quant_out_tuple[2]
def replacement(quant_result: torch.Tensor, residual: torch.Tensor,
input: torch.Tensor, output_scale: torch.Tensor,
weight: torch.Tensor,
input_global_scale: torch.Tensor):
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=None,
quant_out=quant_result,
scale_out=output_scale,
rms_gamma=weight,
rms_eps=self.epsilon,
pattern_code=flashinfer_comm.AllReduceFusionPattern.
kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards
scale_factor=input_global_scale,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# quant_out, rms_norm_residual, output_scale
return allreduce[4], allreduce[2], allreduce[5]
pm.register_replacement(pattern, replacement, get_inputs(),
pm.fwd_only, pm_pass)
class AllReduceFusionPass(VllmInductorPass): class AllReduceFusionPass(VllmInductorPass):
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig):
@ -671,13 +1081,16 @@ class AllReduceFusionPass(VllmInductorPass):
self.tp_size, self.tp_size,
) )
return return
max_num_token = min(
_FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) //
(self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)),
config.compilation_config.pass_config.
fi_allreduce_fusion_max_token_num)
self.ipc_handles, workspace_tensor = ( self.ipc_handles, workspace_tensor = (
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
tp_rank=rank, tp_rank=rank,
tp_size=self.tp_size, tp_size=self.tp_size,
max_token_num=config.compilation_config.pass_config. max_token_num=max_num_token,
fi_allreduce_fusion_max_token_num,
hidden_dim=self.hidden_dim, hidden_dim=self.hidden_dim,
group=self.group, group=self.group,
use_fp32_lamport=use_fp32_lamport, use_fp32_lamport=use_fp32_lamport,
@ -689,12 +1102,38 @@ class AllReduceFusionPass(VllmInductorPass):
rank=rank, rank=rank,
world_size=self.tp_size, world_size=self.tp_size,
use_fp32_lamport=use_fp32_lamport, use_fp32_lamport=use_fp32_lamport,
max_token_num=config.compilation_config.pass_config. max_token_num=max_num_token,
fi_allreduce_fusion_max_token_num, # fuse rms norm static fp8 quant fused op
) # in fallback path, when we don't use flashinfer
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
for epsilon in [1e-5, 1e-6]: for epsilon in [1e-5, 1e-6]:
AllReduceRMSNORMPattern( AllReduceFusedRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
if current_platform.has_device_capability(100):
AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceRMSNormPattern(
epsilon, epsilon,
self.model_dtype, self.model_dtype,
self.device, self.device,
@ -707,6 +1146,10 @@ class AllReduceFusionPass(VllmInductorPass):
self.allreduce_params, self.allreduce_params,
).register(self.patterns) ).register(self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
self.disabled = False self.disabled = False
def __call__(self, graph: fx.Graph): def __call__(self, graph: fx.Graph):
@ -723,5 +1166,5 @@ class AllReduceFusionPass(VllmInductorPass):
if self.disabled: if self.disabled:
return return
if flashinfer_comm is not None: if flashinfer_comm is not None:
flashinfer_comm.trtllm_destroy_ipc_workspace( flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
self.ipc_handles, self.group) self.ipc_handles, self.group)

View File

@ -4051,7 +4051,7 @@ class PassConfig:
"""Whether to enable async TP.""" """Whether to enable async TP."""
enable_fi_allreduce_fusion: bool = False enable_fi_allreduce_fusion: bool = False
"""Whether to enable flashinfer allreduce fusion.""" """Whether to enable flashinfer allreduce fusion."""
fi_allreduce_fusion_max_token_num: int = 1024 fi_allreduce_fusion_max_token_num: int = 16384
"""Max number of tokens to used in flashinfer allreduce fusion.""" """Max number of tokens to used in flashinfer allreduce fusion."""
# TODO(luka) better pass enabling system. # TODO(luka) better pass enabling system.