diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py index 5d2786e122a6..75a81efedea3 100644 --- a/tests/compile/distributed/test_fusions_e2e.py +++ b/tests/compile/distributed/test_fusions_e2e.py @@ -298,10 +298,14 @@ def test_tp2_attn_quant_allreduce_rmsnorm( r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", log_holder.text, ) - assert len(log_matches) == 2, log_holder.text + # 2 for each compile range + # (global compile range can be split due to fuse_allreduce_rmsnorm) + num_compile_ranges = len(compilation_config.get_compile_ranges()) + assert num_compile_ranges in [1, 2] - assert int(log_matches[0]) == matches.attention_fusion - assert int(log_matches[1]) == matches.attention_fusion + assert len(log_matches) == 2 * num_compile_ranges, log_holder.text + + assert all(int(log_match) == matches.attention_fusion for log_match in log_matches) log_matches = re.findall( r"collective_fusion.py:\d+] Replaced (\d+) patterns", @@ -312,6 +316,12 @@ def test_tp2_attn_quant_allreduce_rmsnorm( assert int(log_matches[0]) == matches.allreduce_fusion assert int(log_matches[1]) == matches.allreduce_fusion + log_matches = re.findall( + r"pass_manager.py:\d+] Skipping .*AllReduceFusionPass.* with compile range", + log_holder.text, + ) + assert len(log_matches) == 2 * (num_compile_ranges - 1), log_holder.text + @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( @@ -446,7 +456,6 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg # No cudagraphs by default if compilation_config.cudagraph_mode is None: compilation_config.cudagraph_mode = CUDAGraphMode.NONE - llm = LLM( model=model, compilation_config=compilation_config, @@ -459,3 +468,9 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + # Get the compile ranges split points after vllm config post init + # in order to compute compile ranges correctly + compilation_config.compile_ranges_split_points = ( + llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points + ) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py new file mode 100644 index 000000000000..d849a8617ebd --- /dev/null +++ b/tests/compile/test_compile_ranges.py @@ -0,0 +1,168 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import torch +from torch import fx as fx +from torch import nn + +# This import automatically registers `torch.ops.silly.attention` +import tests.compile.silly_attention # noqa +from vllm.compilation.counter import compilation_counter +from vllm.compilation.decorators import support_torch_compile +from vllm.compilation.inductor_pass import ( + InductorPass, + get_pass_context, +) +from vllm.config import ( + VllmConfig, + set_current_vllm_config, +) +from vllm.config.compilation import CompilationConfig, CompilationMode +from vllm.config.scheduler import SchedulerConfig +from vllm.config.utils import Range +from vllm.forward_context import set_forward_context + +BATCH_SIZE = 64 +MLP_SIZE = 128 + + +@support_torch_compile +class TestModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + 
torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x * 3 + return x + + +@torch.inference_mode +def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]): + with set_forward_context({}, vllm_config=vllm_config): + model(torch.randn(BATCH_SIZE, MLP_SIZE)) + for batch_size in batch_sizes: + model(torch.randn(batch_size, MLP_SIZE)) + + +class PostGradRangeChecker(InductorPass): + def __init__(self, ranges: list[Range]): + self.ranges = ranges + self.num_calls = 0 + + def __call__(self, graph: fx.Graph): + compile_range = get_pass_context().compile_range + assert compile_range in self.ranges, ( + f"Compile range {compile_range} not in {self.ranges}" + ) + self.num_calls += 1 + + def uuid(self) -> str: + state: dict[str, Any] = {} + return InductorPass.hash_dict(state) + + +def test_compile_ranges(use_fresh_inductor_cache): + post_grad_range_checker = PostGradRangeChecker( + [ + Range(start=1, end=8), + Range(start=16, end=16), + Range(start=9, end=32), + Range(start=64, end=64), + Range(start=33, end=8192), + ] + ) + torch.set_default_device("cuda") + vllm_config = VllmConfig( + scheduler_config=SchedulerConfig( + max_num_batched_tokens=8192, + ), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + compile_ranges_split_points=[8, 32], + compile_sizes=[16, 64, 128], + inductor_compile_config={ + "post_grad_custom_post_pass": post_grad_range_checker, + }, + ), + ) + + with set_current_vllm_config(vllm_config): + model = TestModel(vllm_config=vllm_config, prefix="").eval() + # Number of compilations: 3 for each compile range + 2 compile sizes + batch_sizes = [1, 4, 16, 24, 48, 64, 8192] + + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_backend_compilations=5, + ): + run_model(vllm_config, model, batch_sizes) + assert post_grad_range_checker.num_calls == 5 + + +def test_compile_config_get_compile_ranges(): + compilation_config = CompilationConfig( + compile_ranges_split_points=[8, 32], + ) + VllmConfig( + scheduler_config=SchedulerConfig( + max_num_batched_tokens=8192, + ), + compilation_config=compilation_config, + ) + assert compilation_config.get_compile_ranges() == [ + Range(start=1, end=8), + Range(start=9, end=32), + Range(start=33, end=8192), + ] + + +def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache): + # To force multiple compilations, we disable the compile cache + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + post_grad_range_checker = PostGradRangeChecker( + ranges=[ + Range(start=1, end=8), + Range(start=9, end=8192), + ] + ) + scheduler_config = SchedulerConfig( + max_num_batched_tokens=8192, + ) + torch.set_default_device("cuda") + + def create_vllm_config(): + return VllmConfig( + scheduler_config=scheduler_config, + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + compile_ranges_split_points=[8], + inductor_compile_config={ + "post_grad_custom_post_pass": post_grad_range_checker, + }, + ), + ) + + vllm_config_1 = create_vllm_config() + with set_current_vllm_config(vllm_config_1): + model1 = TestModel(vllm_config=vllm_config_1, prefix="").eval() + batch_sizes = [1, 16] + run_model(vllm_config_1, model1, batch_sizes) + assert post_grad_range_checker.num_calls == 2 + + post_grad_range_checker.num_calls = 0 + # Create a new vllm config with the new pass context + vllm_config_2 = create_vllm_config() + with set_current_vllm_config(vllm_config_2): + model2 = TestModel(vllm_config=vllm_config_2, 
prefix="").eval() + batch_sizes = [4, 32] + run_model(vllm_config_2, model2, batch_sizes) + # Check that cache is used, so the number of calls + # should be 0 + assert post_grad_range_checker.num_calls == 0 diff --git a/tests/conftest.py b/tests/conftest.py index 204452b5835c..0d456fb364cd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -67,6 +67,9 @@ from vllm.transformers_utils.utils import maybe_model_redirect from vllm.utils.collection_utils import is_list_of from vllm.utils.torch_utils import set_default_torch_num_threads +from torch._inductor.utils import fresh_cache + + if TYPE_CHECKING: from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers.generation.utils import GenerateOutput @@ -1465,3 +1468,14 @@ def clean_gpu_memory_between_tests(): if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() + + +@pytest.fixture +def use_fresh_inductor_cache(): + """ + Use a fresh inductor cache for the test. + This is useful to ensure that the test is not affected by the + previous test calls. + """ + with fresh_cache(): + yield diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index b5b7fe2b76c2..26f4f16a845a 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -26,7 +26,7 @@ from vllm.compilation.partition_rules import ( should_split, ) from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig -from vllm.config.utils import hash_factors +from vllm.config.utils import Range, hash_factors from vllm.logger import init_logger from vllm.logging_utils import lazy from vllm.platforms import current_platform @@ -90,7 +90,7 @@ class CompilerManager: """ def __init__(self, compilation_config: CompilationConfig): - self.cache: dict[tuple[int | None, int, str], Any] = dict() + self.cache: dict[tuple[Range, int, str], Any] = dict() self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = make_compiler(compilation_config) @@ -99,11 +99,11 @@ class CompilerManager: return self.compiler.compute_hash(vllm_config) @contextmanager - def compile_context(self, runtime_shape: int | None = None): + def compile_context(self, compile_range: Range): """Provide compilation context for the duration of compilation to set any torch global properties we want to scope to a single Inductor compilation (e.g. 
partition rules, pass context).""" - with pass_context(runtime_shape): + with pass_context(compile_range): if self.compilation_config.use_inductor_graph_partition: with inductor_partition_rule_context( self.compilation_config.splitting_ops @@ -159,29 +159,21 @@ class CompilerManager: graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: Range, ) -> Callable | None: - if (runtime_shape, graph_index, self.compiler.name) not in self.cache: + if (compile_range, graph_index, self.compiler.name) not in self.cache: return None - handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] + handle = self.cache[(compile_range, graph_index, self.compiler.name)] compiled_graph = self.compiler.load( - handle, graph, example_inputs, graph_index, runtime_shape + handle, graph, example_inputs, graph_index, compile_range + ) + logger.debug( + "Directly load the %s-th graph for compile range %sfrom %s via handle %s", + graph_index, + str(compile_range), + self.compiler.name, + handle, ) - if runtime_shape is None: - logger.debug( - "Directly load the %s-th graph for dynamic shape from %s via handle %s", - graph_index, - self.compiler.name, - handle, - ) - else: - logger.debug( - "Directly load the %s-th graph for shape %s from %s via handle %s", - graph_index, - str(runtime_shape), - self.compiler.name, - handle, - ) return compiled_graph def compile( @@ -190,9 +182,9 @@ class CompilerManager: example_inputs, additional_inductor_config, compilation_config: CompilationConfig, + compile_range: Range, graph_index: int = 0, num_graphs: int = 1, - runtime_shape: int | None = None, ) -> Any: if graph_index == 0: # before compiling the first graph, record the start time @@ -204,7 +196,7 @@ class CompilerManager: compiled_graph = None # try to load from the cache - compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape) + compiled_graph = self.load(graph, example_inputs, graph_index, compile_range) if compiled_graph is not None: if graph_index == num_graphs - 1: # after loading the last graph for this shape, record the time. 
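For reference while reviewing the hunks above and below: the compiler cache is now keyed by a compile range rather than a single runtime shape, so every shape that falls inside a range reuses one cached artifact. A minimal sketch, assuming the `Range` helper added to `vllm/config/utils.py` in this diff; the dict below only mirrors the key layout of `CompilerManager.cache` and is not part of the patch.

    from vllm.config.utils import Range

    # Same key layout as CompilerManager.cache: (compile_range, graph_index, compiler_name)
    cache: dict[tuple[Range, int, str], str] = {}
    cache[(Range(start=9, end=32), 0, "inductor")] = "artifact-handle"

    # Runtime shapes 12 and 24 both fall in (9, 32), so they resolve to the same entry.
    assert 12 in Range(start=9, end=32) and 24 in Range(start=9, end=32)
    assert (Range(start=9, end=32), 0, "inductor") in cache
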
@@ -212,19 +204,12 @@ class CompilerManager: now = time.time() elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed - if runtime_shape is None: - logger.info( - "Directly load the compiled graph(s) for dynamic shape " - "from the cache, took %.3f s", - elapsed, - ) - else: - logger.info( - "Directly load the compiled graph(s) for shape %s " - "from the cache, took %.3f s", - str(runtime_shape), - elapsed, - ) + logger.info( + "Directly load the compiled graph(s) for compile range %s " + "from the cache, took %.3f s", + str(compile_range), + elapsed, + ) return compiled_graph # no compiler cached the graph, or the cache is disabled, @@ -233,14 +218,15 @@ class CompilerManager: # Let compile_fx generate a key for us maybe_key = None else: - maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" - - with self.compile_context(runtime_shape): + maybe_key = "artifact_compile_range_" + maybe_key += f"{compile_range.start}_{compile_range.end}" + maybe_key += f"_subgraph_{graph_index}" + with self.compile_context(compile_range): compiled_graph, handle = self.compiler.compile( graph, example_inputs, additional_inductor_config, - runtime_shape, + compile_range, maybe_key, ) @@ -248,55 +234,34 @@ class CompilerManager: # store the artifact in the cache if is_compile_cache_enabled(additional_inductor_config) and handle is not None: - self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle + self.cache[(compile_range, graph_index, self.compiler.name)] = handle compilation_counter.num_cache_entries_updated += 1 self.is_cache_updated = True if graph_index == 0: # adds some info logging for the first graph - if runtime_shape is None: - logger.info_once( - "Cache the graph for dynamic shape for later use", scope="local" - ) - else: - logger.info_once( - "Cache the graph of shape %s for later use", - str(runtime_shape), - scope="local", - ) - if runtime_shape is None: - logger.debug( - "Store the %s-th graph for dynamic shape from %s via handle %s", - graph_index, - self.compiler.name, - handle, - ) - else: - logger.debug( - "Store the %s-th graph for shape %s from %s via handle %s", - graph_index, - str(runtime_shape), - self.compiler.name, - handle, + logger.info_once( + "Cache the graph of compile range %s for later use", + str(compile_range), ) + logger.debug( + "Store the %s-th graph for compile range%s from %s via handle %s", + graph_index, + str(compile_range), + self.compiler.name, + handle, + ) # after compiling the last graph, record the end time if graph_index == num_graphs - 1: now = time.time() elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed - if runtime_shape is None: - logger.info_once( - "Compiling a graph for dynamic shape takes %.2f s", - elapsed, - scope="local", - ) - else: - logger.info_once( - "Compiling a graph for shape %s takes %.2f s", - runtime_shape, - elapsed, - scope="local", - ) + logger.info_once( + "Compiling a graph for compile range %s takes %.2f s", + str(compile_range), + elapsed, + scope="local", + ) return compiled_graph @@ -427,19 +392,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): sym_shape_indices = [ i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] - global compilation_start_time - compiled_graph_for_dynamic_shape = ( - self.vllm_backend.compiler_manager.compile( - submod, - args, - self.vllm_backend.inductor_config, - self.compilation_config, - graph_index=index, - num_graphs=len(self.compile_submod_names), - runtime_shape=None, - ) - ) # 
Lazy import here to avoid circular import from .piecewise_backend import PiecewiseBackend @@ -449,7 +402,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): index, len(self.compile_submod_names), sym_shape_indices, - compiled_graph_for_dynamic_shape, self.vllm_backend, ) @@ -589,8 +541,13 @@ class VllmBackend: ) else: # Config should automatically wrap all inductor passes - assert isinstance(self.inductor_config[self.pass_key], InductorPass) - self.pass_manager.add(self.inductor_config[self.pass_key]) + assert isinstance( + self.compilation_config.inductor_compile_config[self.pass_key], + InductorPass, + ) + self.pass_manager.add( + self.compilation_config.inductor_compile_config[self.pass_key] + ) self.inductor_config[self.pass_key] = self.pass_manager def __call__( diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 69d4606d73eb..2717738dd7c2 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -10,6 +10,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import enable_symm_mem_for_group from vllm.config import VllmConfig +from vllm.config.utils import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, @@ -431,7 +432,7 @@ class AsyncTPPass(VllmPatternMatcherPass): self.dump_patterns(config, self.patterns) - def is_applicable(self, shape: int | None) -> bool: + def is_applicable_for_range(self, compile_range: Range) -> bool: # This pass is applied on top of the sequence parallelism pass. # It inherits the same applicability condition as `SequenceParallelismPass`. # See `SequenceParallelismPass.is_applicable` for more details. 
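For reviewers new to the renamed hook: passes now report applicability for a whole compile range instead of a single shape. A minimal sketch of the pattern, assuming `InductorPass` and `Range` as touched by this diff; `ExamplePass` and its divisibility check are hypothetical and not part of the patch (the real conditions used by `AsyncTPPass` and `SequenceParallelismPass` follow in the hunks below).

    from typing import Any

    from torch import fx

    from vllm.compilation.inductor_pass import InductorPass
    from vllm.config.utils import Range

    class ExamplePass(InductorPass):
        def is_applicable_for_range(self, compile_range: Range) -> bool:
            # Hypothetical gate: only run for a concrete compile size divisible by 4.
            return compile_range.is_single_size() and compile_range.end % 4 == 0

        def __call__(self, graph: fx.Graph) -> None:
            pass  # graph rewrite would go here

        def uuid(self) -> str:
            state: dict[str, Any] = {}
            return InductorPass.hash_dict(state)
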
@@ -441,7 +442,7 @@ class AsyncTPPass(VllmPatternMatcherPass): ): return True tp_size = get_tensor_model_parallel_world_size() - return shape is not None and shape % tp_size == 0 + return compile_range.is_single_size() and compile_range.end % tp_size == 0 @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): @@ -505,91 +506,60 @@ if flashinfer_comm is not None: num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size + max_tensor_size = max_token_num * hidden_size * element_size + assert current_tensor_size <= max_tensor_size, ( + f"Current tensor size {current_tensor_size} is larger than " + f"max token num {max_token_num} * hidden size {hidden_size} * " + f"element size {element_size}" + ) + device_capability = current_platform.get_device_capability().to_int() + # Get one shot input size limit for the current world size + # for the current device capability + max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( + device_capability, {} + ).get(world_size, None) + # Use one shot if no max size is specified + use_oneshot = ( + max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB + ) - if num_tokens <= max_token_num: - device_capability = current_platform.get_device_capability().to_int() - # Get one shot input size limit for the current world size - # for the current device capability - max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( - device_capability, {} - ).get(world_size, None) - # Use one shot if no max size for one shot is specified - use_oneshot = ( - max_one_shot_size_mb is None - or current_tensor_size <= max_one_shot_size_mb * MiB - ) - - assert _FI_WORKSPACE_TENSOR is not None, ( - "Flashinfer must be enabled when using flashinfer" - ) - if norm_out is None: - norm_out = allreduce_in - residual_out = residual - else: - # return residual_out as allreduce_out with zeroed residual_in - # as flashinfer does not support rms_norm - # and allreduce_out together - residual_out = allreduce_in - # For the sizes that are smaller than the max size, - # we only use flashinfer one shot allreduce - flashinfer_comm.trtllm_allreduce_fusion( - allreduce_in=allreduce_in, - token_num=allreduce_in.shape[0], - residual_in=residual, - residual_out=residual_out, - norm_out=norm_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - world_rank=world_rank, - world_size=world_size, - hidden_dim=allreduce_in.shape[-1], - workspace_ptrs=_FI_WORKSPACE_TENSOR, - launch_with_pdl=launch_with_pdl, - use_oneshot=use_oneshot, - trigger_completion_at_end=trigger_completion_at_end, - fp32_acc=fp32_acc, - pattern_code=pattern_code, - allreduce_out=None, - quant_out=quant_out, - scale_out=scale_out, - # in vllm we only support swizzled layout - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, - scale_factor=scale_factor, - ) + assert _FI_WORKSPACE_TENSOR is not None, ( + "Flashinfer must be enabled when using flashinfer" + ) + if norm_out is None: + norm_out = allreduce_in + residual_out = residual else: - allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if scale_factor is not None and scale_out is None: - # Do fused rms norm static fp8 quant fused op - if norm_out is None: - torch.ops._C.fused_add_rms_norm_static_fp8_quant( - quant_out, - allreduce_out, - residual, - rms_gamma, - scale_factor, - rms_eps, - ) - else: - torch.ops._C.rms_norm_static_fp8_quant( - quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps - ) - else: - if 
norm_out is None: - torch.ops._C.fused_add_rms_norm( - allreduce_out, residual, rms_gamma, rms_eps - ) - norm_out = allreduce_out - else: - torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps) - if scale_factor is not None and scale_out is not None: - torch.ops._C.scaled_fp4_quant( - quant_out, norm_out, scale_out, scale_factor - ) - if scale_factor is None or norm_out is not None: - # we need to return allreduce output - # in cases of non quant fused AR + RMS norm - # and fused AR + RMS norm + quant without fused add - allreduce_in.copy_(allreduce_out) + # return residual_out as allreduce_out with zeroed residual_in + # as flashinfer does not support rms_norm + # and allreduce_out together + residual_out = allreduce_in + # For the sizes that are smaller than the max size, + # we only use flashinfer one shot allreduce + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + token_num=allreduce_in.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + world_rank=world_rank, + world_size=world_size, + hidden_dim=allreduce_in.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + launch_with_pdl=launch_with_pdl, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=None, + quant_out=quant_out, + scale_out=scale_out, + # in vllm we only support swizzled layout + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + ) def call_trtllm_fused_allreduce_norm_fake( allreduce_in: torch.Tensor, @@ -1128,7 +1098,8 @@ class AllReduceFusionPass(VllmPatternMatcherPass): if max_size is None: # Flashinfer doesn't support current world size logger.warning( - "Flashinfer allreduce fusion is not supported for world size %s", + "Flashinfer allreduce fusion is not supported for world size %s" + " or max size is not provided", self.tp_size, ) return @@ -1216,6 +1187,9 @@ class AllReduceFusionPass(VllmPatternMatcherPass): self.disabled = False + def is_applicable_for_range(self, compile_range: Range) -> bool: + return compile_range.end <= self.max_token_num + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): if self.disabled: diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 7deaba1a99fa..ab56d3561c56 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -15,6 +15,7 @@ import torch.fx as fx import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig +from vllm.config.utils import Range from vllm.utils.hashing import safe_hash from vllm.utils.torch_utils import is_torch_equal_or_newer @@ -63,16 +64,16 @@ class CompilerInterface: graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: """ Compile the graph with the given example inputs and compiler config, - with a runtime shape. If the `runtime_shape` is None, it means - the `example_inputs` have a dynamic shape. Otherwise, the - `runtime_shape` specifies the shape of the inputs. Right now we only - support one variable shape for all inputs, which is the batchsize - (number of tokens) during inference. + with a range. 
The `compile_range` specifies the range of the inputs, + it could be concrete size (if compile_sizes is provided), e.g. [4, 4] + or a range [5, 8]. + Right now we only support one variable in ranges for all inputs, + which is the batchsize (number of tokens) during inference. Dynamo will make sure `graph(*example_inputs)` is valid. @@ -98,7 +99,7 @@ class CompilerInterface: graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: Range, ) -> Callable: """ Load the compiled function from the handle. @@ -212,20 +213,20 @@ class InductorStandaloneAdaptor(CompilerInterface): graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 current_config = {} if compiler_config is not None: current_config.update(compiler_config) - set_inductor_config(current_config, runtime_shape) + set_inductor_config(current_config, compile_range) set_functorch_config() - if isinstance(runtime_shape, int): + if compile_range.is_single_size(): dynamic_shapes = "from_example_inputs" else: - dynamic_shapes = "from_tracing_context" + dynamic_shapes = "from_graph" from torch._inductor import standalone_compile @@ -235,7 +236,6 @@ class InductorStandaloneAdaptor(CompilerInterface): dynamic_shapes=dynamic_shapes, options={"config_patches": current_config}, ) - # Save the compiled artifact to disk in the specified path assert key is not None path = os.path.join(self.cache_dir, key) @@ -251,7 +251,7 @@ class InductorStandaloneAdaptor(CompilerInterface): graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: Range, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -315,7 +315,7 @@ class InductorAdaptor(CompilerInterface): graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -329,7 +329,7 @@ class InductorAdaptor(CompilerInterface): current_config["fx_graph_cache"] = True current_config["fx_graph_remote_cache"] = False - set_inductor_config(current_config, runtime_shape) + set_inductor_config(current_config, compile_range) set_functorch_config() # inductor can inplace modify the graph, so we need to copy it @@ -512,7 +512,7 @@ class InductorAdaptor(CompilerInterface): graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: Range, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -608,9 +608,9 @@ class InductorAdaptor(CompilerInterface): return contextlib.nullcontext() -def set_inductor_config(config, runtime_shape): - if isinstance(runtime_shape, int): - # for a specific batchsize, tuning triton kernel parameters +def set_inductor_config(config, compile_range: Range): + if compile_range.is_single_size(): + # for a specific batch size, tuning triton kernel parameters # can be beneficial config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE config["coordinate_descent_tuning"] = ( @@ -630,7 +630,7 @@ class EagerAdaptor(CompilerInterface): graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + 
compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_eager_compiles += 1 diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 9af635a929b4..8159b817f637 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -14,6 +14,7 @@ import torch from torch import fx from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily +from vllm.config.utils import Range from vllm.utils.torch_utils import is_torch_equal_or_newer if is_torch_equal_or_newer("2.6"): @@ -28,8 +29,8 @@ _pass_context = None class PassContext: - def __init__(self, runtime_shape: int | None): - self.runtime_shape = runtime_shape + def __init__(self, compile_range: Range): + self.compile_range: Range = compile_range def get_pass_context() -> PassContext: @@ -39,13 +40,13 @@ def get_pass_context() -> PassContext: @contextmanager -def pass_context(runtime_shape: int | None): +def pass_context(compile_range: Range): """A context manager that stores the current pass context, usually it is a list of sizes to specialize. """ global _pass_context prev_context = _pass_context - _pass_context = PassContext(runtime_shape) + _pass_context = PassContext(compile_range) try: yield finally: @@ -96,7 +97,7 @@ class InductorPass(CustomGraphPass): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable(self, shape: int | None): + def is_applicable_for_range(self, compile_range: Range): return True diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 37f48721ea20..6848bfb6a3c5 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -24,7 +24,11 @@ if current_platform.is_cuda(): from .collective_fusion import AllReduceFusionPass, AsyncTPPass from .fix_functionalization import FixFunctionalizationPass -from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context +from .inductor_pass import ( + CustomGraphPass, + InductorPass, + get_pass_context, +) from .noop_elimination import NoOpEliminationPass logger = init_logger(__name__) @@ -70,13 +74,13 @@ class PostGradPassManager(CustomGraphPass): def __call__(self, graph: fx.Graph): VllmInductorPass.dump_prefix = 0 # reset dump index - shape = get_pass_context().runtime_shape + compile_range = get_pass_context().compile_range for pass_ in self.passes: - if pass_.is_applicable(shape): + if pass_.is_applicable_for_range(compile_range): pass_(graph) VllmInductorPass.dump_prefix += 1 else: - logger.debug("Skipping %s with shape %s", pass_, shape) + logger.debug("Skipping %s with compile range %s", pass_, compile_range) # post-cleanup goes before fix_functionalization # because it requires a functional graph @@ -133,4 +137,8 @@ class PostGradPassManager(CustomGraphPass): state["passes"].append(pass_.uuid()) state["passes"].append(self.fix_functionalization.uuid()) + # Include the compile range in the uuid to ensure that inductor + # recompiles the graph for the new dynamic compile range. 
+ state["compile_range"] = str(get_pass_context().compile_range) + return InductorPass.hash_dict(state) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index e535d2c461c6..129b9b5deea3 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -7,18 +7,18 @@ from typing import Any import torch.fx as fx -import vllm.envs as envs from vllm.compilation.backends import VllmBackend from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig +from vllm.config.compilation import Range from vllm.logger import init_logger logger = init_logger(__name__) @dataclasses.dataclass -class ConcreteSizeEntry: - runtime_shape: int +class RangeEntry: + compile_range: Range compiled: bool = False runnable: Callable = None # type: ignore @@ -31,7 +31,6 @@ class PiecewiseBackend: piecewise_compile_index: int, total_piecewise_compiles: int, sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, vllm_backend: VllmBackend, ): """ @@ -55,67 +54,111 @@ class PiecewiseBackend: self.is_full_graph = total_piecewise_compiles == 1 - self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes) + self.compile_ranges = self.compilation_config.get_compile_ranges() + log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" + logger.debug_once(log_string) - self.first_run_finished = False - - self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa + self.compile_sizes = self.compilation_config.compile_sizes + log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}" + logger.debug_once(log_string) self.sym_shape_indices = sym_shape_indices - self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + # the entries for ranges that we need to either + self.range_entries: dict[Range, RangeEntry] = {} - # the entries for different shapes that we need to compile - self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} - - # to_be_compiled_sizes tracks the remaining sizes to compile, + # to_be_compiled_ranges tracks the remaining ranges to compile, # and updates during the compilation process, so we need to copy it - self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() + self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges) # We only keep compilation management inside this class directly. 
- for shape in self.compile_sizes: - self.concrete_size_entries[shape] = ConcreteSizeEntry( - runtime_shape=shape, - runnable=self.compiled_graph_for_general_shape, + for size in self.compile_sizes: + range = Range(start=size, end=size) + if range not in self.compile_ranges: + self.range_entries[range] = RangeEntry( + compile_range=range, + ) + self.to_be_compiled_ranges.add(range) + + for range in self.compile_ranges: + self.range_entries[range] = RangeEntry( + compile_range=range, ) def check_for_ending_compilation(self): - if self.is_last_graph and not self.to_be_compiled_sizes: + if self.is_last_graph and not self.to_be_compiled_ranges: # no specific sizes to compile # save the hash of the inductor graph for the next run self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) - def __call__(self, *args) -> Any: - if not self.first_run_finished: - self.first_run_finished = True - self.check_for_ending_compilation() - return self.compiled_graph_for_general_shape(*args) + def _fakify_args(self, args: list[Any]) -> list[Any]: + # We need to pass fake example_inputs, otherwise torch.compile + # will fakify the example_inputs potentially causing some non dynamic + # dimension to be be duck shaped to other existing shapes that have hints + # matching their values. + # This is problem because it can lead to unintended specializations! + # if the new wrongly dynamic dim is specialized + # it will force specializing the whole shape + # torch.compile probably should not accept + # non fake tensors as example inputs! + # See issue https://github.com/vllm-project/vllm/issues/27899 + fake_example_inputs = [] + for node in self.graph.graph.nodes: + # All place holders come first + if node.op == "placeholder": + fake_example_inputs.append(node.meta["example_value"]) + else: + break + assert len(fake_example_inputs) == len(args) + return fake_example_inputs - runtime_shape = args[self.sym_shape_indices[0]] + def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: + if not range_entry.compiled: + range_entry.compiled = True + self.to_be_compiled_ranges.remove(range_entry.compile_range) - if runtime_shape not in self.concrete_size_entries: - # we don't need to do anything for this shape - return self.compiled_graph_for_general_shape(*args) - - entry = self.concrete_size_entries[runtime_shape] - - if not entry.compiled: - entry.compiled = True - self.to_be_compiled_sizes.remove(runtime_shape) # args are real arguments - entry.runnable = self.vllm_backend.compiler_manager.compile( + # fakify for range, real args for concrete size. + # For concrete size, we clear the shape env in + # compiler_manager.compile() so no need to fakify. 
+ args = ( + self._fakify_args(args) + if not range_entry.compile_range.is_single_size() + else args + ) + range_entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, args, self.vllm_backend.inductor_config, self.compilation_config, + compile_range=range_entry.compile_range, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, - runtime_shape=runtime_shape, ) - # finished compilations for all required shapes - if self.is_last_graph and not self.to_be_compiled_sizes: - self.check_for_ending_compilation() + self.check_for_ending_compilation() - return entry.runnable(*args) + def _find_range_for_shape(self, runtime_shape: int) -> Range | None: + # First we try to find the range entry for the concrete compile size + # If not found, we search for the range entry + # that contains the runtime shape. + if runtime_shape in self.compile_sizes: + return self.range_entries[Range(start=runtime_shape, end=runtime_shape)] + else: + for range in self.compile_ranges: + if runtime_shape in range: + return self.range_entries[range] + return None + + def __call__(self, *args) -> Any: + runtime_shape = args[self.sym_shape_indices[0]] + range_entry = self._find_range_for_shape(runtime_shape) + + assert range_entry is not None, ( + f"Shape out of considered range: {runtime_shape} " + "[1, max_num_batched_tokens]" + ) + + self._maybe_compile_for_range_entry(range_entry, args) + return range_entry.runnable(*args) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index cf4b8118f6b5..a4046356bcda 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -9,6 +9,7 @@ import torch.fx as fx from torch._inductor.pattern_matcher import PatternMatcherPass from vllm.config import VllmConfig +from vllm.config.compilation import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -333,7 +334,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass): self.dump_patterns(config, self.patterns) - def is_applicable(self, shape: int | None) -> bool: + def is_applicable_for_range(self, compile_range: Range) -> bool: # When sequence parallelism is enabled, the residual tensor from RMSNorm # needs to be split along the sequence dimension. 
However, this dimension # is symbolic during piecewise compilation, and splitting symbolic shapes @@ -353,7 +354,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass): ): return True tp_size = get_tensor_model_parallel_world_size() - return shape is not None and shape % tp_size == 0 + return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index d3d50e6ae7b2..5f9e2cfdd789 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -13,7 +13,13 @@ from pydantic.dataclasses import dataclass import vllm.envs as envs from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass -from vllm.config.utils import config, get_hash_factors, handle_deprecated, hash_factors +from vllm.config.utils import ( + Range, + config, + get_hash_factors, + handle_deprecated, + hash_factors, +) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.import_utils import resolve_obj_by_qualname @@ -173,6 +179,9 @@ class PassConfig: """ MiB = 1024 * 1024 + FI_SUPPORTED_WORLD_SIZES = [2, 4, 8] + if world_size not in FI_SUPPORTED_WORLD_SIZES: + return None max_size_mb = self.fi_allreduce_fusion_max_size_mb if max_size_mb is None: max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size) @@ -379,6 +388,8 @@ class CompilationConfig: [vllm.config.CompilationConfig.cudagraph_copy_inputs] - Inductor compilation: - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] + - [`compile_ranges_split_points`] + [vllm.config.CompilationConfig.compile_ranges_split_points] - [`inductor_compile_config`] [vllm.config.CompilationConfig.inductor_compile_config] - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes] @@ -492,6 +503,21 @@ class CompilationConfig: to integers, it also supports "cudagraph_capture_sizes" to specify the sizes for cudagraph capture.""" + compile_ranges_split_points: list[int] | None = None + """Split points that represent compile ranges for inductor. + The compile ranges are + [1, split_points[0]], + [split_points[0] + 1, split_points[1]], ..., + [split_points[-1] + 1, max_num_batched_tokens]. + Compile sizes are also used single element ranges, + the range is represented as [compile_sizes[i], compile_sizes[i]]. + + If a range overlaps with the compile size, graph for compile size + will be prioritized, i.e. if we have a range [1, 8] and a compile size 4, + graph for compile size 4 will be compiled and used instead of the graph + for range [1, 8]. + """ + inductor_compile_config: dict = field(default_factory=dict) """Additional configurations for inductor. 
- None: use default configurations.""" @@ -1153,3 +1179,13 @@ class CompilationConfig: self.bs_to_padded_graph_size[bs] = start else: self.bs_to_padded_graph_size[bs] = end + + def get_compile_ranges(self) -> list[Range]: + """Get the compile ranges for the compilation config.""" + if self.compile_ranges_split_points is None: + return [] + split_points = sorted(set(self.compile_ranges_split_points)) + return [ + Range(start=s + 1, end=e) + for s, e in zip([0] + split_points[:-1], split_points) + ] diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 3124fcf00739..93da3fd417ac 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -10,7 +10,7 @@ import json import pathlib import textwrap from collections.abc import Iterable, Mapping, Sequence, Set -from dataclasses import MISSING, Field, field, fields, is_dataclass, replace +from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace from itertools import pairwise from typing import TYPE_CHECKING, Any, Protocol, TypeVar @@ -322,3 +322,35 @@ def handle_deprecated( for new_name in new_names: setattr(config, new_name, old_val) + + +@dataclass +class Range: + """ + A range of numbers. + Inclusive of start, inclusive of end. + """ + + start: int + end: int + + def is_single_size(self) -> bool: + return self.start == self.end + + def __contains__(self, size: int) -> bool: + # Inclusive of start, inclusive of end + return self.start <= size <= self.end + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Range): + return False + return self.start == other.start and self.end == other.end + + def __hash__(self) -> int: + return hash((self.start, self.end)) + + def __str__(self) -> str: + return f"({self.start}, {self.end})" + + def __repr__(self) -> str: + return self.__str__() diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index ce3d3b20865d..47e7ffded797 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -725,6 +725,8 @@ class VllmConfig: "--kv-sharing-fast-prefill requires changes on model side for " "correctness and to realize prefill savings. " ) + # TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands + self._set_compile_ranges() if self.model_config and self.model_config.is_encoder_decoder: from vllm.multimodal import MULTIMODAL_REGISTRY @@ -1126,6 +1128,52 @@ class VllmConfig: # complete the remaining process. self.compilation_config.post_init_cudagraph_sizes() + def _set_compile_ranges(self): + """ + Set the compile ranges for the compilation config. 
+ """ + compilation_config = self.compilation_config + computed_compile_ranges_split_points = [] + + # The upper bound of the compile ranges is the max_num_batched_tokens + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + if max_num_batched_tokens is not None: + computed_compile_ranges_split_points.append(max_num_batched_tokens) + + # Add the compile ranges for flashinfer + if compilation_config.pass_config.fuse_allreduce_rms: + tp_size = self.parallel_config.tensor_parallel_size + max_size = compilation_config.pass_config.flashinfer_max_size(tp_size) + if max_size is not None: + max_token_num = max_size // ( + self.model_config.get_hidden_size() + * self.model_config.dtype.itemsize + ) + if ( + max_num_batched_tokens is not None + and max_token_num < max_num_batched_tokens + ): + computed_compile_ranges_split_points.append(max_token_num) + else: + logger.debug( + "Max num batched tokens below allreduce-rms fusion threshold, " + "allreduce-rms fusion will be enabled for all num_tokens." + ) + + if compilation_config.compile_ranges_split_points is not None: + for x in compilation_config.compile_ranges_split_points: + assert isinstance(x, int) + assert x > 0, f"Invalid compile range split point: {x}" + if ( + max_num_batched_tokens is not None + and x < max_num_batched_tokens + and x > 1 + ): + computed_compile_ranges_split_points.append(x) + compilation_config.compile_ranges_split_points = sorted( + computed_compile_ranges_split_points + ) + def recalculate_max_model_len(self, max_model_len: int): # Can only be called in try_verify_and_update_config model_config = self.model_config diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index d189d0860b9a..a46ec2bd118f 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -15,6 +15,7 @@ import torch.nn as nn import vllm.envs as envs from vllm.config import CUDAGraphMode, VllmConfig +from vllm.config.compilation import CompilationMode from vllm.distributed import ( ensure_model_parallel_initialized, init_distributed_environment, @@ -407,15 +408,31 @@ class Worker(WorkerBase): self.model_runner.initialize_kv_cache(kv_cache_config) def compile_or_warm_up_model(self) -> None: - # warm up sizes that are not in cudagraph capture sizes, - # but users still want to compile for better performance, - # e.g. for the max-num-batched token size in chunked prefill. - compile_sizes = self.vllm_config.compilation_config.compile_sizes - warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] - if not self.model_config.enforce_eager: - capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes - if capture_sizes is not None: - warmup_sizes = [x for x in warmup_sizes if x not in capture_sizes] + warmup_sizes = [] + + if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE: + # warm up sizes that are not in cudagraph capture sizes, + # but users still want to compile for better performance, + # e.g. for the max-num-batched token size in chunked prefill. 
+ compile_sizes = self.vllm_config.compilation_config.compile_sizes + warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] + cg_capture_sizes: list[int] = [] + + if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes + cg_capture_sizes = [] if cg_sizes is None else cg_sizes + warmup_sizes = [x for x in warmup_sizes if x not in cg_capture_sizes] + + compile_ranges = self.vllm_config.compilation_config.get_compile_ranges() + # For each compile_range, if none of the batch sizes + # in warmup_sizes or cudagraph_capture_sizes are in the range, + # add the end of the range to ensure compilation/warmup. + all_sizes = set(cg_capture_sizes) + all_sizes.update([x for x in warmup_sizes if isinstance(x, int)]) + for compile_range in compile_ranges: + if not any(x in compile_range for x in all_sizes): + warmup_sizes.append(compile_range.end) + # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 427a0d296b25..0b0e2006d73d 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -337,7 +337,7 @@ def is_residual_scattered_for_sp( The residual tensor is scattered across tensor parallel ranks when sequence parallelism and tensor parallelism is enabled. - This follows the same logic as SequenceParallelismPass.is_applicable(): + This follows the same logic as SequenceParallelismPass.is_applicable_for_range(): - In full-graph compilation mode (no splitting ops or using inductor graph partition), SP is always applied - Otherwise, SP is only applied for specific shapes in compile_sizes
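Taken together, the config and runtime changes above reduce to a small amount of range arithmetic. The sketch below is for reference only and is not part of the patch: it assumes the `Range` helper from `vllm/config/utils.py`, restates the split-point arithmetic of `CompilationConfig.get_compile_ranges()` (after `VllmConfig._set_compile_ranges()` has appended `max_num_batched_tokens` as the final split point), and uses a hypothetical `pick_range` helper that mirrors the lookup order of `PiecewiseBackend._find_range_for_shape`: concrete compile sizes take precedence over the dynamic range that contains them.

    from vllm.config.utils import Range

    def compile_ranges(split_points: list[int]) -> list[Range]:
        # [1, s0], [s0 + 1, s1], ..., as in CompilationConfig.get_compile_ranges().
        points = sorted(set(split_points))
        return [Range(start=s + 1, end=e) for s, e in zip([0] + points[:-1], points)]

    def pick_range(num_tokens: int, ranges: list[Range], compile_sizes: list[int]) -> Range:
        # Concrete compile sizes win over the dynamic range that contains them.
        if num_tokens in compile_sizes:
            return Range(start=num_tokens, end=num_tokens)
        return next(r for r in ranges if num_tokens in r)

    ranges = compile_ranges([8, 32, 8192])  # (1, 8), (9, 32), (33, 8192)
    assert pick_range(24, ranges, compile_sizes=[16]) == Range(start=9, end=32)
    assert pick_range(16, ranges, compile_sizes=[16]) == Range(start=16, end=16)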