[Compile] Conditional compilation. Introduce compile_ranges (#24252)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Luka Govedič <luka.govedic@gmail.com>
Signed-off-by: ProExpertProg <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
Ilya Markov 2025-12-05 19:17:32 +01:00 committed by GitHub
parent 66e674cdd5
commit 4e26d3b09e
15 changed files with 582 additions and 268 deletions

View File

@@ -298,10 +298,14 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
         r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
         log_holder.text,
     )
-    assert len(log_matches) == 2, log_holder.text
-    assert int(log_matches[0]) == matches.attention_fusion
-    assert int(log_matches[1]) == matches.attention_fusion
+    # 2 for each compile range
+    # (global compile range can be split due to fuse_allreduce_rmsnorm)
+    num_compile_ranges = len(compilation_config.get_compile_ranges())
+    assert num_compile_ranges in [1, 2]
+    assert len(log_matches) == 2 * num_compile_ranges, log_holder.text
+    assert all(int(log_match) == matches.attention_fusion for log_match in log_matches)

     log_matches = re.findall(
         r"collective_fusion.py:\d+] Replaced (\d+) patterns",
@@ -312,6 +316,12 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
     assert int(log_matches[0]) == matches.allreduce_fusion
     assert int(log_matches[1]) == matches.allreduce_fusion

+    log_matches = re.findall(
+        r"pass_manager.py:\d+] Skipping .*AllReduceFusionPass.* with compile range",
+        log_holder.text,
+    )
+    assert len(log_matches) == 2 * (num_compile_ranges - 1), log_holder.text
+

 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize(
@@ -446,7 +456,6 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
     # No cudagraphs by default
     if compilation_config.cudagraph_mode is None:
         compilation_config.cudagraph_mode = CUDAGraphMode.NONE
-
     llm = LLM(
         model=model,
         compilation_config=compilation_config,
@@ -459,3 +468,9 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
         prompt = output.prompt
         generated_text = output.outputs[0].text
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+    # Get the compile ranges split points after vllm config post init
+    # in order to compute compile ranges correctly
+    compilation_config.compile_ranges_split_points = (
+        llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points
+    )

View File

@@ -0,0 +1,168 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from torch import fx as fx
from torch import nn
# This import automatically registers `torch.ops.silly.attention`
import tests.compile.silly_attention # noqa
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.inductor_pass import (
InductorPass,
get_pass_context,
)
from vllm.config import (
VllmConfig,
set_current_vllm_config,
)
from vllm.config.compilation import CompilationConfig, CompilationMode
from vllm.config.scheduler import SchedulerConfig
from vllm.config.utils import Range
from vllm.forward_context import set_forward_context
BATCH_SIZE = 64
MLP_SIZE = 128
@support_torch_compile
class TestModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + x
attn_output = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, attn_output)
x = attn_output
x = x * 3
return x
@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]):
with set_forward_context({}, vllm_config=vllm_config):
model(torch.randn(BATCH_SIZE, MLP_SIZE))
for batch_size in batch_sizes:
model(torch.randn(batch_size, MLP_SIZE))
class PostGradRangeChecker(InductorPass):
def __init__(self, ranges: list[Range]):
self.ranges = ranges
self.num_calls = 0
def __call__(self, graph: fx.Graph):
compile_range = get_pass_context().compile_range
assert compile_range in self.ranges, (
f"Compile range {compile_range} not in {self.ranges}"
)
self.num_calls += 1
def uuid(self) -> str:
state: dict[str, Any] = {}
return InductorPass.hash_dict(state)
def test_compile_ranges(use_fresh_inductor_cache):
post_grad_range_checker = PostGradRangeChecker(
[
Range(start=1, end=8),
Range(start=16, end=16),
Range(start=9, end=32),
Range(start=64, end=64),
Range(start=33, end=8192),
]
)
torch.set_default_device("cuda")
vllm_config = VllmConfig(
scheduler_config=SchedulerConfig(
max_num_batched_tokens=8192,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
compile_ranges_split_points=[8, 32],
compile_sizes=[16, 64, 128],
inductor_compile_config={
"post_grad_custom_post_pass": post_grad_range_checker,
},
),
)
with set_current_vllm_config(vllm_config):
model = TestModel(vllm_config=vllm_config, prefix="").eval()
# Number of compilations: one for each of the 3 compile ranges + 2 compile sizes
batch_sizes = [1, 4, 16, 24, 48, 64, 8192]
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=1,
num_backend_compilations=5,
):
run_model(vllm_config, model, batch_sizes)
assert post_grad_range_checker.num_calls == 5
def test_compile_config_get_compile_ranges():
compilation_config = CompilationConfig(
compile_ranges_split_points=[8, 32],
)
VllmConfig(
scheduler_config=SchedulerConfig(
max_num_batched_tokens=8192,
),
compilation_config=compilation_config,
)
assert compilation_config.get_compile_ranges() == [
Range(start=1, end=8),
Range(start=9, end=32),
Range(start=33, end=8192),
]
def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
# To force multiple compilations, we disable the compile cache
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
post_grad_range_checker = PostGradRangeChecker(
ranges=[
Range(start=1, end=8),
Range(start=9, end=8192),
]
)
scheduler_config = SchedulerConfig(
max_num_batched_tokens=8192,
)
torch.set_default_device("cuda")
def create_vllm_config():
return VllmConfig(
scheduler_config=scheduler_config,
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
compile_ranges_split_points=[8],
inductor_compile_config={
"post_grad_custom_post_pass": post_grad_range_checker,
},
),
)
vllm_config_1 = create_vllm_config()
with set_current_vllm_config(vllm_config_1):
model1 = TestModel(vllm_config=vllm_config_1, prefix="").eval()
batch_sizes = [1, 16]
run_model(vllm_config_1, model1, batch_sizes)
assert post_grad_range_checker.num_calls == 2
post_grad_range_checker.num_calls = 0
# Create a new vllm config with the new pass context
vllm_config_2 = create_vllm_config()
with set_current_vllm_config(vllm_config_2):
model2 = TestModel(vllm_config=vllm_config_2, prefix="").eval()
batch_sizes = [4, 32]
run_model(vllm_config_2, model2, batch_sizes)
# Check that cache is used, so the number of calls
# should be 0
assert post_grad_range_checker.num_calls == 0

View File

@@ -67,6 +67,9 @@ from vllm.transformers_utils.utils import maybe_model_redirect
 from vllm.utils.collection_utils import is_list_of
 from vllm.utils.torch_utils import set_default_torch_num_threads

+from torch._inductor.utils import fresh_cache
+
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
     from transformers.generation.utils import GenerateOutput
@@ -1465,3 +1468,14 @@ def clean_gpu_memory_between_tests():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         gc.collect()
+
+
+@pytest.fixture
+def use_fresh_inductor_cache():
+    """
+    Use a fresh inductor cache for the test.
+    This is useful to ensure that the test is not affected by the
+    previous test calls.
+    """
+    with fresh_cache():
+        yield

View File

@@ -26,7 +26,7 @@ from vllm.compilation.partition_rules import (
     should_split,
 )
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
-from vllm.config.utils import hash_factors
+from vllm.config.utils import Range, hash_factors
 from vllm.logger import init_logger
 from vllm.logging_utils import lazy
 from vllm.platforms import current_platform
@@ -90,7 +90,7 @@ class CompilerManager:
     """

     def __init__(self, compilation_config: CompilationConfig):
-        self.cache: dict[tuple[int | None, int, str], Any] = dict()
+        self.cache: dict[tuple[Range, int, str], Any] = dict()
         self.is_cache_updated = False
         self.compilation_config = compilation_config
         self.compiler = make_compiler(compilation_config)
@@ -99,11 +99,11 @@ class CompilerManager:
         return self.compiler.compute_hash(vllm_config)

     @contextmanager
-    def compile_context(self, runtime_shape: int | None = None):
+    def compile_context(self, compile_range: Range):
         """Provide compilation context for the duration of compilation to set
         any torch global properties we want to scope to a single Inductor
         compilation (e.g. partition rules, pass context)."""
-        with pass_context(runtime_shape):
+        with pass_context(compile_range):
             if self.compilation_config.use_inductor_graph_partition:
                 with inductor_partition_rule_context(
                     self.compilation_config.splitting_ops
@@ -159,29 +159,21 @@ class CompilerManager:
         graph: fx.GraphModule,
         example_inputs: list[Any],
         graph_index: int,
-        runtime_shape: int | None = None,
+        compile_range: Range,
     ) -> Callable | None:
-        if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
+        if (compile_range, graph_index, self.compiler.name) not in self.cache:
             return None
-        handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
+        handle = self.cache[(compile_range, graph_index, self.compiler.name)]
         compiled_graph = self.compiler.load(
-            handle, graph, example_inputs, graph_index, runtime_shape
+            handle, graph, example_inputs, graph_index, compile_range
         )
-        if runtime_shape is None:
-            logger.debug(
-                "Directly load the %s-th graph for dynamic shape from %s via handle %s",
-                graph_index,
-                self.compiler.name,
-                handle,
-            )
-        else:
-            logger.debug(
-                "Directly load the %s-th graph for shape %s from %s via handle %s",
-                graph_index,
-                str(runtime_shape),
-                self.compiler.name,
-                handle,
-            )
+        logger.debug(
+            "Directly load the %s-th graph for compile range %s from %s via handle %s",
+            graph_index,
+            str(compile_range),
+            self.compiler.name,
+            handle,
+        )
         return compiled_graph

     def compile(
@@ -190,9 +182,9 @@ class CompilerManager:
         example_inputs,
         additional_inductor_config,
         compilation_config: CompilationConfig,
+        compile_range: Range,
         graph_index: int = 0,
         num_graphs: int = 1,
-        runtime_shape: int | None = None,
     ) -> Any:
         if graph_index == 0:
             # before compiling the first graph, record the start time
@@ -204,7 +196,7 @@
         compiled_graph = None

         # try to load from the cache
-        compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
+        compiled_graph = self.load(graph, example_inputs, graph_index, compile_range)
         if compiled_graph is not None:
             if graph_index == num_graphs - 1:
                 # after loading the last graph for this shape, record the time.
@@ -212,19 +204,12 @@
                 now = time.time()
                 elapsed = now - compilation_start_time
                 compilation_config.compilation_time += elapsed
-                if runtime_shape is None:
-                    logger.info(
-                        "Directly load the compiled graph(s) for dynamic shape "
-                        "from the cache, took %.3f s",
-                        elapsed,
-                    )
-                else:
-                    logger.info(
-                        "Directly load the compiled graph(s) for shape %s "
-                        "from the cache, took %.3f s",
-                        str(runtime_shape),
-                        elapsed,
-                    )
+                logger.info(
+                    "Directly load the compiled graph(s) for compile range %s "
+                    "from the cache, took %.3f s",
+                    str(compile_range),
+                    elapsed,
+                )
             return compiled_graph

         # no compiler cached the graph, or the cache is disabled,
@@ -233,14 +218,15 @@
             # Let compile_fx generate a key for us
             maybe_key = None
         else:
-            maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
+            maybe_key = "artifact_compile_range_"
+            maybe_key += f"{compile_range.start}_{compile_range.end}"
+            maybe_key += f"_subgraph_{graph_index}"

-        with self.compile_context(runtime_shape):
+        with self.compile_context(compile_range):
             compiled_graph, handle = self.compiler.compile(
                 graph,
                 example_inputs,
                 additional_inductor_config,
-                runtime_shape,
+                compile_range,
                 maybe_key,
             )
@@ -248,55 +234,34 @@
         # store the artifact in the cache
         if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
-            self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
+            self.cache[(compile_range, graph_index, self.compiler.name)] = handle
             compilation_counter.num_cache_entries_updated += 1
             self.is_cache_updated = True
             if graph_index == 0:
                 # adds some info logging for the first graph
-                if runtime_shape is None:
-                    logger.info_once(
-                        "Cache the graph for dynamic shape for later use", scope="local"
-                    )
-                else:
-                    logger.info_once(
-                        "Cache the graph of shape %s for later use",
-                        str(runtime_shape),
-                        scope="local",
-                    )
-            if runtime_shape is None:
-                logger.debug(
-                    "Store the %s-th graph for dynamic shape from %s via handle %s",
-                    graph_index,
-                    self.compiler.name,
-                    handle,
-                )
-            else:
-                logger.debug(
-                    "Store the %s-th graph for shape %s from %s via handle %s",
-                    graph_index,
-                    str(runtime_shape),
-                    self.compiler.name,
-                    handle,
-                )
+                logger.info_once(
+                    "Cache the graph of compile range %s for later use",
+                    str(compile_range),
+                )
+            logger.debug(
+                "Store the %s-th graph for compile range %s from %s via handle %s",
+                graph_index,
+                str(compile_range),
+                self.compiler.name,
+                handle,
+            )

         # after compiling the last graph, record the end time
         if graph_index == num_graphs - 1:
             now = time.time()
             elapsed = now - compilation_start_time
             compilation_config.compilation_time += elapsed
-            if runtime_shape is None:
-                logger.info_once(
-                    "Compiling a graph for dynamic shape takes %.2f s",
-                    elapsed,
-                    scope="local",
-                )
-            else:
-                logger.info_once(
-                    "Compiling a graph for shape %s takes %.2f s",
-                    runtime_shape,
-                    elapsed,
-                    scope="local",
-                )
+            logger.info_once(
+                "Compiling a graph for compile range %s takes %.2f s",
+                str(compile_range),
+                elapsed,
+                scope="local",
+            )

         return compiled_graph
@@ -427,19 +392,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
         sym_shape_indices = [
             i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
         ]
-        global compilation_start_time
-        compiled_graph_for_dynamic_shape = (
-            self.vllm_backend.compiler_manager.compile(
-                submod,
-                args,
-                self.vllm_backend.inductor_config,
-                self.compilation_config,
-                graph_index=index,
-                num_graphs=len(self.compile_submod_names),
-                runtime_shape=None,
-            )
-        )
+
         # Lazy import here to avoid circular import
         from .piecewise_backend import PiecewiseBackend
@@ -449,7 +402,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
             index,
             len(self.compile_submod_names),
             sym_shape_indices,
-            compiled_graph_for_dynamic_shape,
             self.vllm_backend,
         )
@@ -589,8 +541,13 @@ class VllmBackend:
             )
         else:
             # Config should automatically wrap all inductor passes
-            assert isinstance(self.inductor_config[self.pass_key], InductorPass)
-            self.pass_manager.add(self.inductor_config[self.pass_key])
+            assert isinstance(
+                self.compilation_config.inductor_compile_config[self.pass_key],
+                InductorPass,
+            )
+            self.pass_manager.add(
+                self.compilation_config.inductor_compile_config[self.pass_key]
+            )
         self.inductor_config[self.pass_key] = self.pass_manager

     def __call__(

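For reference, a minimal sketch of the artifact naming used by CompilerManager.compile() above (illustrative values only; it assumes the Range helper that this commit adds to vllm.config.utils):

from vllm.config.utils import Range

# Hypothetical range and subgraph index, mirroring the maybe_key logic above.
compile_range = Range(start=9, end=32)
graph_index = 0
maybe_key = "artifact_compile_range_"
maybe_key += f"{compile_range.start}_{compile_range.end}"
maybe_key += f"_subgraph_{graph_index}"
assert maybe_key == "artifact_compile_range_9_32_subgraph_0"

Keys therefore differ per compile range, so an artifact compiled for one range is never reused for another.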
View File

@@ -10,6 +10,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass
 from torch.distributed._symmetric_memory import enable_symm_mem_for_group

 from vllm.config import VllmConfig
+from vllm.config.utils import Range
 from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank,
@@ -431,7 +432,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
         self.dump_patterns(config, self.patterns)

-    def is_applicable(self, shape: int | None) -> bool:
+    def is_applicable_for_range(self, compile_range: Range) -> bool:
         # This pass is applied on top of the sequence parallelism pass.
         # It inherits the same applicability condition as `SequenceParallelismPass`.
         # See `SequenceParallelismPass.is_applicable` for more details.
@@ -441,7 +442,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
         ):
             return True
         tp_size = get_tensor_model_parallel_world_size()
-        return shape is not None and shape % tp_size == 0
+        return compile_range.is_single_size() and compile_range.end % tp_size == 0

     @VllmInductorPass.time_and_log
     def __call__(self, graph: fx.Graph):
@@ -505,91 +506,60 @@ if flashinfer_comm is not None:
         num_tokens, hidden_size = allreduce_in.shape
         element_size = allreduce_in.element_size()
         current_tensor_size = num_tokens * hidden_size * element_size
+        max_tensor_size = max_token_num * hidden_size * element_size
+        assert current_tensor_size <= max_tensor_size, (
+            f"Current tensor size {current_tensor_size} is larger than "
+            f"max token num {max_token_num} * hidden size {hidden_size} * "
+            f"element size {element_size}"
+        )
+        device_capability = current_platform.get_device_capability().to_int()
+        # Get one shot input size limit for the current world size
+        # for the current device capability
+        max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
+            device_capability, {}
+        ).get(world_size, None)
+        # Use one shot if no max size is specified
+        use_oneshot = (
+            max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB
+        )

-        if num_tokens <= max_token_num:
-            device_capability = current_platform.get_device_capability().to_int()
-            # Get one shot input size limit for the current world size
-            # for the current device capability
-            max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
-                device_capability, {}
-            ).get(world_size, None)
-            # Use one shot if no max size for one shot is specified
-            use_oneshot = (
-                max_one_shot_size_mb is None
-                or current_tensor_size <= max_one_shot_size_mb * MiB
-            )
-            assert _FI_WORKSPACE_TENSOR is not None, (
-                "Flashinfer must be enabled when using flashinfer"
-            )
-            if norm_out is None:
-                norm_out = allreduce_in
-                residual_out = residual
-            else:
-                # return residual_out as allreduce_out with zeroed residual_in
-                # as flashinfer does not support rms_norm
-                # and allreduce_out together
-                residual_out = allreduce_in
-            # For the sizes that are smaller than the max size,
-            # we only use flashinfer one shot allreduce
-            flashinfer_comm.trtllm_allreduce_fusion(
-                allreduce_in=allreduce_in,
-                token_num=allreduce_in.shape[0],
-                residual_in=residual,
-                residual_out=residual_out,
-                norm_out=norm_out,
-                rms_gamma=rms_gamma,
-                rms_eps=rms_eps,
-                world_rank=world_rank,
-                world_size=world_size,
-                hidden_dim=allreduce_in.shape[-1],
-                workspace_ptrs=_FI_WORKSPACE_TENSOR,
-                launch_with_pdl=launch_with_pdl,
-                use_oneshot=use_oneshot,
-                trigger_completion_at_end=trigger_completion_at_end,
-                fp32_acc=fp32_acc,
-                pattern_code=pattern_code,
-                allreduce_out=None,
-                quant_out=quant_out,
-                scale_out=scale_out,
-                # in vllm we only support swizzled layout
-                layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
-                scale_factor=scale_factor,
-            )
+        assert _FI_WORKSPACE_TENSOR is not None, (
+            "Flashinfer must be enabled when using flashinfer"
+        )
+        if norm_out is None:
+            norm_out = allreduce_in
+            residual_out = residual
         else:
-            allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
-            if scale_factor is not None and scale_out is None:
-                # Do fused rms norm static fp8 quant fused op
-                if norm_out is None:
-                    torch.ops._C.fused_add_rms_norm_static_fp8_quant(
-                        quant_out,
-                        allreduce_out,
-                        residual,
-                        rms_gamma,
-                        scale_factor,
-                        rms_eps,
-                    )
-                else:
-                    torch.ops._C.rms_norm_static_fp8_quant(
-                        quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps
-                    )
-            else:
-                if norm_out is None:
-                    torch.ops._C.fused_add_rms_norm(
-                        allreduce_out, residual, rms_gamma, rms_eps
-                    )
-                    norm_out = allreduce_out
-                else:
-                    torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps)
-                if scale_factor is not None and scale_out is not None:
-                    torch.ops._C.scaled_fp4_quant(
-                        quant_out, norm_out, scale_out, scale_factor
-                    )
-            if scale_factor is None or norm_out is not None:
-                # we need to return allreduce output
-                # in cases of non quant fused AR + RMS norm
-                # and fused AR + RMS norm + quant without fused add
-                allreduce_in.copy_(allreduce_out)
+            # return residual_out as allreduce_out with zeroed residual_in
+            # as flashinfer does not support rms_norm
+            # and allreduce_out together
+            residual_out = allreduce_in
+        # For the sizes that are smaller than the max size,
+        # we only use flashinfer one shot allreduce
+        flashinfer_comm.trtllm_allreduce_fusion(
+            allreduce_in=allreduce_in,
+            token_num=allreduce_in.shape[0],
+            residual_in=residual,
+            residual_out=residual_out,
+            norm_out=norm_out,
+            rms_gamma=rms_gamma,
+            rms_eps=rms_eps,
+            world_rank=world_rank,
+            world_size=world_size,
+            hidden_dim=allreduce_in.shape[-1],
+            workspace_ptrs=_FI_WORKSPACE_TENSOR,
+            launch_with_pdl=launch_with_pdl,
+            use_oneshot=use_oneshot,
+            trigger_completion_at_end=trigger_completion_at_end,
+            fp32_acc=fp32_acc,
+            pattern_code=pattern_code,
+            allreduce_out=None,
+            quant_out=quant_out,
+            scale_out=scale_out,
+            # in vllm we only support swizzled layout
+            layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
+            scale_factor=scale_factor,
+        )

     def call_trtllm_fused_allreduce_norm_fake(
         allreduce_in: torch.Tensor,
@@ -1128,7 +1098,8 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
         if max_size is None:
             # Flashinfer doesn't support current world size
             logger.warning(
-                "Flashinfer allreduce fusion is not supported for world size %s",
+                "Flashinfer allreduce fusion is not supported for world size %s"
+                " or max size is not provided",
                 self.tp_size,
             )
             return
@@ -1216,6 +1187,9 @@ class AllReduceFusionPass(VllmPatternMatcherPass):

         self.disabled = False

+    def is_applicable_for_range(self, compile_range: Range) -> bool:
+        return compile_range.end <= self.max_token_num
+
     @VllmInductorPass.time_and_log
     def __call__(self, graph: fx.Graph):
         if self.disabled:

View File

@@ -15,6 +15,7 @@ import torch.fx as fx

 import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
 from vllm.config import VllmConfig
+from vllm.config.utils import Range
 from vllm.utils.hashing import safe_hash
 from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -63,16 +64,16 @@ class CompilerInterface:
         graph: fx.GraphModule,
         example_inputs: list[Any],
         compiler_config: dict[str, Any],
-        runtime_shape: int | None = None,
+        compile_range: Range,
         key: str | None = None,
     ) -> tuple[Callable | None, Any | None]:
         """
         Compile the graph with the given example inputs and compiler config,
-        with a runtime shape. If the `runtime_shape` is None, it means
-        the `example_inputs` have a dynamic shape. Otherwise, the
-        `runtime_shape` specifies the shape of the inputs. Right now we only
-        support one variable shape for all inputs, which is the batchsize
-        (number of tokens) during inference.
+        with a range. The `compile_range` specifies the range of the inputs;
+        it can be a concrete size (if compile_sizes is provided), e.g. [4, 4],
+        or a range, e.g. [5, 8].
+        Right now we only support one variable in ranges for all inputs,
+        which is the batchsize (number of tokens) during inference.

         Dynamo will make sure `graph(*example_inputs)` is valid.
@@ -98,7 +99,7 @@ class CompilerInterface:
         graph: fx.GraphModule,
         example_inputs: list[Any],
         graph_index: int,
-        runtime_shape: int | None = None,
+        compile_range: Range,
     ) -> Callable:
         """
         Load the compiled function from the handle.
@@ -212,20 +213,20 @@ class InductorStandaloneAdaptor(CompilerInterface):
         graph: fx.GraphModule,
         example_inputs: list[Any],
         compiler_config: dict[str, Any],
-        runtime_shape: int | None = None,
+        compile_range: Range,
         key: str | None = None,
     ) -> tuple[Callable | None, Any | None]:
         compilation_counter.num_inductor_compiles += 1
         current_config = {}
         if compiler_config is not None:
             current_config.update(compiler_config)
-        set_inductor_config(current_config, runtime_shape)
+        set_inductor_config(current_config, compile_range)
         set_functorch_config()

-        if isinstance(runtime_shape, int):
+        if compile_range.is_single_size():
             dynamic_shapes = "from_example_inputs"
         else:
-            dynamic_shapes = "from_tracing_context"
+            dynamic_shapes = "from_graph"

         from torch._inductor import standalone_compile
@@ -235,7 +236,6 @@
             dynamic_shapes=dynamic_shapes,
             options={"config_patches": current_config},
         )
-
         # Save the compiled artifact to disk in the specified path
         assert key is not None
         path = os.path.join(self.cache_dir, key)
@@ -251,7 +251,7 @@
         graph: fx.GraphModule,
         example_inputs: list[Any],
         graph_index: int,
-        runtime_shape: int | None = None,
+        compile_range: Range,
     ) -> Callable:
         assert isinstance(handle, tuple)
         assert isinstance(handle[0], str)
@@ -315,7 +315,7 @@ class InductorAdaptor(CompilerInterface):
         graph: fx.GraphModule,
         example_inputs: list[Any],
         compiler_config: dict[str, Any],
-        runtime_shape: int | None = None,
+        compile_range: Range,
         key: str | None = None,
     ) -> tuple[Callable | None, Any | None]:
         compilation_counter.num_inductor_compiles += 1
@@ -329,7 +329,7 @@
         current_config["fx_graph_cache"] = True
         current_config["fx_graph_remote_cache"] = False

-        set_inductor_config(current_config, runtime_shape)
+        set_inductor_config(current_config, compile_range)
         set_functorch_config()

         # inductor can inplace modify the graph, so we need to copy it
@@ -512,7 +512,7 @@
         graph: fx.GraphModule,
         example_inputs: list[Any],
         graph_index: int,
-        runtime_shape: int | None = None,
+        compile_range: Range,
     ) -> Callable:
         assert isinstance(handle, tuple)
         assert isinstance(handle[0], str)
@@ -608,9 +608,9 @@
         return contextlib.nullcontext()


-def set_inductor_config(config, runtime_shape):
-    if isinstance(runtime_shape, int):
-        # for a specific batchsize, tuning triton kernel parameters
+def set_inductor_config(config, compile_range: Range):
+    if compile_range.is_single_size():
+        # for a specific batch size, tuning triton kernel parameters
         # can be beneficial
         config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
         config["coordinate_descent_tuning"] = (
@@ -630,7 +630,7 @@ class EagerAdaptor(CompilerInterface):
         graph: fx.GraphModule,
         example_inputs: list[Any],
         compiler_config: dict[str, Any],
-        runtime_shape: int | None = None,
+        compile_range: Range,
     ) -> tuple[Callable | None, Any | None]:
         compilation_counter.num_eager_compiles += 1

View File

@@ -14,6 +14,7 @@ import torch
 from torch import fx
 from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

+from vllm.config.utils import Range
 from vllm.utils.torch_utils import is_torch_equal_or_newer

 if is_torch_equal_or_newer("2.6"):
@@ -28,8 +29,8 @@ _pass_context = None


 class PassContext:
-    def __init__(self, runtime_shape: int | None):
-        self.runtime_shape = runtime_shape
+    def __init__(self, compile_range: Range):
+        self.compile_range: Range = compile_range


 def get_pass_context() -> PassContext:
@@ -39,13 +40,13 @@ def get_pass_context() -> PassContext:


 @contextmanager
-def pass_context(runtime_shape: int | None):
+def pass_context(compile_range: Range):
     """A context manager that stores the current pass context,
     usually it is a list of sizes to specialize.
     """
     global _pass_context
     prev_context = _pass_context
-    _pass_context = PassContext(runtime_shape)
+    _pass_context = PassContext(compile_range)
     try:
         yield
     finally:
@@ -96,7 +97,7 @@ class InductorPass(CustomGraphPass):
         encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
         return hashlib.sha256(encoded).hexdigest()

-    def is_applicable(self, shape: int | None):
+    def is_applicable_for_range(self, compile_range: Range):
         return True

View File

@@ -24,7 +24,11 @@ if current_platform.is_cuda():
     from .collective_fusion import AllReduceFusionPass, AsyncTPPass

 from .fix_functionalization import FixFunctionalizationPass
-from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
+from .inductor_pass import (
+    CustomGraphPass,
+    InductorPass,
+    get_pass_context,
+)
 from .noop_elimination import NoOpEliminationPass

 logger = init_logger(__name__)
@@ -70,13 +74,13 @@ class PostGradPassManager(CustomGraphPass):
     def __call__(self, graph: fx.Graph):
         VllmInductorPass.dump_prefix = 0  # reset dump index

-        shape = get_pass_context().runtime_shape
+        compile_range = get_pass_context().compile_range
         for pass_ in self.passes:
-            if pass_.is_applicable(shape):
+            if pass_.is_applicable_for_range(compile_range):
                 pass_(graph)
                 VllmInductorPass.dump_prefix += 1
             else:
-                logger.debug("Skipping %s with shape %s", pass_, shape)
+                logger.debug("Skipping %s with compile range %s", pass_, compile_range)

         # post-cleanup goes before fix_functionalization
         # because it requires a functional graph
@@ -133,4 +137,8 @@ class PostGradPassManager(CustomGraphPass):
             state["passes"].append(pass_.uuid())
         state["passes"].append(self.fix_functionalization.uuid())
+
+        # Include the compile range in the uuid to ensure that inductor
+        # recompiles the graph for the new dynamic compile range.
+        state["compile_range"] = str(get_pass_context().compile_range)
         return InductorPass.hash_dict(state)

View File

@@ -7,18 +7,18 @@ from typing import Any

 import torch.fx as fx

-import vllm.envs as envs
 from vllm.compilation.backends import VllmBackend
 from vllm.compilation.monitor import end_monitoring_torch_compile
 from vllm.config import VllmConfig
+from vllm.config.compilation import Range
 from vllm.logger import init_logger

 logger = init_logger(__name__)


 @dataclasses.dataclass
-class ConcreteSizeEntry:
-    runtime_shape: int
+class RangeEntry:
+    compile_range: Range
     compiled: bool = False
     runnable: Callable = None  # type: ignore
@@ -31,7 +31,6 @@ class PiecewiseBackend:
         piecewise_compile_index: int,
         total_piecewise_compiles: int,
         sym_shape_indices: list[int],
-        compiled_graph_for_general_shape: Callable,
         vllm_backend: VllmBackend,
     ):
         """
@@ -55,67 +54,111 @@ class PiecewiseBackend:
         self.is_full_graph = total_piecewise_compiles == 1

-        self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes)
+        self.compile_ranges = self.compilation_config.get_compile_ranges()
+        log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
+        logger.debug_once(log_string)

-        self.first_run_finished = False
+        self.compile_sizes = self.compilation_config.compile_sizes
+        log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
+        logger.debug_once(log_string)

-        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa
         self.sym_shape_indices = sym_shape_indices

-        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
-
-        # the entries for different shapes that we need to compile
-        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
-        # to_be_compiled_sizes tracks the remaining sizes to compile,
+        # the entries for ranges that we need to either
+        self.range_entries: dict[Range, RangeEntry] = {}
+        # to_be_compiled_ranges tracks the remaining ranges to compile,
         # and updates during the compilation process, so we need to copy it
-        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
+        self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)

         # We only keep compilation management inside this class directly.
-        for shape in self.compile_sizes:
-            self.concrete_size_entries[shape] = ConcreteSizeEntry(
-                runtime_shape=shape,
-                runnable=self.compiled_graph_for_general_shape,
+        for size in self.compile_sizes:
+            range = Range(start=size, end=size)
+            if range not in self.compile_ranges:
+                self.range_entries[range] = RangeEntry(
+                    compile_range=range,
+                )
+                self.to_be_compiled_ranges.add(range)
+        for range in self.compile_ranges:
+            self.range_entries[range] = RangeEntry(
+                compile_range=range,
             )

     def check_for_ending_compilation(self):
-        if self.is_last_graph and not self.to_be_compiled_sizes:
+        if self.is_last_graph and not self.to_be_compiled_ranges:
             # no specific sizes to compile
             # save the hash of the inductor graph for the next run
             self.vllm_backend.compiler_manager.save_to_file()
             end_monitoring_torch_compile(self.vllm_config)

-    def __call__(self, *args) -> Any:
-        if not self.first_run_finished:
-            self.first_run_finished = True
-            self.check_for_ending_compilation()
-            return self.compiled_graph_for_general_shape(*args)
-
-        runtime_shape = args[self.sym_shape_indices[0]]
-
-        if runtime_shape not in self.concrete_size_entries:
-            # we don't need to do anything for this shape
-            return self.compiled_graph_for_general_shape(*args)
-
-        entry = self.concrete_size_entries[runtime_shape]
-
-        if not entry.compiled:
-            entry.compiled = True
-            self.to_be_compiled_sizes.remove(runtime_shape)
+    def _fakify_args(self, args: list[Any]) -> list[Any]:
+        # We need to pass fake example_inputs, otherwise torch.compile
+        # will fakify the example_inputs potentially causing some non dynamic
+        # dimension to be duck shaped to other existing shapes that have hints
+        # matching their values.
+        # This is a problem because it can lead to unintended specializations!
+        # if the new wrongly dynamic dim is specialized
+        # it will force specializing the whole shape
+        # torch.compile probably should not accept
+        # non fake tensors as example inputs!
+        # See issue https://github.com/vllm-project/vllm/issues/27899
+        fake_example_inputs = []
+        for node in self.graph.graph.nodes:
+            # All place holders come first
+            if node.op == "placeholder":
+                fake_example_inputs.append(node.meta["example_value"])
+            else:
+                break
+        assert len(fake_example_inputs) == len(args)
+        return fake_example_inputs
+
+    def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any:
+        if not range_entry.compiled:
+            range_entry.compiled = True
+            self.to_be_compiled_ranges.remove(range_entry.compile_range)

             # args are real arguments
-            entry.runnable = self.vllm_backend.compiler_manager.compile(
+            # fakify for range, real args for concrete size.
+            # For concrete size, we clear the shape env in
+            # compiler_manager.compile() so no need to fakify.
+            args = (
+                self._fakify_args(args)
+                if not range_entry.compile_range.is_single_size()
+                else args
+            )
+            range_entry.runnable = self.vllm_backend.compiler_manager.compile(
                 self.graph,
                 args,
                 self.vllm_backend.inductor_config,
                 self.compilation_config,
+                compile_range=range_entry.compile_range,
                 graph_index=self.piecewise_compile_index,
                 num_graphs=self.total_piecewise_compiles,
-                runtime_shape=runtime_shape,
             )

-            # finished compilations for all required shapes
-            if self.is_last_graph and not self.to_be_compiled_sizes:
-                self.check_for_ending_compilation()
+            self.check_for_ending_compilation()

-        return entry.runnable(*args)
+    def _find_range_for_shape(self, runtime_shape: int) -> Range | None:
+        # First we try to find the range entry for the concrete compile size
+        # If not found, we search for the range entry
+        # that contains the runtime shape.
+        if runtime_shape in self.compile_sizes:
+            return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
+        else:
+            for range in self.compile_ranges:
+                if runtime_shape in range:
+                    return self.range_entries[range]
+        return None
+
+    def __call__(self, *args) -> Any:
+        runtime_shape = args[self.sym_shape_indices[0]]
+        range_entry = self._find_range_for_shape(runtime_shape)
+        assert range_entry is not None, (
+            f"Shape out of considered range: {runtime_shape} "
+            "[1, max_num_batched_tokens]"
+        )
+
+        self._maybe_compile_for_range_entry(range_entry, args)
+
+        return range_entry.runnable(*args)

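For intuition, a standalone sketch of the dispatch order implemented by _find_range_for_shape above: a concrete compile size takes priority over the range that contains it (plain Python with illustrative names, not the vLLM classes):

def find_range(runtime_shape: int,
               compile_sizes: set[int],
               compile_ranges: list[tuple[int, int]]) -> tuple[int, int] | None:
    # Prefer an exact compile size, modeled as a single-size range.
    if runtime_shape in compile_sizes:
        return (runtime_shape, runtime_shape)
    # Otherwise fall back to the inclusive range containing the shape.
    for start, end in compile_ranges:
        if start <= runtime_shape <= end:
            return (start, end)
    return None

assert find_range(16, {16, 64}, [(1, 8), (9, 32), (33, 8192)]) == (16, 16)
assert find_range(24, {16, 64}, [(1, 8), (9, 32), (33, 8192)]) == (9, 32)
assert find_range(10000, {16, 64}, [(1, 8), (9, 32), (33, 8192)]) is None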
View File

@@ -9,6 +9,7 @@ import torch.fx as fx
 from torch._inductor.pattern_matcher import PatternMatcherPass

 from vllm.config import VllmConfig
+from vllm.config.compilation import Range
 from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
@@ -333,7 +334,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
         self.dump_patterns(config, self.patterns)

-    def is_applicable(self, shape: int | None) -> bool:
+    def is_applicable_for_range(self, compile_range: Range) -> bool:
         # When sequence parallelism is enabled, the residual tensor from RMSNorm
         # needs to be split along the sequence dimension. However, this dimension
         # is symbolic during piecewise compilation, and splitting symbolic shapes
@@ -353,7 +354,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
         ):
             return True
         tp_size = get_tensor_model_parallel_world_size()
-        return shape is not None and shape % tp_size == 0
+        return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0)

     @VllmInductorPass.time_and_log
     def __call__(self, graph: fx.Graph):

View File

@@ -13,7 +13,13 @@ from pydantic.dataclasses import dataclass

 import vllm.envs as envs
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
-from vllm.config.utils import config, get_hash_factors, handle_deprecated, hash_factors
+from vllm.config.utils import (
+    Range,
+    config,
+    get_hash_factors,
+    handle_deprecated,
+    hash_factors,
+)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -173,6 +179,9 @@ class PassConfig:
         """
         MiB = 1024 * 1024
+        FI_SUPPORTED_WORLD_SIZES = [2, 4, 8]
+        if world_size not in FI_SUPPORTED_WORLD_SIZES:
+            return None
         max_size_mb = self.fi_allreduce_fusion_max_size_mb
         if max_size_mb is None:
             max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size)
@@ -379,6 +388,8 @@ class CompilationConfig:
          [vllm.config.CompilationConfig.cudagraph_copy_inputs]
     - Inductor compilation:
         - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
+        - [`compile_ranges_split_points`]
+          [vllm.config.CompilationConfig.compile_ranges_split_points]
        - [`inductor_compile_config`]
          [vllm.config.CompilationConfig.inductor_compile_config]
        - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
@@ -492,6 +503,21 @@ class CompilationConfig:
     to integers, it also supports "cudagraph_capture_sizes" to
     specify the sizes for cudagraph capture."""

+    compile_ranges_split_points: list[int] | None = None
+    """Split points that represent compile ranges for inductor.
+    The compile ranges are
+    [1, split_points[0]],
+    [split_points[0] + 1, split_points[1]], ...,
+    [split_points[-1] + 1, max_num_batched_tokens].
+    Compile sizes are also used as single-element ranges;
+    such a range is represented as [compile_sizes[i], compile_sizes[i]].
+    If a range overlaps with a compile size, the graph for the compile size
+    is prioritized, i.e. if we have a range [1, 8] and a compile size 4,
+    the graph for compile size 4 will be compiled and used instead of the
+    graph for the range [1, 8].
+    """
+
     inductor_compile_config: dict = field(default_factory=dict)
     """Additional configurations for inductor.
     - None: use default configurations."""
@@ -1153,3 +1179,13 @@ class CompilationConfig:
                 self.bs_to_padded_graph_size[bs] = start
             else:
                 self.bs_to_padded_graph_size[bs] = end
+
+    def get_compile_ranges(self) -> list[Range]:
+        """Get the compile ranges for the compilation config."""
+        if self.compile_ranges_split_points is None:
+            return []
+        split_points = sorted(set(self.compile_ranges_split_points))
+        return [
+            Range(start=s + 1, end=e)
+            for s, e in zip([0] + split_points[:-1], split_points)
+        ]

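A minimal standalone sketch of the split-point-to-range mapping described in the docstring and get_compile_ranges() above (plain tuples instead of the Range helper; the 8192 upper bound is a hypothetical max_num_batched_tokens, which _set_compile_ranges appends as the final split point later in this commit):

def ranges_from_split_points(split_points: list[int]) -> list[tuple[int, int]]:
    # Sorted, de-duplicated split points become inclusive ranges.
    points = sorted(set(split_points))
    return [(s + 1, e) for s, e in zip([0] + points[:-1], points)]

assert ranges_from_split_points([8, 32, 8192]) == [(1, 8), (9, 32), (33, 8192)]
# With compile_sizes=[4], a batch of 4 tokens would use the dedicated
# (4, 4) graph rather than the graph compiled for the (1, 8) range.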
View File

@@ -10,7 +10,7 @@ import json
 import pathlib
 import textwrap
 from collections.abc import Iterable, Mapping, Sequence, Set
-from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
+from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
 from itertools import pairwise
 from typing import TYPE_CHECKING, Any, Protocol, TypeVar
@@ -322,3 +322,35 @@ def handle_deprecated(
         for new_name in new_names:
             setattr(config, new_name, old_val)
+
+
+@dataclass
+class Range:
+    """
+    A range of numbers.
+    Inclusive of start, inclusive of end.
+    """
+
+    start: int
+    end: int
+
+    def is_single_size(self) -> bool:
+        return self.start == self.end
+
+    def __contains__(self, size: int) -> bool:
+        # Inclusive of start, inclusive of end
+        return self.start <= size <= self.end
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, Range):
+            return False
+        return self.start == other.start and self.end == other.end
+
+    def __hash__(self) -> int:
+        return hash((self.start, self.end))
+
+    def __str__(self) -> str:
+        return f"({self.start}, {self.end})"
+
+    def __repr__(self) -> str:
+        return self.__str__()

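A few usage examples of the semantics this helper is expected to have, based only on the methods shown above (assuming from vllm.config.utils import Range):

r = Range(start=9, end=32)
assert 9 in r and 32 in r and 33 not in r        # inclusive on both ends
assert not r.is_single_size()
assert Range(start=16, end=16).is_single_size()  # a concrete compile size
assert str(r) == "(9, 32)"
# Equality and hashing by (start, end) let Range serve as a cache key.
assert Range(start=9, end=32) == r and hash(Range(start=9, end=32)) == hash(r)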
View File

@@ -725,6 +725,8 @@
                 "--kv-sharing-fast-prefill requires changes on model side for "
                 "correctness and to realize prefill savings. "
             )

+        # TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
+        self._set_compile_ranges()
+
         if self.model_config and self.model_config.is_encoder_decoder:
             from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -1126,6 +1128,52 @@
         # complete the remaining process.
         self.compilation_config.post_init_cudagraph_sizes()

+    def _set_compile_ranges(self):
+        """
+        Set the compile ranges for the compilation config.
+        """
+        compilation_config = self.compilation_config
+        computed_compile_ranges_split_points = []
+
+        # The upper bound of the compile ranges is the max_num_batched_tokens
+        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
+        if max_num_batched_tokens is not None:
+            computed_compile_ranges_split_points.append(max_num_batched_tokens)
+
+        # Add the compile ranges for flashinfer
+        if compilation_config.pass_config.fuse_allreduce_rms:
+            tp_size = self.parallel_config.tensor_parallel_size
+            max_size = compilation_config.pass_config.flashinfer_max_size(tp_size)
+            if max_size is not None:
+                max_token_num = max_size // (
+                    self.model_config.get_hidden_size()
+                    * self.model_config.dtype.itemsize
+                )
+                if (
+                    max_num_batched_tokens is not None
+                    and max_token_num < max_num_batched_tokens
+                ):
+                    computed_compile_ranges_split_points.append(max_token_num)
+                else:
+                    logger.debug(
+                        "Max num batched tokens below allreduce-rms fusion threshold, "
+                        "allreduce-rms fusion will be enabled for all num_tokens."
+                    )
+
+        if compilation_config.compile_ranges_split_points is not None:
+            for x in compilation_config.compile_ranges_split_points:
+                assert isinstance(x, int)
+                assert x > 0, f"Invalid compile range split point: {x}"
+                if (
+                    max_num_batched_tokens is not None
+                    and x < max_num_batched_tokens
+                    and x > 1
+                ):
+                    computed_compile_ranges_split_points.append(x)
+        compilation_config.compile_ranges_split_points = sorted(
+            computed_compile_ranges_split_points
+        )
+
     def recalculate_max_model_len(self, max_model_len: int):
         # Can only be called in try_verify_and_update_config
         model_config = self.model_config

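To make the max_token_num computation in _set_compile_ranges concrete, a worked example with hypothetical numbers (a 64 MiB flashinfer allreduce-fusion limit, hidden size 4096, bf16 with 2-byte elements; none of these values come from this diff):

max_size = 64 * 1024 * 1024      # assumed flashinfer_max_size(tp_size) in bytes
hidden_size, itemsize = 4096, 2  # assumed model hidden size and bf16 element size
max_token_num = max_size // (hidden_size * itemsize)
assert max_token_num == 8192
# If max_num_batched_tokens were larger than 8192, then 8192 would be added as
# an extra split point, so allreduce-rms fusion only applies to the lower range.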
View File

@@ -15,6 +15,7 @@ import torch.nn as nn

 import vllm.envs as envs
 from vllm.config import CUDAGraphMode, VllmConfig
+from vllm.config.compilation import CompilationMode
 from vllm.distributed import (
     ensure_model_parallel_initialized,
     init_distributed_environment,
@@ -407,15 +408,31 @@ class Worker(WorkerBase):
         self.model_runner.initialize_kv_cache(kv_cache_config)

     def compile_or_warm_up_model(self) -> None:
-        # warm up sizes that are not in cudagraph capture sizes,
-        # but users still want to compile for better performance,
-        # e.g. for the max-num-batched token size in chunked prefill.
-        compile_sizes = self.vllm_config.compilation_config.compile_sizes
-        warmup_sizes = compile_sizes.copy() if compile_sizes is not None else []
-        if not self.model_config.enforce_eager:
-            capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
-            if capture_sizes is not None:
-                warmup_sizes = [x for x in warmup_sizes if x not in capture_sizes]
+        warmup_sizes = []
+
+        if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
+            # warm up sizes that are not in cudagraph capture sizes,
+            # but users still want to compile for better performance,
+            # e.g. for the max-num-batched token size in chunked prefill.
+            compile_sizes = self.vllm_config.compilation_config.compile_sizes
+            warmup_sizes = compile_sizes.copy() if compile_sizes is not None else []
+            cg_capture_sizes: list[int] = []
+            if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
+                cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
+                cg_capture_sizes = [] if cg_sizes is None else cg_sizes
+            warmup_sizes = [x for x in warmup_sizes if x not in cg_capture_sizes]
+
+            compile_ranges = self.vllm_config.compilation_config.get_compile_ranges()
+            # For each compile_range, if none of the batch sizes
+            # in warmup_sizes or cudagraph_capture_sizes are in the range,
+            # add the end of the range to ensure compilation/warmup.
+            all_sizes = set(cg_capture_sizes)
+            all_sizes.update([x for x in warmup_sizes if isinstance(x, int)])
+            for compile_range in compile_ranges:
+                if not any(x in compile_range for x in all_sizes):
+                    warmup_sizes.append(compile_range.end)
+
         # We skip EPLB here since we don't want to record dummy metrics
         for size in sorted(warmup_sizes, reverse=True):
             logger.info("Compile and warming up model for size %d", size)

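A standalone sketch of the coverage rule added here (illustrative names, not the Worker code): each compile range should be exercised by at least one warmup or cudagraph capture size, otherwise the range's upper bound is appended for warmup:

def ensure_range_coverage(warmup_sizes: list[int],
                          capture_sizes: list[int],
                          compile_ranges: list[tuple[int, int]]) -> list[int]:
    covered = set(capture_sizes) | set(warmup_sizes)
    out = list(warmup_sizes)
    for start, end in compile_ranges:
        # If nothing in this inclusive range is covered, warm up its end.
        if not any(start <= s <= end for s in covered):
            out.append(end)
    return out

# 8 and 16 cover the small ranges and 64 covers (33, 8192): nothing is added.
assert ensure_range_coverage([64], [8, 16], [(1, 8), (9, 32), (33, 8192)]) == [64]
# Nothing covers (9, 8192), so its end is appended for compilation/warmup.
assert ensure_range_coverage([], [4], [(1, 8), (9, 8192)]) == [8192]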
View File

@@ -337,7 +337,7 @@ def is_residual_scattered_for_sp(
     The residual tensor is scattered across tensor parallel ranks when sequence
     parallelism and tensor parallelism is enabled.

-    This follows the same logic as SequenceParallelismPass.is_applicable():
+    This follows the same logic as SequenceParallelismPass.is_applicable_for_range():
     - In full-graph compilation mode (no splitting ops or using inductor graph
       partition), SP is always applied
     - Otherwise, SP is only applied for specific shapes in compile_sizes