Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 06:34:58 +08:00)

[Compile] Conditional compilation. Introduce compile_ranges (#24252)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Luka Govedič <luka.govedic@gmail.com>
Signed-off-by: ProExpertProg <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Luka Govedič <luka.govedic@gmail.com>

parent 66e674cdd5
commit 4e26d3b09e
@ -298,10 +298,14 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
        r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
        log_holder.text,
    )
    assert len(log_matches) == 2, log_holder.text
    # 2 for each compile range
    # (global compile range can be split due to fuse_allreduce_rmsnorm)
    num_compile_ranges = len(compilation_config.get_compile_ranges())
    assert num_compile_ranges in [1, 2]

    assert int(log_matches[0]) == matches.attention_fusion
    assert int(log_matches[1]) == matches.attention_fusion
    assert len(log_matches) == 2 * num_compile_ranges, log_holder.text

    assert all(int(log_match) == matches.attention_fusion for log_match in log_matches)

    log_matches = re.findall(
        r"collective_fusion.py:\d+] Replaced (\d+) patterns",
@ -312,6 +316,12 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
    assert int(log_matches[0]) == matches.allreduce_fusion
    assert int(log_matches[1]) == matches.allreduce_fusion

    log_matches = re.findall(
        r"pass_manager.py:\d+] Skipping .*AllReduceFusionPass.* with compile range",
        log_holder.text,
    )
    assert len(log_matches) == 2 * (num_compile_ranges - 1), log_holder.text


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
@ -446,7 +456,6 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
    # No cudagraphs by default
    if compilation_config.cudagraph_mode is None:
        compilation_config.cudagraph_mode = CUDAGraphMode.NONE

    llm = LLM(
        model=model,
        compilation_config=compilation_config,
@ -459,3 +468,9 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    # Get the compile ranges split points after vllm config post init
    # in order to compute compile ranges correctly
    compilation_config.compile_ranges_split_points = (
        llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points
    )
tests/compile/test_compile_ranges.py (new file, 168 lines)
@ -0,0 +1,168 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import torch
from torch import fx as fx
from torch import nn

# This import automatically registers `torch.ops.silly.attention`
import tests.compile.silly_attention  # noqa
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.inductor_pass import (
    InductorPass,
    get_pass_context,
)
from vllm.config import (
    VllmConfig,
    set_current_vllm_config,
)
from vllm.config.compilation import CompilationConfig, CompilationMode
from vllm.config.scheduler import SchedulerConfig
from vllm.config.utils import Range
from vllm.forward_context import set_forward_context

BATCH_SIZE = 64
MLP_SIZE = 128


@support_torch_compile
class TestModel(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + x
        attn_output = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, attn_output)
        x = attn_output
        x = x * 3
        return x


@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]):
    with set_forward_context({}, vllm_config=vllm_config):
        model(torch.randn(BATCH_SIZE, MLP_SIZE))
        for batch_size in batch_sizes:
            model(torch.randn(batch_size, MLP_SIZE))


class PostGradRangeChecker(InductorPass):
    def __init__(self, ranges: list[Range]):
        self.ranges = ranges
        self.num_calls = 0

    def __call__(self, graph: fx.Graph):
        compile_range = get_pass_context().compile_range
        assert compile_range in self.ranges, (
            f"Compile range {compile_range} not in {self.ranges}"
        )
        self.num_calls += 1

    def uuid(self) -> str:
        state: dict[str, Any] = {}
        return InductorPass.hash_dict(state)


def test_compile_ranges(use_fresh_inductor_cache):
    post_grad_range_checker = PostGradRangeChecker(
        [
            Range(start=1, end=8),
            Range(start=16, end=16),
            Range(start=9, end=32),
            Range(start=64, end=64),
            Range(start=33, end=8192),
        ]
    )
    torch.set_default_device("cuda")
    vllm_config = VllmConfig(
        scheduler_config=SchedulerConfig(
            max_num_batched_tokens=8192,
        ),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            compile_ranges_split_points=[8, 32],
            compile_sizes=[16, 64, 128],
            inductor_compile_config={
                "post_grad_custom_post_pass": post_grad_range_checker,
            },
        ),
    )

    with set_current_vllm_config(vllm_config):
        model = TestModel(vllm_config=vllm_config, prefix="").eval()
        # Number of compilations: 3 compile ranges + 2 compile sizes that get hit
        batch_sizes = [1, 4, 16, 24, 48, 64, 8192]

        with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=1,
            num_backend_compilations=5,
        ):
            run_model(vllm_config, model, batch_sizes)
        assert post_grad_range_checker.num_calls == 5


def test_compile_config_get_compile_ranges():
    compilation_config = CompilationConfig(
        compile_ranges_split_points=[8, 32],
    )
    VllmConfig(
        scheduler_config=SchedulerConfig(
            max_num_batched_tokens=8192,
        ),
        compilation_config=compilation_config,
    )
    assert compilation_config.get_compile_ranges() == [
        Range(start=1, end=8),
        Range(start=9, end=32),
        Range(start=33, end=8192),
    ]


def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
    # To force multiple compilations, we disable the compile cache
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    post_grad_range_checker = PostGradRangeChecker(
        ranges=[
            Range(start=1, end=8),
            Range(start=9, end=8192),
        ]
    )
    scheduler_config = SchedulerConfig(
        max_num_batched_tokens=8192,
    )
    torch.set_default_device("cuda")

    def create_vllm_config():
        return VllmConfig(
            scheduler_config=scheduler_config,
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                compile_ranges_split_points=[8],
                inductor_compile_config={
                    "post_grad_custom_post_pass": post_grad_range_checker,
                },
            ),
        )

    vllm_config_1 = create_vllm_config()
    with set_current_vllm_config(vllm_config_1):
        model1 = TestModel(vllm_config=vllm_config_1, prefix="").eval()
        batch_sizes = [1, 16]
        run_model(vllm_config_1, model1, batch_sizes)
        assert post_grad_range_checker.num_calls == 2

    post_grad_range_checker.num_calls = 0
    # Create a new vllm config with the new pass context
    vllm_config_2 = create_vllm_config()
    with set_current_vllm_config(vllm_config_2):
        model2 = TestModel(vllm_config=vllm_config_2, prefix="").eval()
        batch_sizes = [4, 32]
        run_model(vllm_config_2, model2, batch_sizes)
        # Check that the cache is used, so the number of calls should be 0
        assert post_grad_range_checker.num_calls == 0
@ -67,6 +67,9 @@ from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils.collection_utils import is_list_of
from vllm.utils.torch_utils import set_default_torch_num_threads

from torch._inductor.utils import fresh_cache


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
    from transformers.generation.utils import GenerateOutput
@ -1465,3 +1468,14 @@ def clean_gpu_memory_between_tests():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()


@pytest.fixture
def use_fresh_inductor_cache():
    """
    Use a fresh inductor cache for the test.
    This is useful to ensure that the test is not affected by
    previous test calls.
    """
    with fresh_cache():
        yield
@ -26,7 +26,7 @@ from vllm.compilation.partition_rules import (
|
||||
should_split,
|
||||
)
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.config.utils import hash_factors
|
||||
from vllm.config.utils import Range, hash_factors
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logging_utils import lazy
|
||||
from vllm.platforms import current_platform
|
||||
@ -90,7 +90,7 @@ class CompilerManager:
|
||||
"""
|
||||
|
||||
def __init__(self, compilation_config: CompilationConfig):
|
||||
self.cache: dict[tuple[int | None, int, str], Any] = dict()
|
||||
self.cache: dict[tuple[Range, int, str], Any] = dict()
|
||||
self.is_cache_updated = False
|
||||
self.compilation_config = compilation_config
|
||||
self.compiler = make_compiler(compilation_config)
|
||||
@ -99,11 +99,11 @@ class CompilerManager:
|
||||
return self.compiler.compute_hash(vllm_config)
|
||||
|
||||
@contextmanager
|
||||
def compile_context(self, runtime_shape: int | None = None):
|
||||
def compile_context(self, compile_range: Range):
|
||||
"""Provide compilation context for the duration of compilation to set
|
||||
any torch global properties we want to scope to a single Inductor
|
||||
compilation (e.g. partition rules, pass context)."""
|
||||
with pass_context(runtime_shape):
|
||||
with pass_context(compile_range):
|
||||
if self.compilation_config.use_inductor_graph_partition:
|
||||
with inductor_partition_rule_context(
|
||||
self.compilation_config.splitting_ops
|
||||
@ -159,29 +159,21 @@ class CompilerManager:
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: int | None = None,
|
||||
compile_range: Range,
|
||||
) -> Callable | None:
|
||||
if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
|
||||
if (compile_range, graph_index, self.compiler.name) not in self.cache:
|
||||
return None
|
||||
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
|
||||
handle = self.cache[(compile_range, graph_index, self.compiler.name)]
|
||||
compiled_graph = self.compiler.load(
|
||||
handle, graph, example_inputs, graph_index, runtime_shape
|
||||
handle, graph, example_inputs, graph_index, compile_range
|
||||
)
|
||||
logger.debug(
|
||||
"Directly load the %s-th graph for compile range %sfrom %s via handle %s",
|
||||
graph_index,
|
||||
str(compile_range),
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
if runtime_shape is None:
|
||||
logger.debug(
|
||||
"Directly load the %s-th graph for dynamic shape from %s via handle %s",
|
||||
graph_index,
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Directly load the %s-th graph for shape %s from %s via handle %s",
|
||||
graph_index,
|
||||
str(runtime_shape),
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
return compiled_graph
|
||||
|
||||
def compile(
|
||||
@ -190,9 +182,9 @@ class CompilerManager:
|
||||
example_inputs,
|
||||
additional_inductor_config,
|
||||
compilation_config: CompilationConfig,
|
||||
compile_range: Range,
|
||||
graph_index: int = 0,
|
||||
num_graphs: int = 1,
|
||||
runtime_shape: int | None = None,
|
||||
) -> Any:
|
||||
if graph_index == 0:
|
||||
# before compiling the first graph, record the start time
|
||||
@ -204,7 +196,7 @@ class CompilerManager:
|
||||
compiled_graph = None
|
||||
|
||||
# try to load from the cache
|
||||
compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
|
||||
compiled_graph = self.load(graph, example_inputs, graph_index, compile_range)
|
||||
if compiled_graph is not None:
|
||||
if graph_index == num_graphs - 1:
|
||||
# after loading the last graph for this shape, record the time.
|
||||
@ -212,19 +204,12 @@ class CompilerManager:
|
||||
now = time.time()
|
||||
elapsed = now - compilation_start_time
|
||||
compilation_config.compilation_time += elapsed
|
||||
if runtime_shape is None:
|
||||
logger.info(
|
||||
"Directly load the compiled graph(s) for dynamic shape "
|
||||
"from the cache, took %.3f s",
|
||||
elapsed,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Directly load the compiled graph(s) for shape %s "
|
||||
"from the cache, took %.3f s",
|
||||
str(runtime_shape),
|
||||
elapsed,
|
||||
)
|
||||
logger.info(
|
||||
"Directly load the compiled graph(s) for compile range %s "
|
||||
"from the cache, took %.3f s",
|
||||
str(compile_range),
|
||||
elapsed,
|
||||
)
|
||||
return compiled_graph
|
||||
|
||||
# no compiler cached the graph, or the cache is disabled,
|
||||
@ -233,14 +218,15 @@ class CompilerManager:
|
||||
# Let compile_fx generate a key for us
|
||||
maybe_key = None
|
||||
else:
|
||||
maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
|
||||
|
||||
with self.compile_context(runtime_shape):
|
||||
maybe_key = "artifact_compile_range_"
|
||||
maybe_key += f"{compile_range.start}_{compile_range.end}"
|
||||
maybe_key += f"_subgraph_{graph_index}"
|
||||
with self.compile_context(compile_range):
|
||||
compiled_graph, handle = self.compiler.compile(
|
||||
graph,
|
||||
example_inputs,
|
||||
additional_inductor_config,
|
||||
runtime_shape,
|
||||
compile_range,
|
||||
maybe_key,
|
||||
)
|
||||
|
||||
@ -248,55 +234,34 @@ class CompilerManager:
|
||||
|
||||
# store the artifact in the cache
|
||||
if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
|
||||
self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
|
||||
self.cache[(compile_range, graph_index, self.compiler.name)] = handle
|
||||
compilation_counter.num_cache_entries_updated += 1
|
||||
self.is_cache_updated = True
|
||||
if graph_index == 0:
|
||||
# adds some info logging for the first graph
|
||||
if runtime_shape is None:
|
||||
logger.info_once(
|
||||
"Cache the graph for dynamic shape for later use", scope="local"
|
||||
)
|
||||
else:
|
||||
logger.info_once(
|
||||
"Cache the graph of shape %s for later use",
|
||||
str(runtime_shape),
|
||||
scope="local",
|
||||
)
|
||||
if runtime_shape is None:
|
||||
logger.debug(
|
||||
"Store the %s-th graph for dynamic shape from %s via handle %s",
|
||||
graph_index,
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Store the %s-th graph for shape %s from %s via handle %s",
|
||||
graph_index,
|
||||
str(runtime_shape),
|
||||
self.compiler.name,
|
||||
handle,
|
||||
logger.info_once(
|
||||
"Cache the graph of compile range %s for later use",
|
||||
str(compile_range),
|
||||
)
|
||||
logger.debug(
|
||||
"Store the %s-th graph for compile range%s from %s via handle %s",
|
||||
graph_index,
|
||||
str(compile_range),
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
|
||||
# after compiling the last graph, record the end time
|
||||
if graph_index == num_graphs - 1:
|
||||
now = time.time()
|
||||
elapsed = now - compilation_start_time
|
||||
compilation_config.compilation_time += elapsed
|
||||
if runtime_shape is None:
|
||||
logger.info_once(
|
||||
"Compiling a graph for dynamic shape takes %.2f s",
|
||||
elapsed,
|
||||
scope="local",
|
||||
)
|
||||
else:
|
||||
logger.info_once(
|
||||
"Compiling a graph for shape %s takes %.2f s",
|
||||
runtime_shape,
|
||||
elapsed,
|
||||
scope="local",
|
||||
)
|
||||
logger.info_once(
|
||||
"Compiling a graph for compile range %s takes %.2f s",
|
||||
str(compile_range),
|
||||
elapsed,
|
||||
scope="local",
|
||||
)
|
||||
|
||||
return compiled_graph
|
||||
|
||||
@ -427,19 +392,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
sym_shape_indices = [
|
||||
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
|
||||
]
|
||||
global compilation_start_time
|
||||
|
||||
compiled_graph_for_dynamic_shape = (
|
||||
self.vllm_backend.compiler_manager.compile(
|
||||
submod,
|
||||
args,
|
||||
self.vllm_backend.inductor_config,
|
||||
self.compilation_config,
|
||||
graph_index=index,
|
||||
num_graphs=len(self.compile_submod_names),
|
||||
runtime_shape=None,
|
||||
)
|
||||
)
|
||||
# Lazy import here to avoid circular import
|
||||
from .piecewise_backend import PiecewiseBackend
|
||||
|
||||
@ -449,7 +402,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
index,
|
||||
len(self.compile_submod_names),
|
||||
sym_shape_indices,
|
||||
compiled_graph_for_dynamic_shape,
|
||||
self.vllm_backend,
|
||||
)
|
||||
|
||||
@ -589,8 +541,13 @@ class VllmBackend:
|
||||
)
|
||||
else:
|
||||
# Config should automatically wrap all inductor passes
|
||||
assert isinstance(self.inductor_config[self.pass_key], InductorPass)
|
||||
self.pass_manager.add(self.inductor_config[self.pass_key])
|
||||
assert isinstance(
|
||||
self.compilation_config.inductor_compile_config[self.pass_key],
|
||||
InductorPass,
|
||||
)
|
||||
self.pass_manager.add(
|
||||
self.compilation_config.inductor_compile_config[self.pass_key]
|
||||
)
|
||||
self.inductor_config[self.pass_key] = self.pass_manager
|
||||
|
||||
def __call__(
|
||||
|
||||
@ -10,6 +10,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
@ -431,7 +432,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable(self, shape: int | None) -> bool:
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
# This pass is applied on top of the sequence parallelism pass.
|
||||
# It inherits the same applicability condition as `SequenceParallelismPass`.
|
||||
# See `SequenceParallelismPass.is_applicable` for more details.
|
||||
@ -441,7 +442,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
|
||||
):
|
||||
return True
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
return shape is not None and shape % tp_size == 0
|
||||
return compile_range.is_single_size() and compile_range.end % tp_size == 0
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
@ -505,91 +506,60 @@ if flashinfer_comm is not None:
|
||||
num_tokens, hidden_size = allreduce_in.shape
|
||||
element_size = allreduce_in.element_size()
|
||||
current_tensor_size = num_tokens * hidden_size * element_size
|
||||
max_tensor_size = max_token_num * hidden_size * element_size
|
||||
assert current_tensor_size <= max_tensor_size, (
|
||||
f"Current tensor size {current_tensor_size} is larger than "
|
||||
f"max token num {max_token_num} * hidden size {hidden_size} * "
|
||||
f"element size {element_size}"
|
||||
)
|
||||
device_capability = current_platform.get_device_capability().to_int()
|
||||
# Get one shot input size limit for the current world size
|
||||
# for the current device capability
|
||||
max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
|
||||
device_capability, {}
|
||||
).get(world_size, None)
|
||||
# Use one shot if no max size is specified
|
||||
use_oneshot = (
|
||||
max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB
|
||||
)
|
||||
|
||||
if num_tokens <= max_token_num:
|
||||
device_capability = current_platform.get_device_capability().to_int()
|
||||
# Get one shot input size limit for the current world size
|
||||
# for the current device capability
|
||||
max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
|
||||
device_capability, {}
|
||||
).get(world_size, None)
|
||||
# Use one shot if no max size for one shot is specified
|
||||
use_oneshot = (
|
||||
max_one_shot_size_mb is None
|
||||
or current_tensor_size <= max_one_shot_size_mb * MiB
|
||||
)
|
||||
|
||||
assert _FI_WORKSPACE_TENSOR is not None, (
|
||||
"Flashinfer must be enabled when using flashinfer"
|
||||
)
|
||||
if norm_out is None:
|
||||
norm_out = allreduce_in
|
||||
residual_out = residual
|
||||
else:
|
||||
# return residual_out as allreduce_out with zeroed residual_in
|
||||
# as flashinfer does not support rms_norm
|
||||
# and allreduce_out together
|
||||
residual_out = allreduce_in
|
||||
# For the sizes that are smaller than the max size,
|
||||
# we only use flashinfer one shot allreduce
|
||||
flashinfer_comm.trtllm_allreduce_fusion(
|
||||
allreduce_in=allreduce_in,
|
||||
token_num=allreduce_in.shape[0],
|
||||
residual_in=residual,
|
||||
residual_out=residual_out,
|
||||
norm_out=norm_out,
|
||||
rms_gamma=rms_gamma,
|
||||
rms_eps=rms_eps,
|
||||
world_rank=world_rank,
|
||||
world_size=world_size,
|
||||
hidden_dim=allreduce_in.shape[-1],
|
||||
workspace_ptrs=_FI_WORKSPACE_TENSOR,
|
||||
launch_with_pdl=launch_with_pdl,
|
||||
use_oneshot=use_oneshot,
|
||||
trigger_completion_at_end=trigger_completion_at_end,
|
||||
fp32_acc=fp32_acc,
|
||||
pattern_code=pattern_code,
|
||||
allreduce_out=None,
|
||||
quant_out=quant_out,
|
||||
scale_out=scale_out,
|
||||
# in vllm we only support swizzled layout
|
||||
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
assert _FI_WORKSPACE_TENSOR is not None, (
|
||||
"Flashinfer must be enabled when using flashinfer"
|
||||
)
|
||||
if norm_out is None:
|
||||
norm_out = allreduce_in
|
||||
residual_out = residual
|
||||
else:
|
||||
allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
|
||||
if scale_factor is not None and scale_out is None:
|
||||
# Do fused rms norm static fp8 quant fused op
|
||||
if norm_out is None:
|
||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant(
|
||||
quant_out,
|
||||
allreduce_out,
|
||||
residual,
|
||||
rms_gamma,
|
||||
scale_factor,
|
||||
rms_eps,
|
||||
)
|
||||
else:
|
||||
torch.ops._C.rms_norm_static_fp8_quant(
|
||||
quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps
|
||||
)
|
||||
else:
|
||||
if norm_out is None:
|
||||
torch.ops._C.fused_add_rms_norm(
|
||||
allreduce_out, residual, rms_gamma, rms_eps
|
||||
)
|
||||
norm_out = allreduce_out
|
||||
else:
|
||||
torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps)
|
||||
if scale_factor is not None and scale_out is not None:
|
||||
torch.ops._C.scaled_fp4_quant(
|
||||
quant_out, norm_out, scale_out, scale_factor
|
||||
)
|
||||
if scale_factor is None or norm_out is not None:
|
||||
# we need to return allreduce output
|
||||
# in cases of non quant fused AR + RMS norm
|
||||
# and fused AR + RMS norm + quant without fused add
|
||||
allreduce_in.copy_(allreduce_out)
|
||||
# return residual_out as allreduce_out with zeroed residual_in
|
||||
# as flashinfer does not support rms_norm
|
||||
# and allreduce_out together
|
||||
residual_out = allreduce_in
|
||||
# For the sizes that are smaller than the max size,
|
||||
# we only use flashinfer one shot allreduce
|
||||
flashinfer_comm.trtllm_allreduce_fusion(
|
||||
allreduce_in=allreduce_in,
|
||||
token_num=allreduce_in.shape[0],
|
||||
residual_in=residual,
|
||||
residual_out=residual_out,
|
||||
norm_out=norm_out,
|
||||
rms_gamma=rms_gamma,
|
||||
rms_eps=rms_eps,
|
||||
world_rank=world_rank,
|
||||
world_size=world_size,
|
||||
hidden_dim=allreduce_in.shape[-1],
|
||||
workspace_ptrs=_FI_WORKSPACE_TENSOR,
|
||||
launch_with_pdl=launch_with_pdl,
|
||||
use_oneshot=use_oneshot,
|
||||
trigger_completion_at_end=trigger_completion_at_end,
|
||||
fp32_acc=fp32_acc,
|
||||
pattern_code=pattern_code,
|
||||
allreduce_out=None,
|
||||
quant_out=quant_out,
|
||||
scale_out=scale_out,
|
||||
# in vllm we only support swizzled layout
|
||||
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
|
||||
def call_trtllm_fused_allreduce_norm_fake(
|
||||
allreduce_in: torch.Tensor,
|
||||
@ -1128,7 +1098,8 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
if max_size is None:
|
||||
# Flashinfer doesn't support current world size
|
||||
logger.warning(
|
||||
"Flashinfer allreduce fusion is not supported for world size %s",
|
||||
"Flashinfer allreduce fusion is not supported for world size %s"
|
||||
" or max size is not provided",
|
||||
self.tp_size,
|
||||
)
|
||||
return
|
||||
@ -1216,6 +1187,9 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
|
||||
self.disabled = False
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
return compile_range.end <= self.max_token_num
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
if self.disabled:
|
||||
|
||||
@ -15,6 +15,7 @@ import torch.fx as fx
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
@ -63,16 +64,16 @@ class CompilerInterface:
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: int | None = None,
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
"""
|
||||
Compile the graph with the given example inputs and compiler config,
|
||||
with a runtime shape. If the `runtime_shape` is None, it means
|
||||
the `example_inputs` have a dynamic shape. Otherwise, the
|
||||
`runtime_shape` specifies the shape of the inputs. Right now we only
|
||||
support one variable shape for all inputs, which is the batchsize
|
||||
(number of tokens) during inference.
|
||||
with a compile range. The `compile_range` specifies the range of input sizes;
it can be a concrete size (if compile_sizes is provided), e.g. [4, 4],
or a proper range, e.g. [5, 8].
Right now we only support one variable size across all inputs,
which is the batch size (number of tokens) during inference.
|
||||
|
||||
Dynamo will make sure `graph(*example_inputs)` is valid.
|
||||
|
||||
@ -98,7 +99,7 @@ class CompilerInterface:
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: int | None = None,
|
||||
compile_range: Range,
|
||||
) -> Callable:
|
||||
"""
|
||||
Load the compiled function from the handle.
|
||||
@ -212,20 +213,20 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: int | None = None,
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
current_config = {}
|
||||
if compiler_config is not None:
|
||||
current_config.update(compiler_config)
|
||||
set_inductor_config(current_config, runtime_shape)
|
||||
set_inductor_config(current_config, compile_range)
|
||||
set_functorch_config()
|
||||
|
||||
if isinstance(runtime_shape, int):
|
||||
if compile_range.is_single_size():
|
||||
dynamic_shapes = "from_example_inputs"
|
||||
else:
|
||||
dynamic_shapes = "from_tracing_context"
|
||||
dynamic_shapes = "from_graph"
|
||||
|
||||
from torch._inductor import standalone_compile
|
||||
|
||||
@ -235,7 +236,6 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
options={"config_patches": current_config},
|
||||
)
|
||||
|
||||
# Save the compiled artifact to disk in the specified path
|
||||
assert key is not None
|
||||
path = os.path.join(self.cache_dir, key)
|
||||
@ -251,7 +251,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: int | None = None,
|
||||
compile_range: Range,
|
||||
) -> Callable:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
@ -315,7 +315,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: int | None = None,
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
@ -329,7 +329,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
current_config["fx_graph_cache"] = True
|
||||
current_config["fx_graph_remote_cache"] = False
|
||||
|
||||
set_inductor_config(current_config, runtime_shape)
|
||||
set_inductor_config(current_config, compile_range)
|
||||
set_functorch_config()
|
||||
|
||||
# inductor can inplace modify the graph, so we need to copy it
|
||||
@ -512,7 +512,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: int | None = None,
|
||||
compile_range: Range,
|
||||
) -> Callable:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
@ -608,9 +608,9 @@ class InductorAdaptor(CompilerInterface):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
def set_inductor_config(config, runtime_shape):
|
||||
if isinstance(runtime_shape, int):
|
||||
# for a specific batchsize, tuning triton kernel parameters
|
||||
def set_inductor_config(config, compile_range: Range):
|
||||
if compile_range.is_single_size():
|
||||
# for a specific batch size, tuning triton kernel parameters
|
||||
# can be beneficial
|
||||
config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
|
||||
config["coordinate_descent_tuning"] = (
|
||||
@ -630,7 +630,7 @@ class EagerAdaptor(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: int | None = None,
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
compilation_counter.num_eager_compiles += 1
|
||||
|
||||
@ -14,6 +14,7 @@ import torch
|
||||
from torch import fx
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
|
||||
|
||||
from vllm.config.utils import Range
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
@ -28,8 +29,8 @@ _pass_context = None
|
||||
|
||||
|
||||
class PassContext:
|
||||
def __init__(self, runtime_shape: int | None):
|
||||
self.runtime_shape = runtime_shape
|
||||
def __init__(self, compile_range: Range):
|
||||
self.compile_range: Range = compile_range
|
||||
|
||||
|
||||
def get_pass_context() -> PassContext:
|
||||
@ -39,13 +40,13 @@ def get_pass_context() -> PassContext:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def pass_context(runtime_shape: int | None):
|
||||
def pass_context(compile_range: Range):
|
||||
"""A context manager that stores the current pass context,
|
||||
usually the compile range being specialized.
|
||||
"""
|
||||
global _pass_context
|
||||
prev_context = _pass_context
|
||||
_pass_context = PassContext(runtime_shape)
|
||||
_pass_context = PassContext(compile_range)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
@ -96,7 +97,7 @@ class InductorPass(CustomGraphPass):
|
||||
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
||||
return hashlib.sha256(encoded).hexdigest()
|
||||
|
||||
def is_applicable(self, shape: int | None):
|
||||
def is_applicable_for_range(self, compile_range: Range):
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@ -24,7 +24,11 @@ if current_platform.is_cuda():
    from .collective_fusion import AllReduceFusionPass, AsyncTPPass

from .fix_functionalization import FixFunctionalizationPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
from .inductor_pass import (
    CustomGraphPass,
    InductorPass,
    get_pass_context,
)
from .noop_elimination import NoOpEliminationPass

logger = init_logger(__name__)
@ -70,13 +74,13 @@ class PostGradPassManager(CustomGraphPass):
    def __call__(self, graph: fx.Graph):
        VllmInductorPass.dump_prefix = 0  # reset dump index

        shape = get_pass_context().runtime_shape
        compile_range = get_pass_context().compile_range
        for pass_ in self.passes:
            if pass_.is_applicable(shape):
            if pass_.is_applicable_for_range(compile_range):
                pass_(graph)
                VllmInductorPass.dump_prefix += 1
            else:
                logger.debug("Skipping %s with shape %s", pass_, shape)
                logger.debug("Skipping %s with compile range %s", pass_, compile_range)

        # post-cleanup goes before fix_functionalization
        # because it requires a functional graph
@ -133,4 +137,8 @@ class PostGradPassManager(CustomGraphPass):
            state["passes"].append(pass_.uuid())
        state["passes"].append(self.fix_functionalization.uuid())

        # Include the compile range in the uuid to ensure that inductor
        # recompiles the graph for the new dynamic compile range.
        state["compile_range"] = str(get_pass_context().compile_range)

        return InductorPass.hash_dict(state)
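The uuid change above is what keeps Inductor's cache correct across ranges: Inductor keys its graph cache on this uuid, so compilations of the same graph for different ranges must hash differently. A minimal standalone sketch of that idea (the helper below is illustrative, not vLLM's API):

import hashlib
import json


def pass_manager_uuid(pass_uuids: list[str], compile_range: tuple[int, int]) -> str:
    state = {
        "passes": pass_uuids,
        # Without this entry, the range (1, 8) and the range (9, 8192) would
        # produce identical keys, and the second compile could reuse an
        # artifact that was specialized for a different range of sizes.
        "compile_range": str(compile_range),
    }
    encoded = json.dumps(state, sort_keys=True).encode("utf-8")
    return hashlib.sha256(encoded).hexdigest()


assert pass_manager_uuid(["noop"], (1, 8)) != pass_manager_uuid(["noop"], (9, 8192))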
@ -7,18 +7,18 @@ from typing import Any
|
||||
|
||||
import torch.fx as fx
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.compilation.monitor import end_monitoring_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import Range
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ConcreteSizeEntry:
|
||||
runtime_shape: int
|
||||
class RangeEntry:
|
||||
compile_range: Range
|
||||
compiled: bool = False
|
||||
runnable: Callable = None # type: ignore
|
||||
|
||||
@ -31,7 +31,6 @@ class PiecewiseBackend:
|
||||
piecewise_compile_index: int,
|
||||
total_piecewise_compiles: int,
|
||||
sym_shape_indices: list[int],
|
||||
compiled_graph_for_general_shape: Callable,
|
||||
vllm_backend: VllmBackend,
|
||||
):
|
||||
"""
|
||||
@ -55,67 +54,111 @@ class PiecewiseBackend:
|
||||
|
||||
self.is_full_graph = total_piecewise_compiles == 1
|
||||
|
||||
self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes)
|
||||
self.compile_ranges = self.compilation_config.get_compile_ranges()
|
||||
log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
|
||||
logger.debug_once(log_string)
|
||||
|
||||
self.first_run_finished = False
|
||||
|
||||
self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa
|
||||
self.compile_sizes = self.compilation_config.compile_sizes
|
||||
log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
|
||||
logger.debug_once(log_string)
|
||||
|
||||
self.sym_shape_indices = sym_shape_indices
|
||||
|
||||
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
|
||||
# the entries for the ranges that we need to compile
|
||||
self.range_entries: dict[Range, RangeEntry] = {}
|
||||
|
||||
# the entries for different shapes that we need to compile
|
||||
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
|
||||
|
||||
# to_be_compiled_sizes tracks the remaining sizes to compile,
|
||||
# to_be_compiled_ranges tracks the remaining ranges to compile,
|
||||
# and updates during the compilation process, so we need to copy it
|
||||
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
|
||||
self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)
|
||||
|
||||
# We only keep compilation management inside this class directly.
|
||||
for shape in self.compile_sizes:
|
||||
self.concrete_size_entries[shape] = ConcreteSizeEntry(
|
||||
runtime_shape=shape,
|
||||
runnable=self.compiled_graph_for_general_shape,
|
||||
for size in self.compile_sizes:
|
||||
range = Range(start=size, end=size)
|
||||
if range not in self.compile_ranges:
|
||||
self.range_entries[range] = RangeEntry(
|
||||
compile_range=range,
|
||||
)
|
||||
self.to_be_compiled_ranges.add(range)
|
||||
|
||||
for range in self.compile_ranges:
|
||||
self.range_entries[range] = RangeEntry(
|
||||
compile_range=range,
|
||||
)
|
||||
|
||||
def check_for_ending_compilation(self):
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
if self.is_last_graph and not self.to_be_compiled_ranges:
|
||||
# no specific sizes to compile
|
||||
# save the hash of the inductor graph for the next run
|
||||
self.vllm_backend.compiler_manager.save_to_file()
|
||||
end_monitoring_torch_compile(self.vllm_config)
|
||||
|
||||
def __call__(self, *args) -> Any:
|
||||
if not self.first_run_finished:
|
||||
self.first_run_finished = True
|
||||
self.check_for_ending_compilation()
|
||||
return self.compiled_graph_for_general_shape(*args)
|
||||
def _fakify_args(self, args: list[Any]) -> list[Any]:
|
||||
# We need to pass fake example_inputs, otherwise torch.compile
# will fakify the example_inputs, potentially causing a non-dynamic
# dimension to be duck-shaped to other existing shapes whose hints
# match its value.
# This is a problem because it can lead to unintended specializations:
# if the wrongly-dynamic dim is specialized,
# it forces the whole shape to be specialized.
# torch.compile probably should not accept
# non-fake tensors as example inputs!
# See issue https://github.com/vllm-project/vllm/issues/27899
|
||||
fake_example_inputs = []
|
||||
for node in self.graph.graph.nodes:
|
||||
# All place holders come first
|
||||
if node.op == "placeholder":
|
||||
fake_example_inputs.append(node.meta["example_value"])
|
||||
else:
|
||||
break
|
||||
assert len(fake_example_inputs) == len(args)
|
||||
return fake_example_inputs
|
||||
|
||||
runtime_shape = args[self.sym_shape_indices[0]]
|
||||
def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any:
|
||||
if not range_entry.compiled:
|
||||
range_entry.compiled = True
|
||||
self.to_be_compiled_ranges.remove(range_entry.compile_range)
|
||||
|
||||
if runtime_shape not in self.concrete_size_entries:
|
||||
# we don't need to do anything for this shape
|
||||
return self.compiled_graph_for_general_shape(*args)
|
||||
|
||||
entry = self.concrete_size_entries[runtime_shape]
|
||||
|
||||
if not entry.compiled:
|
||||
entry.compiled = True
|
||||
self.to_be_compiled_sizes.remove(runtime_shape)
|
||||
# args are real arguments
|
||||
entry.runnable = self.vllm_backend.compiler_manager.compile(
|
||||
# fakify for range, real args for concrete size.
|
||||
# For concrete size, we clear the shape env in
|
||||
# compiler_manager.compile() so no need to fakify.
|
||||
args = (
|
||||
self._fakify_args(args)
|
||||
if not range_entry.compile_range.is_single_size()
|
||||
else args
|
||||
)
|
||||
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
|
||||
self.graph,
|
||||
args,
|
||||
self.vllm_backend.inductor_config,
|
||||
self.compilation_config,
|
||||
compile_range=range_entry.compile_range,
|
||||
graph_index=self.piecewise_compile_index,
|
||||
num_graphs=self.total_piecewise_compiles,
|
||||
runtime_shape=runtime_shape,
|
||||
)
|
||||
|
||||
# finished compilations for all required shapes
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
self.check_for_ending_compilation()
|
||||
self.check_for_ending_compilation()
|
||||
|
||||
return entry.runnable(*args)
|
||||
def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
|
||||
# First we try to find the range entry for the concrete compile size
|
||||
# If not found, we search for the range entry
|
||||
# that contains the runtime shape.
|
||||
if runtime_shape in self.compile_sizes:
|
||||
return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
|
||||
else:
|
||||
for range in self.compile_ranges:
|
||||
if runtime_shape in range:
|
||||
return self.range_entries[range]
|
||||
return None
|
||||
|
||||
def __call__(self, *args) -> Any:
|
||||
runtime_shape = args[self.sym_shape_indices[0]]
|
||||
range_entry = self._find_range_for_shape(runtime_shape)
|
||||
|
||||
assert range_entry is not None, (
|
||||
f"Shape out of considered range: {runtime_shape} "
|
||||
"[1, max_num_batched_tokens]"
|
||||
)
|
||||
|
||||
self._maybe_compile_for_range_entry(range_entry, args)
|
||||
return range_entry.runnable(*args)
|
||||
|
||||
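The `_find_range_for_shape` lookup above routes a runtime batch size to a compiled entry: an exact compile size wins over the range that contains the size. An illustrative sketch of that routing (SimpleRange and pick_entry are stand-ins, not vLLM's classes):

from dataclasses import dataclass


@dataclass(frozen=True)
class SimpleRange:
    start: int  # inclusive
    end: int    # inclusive

    def __contains__(self, size: int) -> bool:
        return self.start <= size <= self.end


compile_sizes = [16]
compile_ranges = [SimpleRange(1, 8), SimpleRange(9, 8192)]


def pick_entry(runtime_shape: int) -> SimpleRange:
    if runtime_shape in compile_sizes:
        return SimpleRange(runtime_shape, runtime_shape)  # single-size entry
    for r in compile_ranges:
        if runtime_shape in r:
            return r
    raise AssertionError(f"Shape out of considered range: {runtime_shape}")


assert pick_entry(16) == SimpleRange(16, 16)   # exact compile size
assert pick_entry(24) == SimpleRange(9, 8192)  # falls back to the containing range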
@ -9,6 +9,7 @@ import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
@ -333,7 +334,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):

        self.dump_patterns(config, self.patterns)

    def is_applicable(self, shape: int | None) -> bool:
    def is_applicable_for_range(self, compile_range: Range) -> bool:
        # When sequence parallelism is enabled, the residual tensor from RMSNorm
        # needs to be split along the sequence dimension. However, this dimension
        # is symbolic during piecewise compilation, and splitting symbolic shapes
@ -353,7 +354,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
        ):
            return True
        tp_size = get_tensor_model_parallel_world_size()
        return shape is not None and shape % tp_size == 0
        return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0)
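A hedged sketch of the new applicability rule in `is_applicable_for_range` above: with compile ranges, a shape-dependent pass such as sequence parallelism only runs when the range pins a single size that is divisible by the TP world size. SimpleRange here is a stand-in for vllm.config.utils.Range.

from dataclasses import dataclass


@dataclass(frozen=True)
class SimpleRange:
    start: int
    end: int

    def is_single_size(self) -> bool:
        return self.start == self.end


def sp_applicable(compile_range: SimpleRange, tp_size: int) -> bool:
    # A true range (e.g. (9, 32)) is symbolic, so splitting the residual along
    # the sequence dimension cannot be proven safe; only exact sizes qualify.
    return compile_range.is_single_size() and compile_range.end % tp_size == 0


assert sp_applicable(SimpleRange(64, 64), tp_size=2) is True
assert sp_applicable(SimpleRange(9, 32), tp_size=2) is False   # symbolic range
assert sp_applicable(SimpleRange(65, 65), tp_size=2) is False  # not divisible by TP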
@ -13,7 +13,13 @@ from pydantic.dataclasses import dataclass

import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import config, get_hash_factors, handle_deprecated, hash_factors
from vllm.config.utils import (
    Range,
    config,
    get_hash_factors,
    handle_deprecated,
    hash_factors,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname
@ -173,6 +179,9 @@ class PassConfig:
        """

        MiB = 1024 * 1024
        FI_SUPPORTED_WORLD_SIZES = [2, 4, 8]
        if world_size not in FI_SUPPORTED_WORLD_SIZES:
            return None
        max_size_mb = self.fi_allreduce_fusion_max_size_mb
        if max_size_mb is None:
            max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size)
@ -379,6 +388,8 @@ class CompilationConfig:
      [vllm.config.CompilationConfig.cudagraph_copy_inputs]
    - Inductor compilation:
        - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
        - [`compile_ranges_split_points`]
          [vllm.config.CompilationConfig.compile_ranges_split_points]
        - [`inductor_compile_config`]
          [vllm.config.CompilationConfig.inductor_compile_config]
        - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
@ -492,6 +503,21 @@ class CompilationConfig:
    to integers, it also supports "cudagraph_capture_sizes" to
    specify the sizes for cudagraph capture."""

    compile_ranges_split_points: list[int] | None = None
    """Split points that define the compile ranges for inductor.
    The compile ranges are
    [1, split_points[0]],
    [split_points[0] + 1, split_points[1]], ...,
    [split_points[-1] + 1, max_num_batched_tokens].
    Compile sizes are also used as single-element ranges; a compile size
    is represented as the range [compile_sizes[i], compile_sizes[i]].

    If a range overlaps with a compile size, the graph for the compile size
    is prioritized, i.e. if we have a range [1, 8] and a compile size 4,
    the graph for compile size 4 will be compiled and used instead of the
    graph for the range [1, 8].
    """

    inductor_compile_config: dict = field(default_factory=dict)
    """Additional configurations for inductor.
    - None: use default configurations."""
@ -1153,3 +1179,13 @@ class CompilationConfig:
                self.bs_to_padded_graph_size[bs] = start
            else:
                self.bs_to_padded_graph_size[bs] = end

    def get_compile_ranges(self) -> list[Range]:
        """Get the compile ranges for the compilation config."""
        if self.compile_ranges_split_points is None:
            return []
        split_points = sorted(set(self.compile_ranges_split_points))
        return [
            Range(start=s + 1, end=e)
            for s, e in zip([0] + split_points[:-1], split_points)
        ]
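A worked example of the split-point semantics documented above, using a small standalone re-implementation of `get_compile_ranges` (plain tuples instead of Range):

def compile_ranges(split_points: list[int]) -> list[tuple[int, int]]:
    pts = sorted(set(split_points))
    return [(s + 1, e) for s, e in zip([0] + pts[:-1], pts)]


# With compile_ranges_split_points=[8, 32] and max_num_batched_tokens=8192,
# VllmConfig appends 8192 as the final split point, giving three ranges:
assert compile_ranges([8, 32, 8192]) == [(1, 8), (9, 32), (33, 8192)]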
@ -10,7 +10,7 @@ import json
import pathlib
import textwrap
from collections.abc import Iterable, Mapping, Sequence, Set
from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar

@ -322,3 +322,35 @@ def handle_deprecated(

    for new_name in new_names:
        setattr(config, new_name, old_val)


@dataclass
class Range:
    """
    A range of numbers.
    Inclusive of start, inclusive of end.
    """

    start: int
    end: int

    def is_single_size(self) -> bool:
        return self.start == self.end

    def __contains__(self, size: int) -> bool:
        # Inclusive of start, inclusive of end
        return self.start <= size <= self.end

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Range):
            return False
        return self.start == other.start and self.end == other.end

    def __hash__(self) -> int:
        return hash((self.start, self.end))

    def __str__(self) -> str:
        return f"({self.start}, {self.end})"

    def __repr__(self) -> str:
        return self.__str__()
@ -725,6 +725,8 @@ class VllmConfig:
                "--kv-sharing-fast-prefill requires changes on model side for "
                "correctness and to realize prefill savings. "
            )
        # TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
        self._set_compile_ranges()

        if self.model_config and self.model_config.is_encoder_decoder:
            from vllm.multimodal import MULTIMODAL_REGISTRY
@ -1126,6 +1128,52 @@ class VllmConfig:
        # complete the remaining process.
        self.compilation_config.post_init_cudagraph_sizes()

    def _set_compile_ranges(self):
        """
        Set the compile ranges for the compilation config.
        """
        compilation_config = self.compilation_config
        computed_compile_ranges_split_points = []

        # The upper bound of the compile ranges is max_num_batched_tokens
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        if max_num_batched_tokens is not None:
            computed_compile_ranges_split_points.append(max_num_batched_tokens)

        # Add the compile ranges for flashinfer
        if compilation_config.pass_config.fuse_allreduce_rms:
            tp_size = self.parallel_config.tensor_parallel_size
            max_size = compilation_config.pass_config.flashinfer_max_size(tp_size)
            if max_size is not None:
                max_token_num = max_size // (
                    self.model_config.get_hidden_size()
                    * self.model_config.dtype.itemsize
                )
                if (
                    max_num_batched_tokens is not None
                    and max_token_num < max_num_batched_tokens
                ):
                    computed_compile_ranges_split_points.append(max_token_num)
                else:
                    logger.debug(
                        "Max num batched tokens below allreduce-rms fusion threshold, "
                        "allreduce-rms fusion will be enabled for all num_tokens."
                    )

        if compilation_config.compile_ranges_split_points is not None:
            for x in compilation_config.compile_ranges_split_points:
                assert isinstance(x, int)
                assert x > 0, f"Invalid compile range split point: {x}"
                if (
                    max_num_batched_tokens is not None
                    and x < max_num_batched_tokens
                    and x > 1
                ):
                    computed_compile_ranges_split_points.append(x)
        compilation_config.compile_ranges_split_points = sorted(
            computed_compile_ranges_split_points
        )

    def recalculate_max_model_len(self, max_model_len: int):
        # Can only be called in try_verify_and_update_config
        model_config = self.model_config
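A hedged numeric example of the flashinfer-derived split point computed in `_set_compile_ranges` above; the concrete sizes below are illustrative assumptions, not vLLM defaults:

hidden_size = 8192
dtype_itemsize = 2                   # bf16
max_fusion_size = 64 * 1024 * 1024   # assumed flashinfer_max_size(tp_size) in bytes
max_num_batched_tokens = 8192

max_token_num = max_fusion_size // (hidden_size * dtype_itemsize)   # 4096
split_points = [max_num_batched_tokens]
if max_token_num < max_num_batched_tokens:
    # The fusion threshold bisects the token range, so a split point is added.
    split_points.append(max_token_num)

# Resulting ranges: (1, 4096), where the fused allreduce+rmsnorm applies, and
# (4097, 8192), where AllReduceFusionPass is skipped (range end > max_token_num).
assert sorted(split_points) == [4096, 8192]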
@ -15,6 +15,7 @@ import torch.nn as nn

import vllm.envs as envs
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
@ -407,15 +408,31 @@ class Worker(WorkerBase):
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        compile_sizes = self.vllm_config.compilation_config.compile_sizes
        warmup_sizes = compile_sizes.copy() if compile_sizes is not None else []
        if not self.model_config.enforce_eager:
            capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
            if capture_sizes is not None:
                warmup_sizes = [x for x in warmup_sizes if x not in capture_sizes]
        warmup_sizes = []

        if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
            # warm up sizes that are not in cudagraph capture sizes,
            # but users still want to compile for better performance,
            # e.g. for the max-num-batched token size in chunked prefill.
            compile_sizes = self.vllm_config.compilation_config.compile_sizes
            warmup_sizes = compile_sizes.copy() if compile_sizes is not None else []
            cg_capture_sizes: list[int] = []

            if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
                cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
                cg_capture_sizes = [] if cg_sizes is None else cg_sizes
                warmup_sizes = [x for x in warmup_sizes if x not in cg_capture_sizes]

            compile_ranges = self.vllm_config.compilation_config.get_compile_ranges()
            # For each compile_range, if none of the batch sizes
            # in warmup_sizes or cudagraph_capture_sizes are in the range,
            # add the end of the range to ensure compilation/warmup.
            all_sizes = set(cg_capture_sizes)
            all_sizes.update([x for x in warmup_sizes if isinstance(x, int)])
            for compile_range in compile_ranges:
                if not any(x in compile_range for x in all_sizes):
                    warmup_sizes.append(compile_range.end)

        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
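An illustrative sketch of the warmup-size selection above: if no existing warmup or cudagraph-capture size falls inside a compile range, the range's end is added so that every range gets compiled before serving. Ranges are (start, end) tuples here instead of vllm.config.utils.Range, and the numbers are assumed for the example.

compile_ranges = [(1, 8), (9, 4096), (4097, 8192)]
cudagraph_capture_sizes = [1, 2, 4, 8, 16, 32]
warmup_sizes = [2048]  # e.g. taken from compile_sizes

all_sizes = set(cudagraph_capture_sizes) | set(warmup_sizes)
for start, end in compile_ranges:
    if not any(start <= s <= end for s in all_sizes):
        warmup_sizes.append(end)

# Ranges (1, 8) and (9, 4096) are already covered; only (4097, 8192) needs a
# dedicated warmup size, so its end is appended.
assert warmup_sizes == [2048, 8192]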
@ -337,7 +337,7 @@ def is_residual_scattered_for_sp(
    The residual tensor is scattered across tensor parallel ranks when sequence
    parallelism and tensor parallelism is enabled.

    This follows the same logic as SequenceParallelismPass.is_applicable():
    This follows the same logic as SequenceParallelismPass.is_applicable_for_range():
    - In full-graph compilation mode (no splitting ops or using inductor graph
      partition), SP is always applied
    - Otherwise, SP is only applied for specific shapes in compile_sizes