[Feature] Support sequence parallelism for static fp8 quantization (#19181)

Signed-off-by: cascade812 <cascade812@outlook.com>
Authored by cascade on 2025-06-23 13:09:02 -07:00; committed by GitHub
parent d0132f025d
commit e6327c9b3e
7 changed files with 531 additions and 195 deletions

View File

@@ -6,7 +6,9 @@ import torch
import vllm.envs as envs
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import FusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
PassConfig, VllmConfig)
@@ -14,12 +16,15 @@ from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
from ..utils import multi_gpu_test
from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
prompts = [
"Hello, my name is",
"The president of the United States is",
@@ -30,13 +35,16 @@ prompts = [
class TestModel(torch.nn.Module):
def __init__(self, hidden_size=16, intermediate_size=32):
def __init__(self,
hidden_size=16,
intermediate_size=32,
vllm_config: VllmConfig = None):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size)))
self.norm = RMSNorm(hidden_size, 1e-05)
self.norm = RMSNorm(intermediate_size, 1e-05)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
@@ -79,32 +87,138 @@ class TestModel(torch.nn.Module):
return [torch.ops._C.fused_add_rms_norm.default]
class TestQuantModel(torch.nn.Module):
def __init__(self,
hidden_size=16,
intermediate_size=32,
vllm_config: VllmConfig = None):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.vllm_config = vllm_config
self.gate_proj = torch.nn.Parameter(torch.empty(
(intermediate_size, hidden_size)),
requires_grad=False)
self.norm = RMSNorm(intermediate_size, 1e-05)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=True,
use_per_token_if_dynamic=False)
self.scale = torch.rand(1, dtype=torch.float32)
# Create a weight that is compatible with torch._scaled_mm,
# which expects a column-major layout.
self.w = torch.rand(hidden_size,
intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32)
def forward(self, hidden_states, residual):
"""
Forward pass implementing the operations in the FX graph
Args:
hidden_states: Input tensor
residual: Residual tensor from previous layer
Returns:
Tuple of (fp8_linear_result, residual_output)
"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)
# Matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)
# Tensor parallel all-reduce
all_reduce = tensor_model_parallel_all_reduce(mm)
# layer normalization
norm_output, residual_output = self.norm(all_reduce, residual)
# for static input quantization
# self.fp8_linear is initialized with use_per_token_if_dynamic=False
fp8_linear_result = self.fp8_linear.apply(norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(
norm_output.device))
return fp8_linear_result, residual_output
def ops_in_model_before(self):
ops_to_remove = [torch.ops.vllm.all_reduce.default
] # Always removed by SP
# The following are only removed if fusion happens
if self.vllm_config and self.vllm_config.compilation_config \
.pass_config.enable_fusion:
ops_to_remove.extend([
torch.ops._C.fused_add_rms_norm.default,
torch.ops._C.static_scaled_fp8_quant.default,
])
return ops_to_remove
def ops_in_model_after(self):
ops_to_add = [
torch.ops.vllm.reduce_scatter.default,
torch.ops.vllm.all_gather.default
]
# The following is only added if fusion happens
if self.vllm_config and self.vllm_config.compilation_config \
.pass_config.enable_fusion:
ops_to_add.append(
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
return ops_to_add
def ops_in_model(self):
if self.vllm_config and self.vllm_config.compilation_config \
.pass_config.enable_fusion:
# If fusion happens, the fused op is the one
# we check for (de)functionalization
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
] # noqa: E501
else:
# If no fusion, the original ops are checked
return [
torch.ops._C.fused_add_rms_norm.default,
# TODO functionalization pass does not handle this yet
# torch.ops._C.static_scaled_fp8_quant.default,
]
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("enable_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
def test_sequence_parallelism_pass(batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
def test_sequence_parallelism_pass(test_model_cls: type[torch.nn.Module],
batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype,
enable_fusion: bool):
num_processes = 2
def run_torch_spawn(fn, nprocs):
# Need to use torch.multiprocessing.spawn; otherwise there are problems
# with torch.distributed and CUDA
torch.multiprocessing.spawn(fn,
args=(num_processes, batch_size, seq_len,
hidden_size, dtype),
args=(num_processes, test_model_cls,
batch_size, seq_len, hidden_size,
dtype, enable_fusion),
nprocs=nprocs)
run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)
def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
batch_size: int, seq_len: int,
hidden_size: int,
dtype: torch.dtype):
def sequence_parallelism_pass_on_test_model(
local_rank: int, world_size: int,
test_model_cls: type[torch.nn.Module], batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype, enable_fusion: bool):
current_platform.seed_everything(0)
device = torch.device(f"cuda:{local_rank}")
@@ -127,26 +241,39 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
# configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
enable_sequence_parallelism=True))
enable_sequence_parallelism=True,
enable_fusion=enable_fusion,
enable_noop=True)) # NoOp needed for fusion
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# This is a fake model name used only to construct the model config
# in the vllm_config; it is not actually used.
model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
vllm_config.model_config = ModelConfig(model=model,
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
vllm_config.model_config = ModelConfig(model=model_name,
task="auto",
tokenizer=model,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=True,
dtype=dtype,
seed=42)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
backend_no_func = TestBackend(sequence_parallelism_pass)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(sequence_parallelism_pass, func_pass)
model = TestModel(hidden_size, hidden_size * 2)
passes_for_backend = [noop_pass, sequence_parallelism_pass]
if enable_fusion:
fusion_pass = FusionPass.instance(vllm_config)
passes_for_backend.append(fusion_pass)
backend_no_func = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass)
model = test_model_cls(hidden_size,
hidden_size * 2,
vllm_config=vllm_config)
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
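
For context, the ops_in_model_before() and ops_in_model_after() helpers above are what the test compares against the FX graphs captured by TestBackend. A minimal sketch of how such a check can be expressed; the call_targets helper and the graph_after handle below are illustrative assumptions, not TestBackend's actual API:

import torch

def call_targets(gm: torch.fx.GraphModule) -> set:
    # Collect every op called in the graph, unwrapping auto_functionalized
    # wrappers so custom ops like fused_add_rms_norm are visible directly.
    targets = set()
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        if node.target is torch.ops.higher_order.auto_functionalized:
            targets.add(node.args[0])
        else:
            targets.add(node.target)
    return targets

# Hypothetical usage against a post-pass graph `graph_after`:
#   assert all(op in call_targets(graph_after) for op in model.ops_in_model_after())
#   assert all(op not in call_targets(graph_after) for op in model.ops_in_model_before())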

View File

@@ -28,7 +28,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
sp_enabled: bool
enable_fusion: bool
eager_mode: bool
chunked_prefill: bool
@@ -67,49 +67,18 @@ class SPTestSettings:
task: TaskOption = "auto",
load_format: Optional[str] = None,
):
parallel_setups = []
for eager_mode_val in [False, True]:
for pp_multiplier in [1, 2]:
for chunked_prefill_val in [False, True]:
parallel_setups.append(
ParallelSetup(tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
enable_fusion=False,
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val))
return SPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=True),
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=True),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=True),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=True)
],
parallel_setups=parallel_setups,
distributed_backends=["mp", "ray"],
vllm_major_versions=["1", "1"],
task=task,
@@ -126,19 +95,44 @@ class SPTestSettings:
multi_node_only: bool = False,
load_format: Optional[str] = None,
):
parallel_setups = []
for eager_mode_val in [False, True]:
for pp_multiplier in [1, 2]:
for chunked_prefill_val in [False, True]:
parallel_setups.append(
ParallelSetup(tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
enable_fusion=False,
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val))
return SPTestSettings(
parallel_setups=[
parallel_setups=parallel_setups,
distributed_backends=["mp", "ray"],
vllm_major_versions=["1", "1"],
task=task,
test_options=SPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),
)
@staticmethod
def fp8_quant(
*,
tp_base: int = 2,
pp_base: int = 1,
task: TaskOption = "auto",
multi_node_only: bool = False,
load_format: Optional[str] = None,
):
parallel_setups = []
for fusion_val in [False, True]:
parallel_setups.append(
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=False),
],
enable_fusion=fusion_val,
eager_mode=True,
chunked_prefill=False))
return SPTestSettings(
parallel_setups=parallel_setups,
distributed_backends=["mp", "ray"],
vllm_major_versions=["1", "1"],
task=task,
@@ -171,7 +165,7 @@ def _compare_sp(
(
tp_size,
pp_size,
sp_enabled,
enable_fusion,
eager_mode,
chunked_prefill,
) = parallel_setup
@@ -240,9 +234,9 @@ def _compare_sp(
'compile_sizes': [4, 8],
'splitting_ops': [],
'pass_config': {
'enable_sequence_parallelism': sp_enabled,
'enable_sequence_parallelism': True,
'enable_fusion': enable_fusion,
'enable_noop': True,
'enable_fusion': True,
},
}
@@ -291,12 +285,14 @@ def _compare_sp(
SP_TEXT_GENERATION_MODELS = {
# [Decoder-only]
"meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(),
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8": SPTestSettings.fp8_quant(),
}
SP_TEST_MODELS = [
# TODO support other models
# [LANGUAGE GENERATION]
"meta-llama/Llama-3.2-1B-Instruct",
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
]

View File

@@ -193,7 +193,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
extras={"tiny": "ai21labs/Jamba-tiny-dev"}), # noqa: E501
"LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct",
extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501
"hermes": "NousResearch/Hermes-3-Llama-3.1-8B"}), # noqa: E501
"hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501
"fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"}), # noqa: E501
"LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
is_available_online=False),
"MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),

View File

@@ -345,8 +345,8 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
# 0 is always None
fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)}
self.insert_fused_node(fused_return_mapping,
epsilon=rms_node.kwargs["epsilon"],
**kwargs)
**kwargs,
epsilon=rms_node.kwargs["epsilon"])
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):

View File

@@ -51,15 +51,15 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.enable_noop:
self.passes += [NoOpEliminationPass(config)]
if self.pass_config.enable_fusion:
self.passes += [FusionPass.instance(config)]
self.passes += [ActivationQuantFusionPass(config)]
if self.pass_config.enable_sequence_parallelism:
self.passes += [SequenceParallelismPass(config)]
if self.pass_config.enable_async_tp:
self.passes += [AsyncTPPass(config)]
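# The fusion passes are registered after the sequence-parallelism and
# async-TP passes so that the RMSNorm + quant sequences those passes leave
# in the graph can still be fused (e.g. into
# fused_add_rms_norm_static_fp8_quant, as exercised by the test above).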
if self.pass_config.enable_fusion:
self.passes += [FusionPass.instance(config)]
self.passes += [ActivationQuantFusionPass(config)]
if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]

View File

@@ -12,91 +12,142 @@ from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class AllReduceRMSNormPattern:
class _RMSNormAndQuantOpHelper:
"""Base helper for RMSNorm and RMSNorm + Quantization functionalization."""
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
def __init__(self,
epsilon: float,
dtype: torch.dtype,
device: str,
quant_op: Optional[torch._ops.OpOverload] = None,
**kwargs):
self.epsilon = epsilon
self.dtype = dtype
self.device = device
self.quant_op = quant_op
def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor):
return torch.ops.higher_order.auto_functionalized(
torch.ops._C.rms_norm.default,
result=result_buffer,
input=input_tensor,
weight=weight_tensor,
epsilon=self.epsilon)
def _functional_fused_add_rmsnorm(self, input_tensor, residual_tensor,
weight_tensor):
return torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=input_tensor,
residual=residual_tensor,
weight=weight_tensor,
epsilon=self.epsilon)
def _functional_rmsnorm_then_quant(self, rmsnorm_result_buffer,
quant_result_buffer, input_tensor,
weight_tensor, scale_tensor):
if self.quant_op is None:
raise RuntimeError(
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
)
rmsnorm_out_tuple = self._functional_rmsnorm(rmsnorm_result_buffer,
input_tensor,
weight_tensor)
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
self.quant_op,
result=quant_result_buffer,
input=rmsnorm_out_tuple[1],
scale=scale_tensor)
return quant_out_tuple
def _functional_fused_add_rmsnorm_then_quant(self, quant_result_buffer,
input_tensor, residual_tensor,
weight_tensor, scale_tensor):
if self.quant_op is None:
raise RuntimeError(
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
)
fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm(
input_tensor, residual_tensor, weight_tensor)
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
self.quant_op,
result=quant_result_buffer,
input=fused_add_rmsnorm_out_tuple[1],
scale=scale_tensor)
return quant_out_tuple, fused_add_rmsnorm_out_tuple[2]
class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern):
class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
"""Helper for sequence parallelism patterns."""
def __init__(self,
epsilon: float,
dtype: torch.dtype,
device: str,
quant_op: Optional[torch._ops.OpOverload] = None,
**kwargs):
super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs)
self.tp_group = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
return tensor_model_parallel_all_reduce(x)
def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.vllm.reduce_scatter.default(
x,
dim=0,
world_size=self.tp_size,
group_name=self.tp_group.unique_name)
def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.vllm.all_gather.default(
x,
dim=0,
world_size=self.tp_size,
group_name=self.tp_group.unique_name)
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def get_inputs(self):
arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype)
mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]],
device=self.device,
dtype=torch.long)
unsqueeze = torch.rand([1, 8, 1], device=self.device, \
dtype=self.dtype) > 0.5
full_default = torch.zeros([1, 8, 4], device=self.device, \
dtype=self.dtype)
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
return [arg2_1, mul_6, unsqueeze, full_default, permute, arg3_1]
return [input, permute, arg3_1]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
arg2_1: torch.Tensor,
mul_6: torch.Tensor,
unsqueeze: torch.Tensor,
full_default: torch.Tensor,
input: torch.Tensor,
permute: torch.Tensor,
arg3_1: torch.Tensor,
):
embedding = torch.ops.aten.embedding.default(arg2_1, mul_6)
where = torch.ops.aten.where.self(unsqueeze, full_default,
embedding)
all_reduce = tensor_model_parallel_all_reduce(where)
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.rms_norm.default,
result=permute,
input=all_reduce,
weight=arg3_1,
epsilon=self.epsilon,
)
all_reduce = self._all_reduce(input)
rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1)
return rmsnorm[1], all_reduce
def replacement(
arg2_1: torch.Tensor,
mul_6: torch.Tensor,
unsqueeze: torch.Tensor,
full_default: torch.Tensor,
input: torch.Tensor,
permute: torch.Tensor,
arg3_1: torch.Tensor,
):
embedding = torch.ops.aten.embedding.default(arg2_1, mul_6)
where = torch.ops.aten.where.self(unsqueeze, full_default,
embedding)
tp = get_tp_group()
tp_size = get_tensor_model_parallel_world_size()
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
where, dim=0, world_size=tp_size, group_name=tp.unique_name)
reduce_scatter = self._reduce_scatter(input)
rmsnorm_result = torch.empty_like(reduce_scatter)
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.rms_norm.default,
result=rmsnorm_result,
input=reduce_scatter,
weight=arg3_1,
epsilon=self.epsilon,
)
rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter,
arg3_1)
all_gather = torch.ops.vllm.all_gather.default(
rmsnorm[1],
dim=0,
world_size=tp_size,
group_name=tp.unique_name)
all_gather = self._all_gather(rmsnorm[1])
return all_gather, reduce_scatter
@@ -104,7 +155,7 @@ class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern):
pm.fwd_only, pm_pass)
class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
@@ -127,16 +178,9 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(mm_1)
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=all_reduce,
residual=residual,
weight=rms_norm_weights,
epsilon=self.epsilon,
)
all_reduce = self._all_reduce(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
all_reduce, residual, rms_norm_weights)
return rmsnorm[1], rmsnorm[2]
def replacement(
@@ -144,32 +188,17 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
tp = get_tp_group()
tp_size = get_tensor_model_parallel_world_size()
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name)
# TODO is it possible to extract epsilon from somewhere
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=reduce_scatter,
residual=residual,
weight=rms_norm_weights,
epsilon=self.epsilon,
)
all_gather = torch.ops.vllm.all_gather.default(
rmsnorm[1],
dim=0,
world_size=tp_size,
group_name=tp.unique_name)
reduce_scatter = self._reduce_scatter(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
reduce_scatter, residual, rms_norm_weights)
all_gather = self._all_gather(rmsnorm[1])
return all_gather, rmsnorm[2]
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
@@ -192,16 +221,9 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(mm_1)
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=all_reduce,
residual=residual,
weight=rms_norm_weights,
epsilon=self.epsilon,
)
all_reduce = self._all_reduce(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
all_reduce, residual, rms_norm_weights)
return rmsnorm[1]
def replacement(
@@ -209,26 +231,185 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
tp = get_tp_group()
tp_size = get_tensor_model_parallel_world_size()
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name)
reduce_scatter = self._reduce_scatter(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
reduce_scatter, residual, rms_norm_weights)
normalized = self._all_gather(rmsnorm[1])
return normalized
# TODO is it possible to extract epsilon from somewhere
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=reduce_scatter,
residual=residual,
weight=rms_norm_weights,
epsilon=self.epsilon,
)
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
normalized = torch.ops.vllm.all_gather.default(
rmsnorm[1],
dim=0,
world_size=tp_size,
group_name=tp.unique_name)
FP8_DTYPE = current_platform.fp8_dtype()
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
op: torch._ops.OpOverload):
super().__init__(epsilon, dtype, device, quant_op=op)
def get_inputs(self):
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
rmsnorm_result = torch.empty([1, 8, 4],
device=self.device,
dtype=self.dtype)
quant_result = torch.empty([1, 8, 4],
device=self.device,
dtype=FP8_DTYPE)
weight = torch.empty([4], device=self.device, dtype=self.dtype)
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
return [input, rmsnorm_result, quant_result, weight, scale]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
rmsnorm_result: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
all_reduce = self._all_reduce(input)
static_fp8 = self._functional_rmsnorm_then_quant(
rmsnorm_result, quant_result, all_reduce, weight, scale)
return static_fp8[1], all_reduce
def replacement(
input: torch.Tensor,
rmsnorm_result: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
reduce_scatter = self._reduce_scatter(input)
rmsnorm_result = torch.empty_like(reduce_scatter,
dtype=rmsnorm_result.dtype)
quant_result = torch.empty_like(
rmsnorm_result, # Output of RMSNorm
dtype=quant_result.dtype)
static_fp8 = self._functional_rmsnorm_then_quant(
rmsnorm_result, quant_result, reduce_scatter, weight, scale)
all_gather = self._all_gather(static_fp8[1])
return all_gather, reduce_scatter
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
op: torch._ops.OpOverload):
super().__init__(epsilon, dtype, device, quant_op=op)
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4],
device=self.device,
dtype=self.dtype)
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
return [
result,
residual,
mm_1,
rms_norm_weights,
scale,
]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
result, all_reduce, residual, rms_norm_weights, scale)
return static_fp8[1], rmsnorm_residual_out
def replacement(
result: torch.Tensor,
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(mm_1)
quant_result_buf = torch.empty_like(reduce_scatter,
dtype=result.dtype)
static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
quant_result_buf, reduce_scatter, residual, rms_norm_weights,
scale)
all_gather = self._all_gather(static_fp8[1])
return all_gather, rmsnorm_residual_out
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
op: torch._ops.OpOverload):
super().__init__(epsilon, dtype, device, quant_op=op)
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4],
device=self.device,
dtype=self.dtype)
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
return [
result,
residual,
mm_1,
rms_norm_weights,
scale,
]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
result, all_reduce, residual, rms_norm_weights, scale)
return static_fp8[1]
def replacement(
result: torch.Tensor,
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(mm_1)
quant_result_buf = torch.empty_like(reduce_scatter,
dtype=result.dtype)
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
quant_result_buf, reduce_scatter, residual, rms_norm_weights,
scale)
normalized = self._all_gather(static_fp8[1])
return normalized
pm.register_replacement(pattern, replacement, self.get_inputs(),
@@ -236,21 +417,54 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
class SequenceParallelismPass(VllmInductorPass):
"""
This pass enables sequence parallelism for models.
It identifies patterns where an AllReduce operation is followed by
an RMSNorm (or RMSNorm and then Quantization) operation.
These patterns are replaced with a ReduceScatter operation, followed by
a local RMSNorm/Quantization, and then an AllGather operation.
The general transformation is:
Input -> AllReduce -> RMSNorm -> Output
becomes
Input -> ReduceScatter -> RMSNorm -> AllGather -> Output
While this pass itself does not directly yield performance improvements,
it lays the groundwork for subsequent fusion passes, such as
GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
significantly reduce communication overhead and improve overall model
performance.
"""
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="sequence_parallelism_pass")
for epsilon in [1e-5, 1e-6]:
EmbeddingAllReduceRMSNormPattern(
epsilon, self.model_dtype, self.device).register(self.patterns)
# RMSNorm + Static FP8 quantization patterns
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
FirstAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device,
fp8_quant_op).register(self.patterns)
MiddleAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device,
fp8_quant_op).register(self.patterns)
LastAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device,
fp8_quant_op).register(self.patterns)
# Normal RMSNorm patterns
FirstAllReduceRMSNormPattern(epsilon, self.model_dtype,
self.device).register(self.patterns)
MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype,
self.device).register(self.patterns)
LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
self.device).register(self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
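
The SequenceParallelismPass docstring above summarizes the rewrite: Input -> AllReduce -> RMSNorm becomes Input -> ReduceScatter -> RMSNorm -> AllGather. A single-process sketch of why the rewrite is numerically equivalent, simulating two tensor-parallel ranks with plain tensors (no torch.distributed involved):

import torch

# Partial GEMM outputs held by two simulated TP ranks.
x_rank0 = torch.randn(8, 4)
x_rank1 = torch.randn(8, 4)

# Original pattern: every rank materializes the full all-reduced tensor.
all_reduced = x_rank0 + x_rank1

# Sequence-parallel pattern: reduce-scatter along dim 0 gives each rank one
# shard of the summed tensor; all-gather along dim 0 reassembles it.
shard0, shard1 = all_reduced.chunk(2, dim=0)
reassembled = torch.cat([shard0, shard1], dim=0)
assert torch.equal(all_reduced, reassembled)

# RMSNorm and static fp8 quantization operate row-wise over the hidden
# dimension, so applying them to each shard before the all-gather yields the
# same values as applying them after the all-reduce, while each rank only
# processes 1/TP of the tokens. The GEMM + ReduceScatter and AllGather + GEMM
# fusions mentioned in the docstring can then build on this shape.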

View File

@@ -3802,11 +3802,11 @@ class PassConfig:
its own stages (before, after, maybe in-between)."""
dump_graph_dir: Path = Path(".")
"""Directory to dump the graphs."""
enable_fusion: bool = True
enable_fusion: bool = field(default_factory=lambda: not envs.VLLM_USE_V1)
"""Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
enable_attn_fusion: bool = False
"""Whether to enable the custom attention+quant fusion pass."""
enable_noop: bool = True
enable_noop: bool = field(default_factory=lambda: not envs.VLLM_USE_V1)
"""Whether to enable the custom no-op elimination pass."""
enable_sequence_parallelism: bool = False
"""Whether to enable sequence parallelism."""
@@ -4451,8 +4451,6 @@ class VllmConfig:
# By default, V1 uses piecewise CUDA graphs. If full_cuda_graph
# is set to True, full CUDA graphs will be used.
self.compilation_config.cudagraph_num_of_warmups = 1
self.compilation_config.pass_config.enable_fusion = False
self.compilation_config.pass_config.enable_noop = False
self.compilation_config.level = CompilationLevel.PIECEWISE
self.compilation_config.set_splitting_ops_for_v1()
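
Taken together with the PassConfig fields above, the pass_config block used by the distributed test maps directly onto user-facing configuration. A hedged sketch of enabling sequence parallelism plus the static-fp8 RMSNorm fusion from Python; it assumes the LLM entrypoint accepts a compilation_config dict of this shape, mirroring the test's pass_config:

from vllm import LLM

# Sketch only: serve a statically quantized fp8 checkpoint with sequence
# parallelism and fusion enabled. enable_noop stays on because the fusion
# pass relies on no-op elimination having run first.
llm = LLM(
    model="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
    tensor_parallel_size=2,
    compilation_config={
        "pass_config": {
            "enable_sequence_parallelism": True,
            "enable_fusion": True,
            "enable_noop": True,
        },
    },
)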