[compile] Enable sequence parallelism for full cuda graph without specifying compile sizes (#26681)
Signed-off-by: angelayi <yiangela7@gmail.com>
Parent: 3e051bda82
Commit: b59dd19b55
@@ -431,8 +431,15 @@ class AsyncTPPass(VllmPatternMatcherPass):
         self.dump_patterns(config, self.patterns)

-    def is_applicable_for_shape(self, shape: int | None) -> bool:
-        # only do replace for specific shapes
+    def is_applicable(self, shape: int | None) -> bool:
+        # This pass is applied on top of the sequence parallelism pass.
+        # It inherits the same applicability condition as `SequenceParallelismPass`.
+        # See `SequenceParallelismPass.is_applicable` for more details.
+        if (
+            not self.compilation_config.splitting_ops
+            or self.compilation_config.use_inductor_graph_partition
+        ):
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
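Note: the new branch in `AsyncTPPass.is_applicable` mirrors the condition added to `SequenceParallelismPass` further down. As a reading aid, here is the same predicate written as a standalone function; the `FakeCompilationConfig` stand-in and the literal values are illustrative assumptions, not vLLM's real types.

from dataclasses import dataclass, field


@dataclass
class FakeCompilationConfig:
    # Minimal stand-in for vLLM's CompilationConfig (illustrative only).
    splitting_ops: list[str] = field(default_factory=list)
    use_inductor_graph_partition: bool = False


def is_applicable(config: FakeCompilationConfig, shape: int | None, tp_size: int) -> bool:
    # Full cuda graph compilation (no splitting ops) or inductor graph
    # partition: the pass applies regardless of the compile size.
    if not config.splitting_ops or config.use_inductor_graph_partition:
        return True
    # Otherwise only apply for a concrete shape divisible by the TP size.
    return shape is not None and shape % tp_size == 0


piecewise = FakeCompilationConfig(splitting_ops=["attention_op_placeholder"])
print(is_applicable(piecewise, None, tp_size=4))  # False: symbolic shape
print(is_applicable(piecewise, 64, tp_size=4))    # True: 64 % 4 == 0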
@@ -96,7 +96,7 @@ class InductorPass(CustomGraphPass):
         encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
         return hashlib.sha256(encoded).hexdigest()

-    def is_applicable_for_shape(self, shape: int | None):
+    def is_applicable(self, shape: int | None):
        return True
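The base-class default keeps every existing pass applicable for all shapes; only passes that care about the runtime shape override the hook. A hypothetical override, sketched outside of vLLM's class hierarchy:

class PassStub:
    # Mirrors the base-class default above: applicable for every shape.
    def is_applicable(self, shape: int | None) -> bool:
        return True


class ConcreteShapeOnlyPass(PassStub):
    # Hypothetical pass that should only run when the compile size is known.
    def is_applicable(self, shape: int | None) -> bool:
        return shape is not None


print(PassStub().is_applicable(None))               # True
print(ConcreteShapeOnlyPass().is_applicable(None))  # False
print(ConcreteShapeOnlyPass().is_applicable(128))   # True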
@@ -71,9 +71,11 @@ class PostGradPassManager(CustomGraphPass):

         shape = get_pass_context().runtime_shape
         for pass_ in self.passes:
-            if pass_.is_applicable_for_shape(shape):
+            if pass_.is_applicable(shape):
                 pass_(graph)
                 VllmInductorPass.dump_prefix += 1
+            else:
+                logger.debug("Skipping %s with shape %s", pass_, shape)

         # post-cleanup goes before fix_functionalization
         # because it requires a functional graph
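For context, the control flow of the pass manager after this change, reduced to a toy that rewrites a string instead of an FX graph. Everything here is a simplified stand-in, not vLLM's PostGradPassManager:

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("pass_manager_sketch")


class UpperCasePass:
    # Toy pass: only applicable when the shape is concrete and even.
    def is_applicable(self, shape: int | None) -> bool:
        return shape is not None and shape % 2 == 0

    def __call__(self, graph: str) -> str:
        return graph.upper()


def run_passes(passes, graph: str, runtime_shape: int | None) -> str:
    # Same branching as in the diff: run applicable passes, log skipped ones.
    for pass_ in passes:
        if pass_.is_applicable(runtime_shape):
            graph = pass_(graph)
        else:
            logger.debug("Skipping %s with shape %s", pass_, runtime_shape)
    return graph


print(run_passes([UpperCasePass()], "graph", runtime_shape=None))  # skipped -> "graph"
print(run_passes([UpperCasePass()], "graph", runtime_shape=8))     # applied -> "GRAPH"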
@@ -482,7 +482,25 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
         ).register(self.patterns)
         self.dump_patterns(config, self.patterns)

-    def is_applicable_for_shape(self, shape: int | None) -> bool:
+    def is_applicable(self, shape: int | None) -> bool:
+        # When sequence parallelism is enabled, the residual tensor from RMSNorm
+        # needs to be split along the sequence dimension. However, this dimension
+        # is symbolic during piecewise compilation, and splitting symbolic shapes
+        # is not supported.
+        #
+        # This pass is therefore only applied when the sequence dimension is
+        # concrete:
+        # 1. In full-graph compilation mode (no Dynamo splitting ops are used).
+        #    For this case we always pad num_tokens to be a multiple of
+        #    tensor_parallel_size, so there's no need to check shape % tp_size == 0.
+        # 2. For specific shape provided during compilation (e.g., from
+        #    `compile_sizes`), which must be divisible by the tensor-parallel
+        #    size.
+        if (
+            not self.compilation_config.splitting_ops
+            or self.compilation_config.use_inductor_graph_partition
+        ):
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
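The first case in the new comment relies on num_tokens being padded to a multiple of the tensor-parallel size in full cuda graph mode, which is why the modulo check is unnecessary there. A small illustration of that invariant; the pad_to_multiple helper is made up for this sketch and is not vLLM's padding code:

def pad_to_multiple(num_tokens: int, tp_size: int) -> int:
    # Round num_tokens up to the next multiple of tp_size.
    return ((num_tokens + tp_size - 1) // tp_size) * tp_size


tp_size = 4
for num_tokens in (1, 5, 8, 13):
    padded = pad_to_multiple(num_tokens, tp_size)
    # After padding, the sequence dimension splits evenly across TP ranks,
    # so padded % tp_size == 0 holds by construction.
    assert padded % tp_size == 0
    print(num_tokens, "->", padded, "(per-rank:", padded // tp_size, "tokens)")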
@@ -3,6 +3,7 @@
 import functools
 import operator
 import time
+import weakref
 from typing import ClassVar

 import regex as re
@@ -28,6 +29,7 @@ class VllmInductorPass(InductorPass):
     """Keep track of pass index for debug dump ordering."""

     def __init__(self, config: VllmConfig):
+        self.compilation_config = weakref.proxy(config.compilation_config)
         self.pass_config = config.compilation_config.pass_config
         self.model_dtype = config.model_config.dtype if config.model_config else None
         self.device = config.device_config.device if config.device_config else None
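The constructor now stores the compilation config behind `weakref.proxy`, so the pass does not hold a strong reference to the config object. A minimal demonstration of the proxy semantics with a toy object (not VllmConfig):

import gc
import weakref


class ToyConfig:
    # Stand-in for the compilation config.
    splitting_ops = ["attn_placeholder"]


config = ToyConfig()
proxy = weakref.proxy(config)

# The proxy forwards attribute access while the target is alive...
print(proxy.splitting_ops)

# ...but it does not keep the target alive by itself.
del config
gc.collect()
try:
    proxy.splitting_ops
except ReferenceError:
    print("config was garbage collected; the proxy no longer resolves")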