[Feature] support sequence parallelism using compilation pass (#16155)

Signed-off-by: cascade812 <cascade812@outlook.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
cascade 2025-04-27 06:29:35 -07:00 committed by GitHub
parent ed7a29d9f8
commit 690fe019f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 1072 additions and 44 deletions

View File

@ -299,6 +299,7 @@ steps:
commands:
- pytest -v -s compile/test_pass_manager.py
- pytest -v -s compile/test_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py
- label: PyTorch Fullgraph Smoke Test # 9min
source_file_dependencies:
@ -583,6 +584,8 @@ steps:
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
# test sequence parallel
- pytest -v -s distributed/test_sequence_parallel.py
# this test fails consistently.
# TODO: investigate and fix
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py

View File

@ -10,7 +10,7 @@ from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig
from vllm.config import CompilationConfig, VllmConfig
from .backend import TestBackend
@ -49,13 +49,15 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
do_fusion: bool):
torch.set_default_device("cuda")
config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
enable_noop=True)
noop_pass = NoOpEliminationPass(config)
fusion_pass = FusionPass.instance(config)
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(pass_config= \
CompilationConfig.PassConfig(enable_fusion=do_fusion,
enable_noop=True))
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
func_pass = FixFunctionalizationPass(config)
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)

View File

@ -77,12 +77,13 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
vllm_config.compilation_config.pass_config = \
CompilationConfig.PassConfig(enable_fusion=True,
enable_noop=True)
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
config = CompilationConfig.PassConfig(enable_fusion=True,
enable_noop=True)
noop_pass = NoOpEliminationPass(config)
fusion_pass = FusionPass.instance(config)
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
backend = TestBackend(noop_pass, fusion_pass)
model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)

View File

@ -6,7 +6,7 @@ import torch
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import CompilationConfig
from vllm.config import VllmConfig
# dummy custom pass that doesn't inherit
@ -16,7 +16,7 @@ def simple_callable(graph: torch.fx.Graph):
# Should fail to add directly to the pass manager
def test_bad_callable():
config = CompilationConfig().pass_config
config = VllmConfig()
pass_manager = PostGradPassManager()
pass_manager.configure(config)
@ -43,7 +43,7 @@ class ProperPass(InductorPass):
],
)
def test_pass_manager_uuid(callable):
config = CompilationConfig().pass_config
config = VllmConfig()
pass_manager = PostGradPassManager()
pass_manager.configure(config)
@ -64,7 +64,8 @@ def test_pass_manager_uuid(callable):
# UUID should be different due to config change
config2 = copy.deepcopy(config)
config2.enable_fusion = not config2.enable_fusion
config2.compilation_config.pass_config.enable_fusion = not \
config2.compilation_config.pass_config.enable_fusion
pass_manager3 = PostGradPassManager()
pass_manager3.configure(config2)
pass_manager3.add(callable)

View File

@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import vllm.envs as envs
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe,
find_specified_fn,
find_specified_fn_maybe, is_func)
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
from ..utils import multi_gpu_test
from .backend import TestBackend
OPS_IN_MODEL_BEFORE = [
torch.ops.vllm.all_reduce.default,
]
OPS_IN_MODEL_AFTER = [
torch.ops.vllm.reduce_scatter.default,
torch.ops.vllm.all_gather.default,
]
OPS_IN_MODEL = [torch.ops._C.fused_add_rms_norm.default]
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
class TestModel(torch.nn.Module):
def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size)))
self.norm = RMSNorm(hidden_size, 1e-05)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
def forward(self, hidden_states, residual):
"""
Forward pass implementing the operations in the FX graph
Args:
hidden_states: Input tensor
residual: Residual tensor from previous layer
Returns:
Tuple containing the output tensor
"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)
#matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)
# Tensor parallel all-reduce
all_reduce = tensor_model_parallel_all_reduce(mm)
# layer normalization
norm_output, residual_output = self.norm(all_reduce, residual)
return norm_output, residual_output
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
def test_sequence_parallelism_pass(batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
num_processes = 2
def run_torch_spawn(fn, nprocs):
# need to use torch.mp.spawn otherwise will have problems with
# torch.distributed and cuda
torch.multiprocessing.spawn(fn,
args=(num_processes, batch_size, seq_len,
hidden_size, dtype),
nprocs=nprocs)
run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)
def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
batch_size: int, seq_len: int,
hidden_size: int,
dtype: torch.dtype):
current_platform.seed_everything(0)
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': '12345',
})
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
pass_config=CompilationConfig.PassConfig(
enable_sequence_parallelism=True, ), )
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config
# in the vllm_config, it's not really used.
model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
vllm_config.model_config = ModelConfig(model=model,
task="auto",
tokenizer=model,
tokenizer_mode="auto",
trust_remote_code=True,
dtype=dtype,
seed=42)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
backend_no_func = TestBackend(sequence_parallelism_pass)
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(sequence_parallelism_pass, func_pass)
model = TestModel(hidden_size, hidden_size * 2)
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
compiled_model_no_func(hidden_states, residual)
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual)
# Check substitution worked
pre_nodes = backend_no_func.graph_pre_pass.nodes
post_nodes = backend_no_func.graph_post_pass.nodes
# In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not
for op in OPS_IN_MODEL_BEFORE:
find_specified_fn(pre_nodes, op)
for op in OPS_IN_MODEL_AFTER:
assert find_specified_fn_maybe(pre_nodes, op) is None
# In post-nodes, reduce scatter and all gather should be there,
# all reduce should not
for op in OPS_IN_MODEL_AFTER:
find_specified_fn(post_nodes, op)
for op in OPS_IN_MODEL_BEFORE:
assert find_specified_fn_maybe(post_nodes, op) is None
# check if the functionalization pass is applied
for op in OPS_IN_MODEL:
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
op) is None # noqa: E501
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in OPS_IN_MODEL:
if is_func(node, op):
found[op] = True
assert all(found[op] for op in OPS_IN_MODEL)

View File

@ -14,7 +14,8 @@ import torch
from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from ..utils import init_test_distributed_environment, multi_process_parallel
@ -47,6 +48,34 @@ def all_reduce_test_worker(
torch.testing.assert_close(t, expected)
@ray.remote(num_gpus=1, max_calls=1)
def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int,
pp_size: int, rank: int,
distributed_init_port: str):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# they will be able to set the device to the correct GPU
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
num_elements = 8
all_tensors = [
torch.arange(num_elements, dtype=torch.float32, device="cuda") *
(r + 1) for r in range(tp_size)
]
index = rank % tp_size
partition_size = num_elements // tp_size
all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
expected = all_reduce[index * partition_size:(index + 1) * partition_size]
t = all_tensors[index]
t = tensor_model_parallel_reduce_scatter(t, 0)
torch.testing.assert_close(t, expected)
@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(
monkeypatch: pytest.MonkeyPatch,

View File

@ -0,0 +1,296 @@
# SPDX-License-Identifier: Apache-2.0
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
(2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
important to set the distributed backend to "mp" to avoid Ray scheduling
all workers in a node other than the head node, which can cause the test
to fail.
"""
import json
import os
from dataclasses import dataclass
from typing import Literal, NamedTuple, Optional
import pytest
from vllm.config import TaskOption
from vllm.logger import init_logger
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, create_new_process_for_each_test
logger = init_logger("test_sequence_parallel")
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
class ParallelSetup(NamedTuple):
tp_size: int
sp_enabled: bool
eager_mode: bool
chunked_prefill: bool
class SPTestOptions(NamedTuple):
multi_node_only: bool
load_format: Optional[str] = None
@dataclass
class SPTestSettings:
parallel_setups: list[ParallelSetup]
# NOTE: the length of distributed_backends and
# vllm_major_versions should be the same, and they
# are first zipped together to iterate over all
# test settings.
distributed_backends: list[str]
# vllm major version: "0" for V0, "1" for V1
vllm_major_versions: list[str]
task: TaskOption
test_options: SPTestOptions
def __post_init__(self):
if len(self.distributed_backends) != len(self.vllm_major_versions):
raise ValueError(
f"Length mismatch: distributed_backends "
f"({len(self.distributed_backends)}) != "
f"vllm_major_versions ({len(self.vllm_major_versions)})")
@staticmethod
def detailed(
*,
tp_base: int = 2,
multi_node_only: bool = False,
task: TaskOption = "auto",
load_format: Optional[str] = None,
):
return SPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=True),
ParallelSetup(tp_size=tp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=True)
],
distributed_backends=["mp", "ray"],
vllm_major_versions=["1", "1"],
task=task,
test_options=SPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),
)
@staticmethod
def fast(
*,
tp_base: int = 2,
task: TaskOption = "auto",
multi_node_only: bool = False,
load_format: Optional[str] = None,
):
return SPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=False),
],
distributed_backends=["mp", "ray"],
vllm_major_versions=["1", "1"],
task=task,
test_options=SPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),
)
def iter_params(self, model_id: str):
opts = self.test_options
for parallel_setup in self.parallel_setups:
for backend, vllm_major_version in zip(self.distributed_backends,
self.vllm_major_versions):
yield (model_id, parallel_setup, backend, vllm_major_version,
self.task, opts)
def _compare_sp(
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: SPTestOptions,
num_gpus_available: int,
*,
method: Literal["generate", "encode"],
is_multimodal: bool,
):
(
tp_size,
sp_enabled,
eager_mode,
chunked_prefill,
) = parallel_setup
multi_node_only, load_format = test_options
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip")
trust_remote_code = model_info.trust_remote_code
tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides
if load_format == "dummy":
# Avoid OOM
text_overrides = {
"num_hidden_layers": 4,
"hidden_size": 512,
"intermediate_size": 800,
"num_attention_heads": 4,
"num_key_value_heads": 1,
}
if is_multimodal:
hf_overrides.update({"text_config": text_overrides})
else:
hf_overrides.update(text_overrides)
else:
model_info.check_available_online(on_fail="skip")
pp_size = 1
if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
if VLLM_MULTI_NODE and distributed_backend == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend")
if multi_node_only and not VLLM_MULTI_NODE:
pytest.skip("Not in multi-node setting")
common_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"float16",
"--max-model-len",
"2048",
"--max-num-seqs",
"8",
]
if chunked_prefill:
common_args.append("--enable-chunked-prefill")
if eager_mode:
common_args.append("--enforce-eager")
if task != "auto":
common_args.extend(["--task", task])
if trust_remote_code:
common_args.append("--trust-remote-code")
if tokenizer_mode:
common_args.extend(["--tokenizer-mode", tokenizer_mode])
if load_format:
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
compilation_config = {
'level': 3,
'custom_ops': ["+rms_norm"],
'compile_sizes': [4, 8],
'splitting_ops': [],
'pass_config': {
'enable_sequence_parallism': sp_enabled,
'enable_noop': True,
'enable_fusion': True,
},
}
tp_sp_env = tp_env = {
"VLLM_USE_V1": vllm_major_version,
}
tp_sp_args = [
*common_args,
"--tensor-parallel-size",
str(tp_size),
"--distributed-executor-backend",
distributed_backend,
"--compilation_config",
str(compilation_config),
]
tp_env = {
"VLLM_USE_V1": vllm_major_version,
}
tp_args = [
*common_args,
"--tensor-parallel-size",
str(tp_size),
"--distributed-executor-backend",
"mp",
]
try:
compare_two_settings(model_id,
tp_sp_args,
tp_args,
tp_sp_env,
tp_env,
method=method)
except Exception:
testing_ray_compiled_graph = tp_sp_env is not None
if testing_ray_compiled_graph and vllm_major_version == "0":
# Ray Compiled Graph tests are flaky for V0,
# so we don't want to fail the test
logger.exception("Ray Compiled Graph tests failed")
else:
raise
SP_TEXT_GENERATION_MODELS = {
# [Decoder-only]
"meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.detailed(),
}
SP_TEST_MODELS = [
# TODO support other models
# [LANGUAGE GENERATION]
"meta-llama/Llama-3.2-1B-Instruct",
]
@pytest.mark.parametrize(
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
"task", "test_options"),
[
params for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
for params in settings.iter_params(model_id)
if model_id in SP_TEST_MODELS
],
)
@create_new_process_for_each_test()
def test_tp_sp_generation(
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: SPTestOptions,
num_gpus_available,
):
_compare_sp(model_id,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
method="generate",
is_multimodal=False)

View File

@ -339,7 +339,7 @@ class VllmBackend:
def configure_post_pass(self):
config = self.compilation_config
self.post_grad_pass_manager.configure(config.pass_config)
self.post_grad_pass_manager.configure(self.vllm_config)
# Post-grad custom passes are run using the post_grad_custom_post_pass
# hook. If a pass for that hook exists, add it to the pass manager.

View File

@ -15,6 +15,8 @@ import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.utils import is_torch_equal_or_newer
from .inductor_pass import pass_context
class CompilerInterface:
"""
@ -312,11 +314,12 @@ class InductorAdaptor(CompilerInterface):
torch._functorch.config.patch(
enable_remote_autograd_cache=False))
compiled_graph = compile_fx(
graph,
example_inputs,
inner_compile=hijacked_compile_fx_inner,
config_patches=current_config)
with pass_context(runtime_shape):
compiled_graph = compile_fx(
graph,
example_inputs,
inner_compile=hijacked_compile_fx_inner,
config_patches=current_config)
# We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
# compilation cache. So turn off the checks if we disable the

View File

@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
from vllm.config import CompilationConfig
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
@ -531,7 +531,7 @@ class FusionPass(VllmInductorPass):
_instance: 'Optional[FusionPass]' = None
@classmethod
def instance(cls, config: CompilationConfig.PassConfig):
def instance(cls, config: VllmConfig):
"""
Get the singleton instance of the FusionPass.
If the instance exists, the config is updated but
@ -540,10 +540,10 @@ class FusionPass(VllmInductorPass):
if cls._instance is None:
cls._instance = FusionPass(config)
else:
cls._instance.config = config
cls._instance.pass_config = config.compilation_config.pass_config
return cls._instance
def __init__(self, config: CompilationConfig.PassConfig):
def __init__(self, config: VllmConfig):
assert self.__class__._instance is None, \
"FusionPass singleton instance already exists"
super().__init__(config)

View File

@ -12,6 +12,22 @@ def is_func(node: fx.Node, target) -> bool:
return node.op == "call_function" and node.target == target
# Returns the first specified node with the given op (if it exists)
def find_specified_fn_maybe(nodes: Iterable[fx.Node],
op: OpOverload) -> Optional[fx.Node]:
for node in nodes:
if node.target == op:
return node
return None
# Returns the first specified node with the given op
def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
node = find_specified_fn_maybe(nodes, op)
assert node is not None, f"Could not find {op} in nodes {nodes}"
return node
# Returns the first auto_functionalized node with the given op (if it exists)
def find_auto_fn_maybe(nodes: Iterable[fx.Node],
op: OpOverload) -> Optional[fx.Node]:

View File

@ -4,6 +4,7 @@ import hashlib
import inspect
import json
import types
from contextlib import contextmanager
from typing import Any, Callable, Dict, Optional, Union
import torch
@ -18,6 +19,34 @@ else:
from .torch25_custom_graph_pass import ( # noqa: yapf
Torch25CustomGraphPass as CustomGraphPass)
_pass_context = None
class PassContext:
def __init__(self, runtime_shape: Optional[int]):
self.runtime_shape = runtime_shape
def get_pass_context() -> PassContext:
"""Get the current pass context."""
assert _pass_context is not None
return _pass_context
@contextmanager
def pass_context(runtime_shape: Optional[int]):
"""A context manager that stores the current pass context,
usually it is a list of sizes to specialize.
"""
global _pass_context
prev_context = _pass_context
_pass_context = PassContext(runtime_shape)
try:
yield
finally:
_pass_context = prev_context
class InductorPass(CustomGraphPass):
"""
@ -62,6 +91,9 @@ class InductorPass(CustomGraphPass):
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).hexdigest()
def is_applicable_for_shape(self, shape: Optional[int]):
return True
class CallableInductorPass(InductorPass):
"""

View File

@ -4,13 +4,15 @@ from typing import List
from torch import fx as fx
from vllm.config import CompilationConfig
from vllm.config import VllmConfig
from vllm.logger import init_logger
from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
from .inductor_pass import CustomGraphPass, InductorPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
from .noop_elimination import NoOpEliminationPass
from .sequence_parallelism import SequenceParallelismPass
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
@ -31,24 +33,29 @@ class PostGradPassManager(CustomGraphPass):
"""
def __init__(self):
self.passes: List[InductorPass] = []
self.passes: List[VllmInductorPass] = []
def __call__(self, graph: fx.Graph):
shape = get_pass_context().runtime_shape
for pass_ in self.passes:
pass_(graph)
if pass_.is_applicable_for_shape(shape):
pass_(graph)
# always run fix_functionalization last
self.fix_functionalization(graph)
def configure(self, pass_config: CompilationConfig.PassConfig):
self.pass_config = pass_config
if pass_config.enable_noop:
self.passes += [NoOpEliminationPass(pass_config)]
def configure(self, config: VllmConfig):
self.pass_config = config.compilation_config.pass_config
if self.pass_config.enable_noop:
self.passes += [NoOpEliminationPass(config)]
if pass_config.enable_fusion:
self.passes += [FusionPass.instance(pass_config)]
if self.pass_config.enable_fusion:
self.passes += [FusionPass.instance(config)]
self.fix_functionalization = FixFunctionalizationPass(pass_config)
if self.pass_config.enable_sequence_parallelism:
self.passes += [SequenceParallelismPass(config)]
self.fix_functionalization = FixFunctionalizationPass(config)
def add(self, pass_: InductorPass):
assert isinstance(pass_, InductorPass)

View File

@ -0,0 +1,266 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class AllReduceRMSNormPattern:
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
self.epsilon = epsilon
self.dtype = dtype
self.device = device
class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern):
def get_inputs(self):
arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype)
mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]],
device=self.device,
dtype=torch.long)
unsqueeze = torch.rand([1, 8, 1], device=self.device, \
dtype=self.dtype) > 0.5
full_default = torch.zeros([1, 8, 4], device=self.device, \
dtype=self.dtype)
permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
return [arg2_1, mul_6, unsqueeze, full_default, permute, arg3_1]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
arg2_1: torch.Tensor,
mul_6: torch.Tensor,
unsqueeze: torch.Tensor,
full_default: torch.Tensor,
permute: torch.Tensor,
arg3_1: torch.Tensor,
):
embedding = torch.ops.aten.embedding.default(arg2_1, mul_6)
where = torch.ops.aten.where.self(unsqueeze, full_default,
embedding)
all_reduce = tensor_model_parallel_all_reduce(where)
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.rms_norm.default,
result=permute,
input=all_reduce,
weight=arg3_1,
epsilon=self.epsilon,
)
return rmsnorm[1], all_reduce
def replacement(
arg2_1: torch.Tensor,
mul_6: torch.Tensor,
unsqueeze: torch.Tensor,
full_default: torch.Tensor,
permute: torch.Tensor,
arg3_1: torch.Tensor,
):
embedding = torch.ops.aten.embedding.default(arg2_1, mul_6)
where = torch.ops.aten.where.self(unsqueeze, full_default,
embedding)
tp = get_tp_group()
tp_size = get_tensor_model_parallel_world_size()
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
where, dim=0, world_size=tp_size, group_name=tp.unique_name)
rmsnorm_result = torch.empty_like(reduce_scatter)
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.rms_norm.default,
result=rmsnorm_result,
input=reduce_scatter,
weight=arg3_1,
epsilon=self.epsilon,
)
all_gather = torch.ops.vllm.all_gather.default(
rmsnorm[1],
dim=0,
world_size=tp_size,
group_name=tp.unique_name)
return all_gather, reduce_scatter
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4],
device=self.device,
dtype=self.dtype)
return [
residual,
mm_1,
rms_norm_weights,
]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(mm_1)
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=all_reduce,
residual=residual,
weight=rms_norm_weights,
epsilon=self.epsilon,
)
return rmsnorm[1], rmsnorm[2]
def replacement(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
tp = get_tp_group()
tp_size = get_tensor_model_parallel_world_size()
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name)
# TODO is it possible to extract epsilon from somewhere
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=reduce_scatter,
residual=residual,
weight=rms_norm_weights,
epsilon=self.epsilon,
)
all_gather = torch.ops.vllm.all_gather.default(
rmsnorm[1],
dim=0,
world_size=tp_size,
group_name=tp.unique_name)
return all_gather, rmsnorm[2]
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4],
device=self.device,
dtype=self.dtype)
return [
residual,
mm_1,
rms_norm_weights,
]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(mm_1)
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=all_reduce,
residual=residual,
weight=rms_norm_weights,
epsilon=self.epsilon,
)
return rmsnorm[1]
def replacement(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
tp = get_tp_group()
tp_size = get_tensor_model_parallel_world_size()
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name)
# TODO is it possible to extract epsilon from somewhere
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=reduce_scatter,
residual=residual,
weight=rms_norm_weights,
epsilon=self.epsilon,
)
normalized = torch.ops.vllm.all_gather.default(
rmsnorm[1],
dim=0,
world_size=tp_size,
group_name=tp.unique_name)
return normalized
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class SequenceParallelismPass(VllmInductorPass):
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="sequence_parallelism_pass")
for epsilon in [1e-5, 1e-6]:
EmbeddingAllReduceRMSNormPattern(
epsilon, self.dtype, self.device).register(self.patterns)
MiddleAllReduceRMSNormPattern(epsilon, self.dtype,
self.device).register(self.patterns)
LastAllReduceRMSNormPattern(epsilon, self.dtype,
self.device).register(self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
# only do replace for specific shapes
tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0
def __call__(self, graph: fx.Graph):
self.dump_graph(graph, "before_sequence_parallelism_pass")
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", count)
self.dump_graph(graph, "after_sequence_parallelism_pass")

View File

@ -4,7 +4,7 @@ import time
import torch
from vllm.config import CompilationConfig
from vllm.config import CompilationConfig, VllmConfig
# yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import (
@ -24,16 +24,19 @@ class VllmInductorPass(InductorPass):
It provides timing, logging, and dumping utilities.
"""
def __init__(self, config: CompilationConfig.PassConfig):
self.config = config
def __init__(self, config: VllmConfig):
self.pass_config = config.compilation_config.pass_config
self.dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config \
else None
self.pass_name = self.__class__.__name__
def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False):
if stage in self.config.dump_graph_stages or always:
if stage in self.pass_config.dump_graph_stages or always:
# Make sure filename includes rank in the distributed setting
parallel = p_is_init() and get_tp_world_size() > 1
rank = f"-{get_tp_rank()}" if parallel else ""
filepath = self.config.dump_graph_dir / f"{stage}{rank}.py"
filepath = self.pass_config.dump_graph_dir / f"{stage}{rank}.py"
logger.info("%s printing graph to %s", self.pass_name, filepath)
with open(filepath, "w") as f:

View File

@ -3405,11 +3405,13 @@ class CompilationConfig(BaseModel):
- enable_fusion: whether to enable the custom fusion pass.
- enable_noop: whether to enable the custom no-op elimination pass.
TODO(luka) better pass enabling system.
- enable_sequence_parallelism: whether to enable sequence parallelism.
"""
dump_graph_stages: list[str] = Field(default_factory=list)
dump_graph_dir: Path = Field(default=Path("."))
enable_fusion: bool = True
enable_noop: bool = True
enable_sequence_parallelism: bool = False
def uuid(self):
"""
@ -3418,7 +3420,8 @@ class CompilationConfig(BaseModel):
Do not include dump_graph_* in the hash - they don't affect
compilation.
"""
dict_ = self.model_dump(include={"enable_fusion", "enable_noop"})
dict_ = self.model_dump(include={"enable_fusion", "enable_noop", \
"enable_sequence_parallelism"})
return InductorPass.hash_dict(dict_)
def model_post_init(self, __context: Any) -> None:
@ -3840,6 +3843,8 @@ class VllmConfig:
if self.compilation_config is None:
self.compilation_config = CompilationConfig()
if self.compilation_config.pass_config.enable_sequence_parallelism:
self.compilation_config.custom_ops.append("+rms_norm")
if envs.VLLM_USE_V1 and self.model_config is not None and \
not self.model_config.enforce_eager:
# NOTE(woosuk): Currently, we use inductor because the piecewise
@ -3847,7 +3852,8 @@ class VllmConfig:
# FIXME(woosuk): Disable inductor to reduce the compilation time
# and avoid any potential issues with the inductor.
# FIXME(rob): Add function to set all of these.
self.compilation_config.custom_ops = ["none"]
if not self.compilation_config.custom_ops:
self.compilation_config.custom_ops = ["none"]
self.compilation_config.use_cudagraph = True
self.compilation_config.use_inductor = True
self.compilation_config.cudagraph_num_of_warmups = 1
@ -3856,6 +3862,18 @@ class VllmConfig:
self.compilation_config.level = CompilationLevel.PIECEWISE
self.compilation_config.set_splitting_ops_for_v1()
if self.parallel_config is not None and \
self.parallel_config.tensor_parallel_size > 1 and \
self.parallel_config.pipeline_parallel_size > 1 and \
self.compilation_config is not None and \
self.compilation_config.pass_config is not None and \
self.compilation_config.pass_config.enable_sequence_parallelism:
logger.warning_once(
"Sequence parallelism is not supported with pipeline "
"parallelism. Disabling sequence parallelism.")
self.compilation_config.pass_config.\
enable_sequence_parallelism = False
self._set_cudagraph_sizes()
if self.cache_config is not None and \
@ -3895,6 +3913,26 @@ class VllmConfig:
if not self.instance_id:
self.instance_id = random_uuid()[:5]
def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list:
# remove the sizes that not multiple of tp_size when
# enable sequence parallelism
removed_sizes = [
size for size in possible_sizes
if size % self.parallel_config.tensor_parallel_size != 0
]
if removed_sizes:
logger.warning(
"Batch sizes %s are removed because they are not "
"multiple of tp_size %d when "
"sequence parallelism is enabled", removed_sizes,
self.parallel_config.tensor_parallel_size)
return [
size for size in possible_sizes
if size % self.parallel_config.tensor_parallel_size == 0
]
def _set_cudagraph_sizes(self):
"""
cudagraph batchsize padding logic:
@ -3932,6 +3970,11 @@ class VllmConfig:
not self.model_config.enforce_eager:
possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]
if self.parallel_config.tensor_parallel_size > 1 and \
self.compilation_config.pass_config.enable_sequence_parallelism:
possible_sizes = self.update_sizes_for_sequence_parallelism(
possible_sizes)
# find the minimum size that is larger than max_num_seqs,
# which then becomes the max_batchsize_to_capture
larger_sizes = [
@ -3955,6 +3998,11 @@ class VllmConfig:
not self.model_config.enforce_eager:
batch_size_capture_list = [1, 2, 4
] + [i for i in range(8, 513, 8)]
if self.parallel_config.tensor_parallel_size > 1 and \
self.compilation_config.pass_config.enable_sequence_parallelism:
batch_size_capture_list = \
self.update_sizes_for_sequence_parallelism(batch_size_capture_list)
max_num_tokens = self.scheduler_config.max_num_batched_tokens
batch_size_capture_list = [
size for size in batch_size_capture_list

View File

@ -19,6 +19,12 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor,
return get_tp_group().all_gather(input_, dim)
def tensor_model_parallel_reduce_scatter(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""Reduce-Scatter the input tensor across model parallel group."""
return get_tp_group().reduce_scatter(input_, dim)
def tensor_model_parallel_gather(input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:

View File

@ -61,6 +61,40 @@ class DeviceCommunicatorBase:
input_size[dim + 1:])
return output_tensor
def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor = input_.movedim(0, dim).contiguous()
assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]
output_tensor = torch.empty(output_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)
# Perform reduce-scatter operation
torch.distributed.reduce_scatter_tensor(output_tensor,
input_tensor,
group=self.device_group)
# Reshape before returning
return output_tensor.movedim(0, dim).contiguous()
def gather(self,
input_: torch.Tensor,
dst: int = 0,

View File

@ -70,6 +70,31 @@ class CudaCommunicator(DeviceCommunicatorBase):
torch.distributed.all_reduce(out, group=self.device_group)
return out
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
world_size = self.world_size
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor = input_.movedim(0, dim).contiguous()
assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]
output = torch.empty(output_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)
pynccl_comm.reduce_scatter(output, input_)
# Reshape before returning
return output.movedim(0, dim).contiguous()
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""

View File

@ -113,6 +113,38 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return torch.empty_like(tensor)
def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group.reduce_scatter(tensor, dim)
def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
group_name: str) -> torch.Tensor:
new_shape = list(tensor.shape)
new_shape[dim] = tensor.shape[dim] // world_size
return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)
def all_gather(tensor: torch.Tensor, dim: int, world_size: int,
group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group.all_gather(tensor, dim)
def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
group_name: str) -> torch.Tensor:
new_shape = list(tensor.shape)
new_shape[dim] = tensor.shape[dim] * world_size
return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)
if supports_custom_op():
from vllm.platforms import current_platform
direct_register_custom_op(
@ -123,6 +155,20 @@ if supports_custom_op():
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="reduce_scatter",
op_func=reduce_scatter,
mutates_args=[],
fake_impl=reduce_scatter_fake,
)
direct_register_custom_op(
op_name="all_gather",
op_func=all_gather,
mutates_args=[],
fake_impl=all_gather_fake,
)
class GroupCoordinator:
"""
@ -322,6 +368,18 @@ class GroupCoordinator:
return self.device_communicator.all_gather(input_, dim)
def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
return self.device_communicator.reduce_scatter(input_, dim)
def gather(self,
input_: torch.Tensor,
dst: int = 0,

View File

@ -1027,7 +1027,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens)
else:
# Eager mode.
num_input_tokens = num_scheduled_tokens
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.vllm_config.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
from vllm.utils import round_up
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
attn_metadata.num_input_tokens = num_input_tokens
# _prepare_inputs may reorder the batch, so we must gather multi