diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index ec00bc7f108df..20d858cb15a11 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -299,6 +299,7 @@ steps: commands: - pytest -v -s compile/test_pass_manager.py - pytest -v -s compile/test_fusion.py + - pytest -v -s compile/test_sequence_parallelism.py - label: PyTorch Fullgraph Smoke Test # 9min source_file_dependencies: @@ -583,6 +584,8 @@ steps: - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)' - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)' - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)' + # test sequence parallel + - pytest -v -s distributed/test_sequence_parallel.py # this test fails consistently. # TODO: investigate and fix # - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 9f9b2d06b227c..27cd10b77491a 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -10,7 +10,7 @@ from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym) from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig +from vllm.config import CompilationConfig, VllmConfig from .backend import TestBackend @@ -49,13 +49,15 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, do_fusion: bool): torch.set_default_device("cuda") - config = CompilationConfig.PassConfig(enable_fusion=do_fusion, - enable_noop=True) - noop_pass = NoOpEliminationPass(config) - fusion_pass = FusionPass.instance(config) + vllm_config = VllmConfig() + vllm_config.compilation_config = CompilationConfig(pass_config= \ + 
CompilationConfig.PassConfig(enable_fusion=do_fusion, + enable_noop=True)) + noop_pass = NoOpEliminationPass(vllm_config) + fusion_pass = FusionPass.instance(vllm_config) passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass] - func_pass = FixFunctionalizationPass(config) + func_pass = FixFunctionalizationPass(vllm_config) backend_func = TestBackend(*passes, func_pass) backend_no_func = TestBackend(*passes) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index efebf05b6b047..6a696fe0226b1 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -77,12 +77,13 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, vllm_config = VllmConfig(compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"])) + vllm_config.compilation_config.pass_config = \ + CompilationConfig.PassConfig(enable_fusion=True, + enable_noop=True) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work - config = CompilationConfig.PassConfig(enable_fusion=True, - enable_noop=True) - noop_pass = NoOpEliminationPass(config) - fusion_pass = FusionPass.instance(config) + noop_pass = NoOpEliminationPass(vllm_config) + fusion_pass = FusionPass.instance(vllm_config) backend = TestBackend(noop_pass, fusion_pass) model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled) diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py index 2c1ee4dc74806..673ebe8b6fdc0 100644 --- a/tests/compile/test_pass_manager.py +++ b/tests/compile/test_pass_manager.py @@ -6,7 +6,7 @@ import torch from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.compilation.pass_manager import PostGradPassManager -from vllm.config import CompilationConfig +from vllm.config import VllmConfig # dummy custom pass that doesn't inherit @@ -16,7 +16,7 @@ def simple_callable(graph: torch.fx.Graph): # Should fail 
to add directly to the pass manager def test_bad_callable(): - config = CompilationConfig().pass_config + config = VllmConfig() pass_manager = PostGradPassManager() pass_manager.configure(config) @@ -43,7 +43,7 @@ class ProperPass(InductorPass): ], ) def test_pass_manager_uuid(callable): - config = CompilationConfig().pass_config + config = VllmConfig() pass_manager = PostGradPassManager() pass_manager.configure(config) @@ -64,7 +64,8 @@ def test_pass_manager_uuid(callable): # UUID should be different due to config change config2 = copy.deepcopy(config) - config2.enable_fusion = not config2.enable_fusion + config2.compilation_config.pass_config.enable_fusion = not \ + config2.compilation_config.pass_config.enable_fusion pass_manager3 = PostGradPassManager() pass_manager3.configure(config2) pass_manager3.add(callable) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py new file mode 100644 index 0000000000000..79f5486dadcdd --- /dev/null +++ b/tests/compile/test_sequence_parallelism.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +import vllm.envs as envs +from vllm.compilation.fix_functionalization import FixFunctionalizationPass +from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe, + find_specified_fn, + find_specified_fn_maybe, is_func) +from vllm.compilation.sequence_parallelism import SequenceParallelismPass +from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, + VllmConfig) +from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import (init_distributed_environment, + initialize_model_parallel) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.platforms import current_platform +from vllm.utils import update_environment_variables + +from ..utils import multi_gpu_test +from .backend import TestBackend + +OPS_IN_MODEL_BEFORE = [ + torch.ops.vllm.all_reduce.default, +] + 
+OPS_IN_MODEL_AFTER = [ + torch.ops.vllm.reduce_scatter.default, + torch.ops.vllm.all_gather.default, +] + +OPS_IN_MODEL = [torch.ops._C.fused_add_rms_norm.default] + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + + +class TestModel(torch.nn.Module): + + def __init__(self, hidden_size=16, intermediate_size=32): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = torch.nn.Parameter( + torch.empty((intermediate_size, hidden_size))) + self.norm = RMSNorm(hidden_size, 1e-05) + # Initialize weights + torch.nn.init.normal_(self.gate_proj, std=0.02) + + def forward(self, hidden_states, residual): + """ + Forward pass implementing the operations in the FX graph + + Args: + hidden_states: Input tensor + residual: Residual tensor from previous layer + + Returns: + Tuple containing the output tensor + """ + # Reshape input + view = hidden_states.reshape(-1, self.hidden_size) + + #matrix multiplication + permute = self.gate_proj.permute(1, 0) + mm = torch.mm(view, permute) + + # Tensor parallel all-reduce + all_reduce = tensor_model_parallel_all_reduce(mm) + + # layer normalization + norm_output, residual_output = self.norm(all_reduce, residual) + + return norm_output, residual_output + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seq_len", [16]) +@pytest.mark.parametrize("hidden_size", [16]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], + reason="Only test on CUDA") +def test_sequence_parallelism_pass(batch_size: int, seq_len: int, + hidden_size: int, dtype: torch.dtype): + num_processes = 2 + + def run_torch_spawn(fn, nprocs): + # need to use torch.mp.spawn otherwise will have problems with + # torch.distributed and cuda + torch.multiprocessing.spawn(fn, + args=(num_processes, 
batch_size, seq_len, + hidden_size, dtype), + nprocs=nprocs) + + run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes) + + +def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, + batch_size: int, seq_len: int, + hidden_size: int, + dtype: torch.dtype): + current_platform.seed_everything(0) + + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': '12345', + }) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # configure vllm config for SequenceParallelismPass + vllm_config = VllmConfig() + vllm_config.compilation_config = CompilationConfig( + pass_config=CompilationConfig.PassConfig( + enable_sequence_parallelism=True, ), ) + vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) + + # this is a fake model name to construct the model config + # in the vllm_config, it's not really used. 
+ model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" + vllm_config.model_config = ModelConfig(model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype=dtype, + seed=42) + + sequence_parallelism_pass = SequenceParallelismPass(vllm_config) + backend_no_func = TestBackend(sequence_parallelism_pass) + func_pass = FixFunctionalizationPass(vllm_config) + backend_func = TestBackend(sequence_parallelism_pass, func_pass) + + model = TestModel(hidden_size, hidden_size * 2) + hidden_states = torch.randn((batch_size * seq_len, hidden_size), + dtype=dtype) + residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + + compiled_model_no_func = torch.compile(model, backend=backend_no_func) + compiled_model_no_func(hidden_states, residual) + compiled_model_func = torch.compile(model, backend=backend_func) + compiled_model_func(hidden_states, residual) + + # Check substitution worked + pre_nodes = backend_no_func.graph_pre_pass.nodes + post_nodes = backend_no_func.graph_post_pass.nodes + + # In pre-nodes, all reduce should be there, + # reduce scatter and all gather should not + for op in OPS_IN_MODEL_BEFORE: + find_specified_fn(pre_nodes, op) + for op in OPS_IN_MODEL_AFTER: + assert find_specified_fn_maybe(pre_nodes, op) is None + + # In post-nodes, reduce scatter and all gather should be there, + # all reduce should not + for op in OPS_IN_MODEL_AFTER: + find_specified_fn(post_nodes, op) + for op in OPS_IN_MODEL_BEFORE: + assert find_specified_fn_maybe(post_nodes, op) is None + + # check if the functionalization pass is applied + for op in OPS_IN_MODEL: + find_auto_fn(backend_no_func.graph_post_pass.nodes, op) + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, + op) is None # noqa: E501 + + # make sure the ops were all de-functionalized + found = dict() + for node in backend_func.graph_post_pass.nodes: + for op in OPS_IN_MODEL: + if is_func(node, op): + found[op] = True + assert all(found[op] for op 
in OPS_IN_MODEL) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index ac6d6aae30063..8f4c3537e1586 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -14,7 +14,8 @@ import torch from vllm.distributed import (broadcast_tensor_dict, get_pp_group, tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) from ..utils import init_test_distributed_environment, multi_process_parallel @@ -47,6 +48,34 @@ def all_reduce_test_worker( torch.testing.assert_close(t, expected) +@ray.remote(num_gpus=1, max_calls=1) +def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int, + pp_size: int, rank: int, + distributed_init_port: str): + # it is important to delete the CUDA_VISIBLE_DEVICES environment variable + # so that each worker can see all the GPUs + # they will be able to set the device to the correct GPU + monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + init_test_distributed_environment(tp_size, pp_size, rank, + distributed_init_port) + + num_elements = 8 + all_tensors = [ + torch.arange(num_elements, dtype=torch.float32, device="cuda") * + (r + 1) for r in range(tp_size) + ] + + index = rank % tp_size + partition_size = num_elements // tp_size + all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0) + expected = all_reduce[index * partition_size:(index + 1) * partition_size] + t = all_tensors[index] + t = tensor_model_parallel_reduce_scatter(t, 0) + torch.testing.assert_close(t, expected) + + @ray.remote(num_gpus=1, max_calls=1) def all_gather_test_worker( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py new file mode 100644 index 0000000000000..19497ad9c1409 --- /dev/null +++ 
b/tests/distributed/test_sequence_parallel.py @@ -0,0 +1,296 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +WARNING: This test runs in both single-node (4 GPUs) and multi-node + (2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is + important to set the distributed backend to "mp" to avoid Ray scheduling + all workers in a node other than the head node, which can cause the test + to fail. +""" +import json +import os +from dataclasses import dataclass +from typing import Literal, NamedTuple, Optional + +import pytest + +from vllm.config import TaskOption +from vllm.logger import init_logger + +from ..models.registry import HF_EXAMPLE_MODELS +from ..utils import compare_two_settings, create_new_process_for_each_test + +logger = init_logger("test_sequence_parallel") + +VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" + + +class ParallelSetup(NamedTuple): + tp_size: int + sp_enabled: bool + eager_mode: bool + chunked_prefill: bool + + +class SPTestOptions(NamedTuple): + multi_node_only: bool + load_format: Optional[str] = None + + +@dataclass +class SPTestSettings: + parallel_setups: list[ParallelSetup] + # NOTE: the length of distributed_backends and + # vllm_major_versions should be the same, and they + # are first zipped together to iterate over all + # test settings. 
+ distributed_backends: list[str] + # vllm major version: "0" for V0, "1" for V1 + vllm_major_versions: list[str] + task: TaskOption + test_options: SPTestOptions + + def __post_init__(self): + if len(self.distributed_backends) != len(self.vllm_major_versions): + raise ValueError( + f"Length mismatch: distributed_backends " + f"({len(self.distributed_backends)}) != " + f"vllm_major_versions ({len(self.vllm_major_versions)})") + + @staticmethod + def detailed( + *, + tp_base: int = 2, + multi_node_only: bool = False, + task: TaskOption = "auto", + load_format: Optional[str] = None, + ): + return SPTestSettings( + parallel_setups=[ + ParallelSetup(tp_size=tp_base, + sp_enabled=True, + eager_mode=False, + chunked_prefill=False), + ParallelSetup(tp_size=tp_base, + sp_enabled=True, + eager_mode=False, + chunked_prefill=True), + ParallelSetup(tp_size=tp_base, + sp_enabled=True, + eager_mode=True, + chunked_prefill=False), + ParallelSetup(tp_size=tp_base, + sp_enabled=True, + eager_mode=True, + chunked_prefill=True) + ], + distributed_backends=["mp", "ray"], + vllm_major_versions=["1", "1"], + task=task, + test_options=SPTestOptions(multi_node_only=multi_node_only, + load_format=load_format), + ) + + @staticmethod + def fast( + *, + tp_base: int = 2, + task: TaskOption = "auto", + multi_node_only: bool = False, + load_format: Optional[str] = None, + ): + return SPTestSettings( + parallel_setups=[ + ParallelSetup(tp_size=tp_base, + sp_enabled=True, + eager_mode=False, + chunked_prefill=False), + ], + distributed_backends=["mp", "ray"], + vllm_major_versions=["1", "1"], + task=task, + test_options=SPTestOptions(multi_node_only=multi_node_only, + load_format=load_format), + ) + + def iter_params(self, model_id: str): + opts = self.test_options + + for parallel_setup in self.parallel_setups: + for backend, vllm_major_version in zip(self.distributed_backends, + self.vllm_major_versions): + yield (model_id, parallel_setup, backend, vllm_major_version, + self.task, opts) + + 
+def _compare_sp( + model_id: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + vllm_major_version: str, + task: TaskOption, + test_options: SPTestOptions, + num_gpus_available: int, + *, + method: Literal["generate", "encode"], + is_multimodal: bool, +): + ( + tp_size, + sp_enabled, + eager_mode, + chunked_prefill, + ) = parallel_setup + + multi_node_only, load_format = test_options + + model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) + model_info.check_transformers_version(on_fail="skip") + + trust_remote_code = model_info.trust_remote_code + tokenizer_mode = model_info.tokenizer_mode + hf_overrides = model_info.hf_overrides + + if load_format == "dummy": + # Avoid OOM + text_overrides = { + "num_hidden_layers": 4, + "hidden_size": 512, + "intermediate_size": 800, + "num_attention_heads": 4, + "num_key_value_heads": 1, + } + + if is_multimodal: + hf_overrides.update({"text_config": text_overrides}) + else: + hf_overrides.update(text_overrides) + else: + model_info.check_available_online(on_fail="skip") + + pp_size = 1 + if num_gpus_available < tp_size * pp_size: + pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") + if VLLM_MULTI_NODE and distributed_backend == "mp": + pytest.skip("Skipping multi-node pipeline parallel test for " + "multiprocessing distributed backend") + if multi_node_only and not VLLM_MULTI_NODE: + pytest.skip("Not in multi-node setting") + + common_args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--max-model-len", + "2048", + "--max-num-seqs", + "8", + ] + if chunked_prefill: + common_args.append("--enable-chunked-prefill") + if eager_mode: + common_args.append("--enforce-eager") + if task != "auto": + common_args.extend(["--task", task]) + if trust_remote_code: + common_args.append("--trust-remote-code") + if tokenizer_mode: + common_args.extend(["--tokenizer-mode", tokenizer_mode]) + if load_format: + common_args.extend(["--load-format", load_format]) + 
if hf_overrides:
+        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
+
+    compilation_config = {
+        'level': 3,
+        'custom_ops': ["+rms_norm"],
+        'compile_sizes': [4, 8],
+        'splitting_ops': [],
+        'pass_config': {
+            'enable_sequence_parallelism': sp_enabled,
+            'enable_noop': True,
+            'enable_fusion': True,
+        },
+    }
+
+    tp_sp_env = tp_env = {
+        "VLLM_USE_V1": vllm_major_version,
+    }
+
+    tp_sp_args = [
+        *common_args,
+        "--tensor-parallel-size",
+        str(tp_size),
+        "--distributed-executor-backend",
+        distributed_backend,
+        "--compilation_config",
+        str(compilation_config),
+    ]
+
+    tp_env = {
+        "VLLM_USE_V1": vllm_major_version,
+    }
+    tp_args = [
+        *common_args,
+        "--tensor-parallel-size",
+        str(tp_size),
+        "--distributed-executor-backend",
+        "mp",
+    ]
+
+    try:
+        compare_two_settings(model_id,
+                             tp_sp_args,
+                             tp_args,
+                             tp_sp_env,
+                             tp_env,
+                             method=method)
+    except Exception:
+        testing_ray_compiled_graph = tp_sp_env is not None
+        if testing_ray_compiled_graph and vllm_major_version == "0":
+            # Ray Compiled Graph tests are flaky for V0,
+            # so we don't want to fail the test
+            logger.exception("Ray Compiled Graph tests failed")
+        else:
+            raise
+
+
+SP_TEXT_GENERATION_MODELS = {
+    # [Decoder-only]
+    "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.detailed(),
+}
+
+SP_TEST_MODELS = [
+    # TODO support other models
+    # [LANGUAGE GENERATION]
+    "meta-llama/Llama-3.2-1B-Instruct",
+]
+
+
+@pytest.mark.parametrize(
+    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
+     "task", "test_options"),
+    [
+        params for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
+        for params in settings.iter_params(model_id)
+        if model_id in SP_TEST_MODELS
+    ],
+)
+@create_new_process_for_each_test()
+def test_tp_sp_generation(
+    model_id: str,
+    parallel_setup: ParallelSetup,
+    distributed_backend: str,
+    vllm_major_version: str,
+    task: TaskOption,
+    test_options: SPTestOptions,
+    num_gpus_available,
+):
+    _compare_sp(model_id,
+                parallel_setup,
+
distributed_backend, + vllm_major_version, + task, + test_options, + num_gpus_available, + method="generate", + is_multimodal=False) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index c493a764f56d6..a1d12b5175504 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -339,7 +339,7 @@ class VllmBackend: def configure_post_pass(self): config = self.compilation_config - self.post_grad_pass_manager.configure(config.pass_config) + self.post_grad_pass_manager.configure(self.vllm_config) # Post-grad custom passes are run using the post_grad_custom_post_pass # hook. If a pass for that hook exists, add it to the pass manager. diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index f91e4cf3e89ea..c5454ccdcbf7e 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -15,6 +15,8 @@ import vllm.envs as envs from vllm.config import VllmConfig from vllm.utils import is_torch_equal_or_newer +from .inductor_pass import pass_context + class CompilerInterface: """ @@ -312,11 +314,12 @@ class InductorAdaptor(CompilerInterface): torch._functorch.config.patch( enable_remote_autograd_cache=False)) - compiled_graph = compile_fx( - graph, - example_inputs, - inner_compile=hijacked_compile_fx_inner, - config_patches=current_config) + with pass_context(runtime_shape): + compiled_graph = compile_fx( + graph, + example_inputs, + inner_compile=hijacked_compile_fx_inner, + config_patches=current_config) # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch # compilation cache. 
So turn off the checks if we disable the diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index b46f5f52244fa..8f32fdb03f8b7 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass from torch._ops import OpOverload -from vllm.config import CompilationConfig +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform @@ -531,7 +531,7 @@ class FusionPass(VllmInductorPass): _instance: 'Optional[FusionPass]' = None @classmethod - def instance(cls, config: CompilationConfig.PassConfig): + def instance(cls, config: VllmConfig): """ Get the singleton instance of the FusionPass. If the instance exists, the config is updated but @@ -540,10 +540,10 @@ class FusionPass(VllmInductorPass): if cls._instance is None: cls._instance = FusionPass(config) else: - cls._instance.config = config + cls._instance.pass_config = config.compilation_config.pass_config return cls._instance - def __init__(self, config: CompilationConfig.PassConfig): + def __init__(self, config: VllmConfig): assert self.__class__._instance is None, \ "FusionPass singleton instance already exists" super().__init__(config) diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py index b9a8d3112e775..f9427e48ac315 100644 --- a/vllm/compilation/fx_utils.py +++ b/vllm/compilation/fx_utils.py @@ -12,6 +12,22 @@ def is_func(node: fx.Node, target) -> bool: return node.op == "call_function" and node.target == target +# Returns the first specified node with the given op (if it exists) +def find_specified_fn_maybe(nodes: Iterable[fx.Node], + op: OpOverload) -> Optional[fx.Node]: + for node in nodes: + if node.target == op: + return node + return None + + +# Returns the first specified node with the given op +def find_specified_fn(nodes: Iterable[fx.Node], op: 
OpOverload) -> fx.Node: + node = find_specified_fn_maybe(nodes, op) + assert node is not None, f"Could not find {op} in nodes {nodes}" + return node + + # Returns the first auto_functionalized node with the given op (if it exists) def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]: diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 00a2e89f21aeb..6cd7720fca2f9 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -4,6 +4,7 @@ import hashlib import inspect import json import types +from contextlib import contextmanager from typing import Any, Callable, Dict, Optional, Union import torch @@ -18,6 +19,34 @@ else: from .torch25_custom_graph_pass import ( # noqa: yapf Torch25CustomGraphPass as CustomGraphPass) +_pass_context = None + + +class PassContext: + + def __init__(self, runtime_shape: Optional[int]): + self.runtime_shape = runtime_shape + + +def get_pass_context() -> PassContext: + """Get the current pass context.""" + assert _pass_context is not None + return _pass_context + + +@contextmanager +def pass_context(runtime_shape: Optional[int]): + """A context manager that stores the current pass context, + usually it is a list of sizes to specialize. 
+ """ + global _pass_context + prev_context = _pass_context + _pass_context = PassContext(runtime_shape) + try: + yield + finally: + _pass_context = prev_context + class InductorPass(CustomGraphPass): """ @@ -62,6 +91,9 @@ class InductorPass(CustomGraphPass): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() + def is_applicable_for_shape(self, shape: Optional[int]): + return True + class CallableInductorPass(InductorPass): """ diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 530a88b2b09ae..f8e8c4971cbb6 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -4,13 +4,15 @@ from typing import List from torch import fx as fx -from vllm.config import CompilationConfig +from vllm.config import VllmConfig from vllm.logger import init_logger from .fix_functionalization import FixFunctionalizationPass from .fusion import FusionPass -from .inductor_pass import CustomGraphPass, InductorPass +from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context from .noop_elimination import NoOpEliminationPass +from .sequence_parallelism import SequenceParallelismPass +from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) @@ -31,24 +33,29 @@ class PostGradPassManager(CustomGraphPass): """ def __init__(self): - self.passes: List[InductorPass] = [] + self.passes: List[VllmInductorPass] = [] def __call__(self, graph: fx.Graph): + shape = get_pass_context().runtime_shape for pass_ in self.passes: - pass_(graph) + if pass_.is_applicable_for_shape(shape): + pass_(graph) # always run fix_functionalization last self.fix_functionalization(graph) - def configure(self, pass_config: CompilationConfig.PassConfig): - self.pass_config = pass_config - if pass_config.enable_noop: - self.passes += [NoOpEliminationPass(pass_config)] + def configure(self, config: VllmConfig): + self.pass_config = config.compilation_config.pass_config + if 
self.pass_config.enable_noop: + self.passes += [NoOpEliminationPass(config)] - if pass_config.enable_fusion: - self.passes += [FusionPass.instance(pass_config)] + if self.pass_config.enable_fusion: + self.passes += [FusionPass.instance(config)] - self.fix_functionalization = FixFunctionalizationPass(pass_config) + if self.pass_config.enable_sequence_parallelism: + self.passes += [SequenceParallelismPass(config)] + + self.fix_functionalization = FixFunctionalizationPass(config) def add(self, pass_: InductorPass): assert isinstance(pass_, InductorPass) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py new file mode 100644 index 0000000000000..95db63d34f7ea --- /dev/null +++ b/vllm/compilation/sequence_parallelism.py @@ -0,0 +1,266 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import torch +import torch._inductor.pattern_matcher as pm +import torch.fx as fx +from torch._inductor.pattern_matcher import PatternMatcherPass + +from vllm.config import VllmConfig +from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.logger import init_logger + +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class AllReduceRMSNormPattern: + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + self.epsilon = epsilon + self.dtype = dtype + self.device = device + + +class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern): + + def get_inputs(self): + arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype) + mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]], + device=self.device, + dtype=torch.long) + unsqueeze = torch.rand([1, 8, 1], device=self.device, \ + dtype=self.dtype) > 0.5 + full_default = torch.zeros([1, 8, 4], device=self.device, \ + dtype=self.dtype) + permute = torch.empty([1, 8, 4], device=self.device, 
dtype=self.dtype) + arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype) + + return [arg2_1, mul_6, unsqueeze, full_default, permute, arg3_1] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + arg2_1: torch.Tensor, + mul_6: torch.Tensor, + unsqueeze: torch.Tensor, + full_default: torch.Tensor, + permute: torch.Tensor, + arg3_1: torch.Tensor, + ): + embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) + where = torch.ops.aten.where.self(unsqueeze, full_default, + embedding) + all_reduce = tensor_model_parallel_all_reduce(where) + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm.default, + result=permute, + input=all_reduce, + weight=arg3_1, + epsilon=self.epsilon, + ) + + return rmsnorm[1], all_reduce + + def replacement( + arg2_1: torch.Tensor, + mul_6: torch.Tensor, + unsqueeze: torch.Tensor, + full_default: torch.Tensor, + permute: torch.Tensor, + arg3_1: torch.Tensor, + ): + embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) + where = torch.ops.aten.where.self(unsqueeze, full_default, + embedding) + + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + where, dim=0, world_size=tp_size, group_name=tp.unique_name) + + rmsnorm_result = torch.empty_like(reduce_scatter) + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm.default, + result=rmsnorm_result, + input=reduce_scatter, + weight=arg3_1, + epsilon=self.epsilon, + ) + + all_gather = torch.ops.vllm.all_gather.default( + rmsnorm[1], + dim=0, + world_size=tp_size, + group_name=tp.unique_name) + + return all_gather, reduce_scatter + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern): + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, 
dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + + return [ + residual, + mm_1, + rms_norm_weights, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + all_reduce = tensor_model_parallel_all_reduce(mm_1) + + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=all_reduce, + residual=residual, + weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + return rmsnorm[1], rmsnorm[2] + + def replacement( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) + + # TODO is it possible to extract epsilon from somewhere + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=reduce_scatter, + residual=residual, + weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + all_gather = torch.ops.vllm.all_gather.default( + rmsnorm[1], + dim=0, + world_size=tp_size, + group_name=tp.unique_name) + return all_gather, rmsnorm[2] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern): + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + + return [ + residual, + mm_1, + rms_norm_weights, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: 
torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + all_reduce = tensor_model_parallel_all_reduce(mm_1) + + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=all_reduce, + residual=residual, + weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + return rmsnorm[1] + + def replacement( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) + + # TODO is it possible to extract epsilon from somewhere + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=reduce_scatter, + residual=residual, + weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + normalized = torch.ops.vllm.all_gather.default( + rmsnorm[1], + dim=0, + world_size=tp_size, + group_name=tp.unique_name) + + return normalized + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class SequenceParallelismPass(VllmInductorPass): + + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="sequence_parallelism_pass") + for epsilon in [1e-5, 1e-6]: + EmbeddingAllReduceRMSNormPattern( + epsilon, self.dtype, self.device).register(self.patterns) + + MiddleAllReduceRMSNormPattern(epsilon, self.dtype, + self.device).register(self.patterns) + + LastAllReduceRMSNormPattern(epsilon, self.dtype, + self.device).register(self.patterns) + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. 
+ torch._inductor.pattern_matcher._seen_patterns.clear() + + def is_applicable_for_shape(self, shape: Optional[int]) -> bool: + # only do replace for specific shapes + tp_size = get_tensor_model_parallel_world_size() + return shape is not None and shape % tp_size == 0 + + def __call__(self, graph: fx.Graph): + self.dump_graph(graph, "before_sequence_parallelism_pass") + count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", count) + self.dump_graph(graph, "after_sequence_parallelism_pass") diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index 98ed6f1472a45..e8bffb406f148 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -4,7 +4,7 @@ import time import torch -from vllm.config import CompilationConfig +from vllm.config import CompilationConfig, VllmConfig # yapf: disable from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import ( @@ -24,16 +24,19 @@ class VllmInductorPass(InductorPass): It provides timing, logging, and dumping utilities. 
""" - def __init__(self, config: CompilationConfig.PassConfig): - self.config = config + def __init__(self, config: VllmConfig): + self.pass_config = config.compilation_config.pass_config + self.dtype = config.model_config.dtype if config.model_config else None + self.device = config.device_config.device if config.device_config \ + else None self.pass_name = self.__class__.__name__ def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False): - if stage in self.config.dump_graph_stages or always: + if stage in self.pass_config.dump_graph_stages or always: # Make sure filename includes rank in the distributed setting parallel = p_is_init() and get_tp_world_size() > 1 rank = f"-{get_tp_rank()}" if parallel else "" - filepath = self.config.dump_graph_dir / f"{stage}{rank}.py" + filepath = self.pass_config.dump_graph_dir / f"{stage}{rank}.py" logger.info("%s printing graph to %s", self.pass_name, filepath) with open(filepath, "w") as f: diff --git a/vllm/config.py b/vllm/config.py index fcbf962ac685c..3b39573fc7db1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3405,11 +3405,13 @@ class CompilationConfig(BaseModel): - enable_fusion: whether to enable the custom fusion pass. - enable_noop: whether to enable the custom no-op elimination pass. TODO(luka) better pass enabling system. + - enable_sequence_parallelism: whether to enable sequence parallelism. """ dump_graph_stages: list[str] = Field(default_factory=list) dump_graph_dir: Path = Field(default=Path(".")) enable_fusion: bool = True enable_noop: bool = True + enable_sequence_parallelism: bool = False def uuid(self): """ @@ -3418,7 +3420,8 @@ class CompilationConfig(BaseModel): Do not include dump_graph_* in the hash - they don't affect compilation. 
""" - dict_ = self.model_dump(include={"enable_fusion", "enable_noop"}) + dict_ = self.model_dump(include={"enable_fusion", "enable_noop", \ + "enable_sequence_parallelism"}) return InductorPass.hash_dict(dict_) def model_post_init(self, __context: Any) -> None: @@ -3840,6 +3843,8 @@ class VllmConfig: if self.compilation_config is None: self.compilation_config = CompilationConfig() + if self.compilation_config.pass_config.enable_sequence_parallelism: + self.compilation_config.custom_ops.append("+rms_norm") if envs.VLLM_USE_V1 and self.model_config is not None and \ not self.model_config.enforce_eager: # NOTE(woosuk): Currently, we use inductor because the piecewise @@ -3847,7 +3852,8 @@ class VllmConfig: # FIXME(woosuk): Disable inductor to reduce the compilation time # and avoid any potential issues with the inductor. # FIXME(rob): Add function to set all of these. - self.compilation_config.custom_ops = ["none"] + if not self.compilation_config.custom_ops: + self.compilation_config.custom_ops = ["none"] self.compilation_config.use_cudagraph = True self.compilation_config.use_inductor = True self.compilation_config.cudagraph_num_of_warmups = 1 @@ -3856,6 +3862,18 @@ class VllmConfig: self.compilation_config.level = CompilationLevel.PIECEWISE self.compilation_config.set_splitting_ops_for_v1() + if self.parallel_config is not None and \ + self.parallel_config.tensor_parallel_size > 1 and \ + self.parallel_config.pipeline_parallel_size > 1 and \ + self.compilation_config is not None and \ + self.compilation_config.pass_config is not None and \ + self.compilation_config.pass_config.enable_sequence_parallelism: + logger.warning_once( + "Sequence parallelism is not supported with pipeline " + "parallelism. 
Disabling sequence parallelism.") + self.compilation_config.pass_config.\ + enable_sequence_parallelism = False + self._set_cudagraph_sizes() if self.cache_config is not None and \ @@ -3895,6 +3913,26 @@ class VllmConfig: if not self.instance_id: self.instance_id = random_uuid()[:5] + def update_sizes_for_sequence_parallelism(self, + possible_sizes: list) -> list: + # remove the sizes that not multiple of tp_size when + # enable sequence parallelism + removed_sizes = [ + size for size in possible_sizes + if size % self.parallel_config.tensor_parallel_size != 0 + ] + if removed_sizes: + logger.warning( + "Batch sizes %s are removed because they are not " + "multiple of tp_size %d when " + "sequence parallelism is enabled", removed_sizes, + self.parallel_config.tensor_parallel_size) + + return [ + size for size in possible_sizes + if size % self.parallel_config.tensor_parallel_size == 0 + ] + def _set_cudagraph_sizes(self): """ cudagraph batchsize padding logic: @@ -3932,6 +3970,11 @@ class VllmConfig: not self.model_config.enforce_eager: possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)] + if self.parallel_config.tensor_parallel_size > 1 and \ + self.compilation_config.pass_config.enable_sequence_parallelism: + possible_sizes = self.update_sizes_for_sequence_parallelism( + possible_sizes) + # find the minimum size that is larger than max_num_seqs, # which then becomes the max_batchsize_to_capture larger_sizes = [ @@ -3955,6 +3998,11 @@ class VllmConfig: not self.model_config.enforce_eager: batch_size_capture_list = [1, 2, 4 ] + [i for i in range(8, 513, 8)] + if self.parallel_config.tensor_parallel_size > 1 and \ + self.compilation_config.pass_config.enable_sequence_parallelism: + batch_size_capture_list = \ + self.update_sizes_for_sequence_parallelism(batch_size_capture_list) + max_num_tokens = self.scheduler_config.max_num_batched_tokens batch_size_capture_list = [ size for size in batch_size_capture_list diff --git a/vllm/distributed/communication_op.py 
b/vllm/distributed/communication_op.py index 0228264f91f9a..894a0fafb6403 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -19,6 +19,12 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor, return get_tp_group().all_gather(input_, dim) +def tensor_model_parallel_reduce_scatter(input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + """Reduce-Scatter the input tensor across model parallel group.""" + return get_tp_group().reduce_scatter(input_, dim) + + def tensor_model_parallel_gather(input_: torch.Tensor, dst: int = 0, dim: int = -1) -> Optional[torch.Tensor]: diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index eb12f8834b419..240313b98c88b 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -61,6 +61,40 @@ class DeviceCommunicatorBase: input_size[dim + 1:]) return output_tensor + def reduce_scatter(self, + input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? 
+ input_tensor = input_.movedim(0, dim).contiguous() + + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size, ) + input_tensor.shape[1:] + + output_tensor = torch.empty(output_shape, + dtype=input_tensor.dtype, + device=input_tensor.device) + + # Perform reduce-scatter operation + torch.distributed.reduce_scatter_tensor(output_tensor, + input_tensor, + group=self.device_group) + + # Reshape before returning + return output_tensor.movedim(0, dim).contiguous() + def gather(self, input_: torch.Tensor, dst: int = 0, diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 07c9ff5060924..8bca278f3888b 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -70,6 +70,31 @@ class CudaCommunicator(DeviceCommunicatorBase): torch.distributed.all_reduce(out, group=self.device_group) return out + def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): + world_size = self.world_size + pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? 
+ input_tensor = input_.movedim(0, dim).contiguous() + + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size, ) + input_tensor.shape[1:] + + output = torch.empty(output_shape, + dtype=input_tensor.dtype, + device=input_tensor.device) + + pynccl_comm.reduce_scatter(output, input_tensor) + + # Reshape before returning + return output.movedim(0, dim).contiguous() + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: """Sends a tensor to the destination rank in a non-blocking way""" """NOTE: `dst` is the local rank of the destination rank.""" diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index d0ac7e92f9c89..cb9658ce10043 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -113,6 +113,38 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: return torch.empty_like(tensor) + +def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group.reduce_scatter(tensor, dim) + + +def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + new_shape = list(tensor.shape) + new_shape[dim] = tensor.shape[dim] // world_size + return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) + + +def all_gather(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." 
+ group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group.all_gather(tensor, dim) + + +def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + new_shape = list(tensor.shape) + new_shape[dim] = tensor.shape[dim] * world_size + return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) + + if supports_custom_op(): from vllm.platforms import current_platform direct_register_custom_op( @@ -123,6 +155,20 @@ if supports_custom_op(): dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="reduce_scatter", + op_func=reduce_scatter, + mutates_args=[], + fake_impl=reduce_scatter_fake, + ) + + direct_register_custom_op( + op_name="all_gather", + op_func=all_gather, + mutates_args=[], + fake_impl=all_gather_fake, + ) + class GroupCoordinator: """ @@ -322,6 +368,18 @@ class GroupCoordinator: return self.device_communicator.all_gather(input_, dim) + def reduce_scatter(self, + input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + return self.device_communicator.reduce_scatter(input_, dim) + def gather(self, input_: torch.Tensor, dst: int = 0, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5775701b941e5..e3d8b94fe9d7e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1027,7 +1027,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): num_scheduled_tokens) else: # Eager mode. 
- num_input_tokens = num_scheduled_tokens + # Pad tokens to a multiple of tensor_parallel_size when + # collective fusion for SP is enabled + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if self.vllm_config.compilation_config.pass_config. \ + enable_sequence_parallelism and tp_size > 1: + from vllm.utils import round_up + num_input_tokens = round_up(num_scheduled_tokens, tp_size) + else: + num_input_tokens = num_scheduled_tokens + attn_metadata.num_input_tokens = num_input_tokens + + # _prepare_inputs may reorder the batch, so we must gather multi