Add FlashInfer allreduce RMSNorm Quant fusion (#21069)
Signed-off-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
This commit is contained in:
parent 2dff2e21d9
commit 6e672daf62
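The fusion pass is wired into the existing compilation-pass plumbing, and the updated test below shows how it is meant to be turned on. A minimal sketch, taken from the test changes in this commit (illustrative, not a tuned production configuration):

from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
                         VllmConfig)

# Compile piecewise, keep the custom rms_norm/quant_fp8 ops visible to the
# pattern matcher, and enable the FlashInfer allreduce fusion pass together
# with no-op elimination.
vllm_config = VllmConfig(compilation_config=CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    custom_ops=["+rms_norm", "+quant_fp8"]))
vllm_config.compilation_config.pass_config = PassConfig(
    enable_fi_allreduce_fusion=True, enable_noop=True)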
@@ -353,6 +353,7 @@ steps:
   - pytest -v -s compile/test_silu_mul_quant_fusion.py
   - pytest -v -s compile/test_sequence_parallelism.py
   - pytest -v -s compile/test_async_tp.py
+  - pytest -v -s compile/test_fusion_all_reduce.py

 - label: PyTorch Fullgraph Smoke Test # 9min
   mirror_hardwares: [amdexperimental]
@@ -7,22 +7,26 @@ import torch

 import vllm.envs as envs
 from vllm.compilation.collective_fusion import AllReduceFusionPass
+from vllm.compilation.fix_functionalization import FixFunctionalizationPass
+from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
                          ModelConfig, PassConfig, VllmConfig)
 from vllm.distributed import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (init_distributed_environment,
                                              initialize_model_parallel)
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    GroupShape, QuantFP8)
 from vllm.platforms import current_platform
 from vllm.utils import update_environment_variables

-from ..utils import multi_gpu_test
+from ..utils import has_module_attribute, multi_gpu_test
 from .backend import TestBackend


 class TestAllReduceRMSNormModel(torch.nn.Module):

-    def __init__(self, hidden_size=16, eps=1e-6):
+    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
@@ -43,7 +47,7 @@ class TestAllReduceRMSNormModel(torch.nn.Module):

 class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):

-    def __init__(self, hidden_size=16, eps=1e-6):
+    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
@@ -62,24 +66,101 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
         return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]


+class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
+
+    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.eps = eps
+        self.norm = RMSNorm(hidden_size, eps)
+        self.quant_fp8 = QuantFP8(static=True,
+                                  group_shape=GroupShape.PER_TENSOR)
+        self.scale = torch.rand(1, dtype=torch.float32)
+        self.output = torch.empty((token_num, hidden_size),
+                                  dtype=torch.float32)
+
+    def forward(self, hidden_states, residual):
+        view = hidden_states.reshape(-1, self.hidden_size)
+        all_reduce = tensor_model_parallel_all_reduce(view)
+        norm_output, residual_output = self.norm(all_reduce, residual)
+        torch.ops._C.static_scaled_fp8_quant(self.output,
+                                             norm_output.contiguous(),
+                                             self.scale)
+        return self.output, residual_output
+
+    def ops_in_model_after(self):
+        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
+
+    def ops_in_model_before(self):
+        return [
+            torch.ops.vllm.all_reduce.default,
+            torch.ops._C.static_scaled_fp8_quant.default
+        ]
+
+
+class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
+
+    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.eps = eps
+        self.norm = RMSNorm(hidden_size, eps)
+        self.scale = torch.rand(1, dtype=torch.float32)
+        self.output = torch.empty((token_num, hidden_size),
+                                  dtype=torch.float32)
+
+        round_up = lambda x, y: (x + y - 1) // y * y
+        rounded_m = round_up(token_num, 128)
+        scale_n = hidden_size // 16
+        rounded_n = round_up(scale_n, 4)
+        self.output_scale = torch.empty((rounded_m, rounded_n // 4),
+                                        dtype=torch.int32)
+
+    def forward(self, hidden_states, residual):
+        view = hidden_states.reshape(-1, self.hidden_size)
+        all_reduce = tensor_model_parallel_all_reduce(view)
+        norm_output, residual_output = self.norm(all_reduce, residual)
+        norm_output = norm_output.reshape(-1, norm_output.shape[-1])
+        torch.ops._C.scaled_fp4_quant(self.output, norm_output,
+                                      self.output_scale, self.scale)
+        return self.output, residual_output, self.output_scale
+
+    def ops_in_model_after(self):
+        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
+
+    def ops_in_model_before(self):
+        return [
+            torch.ops.vllm.all_reduce.default,
+            torch.ops._C.scaled_fp4_quant.default
+        ]
+
+
 @multi_gpu_test(num_gpus=2)
-@pytest.mark.parametrize(
-    "test_model",
-    [TestAllReduceRMSNormModel, TestAllReduceFusedAddRMSNormModel])
+@pytest.mark.parametrize("test_model", [
+    TestAllReduceRMSNormModel,
+    TestAllReduceFusedAddRMSNormModel,
+    TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
+    TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
+])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("seq_len", [8])
-@pytest.mark.parametrize("hidden_size", [4096])
+@pytest.mark.parametrize("hidden_size", [16])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                     reason="Only test on CUDA")
-@pytest.mark.skipif(not find_spec("flashinfer"),
-                    reason="flashinfer is not installed")
-@pytest.mark.skipif(not current_platform.is_device_capability(100),
-                    reason="Only test on SM100")
+@pytest.mark.skipif(
+    not find_spec("flashinfer")
+    or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
+    reason="flashinfer is not found or flashinfer "
+    "is not compiled with trtllm_allreduce_fusion")
 def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
                                         batch_size: int, seq_len: int,
                                         hidden_size: int, dtype: torch.dtype):
     num_processes = 2
+    if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
+            and not current_platform.has_device_capability(100)):
+        pytest.skip("Skip as nvfp4 is only supported on "
+                    "devices with compute capability 10.0 (Blackwell)")

     def run_torch_spawn(fn, nprocs):
         torch.multiprocessing.spawn(fn,
@@ -113,12 +194,11 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
     init_distributed_environment()
     initialize_model_parallel(tensor_model_parallel_size=world_size)

-    vllm_config = VllmConfig(
-        compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE,
-                                             custom_ops=["+rms_norm"],
-                                             compile_sizes=[2, 4, 8]))
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        custom_ops=["+rms_norm", "+quant_fp8"]))
     vllm_config.compilation_config.pass_config = PassConfig(
-        enable_fi_allreduce_fusion=True)
+        enable_fi_allreduce_fusion=True, enable_noop=True)
     vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))

     # this is a fake model name to construct the model config
@@ -130,14 +210,16 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
                               seed=42)

     all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
-    backend = TestBackend(all_reduce_fusion_pass)
+    noop_pass = NoOpEliminationPass(vllm_config)
+    func_pass = FixFunctionalizationPass(vllm_config)

-    model = test_model_cls(hidden_size)
+    backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)

-    hidden_states = torch.randn((batch_size * seq_len, hidden_size),
-                                requires_grad=False)
-    residual = torch.randn((batch_size * seq_len, hidden_size),
-                           requires_grad=False)
+    token_num = batch_size * seq_len
+    model = test_model_cls(hidden_size, token_num)
+
+    hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
+    residual = torch.randn((token_num, hidden_size), requires_grad=False)

     compiled_model = torch.compile(model, backend=backend)
     compiled_model(hidden_states, residual)
@@ -4,6 +4,7 @@
 import asyncio
 import copy
 import functools
+import importlib
 import os
 import signal
 import subprocess
@@ -974,3 +975,14 @@ def get_client_text_logprob_generations(
     return [(text_generations, text,
              (None if x.logprobs is None else x.logprobs.top_logprobs))
             for completion in completions for x in completion.choices]
+
+
+def has_module_attribute(module_name, attribute_name):
+    """
+    Helper function to check if a module has a specific attribute.
+    """
+    try:
+        module = importlib.import_module(module_name)
+        return hasattr(module, attribute_name)
+    except ImportError:
+        return False
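The helper keeps optional-dependency checks import-safe: the module is only imported inside the try block, so a missing or partially built flashinfer does not break test collection. A minimal usage sketch (the marker name below is hypothetical; the check itself mirrors the skipif added to test_fusion_all_reduce.py):

import pytest
from tests.utils import has_module_attribute  # assumes the tests package is importable

# Skip when the installed flashinfer lacks the trtllm_allreduce_fusion kernel.
requires_trtllm_fusion = pytest.mark.skipif(
    not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
    reason="flashinfer is not compiled with trtllm_allreduce_fusion")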
@@ -37,6 +37,8 @@ logger = init_logger(__name__)
 ALLREDUCE_OP = torch.ops.vllm.all_reduce.default
 RMS_OP = torch.ops._C.rms_norm.default
 RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
+STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default
+STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default


 class BasePattern:
@@ -394,7 +396,7 @@ if flashinfer_comm is not None:
     # Max size of the input tensor per world size
     # to use flashinfer fused allreduce
     _FI_MAX_SIZES = {
-        2: MiB,  # 1MB
+        2: 64 * MiB,  # 64MB
         4: MiB,  # 1MB
         6: MiB // 2,  # 512KB
         8: MiB // 2,  # 512KB
@@ -414,9 +416,13 @@ if flashinfer_comm is not None:
         trigger_completion_at_end: bool,
         fp32_acc: bool,
         max_token_num: int,
+        pattern_code: int,
+        fuse_rms_quant: bool,
         norm_out: Optional[torch.Tensor] = None,
+        quant_out: Optional[torch.Tensor] = None,
+        scale_out: Optional[torch.Tensor] = None,
+        scale_factor: Optional[torch.Tensor] = None,
     ) -> None:

         num_tokens, hidden_size = allreduce_in.shape
         element_size = allreduce_in.element_size()
         current_tensor_size = num_tokens * hidden_size * element_size
@@ -425,7 +431,6 @@ if flashinfer_comm is not None:
             _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
             max_fusion_size,
         )

         if use_flashinfer:
             assert (_FI_WORKSPACE_TENSOR is not None
                     ), "Flashinfer must be enabled when using flashinfer"
@@ -455,37 +460,65 @@ if flashinfer_comm is not None:
                 use_oneshot=True,
                 trigger_completion_at_end=trigger_completion_at_end,
                 fp32_acc=fp32_acc,
-                pattern_code=flashinfer_comm.AllReduceFusionPattern.
-                kARResidualRMSNorm,
+                pattern_code=pattern_code,
                 allreduce_out=None,
-                quant_out=None,
-                scale_out=None,
-                layout_code=None,
-                scale_factor=None,
+                quant_out=quant_out,
+                scale_out=scale_out,
+                # in vllm we only support swizzled layout
+                layout_code=flashinfer_comm.FP4QuantizationSFLayout.SWIZZLED,
+                scale_factor=scale_factor,
             )
         else:
             allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
-            if norm_out is None:
-                torch.ops._C.fused_add_rms_norm(allreduce_out, residual,
-                                                rms_gamma, rms_eps)
+            if (scale_factor is not None and scale_out is None
+                    and fuse_rms_quant):
+                # Do fused rms norm static fp8 quant fused op
+                if norm_out is None:
+                    torch.ops._C.fused_add_rms_norm_static_fp8_quant(
+                        quant_out, allreduce_out, residual, rms_gamma,
+                        scale_factor, rms_eps)
+                else:
+                    torch.ops._C.rms_norm_static_fp8_quant(
+                        quant_out, allreduce_out, rms_gamma, scale_factor,
+                        rms_eps)
             else:
-                torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma,
-                                      rms_eps)
-            allreduce_in.copy_(allreduce_out)
+                if norm_out is None:
+                    torch.ops._C.fused_add_rms_norm(allreduce_out, residual,
+                                                    rms_gamma, rms_eps)
+                    norm_out = allreduce_out
+                else:
+                    torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma,
+                                          rms_eps)
+                if scale_factor is not None:
+                    if scale_out is not None:
+                        torch.ops._C.scaled_fp4_quant(quant_out, norm_out,
+                                                      scale_out, scale_factor)
+                    else:
+                        torch.ops._C.static_scaled_fp8_quant(
+                            quant_out, norm_out, scale_factor)
+            if scale_factor is None or norm_out is not None:
+                # we need to return allreduce output
+                # in cases of non quant fused AR + RMS norm
+                # and fused AR + RMS norm + quant without fused add
+                allreduce_in.copy_(allreduce_out)

     def call_trtllm_fused_allreduce_norm_fake(
         allreduce_in: torch.Tensor,
         residual: torch.Tensor,
         rms_gamma: torch.Tensor,
         rms_eps: float,
         world_rank: int,
         world_size: int,
         launch_with_pdl: bool,
         trigger_completion_at_end: bool,
         fp32_acc: bool,
         max_token_num: int,
-        norm_out: Optional[torch.Tensor] = None,
-    ) -> None:
+        pattern_code: int,
+        fuse_rms_quant: bool,
+        norm_out: Optional[torch.Tensor] = None,
+        quant_out: Optional[torch.Tensor] = None,
+        scale_out: Optional[torch.Tensor] = None,
+        scale_factor: Optional[torch.Tensor] = None) -> None:
         pass

     direct_register_custom_op(
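The fused op now receives the FlashInfer pattern code from the caller instead of hard-coding kARResidualRMSNorm. The replacement patterns later in this file pass one of three codes, depending on which quantization op follows the norm; a rough summary (enum names taken from this diff, the dict itself is only illustrative):

# Illustrative mapping, assuming flashinfer_comm is the module imported in collective_fusion.py.
PATTERN_CODE_FOR = {
    "allreduce + rmsnorm":
        flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
    "allreduce + rmsnorm + static fp8 quant":
        flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
    "allreduce + rmsnorm + nvfp4 quant":
        flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
}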
@@ -495,6 +528,8 @@ if flashinfer_comm is not None:
             "allreduce_in",
             "residual",
             "norm_out",
+            "quant_out",
+            "scale_out",
         ],
         fake_impl=call_trtllm_fused_allreduce_norm_fake,
         dispatch_key=current_platform.dispatch_key,
@@ -512,6 +547,7 @@ class FlashInferFusedAllReduceParams:
         world_size: int,
         use_fp32_lamport: bool = False,
         max_token_num: int = 1024,
+        fuse_rms_quant: bool = False,
     ):
         self.rank = rank
         self.world_size = world_size
@@ -521,6 +557,7 @@ class FlashInferFusedAllReduceParams:
         self.fp32_acc = True
         self.use_oneshot = False
         self.max_token_num = max_token_num
+        self.fuse_rms_quant = fuse_rms_quant

     def get_trtllm_fused_allreduce_kwargs(self):
         return {
@@ -530,10 +567,16 @@ class FlashInferFusedAllReduceParams:
             "trigger_completion_at_end": self.trigger_completion_at_end,
             "fp32_acc": self.fp32_acc,
             "max_token_num": self.max_token_num,
+            "fuse_rms_quant": self.fuse_rms_quant,
         }


-class AllReduceRMSNORMPattern(BasePattern):
+class AllReduceRMSNormPattern(BasePattern):
+    """
+    This pattern replaces the allreduce + rms norm (without residual)
+    with fused flashinfer implementation.
+    Applies to allreduce + rmsnorm before attn in the first Transformer block.
+    """

     def __init__(
         self,
@@ -559,29 +602,34 @@ class AllReduceRMSNORMPattern(BasePattern):

         def pattern(input: torch.Tensor, rms_result: torch.Tensor,
                     weight: torch.Tensor):
-            all_reduce_output = tensor_model_parallel_all_reduce(input)
+            allreduce_output = tensor_model_parallel_all_reduce(input)
             rms = auto_functionalized(
                 RMS_OP,
                 result=rms_result,
-                input=all_reduce_output,
+                input=allreduce_output,
                 weight=weight,
                 epsilon=self.epsilon,
             )
-            return rms[1], all_reduce_output
+            # rms_result, allreduce_output
+            return rms[1], allreduce_output

         def replacement(input: torch.Tensor, rms_result: torch.Tensor,
                         weight: torch.Tensor):
             residual = torch.zeros_like(input)
             allreduce = auto_functionalized(
-                torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
+                flashinfer_trtllm_fused_allreduce_norm,
                 allreduce_in=input,
                 residual=residual,
                 norm_out=rms_result,
+                quant_out=None,
+                scale_out=None,
                 rms_gamma=weight,
                 rms_eps=self.epsilon,
+                pattern_code=flashinfer_comm.AllReduceFusionPattern.
+                kARResidualRMSNorm,
                 **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
             )
+            # rms_result, allreduce_in
             return allreduce[3], allreduce[1]

         pm.register_replacement(pattern, replacement, self.get_inputs(),
@@ -589,6 +637,11 @@ class AllReduceRMSNORMPattern(BasePattern):


 class AllReduceFusedAddRMSNormPattern(BasePattern):
+    """
+    This pattern replaces the allreduce + rms norm (with residual)
+    with fused flashinfer implementation.
+    Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
+    """

     def __init__(
         self,
@@ -615,33 +668,390 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):

         def pattern(residual: torch.Tensor, input: torch.Tensor,
                     weight: torch.Tensor):
-            all_reduce_output = tensor_model_parallel_all_reduce(input)
+            allreduce_output = tensor_model_parallel_all_reduce(input)
             rms = auto_functionalized(
                 RMS_ADD_OP,
-                input=all_reduce_output,
+                input=allreduce_output,
                 residual=residual,
                 weight=weight,
                 epsilon=self.epsilon,
             )
+            # input, residual
             return rms[1], rms[2]

         def replacement(residual: torch.Tensor, input: torch.Tensor,
                         weight: torch.Tensor):
             allreduce = auto_functionalized(
-                torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
+                flashinfer_trtllm_fused_allreduce_norm,
                 allreduce_in=input,
                 residual=residual,
+                norm_out=None,
+                quant_out=None,
+                scale_out=None,
                 rms_gamma=weight,
                 rms_eps=self.epsilon,
-                norm_out=None,
+                pattern_code=flashinfer_comm.AllReduceFusionPattern.
+                kARResidualRMSNorm,
                 **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
             )
+            # allreduce_in, residual
             return allreduce[1], allreduce[2]

         pm.register_replacement(pattern, replacement, self.get_inputs(),
                                 pm.fwd_only, pm_pass)


+class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
+    """
+    This pattern replaces the allreduce + rms norm (without residual)
+    + static fp8 quant with fused flashinfer implementation.
+    Applies to allreduce + rmsnorm + quant before attn
+    in the first Transformer block.
+    """
+
+    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
+                 allreduce_params: FlashInferFusedAllReduceParams):
+        super().__init__(dtype, device)
+        self.epsilon = epsilon
+        self.allreduce_params = allreduce_params
+        self.quant_dtype = torch.float8_e4m3fn
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def get_inputs():
+            input = torch.zeros([1, 8, 4],
+                                device=self.device,
+                                dtype=self.dtype)
+            rmsnorm_result = torch.empty([1, 8, 4],
+                                         device=self.device,
+                                         dtype=self.dtype)
+            quant_result = torch.empty([1, 8, 4],
+                                       device=self.device,
+                                       dtype=self.quant_dtype)
+            weight = torch.empty([4], device=self.device, dtype=self.dtype)
+            scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
+            return [input, rmsnorm_result, quant_result, weight, scale]
+
+        def pattern(
+            input: torch.Tensor,
+            rmsnorm_result: torch.Tensor,
+            quant_result: torch.Tensor,
+            weight: torch.Tensor,
+            scale: torch.Tensor,
+        ):
+            all_reduce = tensor_model_parallel_all_reduce(input)
+            rmsnorm_out_tuple = auto_functionalized(RMS_OP,
+                                                    result=rmsnorm_result,
+                                                    input=all_reduce,
+                                                    weight=weight,
+                                                    epsilon=self.epsilon)
+
+            quant_out_tuple = auto_functionalized(STATIC_FP8_QUANT_OP,
+                                                  result=quant_result,
+                                                  input=rmsnorm_out_tuple[1],
+                                                  scale=scale)
+
+            # quant_out, allreduce_output
+            return quant_out_tuple[1], all_reduce
+
+        def replacement(input: torch.Tensor, result_rms: torch.Tensor,
+                        quant_result: torch.Tensor, weight: torch.Tensor,
+                        scale: torch.Tensor):
+            residual = torch.zeros_like(input)
+            allreduce = auto_functionalized(
+                flashinfer_trtllm_fused_allreduce_norm,
+                allreduce_in=input,
+                residual=residual,
+                norm_out=result_rms,
+                quant_out=quant_result,
+                scale_out=None,
+                rms_gamma=weight,
+                rms_eps=self.epsilon,
+                pattern_code=flashinfer_comm.AllReduceFusionPattern.
+                kARResidualRMSNormFP8Quant,  # we don't use norm_out afterwards
+                scale_factor=scale,
+                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
+            )
+
+            # quant_out, allreduce_output
+            return allreduce[4], allreduce[1]
+
+        pm.register_replacement(pattern, replacement, get_inputs(),
+                                pm.fwd_only, pm_pass)
+
+
+class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
+    """
+    This pattern replaces the allreduce + rms norm (with residual)
+    + static fp8 quant with fused flashinfer implementation.
+    Applies to o_proj + rmsnorm after attn + quant and
+    mlp + rmsnorm + quant before attn.
+    """
+
+    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
+                 allreduce_params: FlashInferFusedAllReduceParams):
+        super().__init__(dtype, device)
+        self.epsilon = epsilon
+        self.allreduce_params = allreduce_params
+        self.quant_dtype = torch.float8_e4m3fn
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def get_inputs():
+            input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
+
+            residual = torch.empty([4, 4],
+                                   device=self.device,
+                                   dtype=self.dtype)
+            weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
+            quant_result = torch.empty([4, 4],
+                                       device=self.device,
+                                       dtype=self.quant_dtype)
+            scale = torch.empty([1, 1],
+                                device=self.device,
+                                dtype=torch.float32)
+
+            return [
+                quant_result,
+                residual,
+                input,
+                weight,
+                scale,
+            ]
+
+        def pattern(
+            quant_result: torch.Tensor,
+            residual: torch.Tensor,
+            input: torch.Tensor,
+            weight: torch.Tensor,
+            scale: torch.Tensor,
+        ):
+            allreduce_output = tensor_model_parallel_all_reduce(input)
+
+            fused_add_rmsnorm_out_tuple = \
+                auto_functionalized(
+                    RMS_ADD_OP,
+                    input=allreduce_output,
+                    residual=residual,
+                    weight=weight,
+                    epsilon=self.epsilon)
+            quant_out_tuple = auto_functionalized(
+                STATIC_FP8_QUANT_OP,
+                result=quant_result,
+                input=fused_add_rmsnorm_out_tuple[1],
+                scale=scale)
+
+            # quant_out, allreduce_output
+            return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2]
+
+        def replacement(quant_result: torch.Tensor, residual: torch.Tensor,
+                        input: torch.Tensor, weight: torch.Tensor,
+                        scale: torch.Tensor):
+            allreduce = auto_functionalized(
+                flashinfer_trtllm_fused_allreduce_norm,
+                allreduce_in=input,
+                residual=residual,
+                norm_out=None,
+                quant_out=quant_result,
+                scale_out=None,
+                rms_gamma=weight,
+                rms_eps=self.epsilon,
+                pattern_code=flashinfer_comm.AllReduceFusionPattern.
+                kARResidualRMSNormFP8Quant,  # we don't use norm_out afterwards
+                scale_factor=scale,
+                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
+            )
+            # # quant_out, rms_norm_residual
+            return allreduce[4], allreduce[2]
+
+        pm.register_replacement(pattern, replacement, get_inputs(),
+                                pm.fwd_only, pm_pass)
+
+
+class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
+    """
+    This pattern replaces the allreduce + rms norm (without residual)
+    + static nvfp4 quant with fused flashinfer implementation.
+    Applies to allreduce + rmsnorm + quant before attn
+    in the first Transformer block.
+    """
+
+    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
+                 allreduce_params: FlashInferFusedAllReduceParams):
+        super().__init__(dtype, device)
+        self.epsilon = epsilon
+        self.allreduce_params = allreduce_params
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def get_inputs():
+            input = torch.empty([1, 16, 16],
+                                device=self.device,
+                                dtype=self.dtype)
+
+            rmsnorm_result = torch.empty([1, 16, 16],
+                                         device=self.device,
+                                         dtype=self.dtype)
+            quant_result = torch.empty((16, 8),
+                                       device=self.device,
+                                       dtype=torch.uint8)
+            input_global_scale = torch.empty([1, 1],
+                                             device=self.device,
+                                             dtype=torch.float32)
+            weight = torch.empty([16], device=self.device, dtype=self.dtype)
+            output_scale = torch.empty([128, 4],
+                                       device=self.device,
+                                       dtype=torch.int32)
+
+            return [
+                input, rmsnorm_result, quant_result, weight,
+                input_global_scale, output_scale
+            ]
+
+        def pattern(
+            input: torch.Tensor,
+            rmsnorm_result: torch.Tensor,
+            quant_result: torch.Tensor,
+            weight: torch.Tensor,
+            input_global_scale: torch.Tensor,
+            output_scale: torch.Tensor,
+        ):
+            all_reduce = tensor_model_parallel_all_reduce(input)
+            rmsnorm_out_tuple = auto_functionalized(RMS_OP,
+                                                    result=rmsnorm_result,
+                                                    input=all_reduce,
+                                                    weight=weight,
+                                                    epsilon=self.epsilon)
+
+            quant_out_tuple = auto_functionalized(
+                STATIC_FP4_QUANT_OP,
+                output=quant_result,
+                input=rmsnorm_out_tuple[1],
+                output_scale=output_scale,
+                input_scale=input_global_scale)
+
+            # quant_out, allreduce_output, output_scale
+            return quant_out_tuple[1], all_reduce, quant_out_tuple[2]
+
+        def replacement(input: torch.Tensor, result_rms: torch.Tensor,
+                        quant_result: torch.Tensor, weight: torch.Tensor,
+                        input_global_scale: torch.Tensor,
+                        output_scale: torch.Tensor):
+            residual = torch.zeros_like(input)
+            allreduce = auto_functionalized(
+                flashinfer_trtllm_fused_allreduce_norm,
+                allreduce_in=input,
+                residual=residual,
+                norm_out=result_rms,
+                quant_out=quant_result,
+                scale_out=output_scale,
+                rms_gamma=weight,
+                rms_eps=self.epsilon,
+                pattern_code=flashinfer_comm.AllReduceFusionPattern.
+                kARResidualRMSNormFP4Quant,  # we don't use norm_out afterwards
+                scale_factor=input_global_scale,
+                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
+            )
+
+            # quant_out, allreduce_output, output_scale
+            return allreduce[4], allreduce[1], allreduce[5]
+
+        pm.register_replacement(pattern, replacement, get_inputs(),
+                                pm.fwd_only, pm_pass)
+
+
+class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
+    """
+    This pattern replaces the allreduce + rms norm (with residual)
+    + static nvfp4 quant with fused flashinfer implementation.
+    Applies to o_proj + rmsnorm after attn + quant and
+    mlp + rmsnorm + quant before attn.
+    """
+
+    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
+                 allreduce_params: FlashInferFusedAllReduceParams):
+        super().__init__(dtype, device)
+        self.epsilon = epsilon
+        self.allreduce_params = allreduce_params
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def get_inputs():
+            input = torch.empty([16, 16], device=self.device, dtype=self.dtype)
+
+            residual = torch.empty([16, 16],
+                                   device=self.device,
+                                   dtype=self.dtype)
+            weight = torch.empty([16, 16],
+                                 device=self.device,
+                                 dtype=self.dtype)
+            quant_result = torch.empty((16, 8),
+                                       device=self.device,
+                                       dtype=torch.uint8)
+            input_global_scale = torch.empty([1, 1],
+                                             device=self.device,
+                                             dtype=torch.float32)
+            output_scale = torch.empty([128, 4],
+                                       device=self.device,
+                                       dtype=torch.int32)
+
+            return [
+                quant_result,
+                residual,
+                input,
+                output_scale,
+                weight,
+                input_global_scale,
+            ]
+
+        def pattern(quant_result: torch.Tensor, residual: torch.Tensor,
+                    input: torch.Tensor, output_scale: torch.Tensor,
+                    weight: torch.Tensor, input_global_scale: torch.Tensor):
+            allreduce_output = tensor_model_parallel_all_reduce(input)
+
+            fused_add_rmsnorm_out_tuple = \
+                auto_functionalized(
+                    RMS_ADD_OP,
+                    input=allreduce_output,
+                    residual=residual,
+                    weight=weight,
+                    epsilon=self.epsilon)
+            quant_out_tuple = auto_functionalized(
+                STATIC_FP4_QUANT_OP,
+                output=quant_result,
+                input=fused_add_rmsnorm_out_tuple[1],
+                output_scale=output_scale,
+                input_scale=input_global_scale)
+
+            # quant_out, allreduce_output, output_scale
+            return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[
+                2], quant_out_tuple[2]
+
+        def replacement(quant_result: torch.Tensor, residual: torch.Tensor,
+                        input: torch.Tensor, output_scale: torch.Tensor,
+                        weight: torch.Tensor,
+                        input_global_scale: torch.Tensor):
+            allreduce = auto_functionalized(
+                flashinfer_trtllm_fused_allreduce_norm,
+                allreduce_in=input,
+                residual=residual,
+                norm_out=None,
+                quant_out=quant_result,
+                scale_out=output_scale,
+                rms_gamma=weight,
+                rms_eps=self.epsilon,
+                pattern_code=flashinfer_comm.AllReduceFusionPattern.
+                kARResidualRMSNormFP4Quant,  # we don't use norm_out afterwards
+                scale_factor=input_global_scale,
+                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
+            )
+            # quant_out, rms_norm_residual, output_scale
+            return allreduce[4], allreduce[2], allreduce[5]
+
+        pm.register_replacement(pattern, replacement, get_inputs(),
+                                pm.fwd_only, pm_pass)
+
+
 class AllReduceFusionPass(VllmInductorPass):

     def __init__(self, config: VllmConfig):
@@ -671,13 +1081,16 @@ class AllReduceFusionPass(VllmInductorPass):
                 self.tp_size,
             )
             return
+        max_num_token = min(
+            _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) //
+            (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)),
+            config.compilation_config.pass_config.
+            fi_allreduce_fusion_max_token_num)
         self.ipc_handles, workspace_tensor = (
             flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                 tp_rank=rank,
                 tp_size=self.tp_size,
-                max_token_num=config.compilation_config.pass_config.
-                fi_allreduce_fusion_max_token_num,
+                max_token_num=max_num_token,
                 hidden_dim=self.hidden_dim,
                 group=self.group,
                 use_fp32_lamport=use_fp32_lamport,
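The workspace is now sized from the per-world-size byte budget rather than directly from fi_allreduce_fusion_max_token_num. A quick worked example (the numbers are illustrative, not from this commit): with tp_size=2, hidden_dim=4096, bf16 activations (2 bytes) and fp32 lamport disabled,

# illustrative sizing check of the max_num_token formula above
MiB = 1024 * 1024
budget = 64 * MiB                              # _FI_MAX_SIZES[2] after this change
max_num_token = min(budget // (4096 * 2 * 2),  # 67108864 // 16384 = 4096 tokens
                    16384)                     # new fi_allreduce_fusion_max_token_num default
# -> the IPC workspace is allocated for 4096 tokens rather than the full 16384.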
@@ -689,12 +1102,38 @@ class AllReduceFusionPass(VllmInductorPass):
             rank=rank,
             world_size=self.tp_size,
             use_fp32_lamport=use_fp32_lamport,
-            max_token_num=config.compilation_config.pass_config.
-            fi_allreduce_fusion_max_token_num,
-        )
+            max_token_num=max_num_token,
+            # fuse rms norm static fp8 quant fused op
+            # in fallback path, when we don't use flashinfer
+            fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)

         for epsilon in [1e-5, 1e-6]:
-            AllReduceRMSNORMPattern(
+            AllReduceFusedRMSNormStaticQuantFP8Pattern(
+                epsilon,
+                self.model_dtype,
+                self.device,
+                self.allreduce_params,
+            ).register(self.patterns)
+            AllReduceFusedAddRMSNormStaticQuantFP8Pattern(
+                epsilon,
+                self.model_dtype,
+                self.device,
+                self.allreduce_params,
+            ).register(self.patterns)
+            if current_platform.has_device_capability(100):
+                AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
+                    epsilon,
+                    self.model_dtype,
+                    self.device,
+                    self.allreduce_params,
+                ).register(self.patterns)
+                AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
+                    epsilon,
+                    self.model_dtype,
+                    self.device,
+                    self.allreduce_params,
+                ).register(self.patterns)
+            AllReduceRMSNormPattern(
                 epsilon,
                 self.model_dtype,
                 self.device,
@@ -707,6 +1146,10 @@ class AllReduceFusionPass(VllmInductorPass):
                 self.allreduce_params,
             ).register(self.patterns)

+            # WARNING: This is a hack to clear the pattern matcher cache
+            # and allow multiple values of epsilon.
+            torch._inductor.pattern_matcher._seen_patterns.clear()
+
         self.disabled = False

     def __call__(self, graph: fx.Graph):
@@ -723,5 +1166,5 @@ class AllReduceFusionPass(VllmInductorPass):
         if self.disabled:
             return
         if flashinfer_comm is not None:
-            flashinfer_comm.trtllm_destroy_ipc_workspace(
+            flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
                 self.ipc_handles, self.group)
@@ -4051,7 +4051,7 @@ class PassConfig:
     """Whether to enable async TP."""
     enable_fi_allreduce_fusion: bool = False
     """Whether to enable flashinfer allreduce fusion."""
-    fi_allreduce_fusion_max_token_num: int = 1024
+    fi_allreduce_fusion_max_token_num: int = 16384
     """Max number of tokens to used in flashinfer allreduce fusion."""

     # TODO(luka) better pass enabling system.