diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py
index 52e0fcc2881f..37d8ae0c08bf 100644
--- a/tests/compile/test_config.py
+++ b/tests/compile/test_config.py
@@ -1,14 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
-import torch
 
 import vllm
 from vllm.compilation.counter import compilation_counter
-from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
-                         set_current_vllm_config)
-
-from .piecewise.test_simple import SillyModel
+from vllm.config import VllmConfig
 
 
 def test_use_cudagraphs_dynamic(monkeypatch):
@@ -22,23 +18,24 @@ def test_use_cudagraphs_dynamic(monkeypatch):
 
 
 @pytest.mark.parametrize("enabled", [True, False])
-def test_use_cudagraphs(enabled):
+def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
     assert vllm.envs.VLLM_USE_V1
 
-    vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE,
-        use_cudagraph=enabled,
-        cudagraph_capture_sizes=[100],
-    ))
-    with set_current_vllm_config(vllm_config):
-        model = SillyModel(vllm_config=vllm_config, prefix='')
-    inputs = torch.randn(100, device="cuda")
+    # Disable multiprocessing so that the counter is in the same process
+    monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
 
-    with compilation_counter.expect(
-            num_graphs_seen=1,  # one graph for the model
-            num_cudagraph_captured=1 if enabled else 0,
-    ):
-        # first run is warmup
-        model(inputs)
-        # second run does CUDAGraphs recording (if enabled)
-        model(inputs)
+    compilation_config = {
+        "cudagraph_capture_sizes": [100],
+        "use_cudagraph": enabled,
+    }
+    with (
+            compilation_counter.expect(
+                num_graphs_seen=1,
+                num_gpu_runner_capture_triggers=1 if enabled else 0,
+                num_cudagraph_captured=13 if enabled else 0,
+            ),
+            # loading the model causes compilation (if enabled) to happen
+            vllm_runner('facebook/opt-125m',
+                        compilation_config=compilation_config,
+                        gpu_memory_utilization=0.4) as _):
+        pass
diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py
index 165347cfccef..9d7a25689b56 100644
--- a/vllm/compilation/counter.py
+++ b/vllm/compilation/counter.py
@@ -15,6 +15,9 @@ class CompilationCounter:
     # not including the splitting ops
     num_piecewise_capturable_graphs_seen: int = 0
     num_backend_compilations: int = 0
+    # Number of gpu_model_runner attempts to trigger CUDAGraphs capture
+    num_gpu_runner_capture_triggers: int = 0
+    # Number of CUDAGraphs captured
     num_cudagraph_captured: int = 0
     # InductorAdapter.compile calls
     num_inductor_compiles: int = 0
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 516a90e481ed..c4163eb2b8f5 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -18,6 +18,7 @@ import vllm.envs as envs
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.layer import Attention
+from vllm.compilation.counter import compilation_counter
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -200,9 +201,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             block_sizes=[self.cache_config.block_size],
         )
 
-        self.use_cuda_graph = (self.compilation_config.level
-                               == CompilationLevel.PIECEWISE
-                               and not self.model_config.enforce_eager)
+        self.use_cuda_graph = (
+            self.vllm_config.compilation_config.level
+            == CompilationLevel.PIECEWISE
+            and self.vllm_config.compilation_config.use_cudagraph
+            and not self.model_config.enforce_eager)
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         # The convention is different.
         # self.cudagraph_batch_sizes sorts in ascending order.
@@ -2058,10 +2061,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
     def capture_model(self) -> None:
         if not self.use_cuda_graph:
             logger.warning(
-                "Skipping CUDA graph capture. Please add "
-                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
+                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
+                "set -O %s and ensure `use_cudagraph` was not manually set to "
+                "False", CompilationLevel.PIECEWISE)
             return
 
+        compilation_counter.num_gpu_runner_capture_triggers += 1
+
         start_time = time.perf_counter()
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]
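
Note: the rewritten test leans on `compilation_counter.expect(...)`, which snapshots
the process-global counters on entry and asserts the exact deltas on exit; that is
also why the test sets `VLLM_ENABLE_V1_MULTIPROCESSING=0`, since counters bumped in
an engine subprocess would be invisible to the test process. Below is a minimal
sketch of that delta-checking pattern, for illustration only: the field names are
taken from the diff above, but the body is an assumption, not vLLM's exact
implementation.

    import copy
    from contextlib import contextmanager
    from dataclasses import dataclass

    @dataclass
    class CompilationCounter:
        num_graphs_seen: int = 0
        num_gpu_runner_capture_triggers: int = 0
        num_cudagraph_captured: int = 0

        @contextmanager
        def expect(self, **expected_deltas: int):
            # Snapshot all counters on entry, then assert each named
            # counter advanced by exactly the expected amount on exit.
            # (Sketch; vLLM's real expect() may differ in details.)
            before = copy.deepcopy(self)
            yield
            for name, delta in expected_deltas.items():
                actual = getattr(self, name) - getattr(before, name)
                assert actual == delta, (
                    f"{name}: expected delta {delta}, got {actual}")

    compilation_counter = CompilationCounter()

The hard-coded `num_cudagraph_captured=13 if enabled else 0` presumably reflects one
CUDA graph per piecewise-capturable subgraph of facebook/opt-125m at the single
capture size 100, rather than the single graph of the old SillyModel test.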