[compile] Enable sequence parallelism for full cuda graph without specifying compile sizes (#26681)
Signed-off-by: angelayi <yiangela7@gmail.com>
Parent: 3e051bda82
Commit: b59dd19b55
@@ -431,8 +431,15 @@ class AsyncTPPass(VllmPatternMatcherPass):
         self.dump_patterns(config, self.patterns)

-    def is_applicable_for_shape(self, shape: int | None) -> bool:
-        # only do replace for specific shapes
+    def is_applicable(self, shape: int | None) -> bool:
+        # This pass is applied on top of the sequence parallelism pass.
+        # It inherits the same applicability condition as `SequenceParallelismPass`.
+        # See `SequenceParallelismPass.is_applicable` for more details.
+        if (
+            not self.compilation_config.splitting_ops
+            or self.compilation_config.use_inductor_graph_partition
+        ):
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0

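The fallback branch of `is_applicable` (shared with `SequenceParallelismPass` below) is a plain divisibility check of the compiled shape against the tensor-parallel world size. A standalone illustration, with a made-up tp_size of 4 that is not taken from the diff:

    # Illustrative only: the fallback check from is_applicable, in isolation.
    def shape_divides_across_tp(shape: int | None, tp_size: int) -> bool:
        return shape is not None and shape % tp_size == 0

    assert shape_divides_across_tp(8192, 4)          # 8192 % 4 == 0: pass applies
    assert not shape_divides_across_tp(4097, 4)      # 4097 % 4 == 1: pass is skipped
    assert not shape_divides_across_tp(None, 4)      # dynamic/symbolic shape: skipped
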
@@ -96,7 +96,7 @@ class InductorPass(CustomGraphPass):
         encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
         return hashlib.sha256(encoded).hexdigest()

-    def is_applicable_for_shape(self, shape: int | None):
+    def is_applicable(self, shape: int | None):
         return True

@@ -71,9 +71,11 @@ class PostGradPassManager(CustomGraphPass):

         shape = get_pass_context().runtime_shape
         for pass_ in self.passes:
-            if pass_.is_applicable_for_shape(shape):
+            if pass_.is_applicable(shape):
                 pass_(graph)
                 VllmInductorPass.dump_prefix += 1
+            else:
+                logger.debug("Skipping %s with shape %s", pass_, shape)

         # post-cleanup goes before fix_functionalization
         # because it requires a functional graph

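Beyond the rename, the only behavioral change in the pass manager is the debug log for skipped passes. Below is a minimal, runnable sketch of the same dispatch pattern; the pass classes and the `run_passes` helper are hypothetical stand-ins, and only the `is_applicable(shape)` protocol and the skip-logging come from the diff:

    import logging

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger("pass_manager_sketch")

    class AlwaysApplicablePass:
        # Analogous to the InductorPass default: applicable for any runtime shape.
        def is_applicable(self, shape):
            return True

        def __call__(self, graph):
            graph.append("always_applicable")

    class TPAlignedPass:
        # Analogous to the SequenceParallelismPass fallback: needs a concrete,
        # tp-divisible shape (the full-cuda-graph shortcut is omitted here).
        def __init__(self, tp_size):
            self.tp_size = tp_size

        def is_applicable(self, shape):
            return shape is not None and shape % self.tp_size == 0

        def __call__(self, graph):
            graph.append("tp_aligned")

    def run_passes(passes, graph, runtime_shape):
        # Same shape-gated dispatch as the pass manager after this commit:
        # run a pass only when it reports itself applicable, log otherwise.
        for pass_ in passes:
            if pass_.is_applicable(runtime_shape):
                pass_(graph)
            else:
                logger.debug("Skipping %s with shape %s", pass_, runtime_shape)

    graph = []
    run_passes([AlwaysApplicablePass(), TPAlignedPass(tp_size=4)], graph, runtime_shape=4097)
    print(graph)  # ['always_applicable'] -- the tp-aligned pass was skipped for 4097
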
@@ -482,7 +482,25 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
         ).register(self.patterns)
         self.dump_patterns(config, self.patterns)

-    def is_applicable_for_shape(self, shape: int | None) -> bool:
+    def is_applicable(self, shape: int | None) -> bool:
+        # When sequence parallelism is enabled, the residual tensor from RMSNorm
+        # needs to be split along the sequence dimension. However, this dimension
+        # is symbolic during piecewise compilation, and splitting symbolic shapes
+        # is not supported.
+        #
+        # This pass is therefore only applied when the sequence dimension is
+        # concrete:
+        # 1. In full-graph compilation mode (no Dynamo splitting ops are used).
+        #    For this case we always pad num_tokens to be a multiple of
+        #    tensor_parallel_size, so there's no need to check shape % tp_size == 0.
+        # 2. For specific shapes provided during compilation (e.g., from
+        #    `compile_sizes`), which must be divisible by the tensor-parallel
+        #    size.
+        if (
+            not self.compilation_config.splitting_ops
+            or self.compilation_config.use_inductor_graph_partition
+        ):
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0

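Read as a decision procedure, the new `is_applicable` has exactly two ways to return True: either the whole graph is compiled without Dynamo splitting ops (or with inductor graph partitioning), in which case num_tokens is padded to a multiple of the TP size anyway, or the shape is an explicit compile size divisible by the TP size. A free-standing sketch of that logic follows; the dataclass and the explicit `tp_size` argument are stand-ins for vLLM's compilation config and distributed state, not its real API:

    from dataclasses import dataclass, field

    @dataclass
    class CompilationConfigSketch:
        # Stand-ins for the two config fields the check reads.
        splitting_ops: list[str] = field(default_factory=list)
        use_inductor_graph_partition: bool = False

    def sequence_parallel_applicable(
        cfg: CompilationConfigSketch, shape: int | None, tp_size: int
    ) -> bool:
        # Full cuda graph (no splitting ops) or inductor-level partitioning:
        # the sequence dimension is concrete and padded, so always apply.
        if not cfg.splitting_ops or cfg.use_inductor_graph_partition:
            return True
        # Piecewise compilation: apply only for an explicit compile size that
        # splits evenly across tensor-parallel ranks.
        return shape is not None and shape % tp_size == 0

    full_graph = CompilationConfigSketch(splitting_ops=[])
    piecewise = CompilationConfigSketch(splitting_ops=["attention"])  # placeholder op name

    assert sequence_parallel_applicable(full_graph, None, tp_size=4)     # dynamic shape is fine
    assert not sequence_parallel_applicable(piecewise, None, tp_size=4)  # symbolic shape: skip
    assert sequence_parallel_applicable(piecewise, 8192, tp_size=4)      # compile_sizes entry
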
@@ -3,6 +3,7 @@
 import functools
 import operator
 import time
+import weakref
 from typing import ClassVar

 import regex as re

@@ -28,6 +29,7 @@ class VllmInductorPass(InductorPass):
     """Keep track of pass index for debug dump ordering."""

     def __init__(self, config: VllmConfig):
+        self.compilation_config = weakref.proxy(config.compilation_config)
         self.pass_config = config.compilation_config.pass_config
         self.model_dtype = config.model_config.dtype if config.model_config else None
         self.device = config.device_config.device if config.device_config else None

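The diff does not say why `compilation_config` is held through `weakref.proxy` rather than a direct reference; a plausible reading is that the pass should be able to read the config without keeping it (and everything it owns) alive or creating a reference cycle. A toy illustration of the proxy's behavior, using a hypothetical config class:

    import weakref

    class ConfigSketch:
        # Hypothetical stand-in for a compilation config object.
        use_inductor_graph_partition = False

    cfg = ConfigSketch()
    proxy = weakref.proxy(cfg)

    # The proxy forwards attribute access to the live object without owning it.
    print(proxy.use_inductor_graph_partition)  # False

    del cfg
    try:
        proxy.use_inductor_graph_partition
    except ReferenceError:
        # Once the real object is gone, the proxy raises instead of keeping it alive.
        print("config has been garbage-collected")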