[compile] Enable sequence parallelism matching w/o custom ops enabled (#27126)

Signed-off-by: angelayi <yiangela7@gmail.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: ProExpertProg <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
This commit is contained in:
Angela Yi 2025-11-15 03:46:12 -08:00 committed by GitHub
parent 173b356abf
commit f36292dbee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 472 additions and 444 deletions

View File

@ -478,10 +478,11 @@ steps:
- vllm/ - vllm/
- tests/compile - tests/compile
commands: commands:
# fp8 kv scales not supported on sm89, tested on Blackwell instead
- pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile' - pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
# Limit to no custom ops to reduce running time # Limit to no custom ops to reduce running time
# Wrap with quotes to escape yaml and avoid starting -k string with a - # Wrap with quotes to escape yaml and avoid starting -k string with a -
- "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'" - "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and not +quant_fp8 and not Llama-4'"
- label: Cudagraph test - label: Cudagraph test
timeout_in_minutes: 20 timeout_in_minutes: 20
@ -925,7 +926,7 @@ steps:
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py - pytest -v -s tests/kernels/moe/test_flashinfer.py
- label: Blackwell Fusion Tests # 30 min - label: Blackwell Fusion & Compile Tests # 30 min
timeout_in_minutes: 40 timeout_in_minutes: 40
working_dir: "/vllm-workspace/" working_dir: "/vllm-workspace/"
gpu: b200 gpu: b200
@ -946,7 +947,9 @@ steps:
- pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusion_all_reduce.py
# Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time # Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
# Wrap with quotes to escape yaml # Wrap with quotes to escape yaml
- "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and Llama-3.1 and -quant_fp8 and -rms_norm'" - "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'"
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
- pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
- label: Blackwell Fusion E2E Tests # 30 min - label: Blackwell Fusion E2E Tests # 30 min
timeout_in_minutes: 40 timeout_in_minutes: 40
@ -969,8 +972,6 @@ steps:
- nvidia-smi - nvidia-smi
# Run all e2e fusion tests # Run all e2e fusion tests
- pytest -v -s tests/compile/test_fusions_e2e.py - pytest -v -s tests/compile/test_fusions_e2e.py
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
- pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
- label: Blackwell GPT-OSS Eval - label: Blackwell GPT-OSS Eval
timeout_in_minutes: 60 timeout_in_minutes: 60
@ -1266,7 +1267,8 @@ steps:
- pytest -v -s tests/compile/test_async_tp.py - pytest -v -s tests/compile/test_async_tp.py
- pytest -v -s tests/compile/test_sequence_parallelism.py - pytest -v -s tests/compile/test_sequence_parallelism.py
- pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - "pytest -v -s tests/compile/test_fusions_e2e.py -k 'not Llama-4'"
- pytest -v -s tests/distributed/test_sequence_parallel.py
- pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
- pytest -v -s tests/v1/distributed/test_dbo.py - pytest -v -s tests/v1/distributed/test_dbo.py

View File

@ -20,13 +20,22 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..utils import flat_product, multi_gpu_test from ..utils import flat_product, multi_gpu_test
def is_blackwell() -> bool:
    """Report whether the tests are running on a Blackwell GPU.

    A lot of tests in this module depend on it, either to skip themselves
    or to adjust the number of fusions they expect.

    Returns:
        The result of ``current_platform.is_device_capability(100)``.
    """
    # Evaluated on every call rather than cached at import time, matching
    # the behavior of the lambda this function replaces.
    return current_platform.is_device_capability(100)
class Matches(NamedTuple):
    """Expected pattern-match counts that the tests compare against log output.

    Every field defaults to 0, so a test case only needs to spell out the
    counts that are relevant to it (e.g. ``Matches(attention_fusion=32)``).
    """

    # Compared with N in the "Fused quant onto N attention nodes" log line
    # emitted from fusion_attn.py.
    attention_fusion: int = 0
    # Compared with N in collective_fusion.py's "Replaced N patterns" log line
    # by the all-reduce + RMSNorm fusion test.
    allreduce_fusion: int = 0
    # Compared with N in sequence_parallelism.py's "Replaced N patterns"
    # log line.
    sequence_parallel: int = 0
    # Compared with N in collective_fusion.py's "Replaced N patterns" log line
    # by the async TP test.
    async_tp: int = 0
class ModelBackendTestCase(NamedTuple): class ModelBackendTestCase(NamedTuple):
model_name: str model_name: str
model_kwargs: dict[str, Any] model_kwargs: dict[str, Any]
backend: AttentionBackendEnum backend: AttentionBackendEnum
attention_fusions: int matches: Matches
allreduce_fusions: int | None = None
MODELS_FP8: list[ModelBackendTestCase] = [] MODELS_FP8: list[ModelBackendTestCase] = []
@ -38,17 +47,33 @@ if current_platform.is_cuda():
ModelBackendTestCase( ModelBackendTestCase(
# Use smaller model for L40s in CI # Use smaller model for L40s in CI
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
model_kwargs=dict(max_model_len=1024), # TODO while llama4 is broken, use FLASHINFER for llama3 on Blackwell
backend=AttentionBackendEnum.TRITON_ATTN, # so FI attention+fp8_quant is at least tested once
attention_fusions=32, model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
allreduce_fusions=65, backend=AttentionBackendEnum.FLASHINFER
if is_blackwell()
else AttentionBackendEnum.TRITON_ATTN,
matches=Matches(
attention_fusion=32,
allreduce_fusion=65,
sequence_parallel=65,
async_tp=128,
),
), ),
ModelBackendTestCase( ModelBackendTestCase(
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=AttentionBackendEnum.FLASHINFER, # TODO FlashInfer attn broken on Hopper with kvcache=fp8:
attention_fusions=48, # https://github.com/vllm-project/vllm/issues/28568
allreduce_fusions=96, # TODO FlashInfer attn broken on Blackwell for llama4:
# https://github.com/vllm-project/vllm/issues/28604
backend=AttentionBackendEnum.TRITON_ATTN,
matches=Matches(
attention_fusion=48,
allreduce_fusion=96,
sequence_parallel=96,
async_tp=95, # mlp is moe, no fusion there
),
), ),
] ]
@ -57,8 +82,12 @@ if current_platform.is_cuda():
model_name="nvidia/Llama-3.1-8B-Instruct-FP4", model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=AttentionBackendEnum.FLASHINFER, backend=AttentionBackendEnum.FLASHINFER,
attention_fusions=32, matches=Matches(
allreduce_fusions=65, attention_fusion=32,
allreduce_fusion=65,
sequence_parallel=65,
async_tp=128,
),
), ),
] ]
@ -68,15 +97,23 @@ if current_platform.is_cuda():
model_name="meta-llama/Llama-3.1-8B-Instruct", model_name="meta-llama/Llama-3.1-8B-Instruct",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.TRITON_ATTN, backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=0, matches=Matches(
allreduce_fusions=65, attention_fusion=0,
allreduce_fusion=65,
sequence_parallel=65,
async_tp=128,
),
), ),
ModelBackendTestCase( ModelBackendTestCase(
model_name="Qwen/Qwen3-30B-A3B", model_name="Qwen/Qwen3-30B-A3B",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.TRITON_ATTN, backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=0, matches=Matches(
allreduce_fusions=97, attention_fusion=0,
allreduce_fusion=97,
sequence_parallel=97,
async_tp=96, # MLP is MoE, half the fusions of dense
),
), ),
] ]
@ -86,19 +123,19 @@ elif current_platform.is_rocm():
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.TRITON_ATTN, backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=32, matches=Matches(attention_fusion=32),
), ),
ModelBackendTestCase( ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.ROCM_ATTN, backend=AttentionBackendEnum.ROCM_ATTN,
attention_fusions=32, matches=Matches(attention_fusion=32),
), ),
ModelBackendTestCase( ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN, backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
attention_fusions=32, matches=Matches(attention_fusion=32),
), ),
] ]
@ -106,8 +143,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name, model_kwargs, backend, " "model_name, model_kwargs, backend, matches, custom_ops",
"attention_fusions, allreduce_fusions, custom_ops",
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8)) list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
# quant_fp4 only has the custom impl # quant_fp4 only has the custom impl
@ -118,15 +154,14 @@ def test_attn_quant(
model_name: str, model_name: str,
model_kwargs: dict[str, Any], model_kwargs: dict[str, Any],
backend: AttentionBackendEnum, backend: AttentionBackendEnum,
attention_fusions: int, matches: Matches,
allreduce_fusions: int,
custom_ops: str, custom_ops: str,
inductor_graph_partition: bool, inductor_graph_partition: bool,
caplog_mp_spawn, caplog_mp_spawn,
monkeypatch, monkeypatch,
): ):
if backend == AttentionBackendEnum.FLASHINFER and ( if backend == AttentionBackendEnum.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer() not is_blackwell() or not has_flashinfer()
): ):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
@ -169,12 +204,12 @@ def test_attn_quant(
with caplog_mp_spawn(logging.DEBUG) as log_holder: with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(compilation_config, model_name, **model_kwargs) run_model(compilation_config, model_name, **model_kwargs)
matches = re.findall( log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text, log_holder.text,
) )
assert len(matches) == 1, log_holder.text assert len(log_matches) == 1, log_holder.text
assert int(matches[0]) == attention_fusions assert int(log_matches[0]) == matches.attention_fusion
CUSTOM_OPS_RMS_NORM = ["-rms_norm", "+rms_norm"] CUSTOM_OPS_RMS_NORM = ["-rms_norm", "+rms_norm"]
@ -187,8 +222,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name, model_kwargs, backend, " "model_name, model_kwargs, backend, matches, custom_ops",
"attention_fusions, allreduce_fusions, custom_ops",
# Toggle RMSNorm and QuantFP8 for FP8 models # Toggle RMSNorm and QuantFP8 for FP8 models
list( list(
flat_product( flat_product(
@ -209,8 +243,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
model_name: str, model_name: str,
model_kwargs: dict, model_kwargs: dict,
backend: AttentionBackendEnum, backend: AttentionBackendEnum,
attention_fusions: int, matches: Matches,
allreduce_fusions: int,
custom_ops: str, custom_ops: str,
inductor_graph_partition: bool, inductor_graph_partition: bool,
caplog_mp_spawn, caplog_mp_spawn,
@ -219,6 +252,13 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition requires torch>=2.9") pytest.skip("Inductor graph partition requires torch>=2.9")
if "fp4" in model_name.lower() and not is_blackwell():
pytest.skip("NVFP4 quant requires Blackwell")
if backend == AttentionBackendEnum.FLASHINFER and not is_blackwell():
# FlashInfer attn fusion requires Blackwell
matches = matches._replace(attention_fusion=0)
custom_ops_list = custom_ops.split(",") if custom_ops else [] custom_ops_list = custom_ops.split(",") if custom_ops else []
if inductor_graph_partition: if inductor_graph_partition:
@ -258,23 +298,135 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
run_model( run_model(
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
) )
matches = re.findall( log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text, log_holder.text,
) )
assert len(matches) == 2, log_holder.text assert len(log_matches) == 2, log_holder.text
assert int(matches[0]) == attention_fusions assert int(log_matches[0]) == matches.attention_fusion
assert int(matches[1]) == attention_fusions assert int(log_matches[1]) == matches.attention_fusion
matches = re.findall( log_matches = re.findall(
r"collective_fusion.py:\d+] Replaced (\d+) patterns", r"collective_fusion.py:\d+] Replaced (\d+) patterns",
log_holder.text, log_holder.text,
) )
assert len(matches) == 2, log_holder.text assert len(log_matches) == 2, log_holder.text
assert int(matches[0]) == allreduce_fusions assert int(log_matches[0]) == matches.allreduce_fusion
assert int(matches[1]) == allreduce_fusions assert int(log_matches[1]) == matches.allreduce_fusion
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "model_name, model_kwargs, backend, matches, custom_ops",
    # FP8 models: every combination of the QuantFP8 and RMSNorm toggles
    list(
        flat_product(
            MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
        )
    )
    # FP4 and unquantized models: only the RMSNorm toggle applies
    + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="sequence parallel only tested on CUDA",
)
def test_tp2_attn_quant_async_tp(
    model_name: str,
    model_kwargs: dict,
    backend: AttentionBackendEnum,
    matches: Matches,
    custom_ops: str,
    inductor_graph_partition: bool,
    caplog_mp_spawn,
    monkeypatch,
):
    """Run a model with TP=2 and check the fusion counts reported in the logs.

    Attention+quant fusion, sequence parallelism and async TP are all enabled
    in the pass config; the counts each pass logs are compared with the
    corresponding fields of ``matches``.
    """
    if is_blackwell():
        # TODO: https://github.com/vllm-project/vllm/issues/27893
        pytest.skip("Blackwell is not supported for AsyncTP pass")
    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition requires torch>=2.9")
    # NOTE: because of the Blackwell skip above, the `not is_blackwell()`
    # checks below always hold when reached; they matter once that skip goes.
    if "fp4" in model_name.lower() and not is_blackwell():
        pytest.skip("NVFP4 quant requires Blackwell")

    if backend == AttentionBackendEnum.FLASHINFER:
        if not has_flashinfer():
            pytest.skip("FlashInfer backend requires flashinfer installed")
        if not is_blackwell():
            # Without Blackwell there is no FlashInfer attention fusion,
            # so none should be reported.
            matches = matches._replace(attention_fusion=0)

    enabled_ops = custom_ops.split(",") if custom_ops else []

    # Inductor graph partition leaves the splitting ops unset (None);
    # the other configuration passes an explicit empty list.
    partition_ops: list[str] | None = None if inductor_graph_partition else []
    graph_mode = (
        CUDAGraphMode.FULL_AND_PIECEWISE
        if inductor_graph_partition
        else CUDAGraphMode.FULL_DECODE_ONLY
    )

    env_overrides = (
        # Disable the compile cache so the custom passes actually run;
        # otherwise fusion cannot be verified through the logs.
        ("VLLM_DISABLE_COMPILE_CACHE", "1"),
        # Capturing subprocess logs requires knowing whether spawn or fork
        # is used; force spawn as it is more general.
        ("VLLM_WORKER_MULTIPROC_METHOD", "spawn"),
        ("VLLM_ATTENTION_BACKEND", backend.name),
    )
    for env_name, env_value in env_overrides:
        monkeypatch.setenv(env_name, env_value)

    passes = PassConfig(
        enable_attn_fusion=True,
        enable_noop=True,
        enable_sequence_parallelism=True,
        enable_async_tp=True,
    )
    compilation_config = CompilationConfig(
        # Properties under test
        use_inductor_graph_partition=inductor_graph_partition,
        cudagraph_mode=graph_mode,
        custom_ops=enabled_ops,
        splitting_ops=partition_ops,
        # Shared settings
        level=CompilationMode.VLLM_COMPILE,
        pass_config=passes,
        # Inductor also caches custom passes by default (via uuid)
        inductor_compile_config={"force_disable_caches": True},
    )

    with caplog_mp_spawn(logging.DEBUG) as log_holder:
        run_model(
            compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
        )

    # Each log line is expected exactly twice (presumably once per
    # tensor-parallel worker -- TODO confirm), both with the same count.
    expected_counts = (
        (
            r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
            matches.attention_fusion,
        ),
        (
            r"sequence_parallelism.py:\d+] Replaced (\d+) patterns",
            matches.sequence_parallel,
        ),
        (
            r"collective_fusion.py:\d+] Replaced (\d+) patterns",
            matches.async_tp,
        ),
    )
    for pattern, expected in expected_counts:
        log_matches = re.findall(pattern, log_holder.text)
        assert len(log_matches) == 2, log_holder.text
        assert int(log_matches[0]) == expected
        assert int(log_matches[1]) == expected
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs): def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):

View File

@ -5,15 +5,15 @@ import pytest
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import RMSNormQuantFusionPass from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.fx_utils import find_auto_fn
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import ( from vllm.config import (
CompilationConfig, CompilationConfig,
CUDAGraphMode,
DeviceConfig, DeviceConfig,
ModelConfig, ModelConfig,
PassConfig, PassConfig,
@ -27,6 +27,7 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel, initialize_model_parallel,
) )
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
@ -43,172 +44,157 @@ prompts = [
] ]
class TestModel(torch.nn.Module): class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, intermediate_size=32): def __init__(self, hidden_size=16, eps=1e-6):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.intermediate_size = intermediate_size self.eps = eps
self.gate_proj = torch.nn.Parameter( self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
torch.empty((intermediate_size, hidden_size)) self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
)
self.norm = RMSNorm(intermediate_size, 1e-05)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
def forward(self, hidden_states, residual): def forward(self, x):
""" z = torch.relu(x)
Forward pass implementing the operations in the FX graph x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
Args: z2 = torch.mm(y, self.w[0])
hidden_states: Input tensor x2 = tensor_model_parallel_all_reduce(z2)
residual: Residual tensor from previous layer
Returns: y2, resid = self.norm[1](x2, resid)
Tuple containing the output tensor
"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)
# matrix multiplication z3 = torch.mm(y2, self.w[1])
permute = self.gate_proj.permute(1, 0) x3 = tensor_model_parallel_all_reduce(z3)
mm = torch.mm(view, permute)
# Tensor parallel all-reduce y3, resid = self.norm[2](x3, resid)
all_reduce = tensor_model_parallel_all_reduce(mm)
# layer normalization z4 = torch.mm(y3, self.w[2])
norm_output, residual_output = self.norm(all_reduce, residual) x4 = tensor_model_parallel_all_reduce(z4)
return norm_output, residual_output y4, resid = self.norm[3](x4, resid)
return y4
def ops_in_model_before(self): def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default] return [torch.ops.vllm.all_reduce.default]
def ops_in_model_after(self): def ops_in_model_after(self):
return [ return [
torch.ops.vllm.reduce_scatter.default,
torch.ops.vllm.all_gather.default, torch.ops.vllm.all_gather.default,
torch.ops.vllm.reduce_scatter.default,
] ]
def ops_in_model(self): def ops_in_model(self):
return [torch.ops._C.fused_add_rms_norm.default] if RMSNorm.enabled():
return [
torch.ops._C.rms_norm.default,
torch.ops._C.fused_add_rms_norm.default,
]
else:
return []
class TestQuantModel(torch.nn.Module): class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
def __init__(self, hidden_size=16, intermediate_size=32): def __init__(self, hidden_size=16, eps=1e-6):
super().__init__() super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.vllm_config = get_current_vllm_config() self.vllm_config = get_current_vllm_config()
self.gate_proj = torch.nn.Parameter( self.hidden_size = hidden_size
torch.empty((intermediate_size, hidden_size)), requires_grad=False self.eps = eps
) self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.norm = RMSNorm(intermediate_size, 1e-05) self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
# Initialize weights self.w = [
torch.nn.init.normal_(self.gate_proj, std=0.02) torch.rand(hidden_size, hidden_size)
.to(dtype=current_platform.fp8_dtype())
.t()
for _ in range(3)
]
self.fp8_linear = Fp8LinearOp(act_quant_static=True) self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
self.scale = torch.rand(1, dtype=torch.float32) act_quant_group_shape=GroupShape.PER_TENSOR,
# Create a weight that is compatible with torch._scaled_mm,
# which expects a column-major layout.
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32)
def forward(self, hidden_states, residual):
"""
Forward pass implementing the operations in the FX graph
Args:
hidden_states: Input tensor
residual: Residual tensor from previous layer
Returns:
Tuple containing the output tensor
"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)
# matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)
# Tensor parallel all-reduce
all_reduce = tensor_model_parallel_all_reduce(mm)
# layer normalization
norm_output, residual_output = self.norm(all_reduce, residual)
# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(
norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(norm_output.device),
) )
return fp8_linear_result, residual_output self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
def ops_in_model_before(self): def forward(self, hidden_states):
ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP # avoid having graph input be an arg to a pattern directly
# The following are only removed if fusion happens z = torch.relu(hidden_states)
if ( x = resid = tensor_model_parallel_all_reduce(z)
self.vllm_config y = self.norm[0](x)
and self.vllm_config.compilation_config.pass_config.enable_fusion
): z2 = self.fp8_linear.apply(
ops_to_remove.extend( y, self.w[0], self.wscale[0], input_scale=self.scale[0]
[ )
torch.ops._C.fused_add_rms_norm.default,
torch.ops._C.static_scaled_fp8_quant.default, x2 = tensor_model_parallel_all_reduce(z2)
] y2, resid = self.norm[1](x2, resid)
)
return ops_to_remove z3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here
z4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_after(self): def ops_in_model_after(self):
ops_to_add = [ return [
torch.ops.vllm.reduce_scatter.default,
torch.ops.vllm.all_gather.default, torch.ops.vllm.all_gather.default,
torch.ops.vllm.reduce_scatter.default,
]
def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
] ]
# The following is only added if fusion happens
if (
self.vllm_config
and self.vllm_config.compilation_config.pass_config.enable_fusion
):
ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
return ops_to_add
def ops_in_model(self): def ops_in_model(self):
if ( if self.vllm_config.compilation_config.pass_config.enable_fusion:
self.vllm_config
and self.vllm_config.compilation_config.pass_config.enable_fusion
):
# If fusion happens, the fused op is the one
# we check for (de)functionalization
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default] return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
else: elif RMSNorm.enabled():
# If no fusion, the original ops are checked
return [ return [
torch.ops._C.fused_add_rms_norm.default, torch.ops._C.fused_add_rms_norm.default,
# TODO functionalization pass does not handle this yet
# torch.ops._C.static_scaled_fp8_quant.default,
] ]
elif self.fp8_linear.quant_fp8.enabled():
return [
torch.ops._C.static_scaled_fp8_quant.default,
]
else:
return []
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel]) @pytest.mark.parametrize(
"test_model_cls, custom_ops",
[
(TestAllReduceRMSNormModel, "+rms_norm"),
(TestAllReduceRMSNormModel, "-rms_norm"),
(TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,+quant_fp8"),
(TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,-quant_fp8"),
(TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,+quant_fp8"),
(TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,-quant_fp8"),
],
)
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("enable_fusion", [True, False]) @pytest.mark.parametrize("enable_fusion", [True, False])
@pytest.mark.parametrize("dynamic", [False, True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_sequence_parallelism_pass( def test_sequence_parallelism_pass(
test_model_cls: type[torch.nn.Module], test_model_cls: type[torch.nn.Module],
custom_ops: str,
batch_size: int, batch_size: int,
seq_len: int, seq_len: int,
hidden_size: int, hidden_size: int,
dtype: torch.dtype, dtype: torch.dtype,
enable_fusion: bool, enable_fusion: bool,
dynamic: bool,
): ):
num_processes = 2 num_processes = 2
@ -220,11 +206,13 @@ def test_sequence_parallelism_pass(
args=( args=(
num_processes, num_processes,
test_model_cls, test_model_cls,
custom_ops,
batch_size, batch_size,
seq_len, seq_len,
hidden_size, hidden_size,
dtype, dtype,
enable_fusion, enable_fusion,
dynamic,
), ),
nprocs=nprocs, nprocs=nprocs,
) )
@ -236,11 +224,13 @@ def sequence_parallelism_pass_on_test_model(
local_rank: int, local_rank: int,
world_size: int, world_size: int,
test_model_cls: type[torch.nn.Module], test_model_cls: type[torch.nn.Module],
custom_ops: str,
batch_size: int, batch_size: int,
seq_len: int, seq_len: int,
hidden_size: int, hidden_size: int,
dtype: torch.dtype, dtype: torch.dtype,
enable_fusion: bool, enable_fusion: bool,
dynamic: bool,
): ):
current_platform.seed_everything(0) current_platform.seed_everything(0)
@ -264,12 +254,16 @@ def sequence_parallelism_pass_on_test_model(
initialize_model_parallel(tensor_model_parallel_size=world_size) initialize_model_parallel(tensor_model_parallel_size=world_size)
# configure vllm config for SequenceParallelismPass # configure vllm config for SequenceParallelismPass
custom_ops_list = custom_ops.split(",") if custom_ops else []
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
splitting_ops=[], # avoid automatic rms_norm enablement
cudagraph_mode=CUDAGraphMode.NONE, # avoid piecewise warnings
custom_ops=custom_ops_list,
pass_config=PassConfig( pass_config=PassConfig(
enable_sequence_parallelism=True, enable_sequence_parallelism=True,
enable_fusion=enable_fusion, enable_fusion=enable_fusion,
enable_noop=True, enable_noop=True,
) ),
) # NoOp needed for fusion ) # NoOp needed for fusion
device_config = DeviceConfig(device=torch.device("cuda")) device_config = DeviceConfig(device=torch.device("cuda"))
@ -289,7 +283,6 @@ def sequence_parallelism_pass_on_test_model(
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config) sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config)
assert ( assert (
sequence_parallelism_pass.compilation_config.splitting_ops sequence_parallelism_pass.compilation_config.splitting_ops
@ -310,38 +303,29 @@ def sequence_parallelism_pass_on_test_model(
passes_for_backend.append(cleanup_pass) passes_for_backend.append(cleanup_pass)
backend_no_func = TestBackend(*passes_for_backend) backend = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass)
model = test_model_cls(hidden_size, hidden_size * 2) model = test_model_cls(hidden_size)
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
compiled_model_no_func = torch.compile(model, backend=backend_no_func) if dynamic:
compiled_model_no_func(hidden_states, residual) torch._dynamo.mark_dynamic(hidden_states, 0)
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual)
assert sequence_parallelism_pass.matched_count == 1 compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)
assert sequence_parallelism_pass.matched_count == 4
# In pre-nodes, all reduce should be there, # In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not # reduce scatter and all gather should not
backend_no_func.check_before_ops(model.ops_in_model_before()) for op in model.ops_in_model_before():
assert backend.op_count(op, before=True) == 4
# In post-nodes, reduce scatter and all gather should be there, # In post-nodes, reduce scatter and all gather should be there,
# all reduce should not # all reduce should not
backend_no_func.check_after_ops(model.ops_in_model_after()) for op in model.ops_in_model_after():
assert backend.op_count(op, before=False) == 4
# check if the functionalization pass is applied
for op in model.ops_in_model(): for op in model.ops_in_model():
find_auto_fn(backend_no_func.graph_post_pass.nodes, op) find_auto_fn(backend.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in model.ops_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model())

View File

@ -18,6 +18,7 @@ import pytest
from vllm.config.compilation import CompilationMode from vllm.config.compilation import CompilationMode
from vllm.config.model import RunnerOption from vllm.config.model import RunnerOption
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..models.registry import HF_EXAMPLE_MODELS from ..models.registry import HF_EXAMPLE_MODELS
@ -161,6 +162,7 @@ def _compare_sp(
test_options: SPTestOptions, test_options: SPTestOptions,
num_gpus_available: int, num_gpus_available: int,
use_inductor_graph_partition: bool, use_inductor_graph_partition: bool,
enable_async_tp: bool,
*, *,
method: Literal["generate", "encode"], method: Literal["generate", "encode"],
is_multimodal: bool, is_multimodal: bool,
@ -244,10 +246,10 @@ def _compare_sp(
compilation_config = { compilation_config = {
"mode": CompilationMode.VLLM_COMPILE, "mode": CompilationMode.VLLM_COMPILE,
"custom_ops": ["+rms_norm"],
"compile_sizes": [4, 8], "compile_sizes": [4, 8],
"pass_config": { "pass_config": {
"enable_sequence_parallelism": True, "enable_sequence_parallelism": True,
"enable_async_tp": enable_async_tp,
"enable_fusion": enable_fusion, "enable_fusion": enable_fusion,
"enable_noop": True, "enable_noop": True,
}, },
@ -307,6 +309,7 @@ SP_TEST_MODELS = [
], ],
) )
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False]) @pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
@pytest.mark.parametrize("enable_async_tp", [False]) # TODO: enable async TP
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_tp_sp_generation( def test_tp_sp_generation(
model_id: str, model_id: str,
@ -316,10 +319,19 @@ def test_tp_sp_generation(
test_options: SPTestOptions, test_options: SPTestOptions,
num_gpus_available, num_gpus_available,
use_inductor_graph_partition: bool, use_inductor_graph_partition: bool,
enable_async_tp: bool,
): ):
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+") pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
# Skip FP8 SP-only test on sm89 (compute capability 8.9)
if (
"fp8" in model_id.lower()
and current_platform.get_device_capability() < (9, 0)
and (not enable_async_tp)
):
pytest.skip("FP8 reduction support begins with sm90 capable devices.")
_compare_sp( _compare_sp(
model_id, model_id,
parallel_setup, parallel_setup,
@ -328,6 +340,7 @@ def test_tp_sp_generation(
test_options, test_options,
num_gpus_available, num_gpus_available,
use_inductor_graph_partition, use_inductor_graph_partition,
enable_async_tp=enable_async_tp,
method="generate", method="generate",
is_multimodal=False, is_multimodal=False,
) )

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import torch import torch
import torch._inductor.pattern_matcher as pm import torch._inductor.pattern_matcher as pm
import torch.fx as fx import torch.fx as fx
@ -10,98 +12,28 @@ from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .noop_elimination import NoOpEliminationPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__) logger = init_logger(__name__)
class _RMSNormAndQuantOpHelper: def get_first_out_wrapper(fn):
"""Base helper for RMSNorm and RMSNorm + Quantization functionalization.""" @functools.wraps(fn)
def wrapper(*args):
return fn(*args)[0]
def __init__( return wrapper
self,
epsilon: float,
dtype: torch.dtype,
device: str,
quant_op: torch._ops.OpOverload | None = None,
**kwargs,
):
self.epsilon = epsilon
self.dtype = dtype
self.device = device
self.quant_op = quant_op
def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor):
return torch.ops.higher_order.auto_functionalized(
torch.ops._C.rms_norm.default,
result=result_buffer,
input=input_tensor,
weight=weight_tensor,
epsilon=self.epsilon,
)
def _functional_fused_add_rmsnorm(
self, input_tensor, residual_tensor, weight_tensor
):
return torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=input_tensor,
residual=residual_tensor,
weight=weight_tensor,
epsilon=self.epsilon,
)
def _functional_rmsnorm_then_quant(
self,
rmsnorm_result_buffer,
quant_result_buffer,
input_tensor,
weight_tensor,
scale_tensor,
):
if self.quant_op is None:
raise RuntimeError(
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
)
rmsnorm_out_tuple = self._functional_rmsnorm(
rmsnorm_result_buffer, input_tensor, weight_tensor
)
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
self.quant_op,
result=quant_result_buffer,
input=rmsnorm_out_tuple[1],
scale=scale_tensor,
)
return quant_out_tuple
def _functional_fused_add_rmsnorm_then_quant(
self,
quant_result_buffer,
input_tensor,
residual_tensor,
weight_tensor,
scale_tensor,
):
if self.quant_op is None:
raise RuntimeError(
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
)
fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm(
input_tensor, residual_tensor, weight_tensor
)
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
self.quant_op,
result=quant_result_buffer,
input=fused_add_rmsnorm_out_tuple[1],
scale=scale_tensor,
)
return quant_out_tuple, fused_add_rmsnorm_out_tuple[2]
class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper): class _SequenceParallelPatternHelper:
"""Helper for sequence parallelism patterns.""" """Helper for sequence parallelism patterns."""
def __init__( def __init__(
@ -109,10 +41,10 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
epsilon: float, epsilon: float,
dtype: torch.dtype, dtype: torch.dtype,
device: str, device: str,
quant_op: torch._ops.OpOverload | None = None,
**kwargs,
): ):
super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs) self.epsilon = epsilon
self.dtype = dtype
self.device = device
self.tp_group = get_tp_group() self.tp_group = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
@ -131,36 +63,34 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self): def get_inputs(self):
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype) arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
return [input, permute, arg3_1] return [input, arg3_1]
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def pattern( def pattern(
input: torch.Tensor, input: torch.Tensor,
permute: torch.Tensor,
arg3_1: torch.Tensor, arg3_1: torch.Tensor,
): ):
all_reduce = self._all_reduce(input) all_reduce = self._all_reduce(input)
rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1) rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
return rmsnorm[1], all_reduce return rmsnorm, all_reduce
def replacement( def replacement(
input: torch.Tensor, input: torch.Tensor,
permute: torch.Tensor,
arg3_1: torch.Tensor, arg3_1: torch.Tensor,
): ):
reduce_scatter = self._reduce_scatter(input) reduce_scatter = self._reduce_scatter(input)
rmsnorm_result = torch.empty_like(reduce_scatter) rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, arg3_1) all_gather = self._all_gather(rmsnorm)
all_gather = self._all_gather(rmsnorm[1])
return all_gather, reduce_scatter return all_gather, reduce_scatter
pm.register_replacement( pm.register_replacement(
@ -169,6 +99,10 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def get_inputs(self): def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
@ -188,67 +122,34 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
rms_norm_weights: torch.Tensor, rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1) all_reduce = self._all_reduce(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm( rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual)
all_reduce, residual, rms_norm_weights return rmsnorm[0], rmsnorm[1]
)
return rmsnorm[1], rmsnorm[2]
def replacement( def replacement(
residual: torch.Tensor, residual: torch.Tensor,
mm_1: torch.Tensor, mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor, rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# pattern matcher replaces from top-to-bottom,
# so residual is still the full size here.
# add a temporary slice which will become a noop
# once the seqpar pattern with the previous rmsnorm is replaced
reduce_scatter = self._reduce_scatter(mm_1) reduce_scatter = self._reduce_scatter(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm( residual = residual[0 : reduce_scatter.size(0), ...]
reduce_scatter, residual, rms_norm_weights rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual)
) all_gather = self._all_gather(rmsnorm[0])
all_gather = self._all_gather(rmsnorm[1]) # shape of residual changes but that's fine,
return all_gather, rmsnorm[2] # next node is already slicing it, now becomes a noop
return all_gather, rmsnorm[1]
pm.register_replacement( pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
) )
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [
residual,
mm_1,
rms_norm_weights,
]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
all_reduce, residual, rms_norm_weights
)
return rmsnorm[1]
def replacement(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
reduce_scatter, residual, rms_norm_weights
)
normalized = self._all_gather(rmsnorm[1])
return normalized
pm.register_replacement( pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass get_first_out_wrapper(pattern),
get_first_out_wrapper(replacement),
self.get_inputs(),
pm.fwd_only,
pm_pass,
) )
@ -257,52 +158,41 @@ FP8_DTYPE = current_platform.fp8_dtype()
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__( def __init__(
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload self,
epsilon: float,
dtype: torch.dtype,
device: str,
): ):
super().__init__(epsilon, dtype, device, quant_op=op) super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self): def get_inputs(self):
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
rmsnorm_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
quant_result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE)
weight = torch.empty([4], device=self.device, dtype=self.dtype) weight = torch.empty([4], device=self.device, dtype=self.dtype)
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
return [input, rmsnorm_result, quant_result, weight, scale] return [input, weight, scale]
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def pattern( def pattern(
input: torch.Tensor, input: torch.Tensor,
rmsnorm_result: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ):
all_reduce = self._all_reduce(input) all_reduce = self._all_reduce(input)
static_fp8 = self._functional_rmsnorm_then_quant( rms = self.rmsnorm_matcher(all_reduce, weight)
rmsnorm_result, quant_result, all_reduce, weight, scale quant, _ = self.quant_matcher(rms, scale)
) return quant, all_reduce
return static_fp8[1], all_reduce
def replacement( def replacement(
input: torch.Tensor, input: torch.Tensor,
rmsnorm_result: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ):
reduce_scatter = self._reduce_scatter(input) reduce_scatter = self._reduce_scatter(input)
rms = self.rmsnorm_matcher(reduce_scatter, weight)
rmsnorm_result = torch.empty_like( quant, _ = self.quant_matcher(rms, scale)
reduce_scatter, dtype=rmsnorm_result.dtype all_gather = self._all_gather(quant)
)
quant_result = torch.empty_like(
rmsnorm_result, # Output of RMSNorm
dtype=quant_result.dtype,
)
static_fp8 = self._functional_rmsnorm_then_quant(
rmsnorm_result, quant_result, reduce_scatter, weight, scale
)
all_gather = self._all_gather(static_fp8[1])
return all_gather, reduce_scatter return all_gather, reduce_scatter
@ -312,118 +202,64 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__( def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload super().__init__(epsilon, dtype, device)
): self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
super().__init__(epsilon, dtype, device, quant_op=op) self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self): def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
return [ return [residual, mm_1, rms_norm_weights, scale]
result,
residual,
mm_1,
rms_norm_weights,
scale,
]
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def pattern( def pattern(
result: torch.Tensor,
residual: torch.Tensor, residual: torch.Tensor,
mm_1: torch.Tensor, mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor, rms_norm_weights: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1) all_reduce = self._all_reduce(mm_1)
static_fp8, rmsnorm_residual_out = ( rms, residual_out = self.rmsnorm_matcher(
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 all_reduce, rms_norm_weights, residual
result, all_reduce, residual, rms_norm_weights, scale
)
) )
return static_fp8[1], rmsnorm_residual_out quant, _ = self.quant_matcher(rms, scale)
return quant, residual_out
def replacement( def replacement(
result: torch.Tensor,
residual: torch.Tensor, residual: torch.Tensor,
mm_1: torch.Tensor, mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor, rms_norm_weights: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# pattern matcher replaces from top-to-bottom,
# so residual is still the full size here.
# add a temporary slice which will become a noop
# once the seqpar pattern with the previous rmsnorm is replaced
reduce_scatter = self._reduce_scatter(mm_1) reduce_scatter = self._reduce_scatter(mm_1)
quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype) residual = residual[0 : reduce_scatter.size(0), ...]
static_fp8, rmsnorm_residual_out = ( rms, residual_out = self.rmsnorm_matcher(
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 reduce_scatter, rms_norm_weights, residual
quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
)
) )
all_gather = self._all_gather(static_fp8[1]) quant, _ = self.quant_matcher(rms, scale)
return all_gather, rmsnorm_residual_out all_gather = self._all_gather(quant)
# shape of residual changes but that's fine,
# next node is already slicing it, now becomes a noop
return all_gather, residual_out
pm.register_replacement( pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
) )
class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
):
super().__init__(epsilon, dtype, device, quant_op=op)
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
return [
result,
residual,
mm_1,
rms_norm_weights,
scale,
]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
result, all_reduce, residual, rms_norm_weights, scale
)
return static_fp8[1]
def replacement(
result: torch.Tensor,
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(mm_1)
quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
)
normalized = self._all_gather(static_fp8[1])
return normalized
pm.register_replacement( pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass get_first_out_wrapper(pattern),
get_first_out_wrapper(replacement),
self.get_inputs(),
pm.fwd_only,
pm_pass,
) )
@ -445,27 +281,45 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
significantly reduce communication overhead and improve overall model significantly reduce communication overhead and improve overall model
performance. performance.
This pass splits up the residual tensor across TP ranks and hence divides its size.
Because the pattern matcher starts at the end of the graph, the replacement
contains a slice that temporarily conforms the input residual to the correct size.
After all patterns have been matched, we use a NoOpEliminationPass to clean up
what have now become no-op slices.
Note that an older version of the pass did not need this as it operated only on
the rms_norm and fused_add_rms_norm custom ops, which did not complain about
mismatched shapes during replacement. So this approach has the same assumption that
correctness is only maintained if all rms_norm operations are split across ranks.
Correctness-wise, this approach is strictly better than before - before,
the graph was incorrect semantically and shape-wise during the pass.
With this approach there's only semantic incorrectness during the pass.
Both approaches restore a correct graph once all patterns are matched.
""" """
@enable_fake_mode @enable_fake_mode
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig):
super().__init__(config) super().__init__(config)
# Used to cleanup redundant views created temporarily
# to circumvent residual shape change issues
self.noop_cleanup = NoOpEliminationPass(config)
self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"
self.patterns: PatternMatcherPass = PatternMatcherPass( self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="sequence_parallelism_pass" pass_name="sequence_parallelism_pass"
) )
for epsilon in [1e-5, 1e-6]: for epsilon in [1e-5, 1e-6]:
# RMSNorm + Static FP8 quantization patterns # RMSNorm + Static FP8 quantization patterns
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
FirstAllReduceRMSNormStaticFP8Pattern( FirstAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device, fp8_quant_op epsilon, self.model_dtype, self.device
).register(self.patterns) ).register(self.patterns)
MiddleAllReduceRMSNormStaticFP8Pattern( MiddleAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device, fp8_quant_op epsilon, self.model_dtype, self.device
).register(self.patterns)
LastAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device, fp8_quant_op
).register(self.patterns) ).register(self.patterns)
# Normal RMSNorm patterns # Normal RMSNorm patterns
@ -477,9 +331,6 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
epsilon, self.model_dtype, self.device epsilon, self.model_dtype, self.device
).register(self.patterns) ).register(self.patterns)
LastAllReduceRMSNormPattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)
self.dump_patterns(config, self.patterns) self.dump_patterns(config, self.patterns)
def is_applicable(self, shape: int | None) -> bool: def is_applicable(self, shape: int | None) -> bool:
@ -508,3 +359,5 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
def __call__(self, graph: fx.Graph): def __call__(self, graph: fx.Graph):
self.matched_count = self.patterns.apply(graph) self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count) logger.debug("Replaced %s patterns", self.matched_count)
# Clean up reshape nodes
self.noop_cleanup(graph)

View File

@ -445,8 +445,6 @@ class VllmConfig:
# and requires it to be enabled. # and requires it to be enabled.
if self.compilation_config.pass_config.enable_async_tp: if self.compilation_config.pass_config.enable_async_tp:
self.compilation_config.pass_config.enable_sequence_parallelism = True self.compilation_config.pass_config.enable_sequence_parallelism = True
if self.compilation_config.pass_config.enable_sequence_parallelism:
self.compilation_config.custom_ops.append("+rms_norm")
if current_platform.support_static_graph_mode(): if current_platform.support_static_graph_mode():
# if cudagraph_mode is not explicitly set by users, set default # if cudagraph_mode is not explicitly set by users, set default
@ -620,6 +618,32 @@ class VllmConfig:
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE: if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
self.compilation_config.set_splitting_ops_for_v1() self.compilation_config.set_splitting_ops_for_v1()
if self.compilation_config.pass_config.enable_sequence_parallelism:
    # With pipeline parallelism or dynamo partitioning,
    # native rms norm tracing errors due to incorrect residual shape.
    # Use custom rms norm to unblock. In the future,
    # the pass will operate on higher-level IR to avoid the issue.
    # TODO: https://github.com/vllm-project/vllm/issues/27894
    #
    # The graph counts as "full" when inductor graph partitioning is used
    # or when no splitting ops are configured.
    is_fullgraph = (
        self.compilation_config.use_inductor_graph_partition
        or len(self.compilation_config.splitting_ops) == 0
    )
    if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph:
        if "-rms_norm" not in self.compilation_config.custom_ops:
            # Custom rms_norm was not explicitly disabled: enable it.
            self.compilation_config.custom_ops.append("+rms_norm")
        else:
            # "-rms_norm" was explicitly requested; leave it as is and
            # only warn that this combination is expected to fail.
            regime = (
                "Dynamo partition"
                if not is_fullgraph
                else "pipeline parallelism"
            )
            # Adjacent string literals are concatenated verbatim, so each
            # fragment must carry its own trailing space.
            logger.warning_once(
                "Sequence parallelism not supported with "
                "native rms_norm when using %s, "
                "this will likely lead to an error.",
                regime,
            )
# final check of cudagraph mode after all possible updates # final check of cudagraph mode after all possible updates
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
if ( if (