[compile] Enable sequence parallelism for full cuda graph without specifying compile sizes (#26681)

Signed-off-by: angelayi <yiangela7@gmail.com>
Angela Yi 2025-10-13 18:15:34 -07:00 committed by GitHub
parent 3e051bda82
commit b59dd19b55
5 changed files with 34 additions and 5 deletions


@@ -431,8 +431,15 @@ class AsyncTPPass(VllmPatternMatcherPass):
         self.dump_patterns(config, self.patterns)
 
-    def is_applicable_for_shape(self, shape: int | None) -> bool:
-        # only do replace for specific shapes
+    def is_applicable(self, shape: int | None) -> bool:
+        # This pass is applied on top of the sequence parallelism pass.
+        # It inherits the same applicability condition as `SequenceParallelismPass`.
+        # See `SequenceParallelismPass.is_applicable` for more details.
+        if (
+            not self.compilation_config.splitting_ops
+            or self.compilation_config.use_inductor_graph_partition
+        ):
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
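
For the compile-size path, the applicability test reduces to a plain divisibility check on the concrete token count. A minimal standalone sketch of that check; the helper name and the hard-coded tp_size are illustrative, whereas the real pass queries get_tensor_model_parallel_world_size():

# Illustrative stand-alone version of the shape check; tp_size is hard-coded
# instead of being read from the distributed state.
def applicable_for_compile_size(shape: int | None, tp_size: int = 4) -> bool:
    return shape is not None and shape % tp_size == 0

assert applicable_for_compile_size(8192)      # 8192 % 4 == 0 -> pass runs
assert not applicable_for_compile_size(8190)  # not divisible -> pass skipped
assert not applicable_for_compile_size(None)  # symbolic shape -> pass skipped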


@@ -96,7 +96,7 @@ class InductorPass(CustomGraphPass):
         encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
         return hashlib.sha256(encoded).hexdigest()
 
-    def is_applicable_for_shape(self, shape: int | None):
+    def is_applicable(self, shape: int | None):
         return True
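
With the rename, the base class keeps a permissive default, so only shape-sensitive passes need to override the hook. A hedged sketch of such an override; the class name, modulus, and module path below are assumptions for illustration, not part of this change:

import torch
# Module path is an assumption based on the class shown above.
from vllm.compilation.inductor_pass import InductorPass

class EvenShapeOnlyPass(InductorPass):
    """Hypothetical pass that only runs for concrete, even token counts."""

    def __call__(self, graph: torch.fx.Graph) -> None:
        pass  # a real pass would rewrite `graph` here

    def uuid(self) -> str:
        return "even-shape-only-pass-v1"  # stable id used for inductor caching

    def is_applicable(self, shape: int | None) -> bool:
        # Override the permissive base-class default: skip symbolic shapes.
        return shape is not None and shape % 2 == 0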


@@ -71,9 +71,11 @@ class PostGradPassManager(CustomGraphPass):
         shape = get_pass_context().runtime_shape
         for pass_ in self.passes:
-            if pass_.is_applicable_for_shape(shape):
+            if pass_.is_applicable(shape):
                 pass_(graph)
                 VllmInductorPass.dump_prefix += 1
+            else:
+                logger.debug("Skipping %s with shape %s", pass_, shape)
 
         # post-cleanup goes before fix_functionalization
         # because it requires a functional graph
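
The manager-side behaviour can be pictured with a small standalone mock: only passes whose is_applicable returns True for the current runtime shape are invoked, the rest are skipped with a debug log. ToyPass and run_passes below are stand-ins, not vLLM classes; only the gating logic mirrors the loop above.

import logging

logger = logging.getLogger("toy_pass_manager")

class ToyPass:
    """Stand-in for an inductor pass; records the graphs it ran on."""

    def __init__(self, name: str, tp_size: int):
        self.name = name
        self.tp_size = tp_size
        self.ran_on: list[object] = []

    def is_applicable(self, shape: int | None) -> bool:
        return shape is not None and shape % self.tp_size == 0

    def __call__(self, graph) -> None:
        self.ran_on.append(graph)

def run_passes(passes, graph, runtime_shape: int | None) -> None:
    # Mirrors the gating loop added to PostGradPassManager.__call__.
    for pass_ in passes:
        if pass_.is_applicable(runtime_shape):
            pass_(graph)
        else:
            logger.debug("Skipping %s with shape %s", pass_.name, runtime_shape)

p = ToyPass("sequence_parallelism", tp_size=4)
run_passes([p], "fx-graph-placeholder", runtime_shape=None)  # symbolic -> skipped
run_passes([p], "fx-graph-placeholder", runtime_shape=4096)  # concrete -> runs
assert p.ran_on == ["fx-graph-placeholder"]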


@@ -482,7 +482,25 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
         ).register(self.patterns)
         self.dump_patterns(config, self.patterns)
 
-    def is_applicable_for_shape(self, shape: int | None) -> bool:
+    def is_applicable(self, shape: int | None) -> bool:
+        # When sequence parallelism is enabled, the residual tensor from RMSNorm
+        # needs to be split along the sequence dimension. However, this dimension
+        # is symbolic during piecewise compilation, and splitting symbolic shapes
+        # is not supported.
+        #
+        # This pass is therefore only applied when the sequence dimension is
+        # concrete:
+        # 1. In full-graph compilation mode (no Dynamo splitting ops are used).
+        #    In this case num_tokens is always padded to a multiple of
+        #    tensor_parallel_size, so there is no need to check shape % tp_size == 0.
+        # 2. For a specific shape provided during compilation (e.g., from
+        #    `compile_sizes`), which must be divisible by the tensor-parallel
+        #    size.
+        if (
+            not self.compilation_config.splitting_ops
+            or self.compilation_config.use_inductor_graph_partition
+        ):
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
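
The decision condenses to a small pure function. The sketch below restates it with explicit arguments standing in for the pass's compilation_config fields and the TP world size; the function name is illustrative and the splitting-op string is just a placeholder for any non-empty list.

# Condensed restatement of the applicability rule above.
def sp_pass_is_applicable(
    shape: int | None,
    splitting_ops: list[str],
    use_inductor_graph_partition: bool,
    tp_size: int,
) -> bool:
    if not splitting_ops or use_inductor_graph_partition:
        # Full-graph compilation (or inductor-level graph partitioning):
        # num_tokens is padded to a multiple of tp_size, so always applicable.
        return True
    # Piecewise compilation: only apply for concrete compile sizes that are
    # divisible by the TP size.
    return shape is not None and shape % tp_size == 0

# Full cuda graph without compile sizes: applicable even for the symbolic shape.
assert sp_pass_is_applicable(None, [], False, tp_size=4)
# Piecewise, symbolic shape: not applicable.
assert not sp_pass_is_applicable(None, ["vllm.unified_attention"], False, tp_size=4)
# Piecewise, concrete compile size divisible by tp_size: applicable.
assert sp_pass_is_applicable(8192, ["vllm.unified_attention"], False, tp_size=4)
# Inductor graph partition: applicable regardless of shape.
assert sp_pass_is_applicable(None, ["vllm.unified_attention"], True, tp_size=4)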


@@ -3,6 +3,7 @@
 import functools
 import operator
 import time
+import weakref
 from typing import ClassVar
 
 import regex as re
@@ -28,6 +29,7 @@ class VllmInductorPass(InductorPass):
     """Keep track of pass index for debug dump ordering."""
 
     def __init__(self, config: VllmConfig):
+        self.compilation_config = weakref.proxy(config.compilation_config)
         self.pass_config = config.compilation_config.pass_config
         self.model_dtype = config.model_config.dtype if config.model_config else None
         self.device = config.device_config.device if config.device_config else None
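
The new attribute is stored through weakref.proxy, presumably so the pass can read compilation options without keeping the config object alive through a strong back-reference. A small self-contained illustration of the proxy semantics; the Config class here is a stand-in, not VllmConfig.

import weakref

class Config:
    """Stand-in for a compilation config object."""
    splitting_ops: list[str] = []

cfg = Config()
proxy = weakref.proxy(cfg)

# The proxy forwards attribute access to the live object...
assert proxy.splitting_ops == []

# ...but holds no strong reference: once the real object is gone, any further
# use of the proxy raises ReferenceError.
del cfg
try:
    proxy.splitting_ops
except ReferenceError:
    print("config was garbage-collected; proxy is dead")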