From b59dd19b55036050cf491501614fb46cedf5c5c1 Mon Sep 17 00:00:00 2001
From: Angela Yi
Date: Mon, 13 Oct 2025 18:15:34 -0700
Subject: [PATCH] [compile] Enable sequence parallelism for full cuda graph
 without specifying compile sizes (#26681)

Signed-off-by: angelayi
---
 vllm/compilation/collective_fusion.py    | 11 +++++++++--
 vllm/compilation/inductor_pass.py        |  2 +-
 vllm/compilation/pass_manager.py         |  4 +++-
 vllm/compilation/sequence_parallelism.py | 20 +++++++++++++++++++-
 vllm/compilation/vllm_inductor_pass.py   |  2 ++
 5 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index 1dc8888607f5..7c85c89bcd7a 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -431,8 +431,15 @@ class AsyncTPPass(VllmPatternMatcherPass):
 
         self.dump_patterns(config, self.patterns)
 
-    def is_applicable_for_shape(self, shape: int | None) -> bool:
-        # only do replace for specific shapes
+    def is_applicable(self, shape: int | None) -> bool:
+        # This pass is applied on top of the sequence parallelism pass,
+        # so it shares that pass's applicability condition; see
+        # `SequenceParallelismPass.is_applicable` for details.
+        if (
+            not self.compilation_config.splitting_ops
+            or self.compilation_config.use_inductor_graph_partition
+        ):
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
 
diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py
index b9ec3cf6c5ed..4b263fa6f5a2 100644
--- a/vllm/compilation/inductor_pass.py
+++ b/vllm/compilation/inductor_pass.py
@@ -96,7 +96,7 @@ class InductorPass(CustomGraphPass):
         encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
         return hashlib.sha256(encoded).hexdigest()
 
-    def is_applicable_for_shape(self, shape: int | None):
+    def is_applicable(self, shape: int | None):
         return True
 
 
diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py
index e323fa1f7734..55fe235e2d2c 100644
--- a/vllm/compilation/pass_manager.py
+++ b/vllm/compilation/pass_manager.py
@@ -71,9 +71,11 @@ class PostGradPassManager(CustomGraphPass):
         shape = get_pass_context().runtime_shape
 
         for pass_ in self.passes:
-            if pass_.is_applicable_for_shape(shape):
+            if pass_.is_applicable(shape):
                 pass_(graph)
                 VllmInductorPass.dump_prefix += 1
+            else:
+                logger.debug("Skipping %s with shape %s", pass_, shape)
 
         # post-cleanup goes before fix_functionalization
         # because it requires a functional graph
diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py
index 8ff530cebd82..31624a8fdcc0 100644
--- a/vllm/compilation/sequence_parallelism.py
+++ b/vllm/compilation/sequence_parallelism.py
@@ -482,7 +482,25 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
         ).register(self.patterns)
         self.dump_patterns(config, self.patterns)
 
-    def is_applicable_for_shape(self, shape: int | None) -> bool:
+    def is_applicable(self, shape: int | None) -> bool:
+        # When sequence parallelism is enabled, the residual tensor from
+        # RMSNorm needs to be split along the sequence dimension. However,
+        # this dimension is symbolic during piecewise compilation, and
+        # splitting symbolic shapes is not supported.
+        #
+        # This pass is therefore only applied when the sequence dimension
+        # is concrete:
+        # 1. In full-graph compilation mode (no Dynamo splitting ops are
+        #    used). In this case num_tokens is always padded to a multiple
+        #    of tensor_parallel_size, so there is no need to check
+        #    shape % tp_size == 0.
+        # 2. For a specific shape provided during compilation (e.g., from
+        #    `compile_sizes`), which must be divisible by tp_size.
+        if (
+            not self.compilation_config.splitting_ops
+            or self.compilation_config.use_inductor_graph_partition
+        ):
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
 
diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py
index ad83e7b3e0c2..beac928b5d71 100644
--- a/vllm/compilation/vllm_inductor_pass.py
+++ b/vllm/compilation/vllm_inductor_pass.py
@@ -3,6 +3,7 @@
 import functools
 import operator
 import time
+import weakref
 from typing import ClassVar
 
 import regex as re
@@ -28,6 +29,7 @@ class VllmInductorPass(InductorPass):
     """Keep track of pass index for debug dump ordering."""
 
     def __init__(self, config: VllmConfig):
+        self.compilation_config = weakref.proxy(config.compilation_config)
         self.pass_config = config.compilation_config.pass_config
         self.model_dtype = config.model_config.dtype if config.model_config else None
         self.device = config.device_config.device if config.device_config else None
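Below is a minimal, self-contained sketch of the gating behavior this patch
introduces, for reading the diff without the surrounding code. Only the body
of `is_applicable` mirrors the patch; `FakeCompilationConfig`,
`ToySequenceParallelismPass`, the hard-coded `tp_size=4`, and the example
splitting-op string are illustrative stand-ins, not vLLM APIs:

from dataclasses import dataclass, field


@dataclass
class FakeCompilationConfig:
    # Illustrative stand-in for the two CompilationConfig fields the
    # patched condition reads.
    splitting_ops: list[str] = field(default_factory=list)
    use_inductor_graph_partition: bool = False


class ToySequenceParallelismPass:
    # Only is_applicable mirrors the patch; construction is simplified.
    def __init__(self, compilation_config: FakeCompilationConfig, tp_size: int):
        self.compilation_config = compilation_config
        # Stand-in for get_tensor_model_parallel_world_size().
        self.tp_size = tp_size

    def is_applicable(self, shape: int | None) -> bool:
        # Full-graph compilation (no splitting ops) or inductor graph
        # partition: num_tokens is padded to a multiple of tp_size, so the
        # pass applies even when the shape is symbolic (shape is None).
        if (
            not self.compilation_config.splitting_ops
            or self.compilation_config.use_inductor_graph_partition
        ):
            return True
        # Piecewise compilation: only apply for concrete compile sizes
        # divisible by the tensor-parallel size.
        return shape is not None and shape % self.tp_size == 0


# Piecewise compilation: the symbolic-shape graph is skipped, as before.
piecewise = ToySequenceParallelismPass(
    FakeCompilationConfig(splitting_ops=["vllm.unified_attention"]), tp_size=4
)
assert not piecewise.is_applicable(None)  # symbolic shape: skip
assert piecewise.is_applicable(128)       # 128 % 4 == 0: apply
assert not piecewise.is_applicable(130)   # not divisible: skip

# Full cuda graph, no compile_sizes: the pass now applies to the
# symbolic-shape graph as well, which is the point of this patch.
full_graph = ToySequenceParallelismPass(FakeCompilationConfig(), tp_size=4)
assert full_graph.is_applicable(None)

One design note: the config handle added in `VllmInductorPass.__init__` goes
through `weakref.proxy`, presumably so that long-lived pass objects do not
keep the full `VllmConfig` alive; the patch itself does not state the
motivation.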