diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index 134bade486079..c1f5d9658af16 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -2,15 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib import os +import weakref +from contextlib import ExitStack import pytest +from tests.utils import wait_for_gpu_memory_to_clear from vllm import LLM, SamplingParams from vllm.config import CompilationConfig from vllm.platforms import current_platform -MODEL = "Qwen/Qwen2-1.5B-Instruct" - @contextlib.contextmanager def temporary_environ(env_vars): @@ -31,64 +32,119 @@ def temporary_environ(env_vars): os.environ[k] = v -@pytest.fixture(scope="module") -def full_cudagraph_llm(): +@pytest.fixture(scope="class") +def llm_pair(request): + model = request.param + with temporary_environ({ "VLLM_USE_V1": "1", "VLLM_FLASH_ATTN_VERSION": "3" }): - return LLM(model=MODEL, - gpu_memory_utilization=0.3, - compilation_config=CompilationConfig(full_cuda_graph=True)) - - -@pytest.fixture(scope="module") -def piecewise_llm(): - with temporary_environ({ - "VLLM_USE_V1": "1", - "VLLM_FLASH_ATTN_VERSION": "3" - }): - return LLM(model=MODEL, - gpu_memory_utilization=0.6, - compilation_config=CompilationConfig()) - - -def generate_text(llm: LLM, batch_size: int, max_tokens: int): - prompts = ["Hi my name is"] * batch_size - sampling_params = SamplingParams(temperature=0.0, - max_tokens=max_tokens, - top_p=0.95) - - return llm.generate(prompts, sampling_params) + full = LLM( + model=model, + gpu_memory_utilization=0.45, + trust_remote_code=True, + max_model_len=1024, + compilation_config=CompilationConfig(full_cuda_graph=True), + ) + piecewise = LLM( + model=model, + gpu_memory_utilization=0.45, + trust_remote_code=True, + max_model_len=1024, + compilation_config=CompilationConfig(), + ) + + # PyTest caches the fixture values so we use weakref.proxy to enable GC + yield weakref.proxy(full), weakref.proxy(piecewise) + del full + del piecewise + + wait_for_gpu_memory_to_clear( + devices=[0], + threshold_ratio=0.1, + ) +@pytest.mark.parametrize( + "llm_pair", + [ + # Model names for the llm_pair fixture + "deepseek-ai/DeepSeek-V2-Lite", + "Qwen/Qwen2-1.5B-Instruct" + ], + indirect=True) @pytest.mark.skipif(current_platform.get_device_capability() != (9, 0), - reason="Only Hopper GPUs support FlashAttention 3") -@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10), - (16, 10), (25, 10), - (32, 10), (45, 10), - (64, 10), (8, 5), - (8, 20), (8, 200)]) -def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm, - piecewise_llm): + reason="Only Hopper GPUs support FA3 and FlashMLA") +class TestFullCUDAGraph: """ - Load full cudagraph model and piecewise model once, and at the same time to - reuse them across various test cases. + Use a class such that an llm pair is constructed once for all + batch_size/max_tokens combinations and released immediately after. - Test various batch sizes and max_tokens to ensure that the full cudagraph - compilation works for padded cases too. + Module-scope fixtures would stick around the whole time, + meaning there would be multiple LLM instances hogging memory simultaneously. """ - piecewise_responses = generate_text(piecewise_llm, - batch_size=batch_size, - max_tokens=max_tokens) - full_cudagraph_responses = generate_text(full_cudagraph_llm, - batch_size=batch_size, - max_tokens=max_tokens) - # Check that all responses are the same - for i in range(len(piecewise_responses)): - assert piecewise_responses[i].outputs[ - 0].text == full_cudagraph_responses[i].outputs[0].text + @pytest.mark.parametrize(("batch_size", "max_tokens"), [ + (1, 10), + (7, 10), + (16, 10), + (25, 10), + (32, 10), + (45, 10), + (64, 10), + (123, 10), + (8, 5), + (8, 30), + ]) + def test_full_cudagraph(self, batch_size, max_tokens, + llm_pair: tuple[LLM, LLM]): + """ + Test various batch sizes and max_tokens to ensure that the + full cudagraph compilation works for padded cases too. + """ + + piecewise_llm, full_cudagraph_llm = llm_pair + + prompts = ["Hello, my name is"] * batch_size + sampling_params = SamplingParams(temperature=0.0, + max_tokens=max_tokens, + top_p=0.95) + + piecewise_responses = piecewise_llm.generate(prompts, sampling_params) + full_responses = full_cudagraph_llm.generate(prompts, sampling_params) + + # Check that all responses are the same + for piecewise_res, full_res in zip(piecewise_responses, + full_responses): + assert piecewise_res.outputs[0].text == full_res.outputs[0].text + + +@pytest.mark.parametrize( + "model, supported", + [ + ("Qwen/Qwen2-1.5B-Instruct", True), + # MLA does not support capturing CUDA Graphs with size > max_num_seqs + ("deepseek-ai/DeepSeek-V2-Lite", False), + ]) +@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0), + reason="Only Hopper GPUs support FA3 and FlashMLA") +def test_lower_max_num_seqs(model, supported): + with temporary_environ({ + "VLLM_USE_V1": "1", + "VLLM_FLASH_ATTN_VERSION": "3" + }), ExitStack() as stack: + if not supported: + stack.enter_context(pytest.raises(RuntimeError)) + + llm = LLM(model=model, + max_num_seqs=256, + trust_remote_code=True, + max_model_len=1024, + compilation_config=CompilationConfig( + full_cuda_graph=True, + cudagraph_capture_sizes=[64, 256, 512])) + llm.generate(["Hello, my name is"] * 10) def test_full_cudagraph_with_invalid_backend(): @@ -97,5 +153,5 @@ def test_full_cudagraph_with_invalid_backend(): "VLLM_FLASH_ATTN_VERSION": "2" #FA2 not supported with full_cuda_graph }), pytest.raises(RuntimeError): - LLM(model=MODEL, + LLM(model="Qwen/Qwen2-1.5B-Instruct", compilation_config=CompilationConfig(full_cuda_graph=True)) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 9633f139873c4..06ac3527e1fb8 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -4,7 +4,7 @@ Test the piecewise compilation with a simple model so that we can exactly calculate the expected output and side effects. """ - +import pytest import torch from torch import nn from torch.library import Library @@ -14,6 +14,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, set_current_vllm_config) from vllm.envs import VLLM_USE_V1 +from vllm.forward_context import set_forward_context from vllm.utils import direct_register_custom_op global_counter = 0 @@ -76,7 +77,8 @@ class SillyModel(nn.Module): return x -def _test_simple_piecewise_compile(*, use_inductor): +@pytest.mark.parametrize("use_inductor", [True, False]) +def test_simple_piecewise_compile(use_inductor): assert VLLM_USE_V1 vllm_config = VllmConfig(compilation_config=CompilationConfig( @@ -99,7 +101,7 @@ def _test_simple_piecewise_compile(*, use_inductor): num_backend_compilations=3, # num_piecewise_capturable_graphs_seen num_cudagraph_captured= 6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - ): + ), set_forward_context({}, vllm_config=vllm_config): model(inputs) @@ -112,11 +114,3 @@ def _test_simple_piecewise_compile(*, use_inductor): output = model(input) assert global_counter == 2 assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) - - -def test_simple_piecewise_compile_inductor(): - _test_simple_piecewise_compile(use_inductor=True) - - -def test_simple_piecewise_compile_no_inductor(): - _test_simple_piecewise_compile(use_inductor=False) diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 410c0101c99b9..b7ed8353b3cef 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -11,6 +11,7 @@ initialized randomly with a fixed seed. from dataclasses import dataclass from typing import Any, Optional +import pytest import torch from torch import nn from torch.library import Library @@ -19,6 +20,7 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, set_current_vllm_config) +from vllm.forward_context import set_forward_context from vllm.utils import direct_register_custom_op # create a library to hold the custom op @@ -285,29 +287,32 @@ def run_model(llama_config, vllm_config=vllm_config, prefix="").eval().cuda() - B = 16 # max batch size - input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() - positions = torch.arange(B).cuda() + with set_forward_context({}, vllm_config=vllm_config): + B = 16 # max batch size + input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() + positions = torch.arange(B).cuda() - model(input_ids, positions) - model(input_ids[:2], positions[:2]) - model(input_ids[:1], positions[:1]) + model(input_ids, positions) + model(input_ids[:2], positions[:2]) + model(input_ids[:1], positions[:1]) - input_ids[:2].zero_() - output = model(input_ids[:2], positions[:2]) + input_ids[:2].zero_() + output = model(input_ids[:2], positions[:2]) - output = output.cpu() + output = output.cpu() - if llama_config.tractable_init: - expected_output = tractable_computation(input_ids[:2], positions[:2], - llama_config).cpu() + if llama_config.tractable_init: + expected_output = tractable_computation(input_ids[:2], + positions[:2], + llama_config).cpu() - assert torch.allclose(output, expected_output) - else: - return output.cpu() + assert torch.allclose(output, expected_output) + else: + return output.cpu() -def _test_toy_llama(*, use_inductor): +@pytest.mark.parametrize("use_inductor", [True, False]) +def test_toy_llama(use_inductor: bool): # compare output with and without piecewise compilation llama_config = LlamaConfig(hidden_size=128, @@ -379,14 +384,6 @@ def _test_toy_llama(*, use_inductor): assert torch.allclose(outputs[0], outputs[i]) -def test_toy_llama_inductor(): - _test_toy_llama(use_inductor=True) - - -def test_toy_no_inductor(): - _test_toy_llama(use_inductor=False) - - @torch.inference_mode def benchmark(): from triton.testing import do_bench diff --git a/tests/utils.py b/tests/utils.py index ade28a481261c..a37872830dade 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -667,42 +667,54 @@ def get_physical_device_indices(devices): @_nvml() -def wait_for_gpu_memory_to_clear(devices: list[int], - threshold_bytes: int, +def wait_for_gpu_memory_to_clear(*, + devices: list[int], + threshold_bytes: Optional[int] = None, + threshold_ratio: Optional[float] = None, timeout_s: float = 120) -> None: + assert threshold_bytes is not None or threshold_ratio is not None # Use nvml instead of pytorch to reduce measurement error from torch cuda # context. devices = get_physical_device_indices(devices) start_time = time.time() while True: output: dict[int, str] = {} - output_raw: dict[int, float] = {} + output_raw: dict[int, tuple[float, float]] = {} for device in devices: if current_platform.is_rocm(): dev_handle = amdsmi_get_processor_handles()[device] mem_info = amdsmi_get_gpu_vram_usage(dev_handle) gb_used = mem_info["vram_used"] / 2**10 + gb_total = mem_info["vram_total"] / 2**10 else: dev_handle = nvmlDeviceGetHandleByIndex(device) mem_info = nvmlDeviceGetMemoryInfo(dev_handle) gb_used = mem_info.used / 2**30 - output_raw[device] = gb_used - output[device] = f'{gb_used:.02f}' + gb_total = mem_info.total / 2**30 + output_raw[device] = (gb_used, gb_total) + output[device] = f'{gb_used:.02f}/{gb_total:.02f}' - print('gpu memory used (GB): ', end='') + print('gpu memory used/total (GiB): ', end='') for k, v in output.items(): print(f'{k}={v}; ', end='') print('') + if threshold_bytes is not None: + is_free = lambda used, total: used <= threshold_bytes / 2**30 + threshold = f"{threshold_bytes/2**30} GiB" + else: + is_free = lambda used, total: used / total <= threshold_ratio + threshold = f"{threshold_ratio:.2f}" + dur_s = time.time() - start_time - if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()): + if all(is_free(used, total) for used, total in output_raw.values()): print(f'Done waiting for free GPU memory on devices {devices=} ' - f'({threshold_bytes/2**30=}) {dur_s=:.02f}') + f'({threshold=}) {dur_s=:.02f}') break if dur_s >= timeout_s: raise ValueError(f'Memory of devices {devices=} not free after ' - f'{dur_s=:.02f} ({threshold_bytes/2**30=})') + f'{dur_s=:.02f} ({threshold=})') time.sleep(5) diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/cuda_piecewise_backend.py index 993def49af700..8c49ea6cc1074 100644 --- a/vllm/compilation/cuda_piecewise_backend.py +++ b/vllm/compilation/cuda_piecewise_backend.py @@ -14,6 +14,7 @@ from vllm.compilation.backends import VllmBackend from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.utils import weak_ref_tensors @@ -137,7 +138,10 @@ class CUDAPiecewiseBackend: if self.is_last_graph and not self.to_be_compiled_sizes: self.check_for_ending_compilation() - if not entry.use_cudagraph: + # Skip CUDA graphs if this entry doesn't use them OR + # if we're supposed to skip them globally + skip_cuda_graphs = get_forward_context().skip_cuda_graphs + if not entry.use_cudagraph or skip_cuda_graphs: return entry.runnable(*args) if entry.cudagraph is None: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 6e3cb18fc5595..f3841d9d86bcb 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -179,7 +179,8 @@ class LLM: hf_overrides: Optional[HfOverrides] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None, override_pooler_config: Optional[PoolerConfig] = None, - compilation_config: Optional[Union[int, dict[str, Any]]] = None, + compilation_config: Optional[Union[int, dict[str, Any], + CompilationConfig]] = None, **kwargs, ) -> None: """LLM constructor.""" diff --git a/vllm/forward_context.py b/vllm/forward_context.py index f3b0518a44e03..dd55b19feeaf6 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -94,6 +94,7 @@ class ForwardContext: virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None + skip_cuda_graphs: bool = False _forward_context: Optional[ForwardContext] = None @@ -108,11 +109,14 @@ def get_forward_context() -> ForwardContext: @contextmanager -def set_forward_context(attn_metadata: Any, - vllm_config: VllmConfig, - virtual_engine: int = 0, - num_tokens: Optional[int] = None, - num_tokens_across_dp: Optional[torch.Tensor] = None): +def set_forward_context( + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: Optional[int] = None, + num_tokens_across_dp: Optional[torch.Tensor] = None, + skip_cuda_graphs: bool = False, +): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. @@ -135,7 +139,9 @@ def set_forward_context(attn_metadata: Any, static_forward_context, virtual_engine=virtual_engine, attn_metadata=attn_metadata, - dp_metadata=dp_metadata) + dp_metadata=dp_metadata, + skip_cuda_graphs=skip_cuda_graphs, + ) try: yield diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index d7a580c2883c3..1c4604cc27e47 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -7,7 +7,8 @@ from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl, TorchSDPAMetadata) from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.ops.ipex_attn import PagedAttention -from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -53,7 +54,7 @@ class TorchSDPABackend: return False -class TorchSDPAMetadataBuilderV1: +class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec, block_table: BlockTable) -> None: @@ -118,9 +119,12 @@ class TorchSDPAMetadataBuilderV1: return True - def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int, + def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata): + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + runner = self.runner block_table = self.block_table seq_lens_np = runner.seq_lens_np[:num_reqs] diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ebd9bd88dfd08..630ac13228f14 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, ClassVar, Optional import numpy as np import torch @@ -21,13 +21,12 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv -from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner if current_platform.is_cuda(): @@ -306,7 +305,9 @@ def _get_sliding_window_configs( return sliding_window_configs -class FlashAttentionMetadataBuilder: +class FlashAttentionMetadataBuilder( + AttentionMetadataBuilder[FlashAttentionMetadata]): + full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3 def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, block_table: BlockTable): @@ -336,13 +337,14 @@ class FlashAttentionMetadataBuilder: # populated on first build() call. self.aot_sliding_window: Optional[tuple[int, int]] = None - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: - return False + def build( + self, common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata + ) -> FlashAttentionMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len - def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens @@ -496,6 +498,11 @@ class FlashAttentionMetadataBuilder: ) return attn_metadata + def can_run_in_cudagraph( + self, common_attn_metadata: CommonAttentionMetadata) -> bool: + # Full CUDA Graph always supported (FA2 support checked separately) + return True + def use_cascade_attention(self, *args, **kwargs) -> bool: return use_cascade_attention(*args, **kwargs) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 277fc3ea5db9b..12547b99e5b6e 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -18,7 +18,8 @@ from vllm.attention.layer import Attention from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import use_cascade_attention -from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -202,7 +203,7 @@ class FlashInferMetadata: f" received {self.head_dim}.") -class FlashInferMetadataBuilder: +class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, block_table: BlockTable): @@ -399,9 +400,11 @@ class FlashInferMetadataBuilder: kv_data_type=attn_metadata.data_type, ) - def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int, + def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata): + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + assert self._num_decodes + self._num_prefills == num_reqs assert (self._num_decode_tokens + self._num_prefill_tokens == num_actual_tokens) diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 1588839b685e5..c8cb1481c8b46 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -15,7 +15,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, is_quantized_kv_cache) from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -25,8 +26,6 @@ if current_platform.is_cuda(): logger = init_logger(__name__) if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner create_block_mask_compiled = torch.compile(create_block_mask, @@ -256,7 +255,8 @@ class FlexAttentionMetadata: self.block_mask = self.build_block_mask() -class FlexAttentionMetadataBuilder: +class FlexAttentionMetadataBuilder( + AttentionMetadataBuilder[FlexAttentionMetadata]): def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, block_table: BlockTable): @@ -272,13 +272,12 @@ class FlexAttentionMetadataBuilder: self.kv_cache_spec = kv_cache_spec self.block_table = block_table - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: - return False - - def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int, + def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata): + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + max_seq_len = self.runner.seq_lens_np[:num_reqs].max() query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens @@ -332,9 +331,6 @@ class FlexAttentionMetadataBuilder: ) return out - def use_cascade_attention(self, *args, **kwargs) -> bool: - return False - class FlexAttentionImpl(AttentionImpl): sliding_window: Optional[tuple[int, int]] diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 86e78d7894a11..1878ae74dbc6f 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -207,7 +207,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, UnquantizedLinearMethod) from vllm.platforms import current_platform from vllm.utils import cdiv, round_down -from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -329,7 +330,7 @@ class MLACommonMetadata(Generic[D]): M = TypeVar("M", bound=MLACommonMetadata) -class MLACommonMetadataBuilder(Generic[M]): +class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): """ NOTE: Please read the comment at the top of the file before trying to understand this class @@ -450,9 +451,32 @@ class MLACommonMetadataBuilder(Generic[M]): seq_lens=seq_lens, ) - def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int, + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata) -> M: + """ + This method builds the metadata for full cudagraph capture. + Currently, only decode is supported for full cudagraphs with MLA. + """ + m = common_attn_metadata + assert m.num_reqs == m.num_actual_tokens, \ + "MLA only supports decode-only full CUDAGraph capture. " \ + "Make sure all cudagraph capture sizes <= max_num_seq." + + m.max_query_len = 1 # decode-only + + # Update state usually set in reorder_batch. + self._num_decodes = m.num_reqs + self._num_decode_tokens = m.num_actual_tokens + self._num_prefills = 0 + self._num_prefill_tokens = 0 + return self.build(0, m) + + def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata) -> M: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + assert self._num_decodes + self._num_prefills == num_reqs # Note(simon): be careful about the CPU <> GPU memory movement in this @@ -461,8 +485,11 @@ class MLACommonMetadataBuilder(Generic[M]): device = self.runner.device block_table = self.block_table block_table_tensor = block_table.get_device_tensor()[:num_reqs] - slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to( - device, non_blocking=True).long() + block_table.slot_mapping[:num_actual_tokens].copy_( + block_table.slot_mapping_cpu[:num_actual_tokens], + non_blocking=True) + block_table.slot_mapping[num_actual_tokens:].fill_(-1) + slot_mapping = block_table.slot_mapping[:num_actual_tokens] query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens @@ -564,8 +591,9 @@ class MLACommonMetadataBuilder(Generic[M]): decode=decode_metadata, ) - def use_cascade_attention(self, *args, **kwargs) -> bool: - return False + def can_run_in_cudagraph( + self, common_attn_metadata: CommonAttentionMetadata) -> bool: + return common_attn_metadata.max_query_len == 1 class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 318b8ede14366..be26e0060db5e 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, ClassVar, Optional import torch @@ -44,7 +44,7 @@ class FlashMLABackend(MLACommonBackend): @dataclass class FlashMLADecodeMetadata(MLACommonDecodeMetadata): - tile_scheduler_metadata: tuple[torch.Tensor, torch.Tensor] + tile_scheduler_metadata: torch.Tensor num_splits: torch.Tensor @@ -54,14 +54,18 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): + full_cudagraph_supported: ClassVar[bool] = True # Decode-only def __init__(self, runner, kv_cache_spec: AttentionSpec, block_table: BlockTable): - super().__init__(runner, kv_cache_spec, block_table) + super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata) self.num_q_heads = self.runner.model_config.get_num_attention_heads( self.runner.parallel_config) + self.cg_buf_tile_scheduler_metadata = None + self.cg_buf_num_splits = None + def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ @@ -71,6 +75,30 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): 1, # MQA for the decode path ) + if self.runner.full_cuda_graph: + # First time around (CUDAGraph capture), allocate the static buffer + if self.cg_buf_tile_scheduler_metadata is None: + self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata + self.cg_buf_num_splits = num_splits + else: + assert self.cg_buf_num_splits is not None + + # Metadata per-SM, fixed size (#SMs, TileMetadataSize) + assert (self.cg_buf_tile_scheduler_metadata.size() == + tile_scheduler_metadata.size()) + self.cg_buf_tile_scheduler_metadata.\ + copy_(tile_scheduler_metadata) + tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata + + # Num splits is per-batch, varying size (batch_size,) + n = num_splits.size(0) + # make sure static buffer is large enough + assert n <= self.cg_buf_num_splits.size(0) + num_splits_view = self.cg_buf_num_splits[:n] + num_splits_view.copy_(num_splits) + self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s + num_splits = num_splits_view + return FlashMLADecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 1f0406a7ac1f8..9fbca2e955e72 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -66,7 +66,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): def __init__(self, runner, kv_cache_spec: AttentionSpec, block_table: BlockTable): - super().__init__(runner, kv_cache_spec, block_table) + super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata) assert self.kv_cache_spec.block_size == 1, "AITER MLA" \ "only supports block size 1." diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 72c7643539273..8f6ecd532ccff 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -1,15 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import abc +from abc import abstractmethod from dataclasses import dataclass +from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar +import numpy as np import torch +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + @dataclass class CommonAttentionMetadata: """ - Attention metadata attributes that can be shared by layers in different KV - cache groups and thus having different block table. + Per-batch attention metadata, shared across layers and backends. + AttentionMetadataBuilder instances use it to construct per-layer metadata. """ query_start_loc: torch.Tensor @@ -18,6 +26,67 @@ class CommonAttentionMetadata: """(batch_size,), the length of each request including both computed tokens and newly scheduled tokens""" + num_reqs: int + """Number of requests""" + num_actual_tokens: int + """Total number of tokens in batch""" + max_query_len: int + """Longest query in batch""" + + +M = TypeVar("M") + + +class AttentionMetadataBuilder(abc.ABC, Generic[M]): + # Does this backend/builder support CUDA Graphs for attention. + full_cudagraph_supported: ClassVar[bool] = False + + @abstractmethod + def build(self, common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata) -> M: + """ + Central method that builds attention metadata. + Some builders (MLA) require reorder_batch to be called prior to build. + """ + raise NotImplementedError + + def can_run_in_cudagraph( + self, common_attn_metadata: CommonAttentionMetadata) -> bool: + """ + Can this batch (with given metadata) use CUDA Graphs for attention. + """ + return False + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata) -> M: + """ + Build attention metadata for CUDA graph capture. Uses build by default. + Subclasses that override this method should call self.build or + super().build_for_cudagraph_capture. + """ + return self.build(common_prefix_len=0, + common_attn_metadata=common_attn_metadata) + + def use_cascade_attention( + self, + common_prefix_len: int, + query_lens: np.ndarray, + num_query_heads: int, + num_kv_heads: int, + use_alibi: bool, + use_sliding_window: bool, + num_sms: int, + ) -> bool: + return False + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + """ + This method can reorder the batch if desired by the backend. + :return: Has the batch been reordered (default False). + """ + return False + def validate_kv_sharing_target(current_layer_name, target_layer_name, static_forward_context): diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 7b550739a83da..153b67fe57147 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -138,15 +138,17 @@ class EagleProposer: max_query_len = query_lens.max().item() common_attn_metadata = CommonAttentionMetadata( - query_start_loc=cu_num_tokens, seq_lens=seq_lens) + query_start_loc=cu_num_tokens, + seq_lens=seq_lens, + num_reqs=batch_size, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + ) assert self.runner is not None # FIXME: need to consider multiple kv_cache_groups attn_metadata = self.runner.attn_metadata_builder.build( - num_reqs=batch_size, - num_actual_tokens=num_tokens, - max_query_len=max_query_len, common_prefix_len=0, common_attn_metadata=common_attn_metadata, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e15daaac95a47..558325fa0347e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -16,10 +16,8 @@ from tqdm import tqdm import vllm.envs as envs from vllm.attention import AttentionType, get_attn_backend -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadataBuilder) +from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention -from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config) from vllm.distributed.kv_transfer import (get_kv_transfer_group, @@ -41,7 +39,8 @@ from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, check_use_alibi, is_pin_memory_available) -from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheSpec, @@ -89,6 +88,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config + self.compilation_config = vllm_config.compilation_config self.lora_config = vllm_config.lora_config self.load_config = vllm_config.load_config self.parallel_config = vllm_config.parallel_config @@ -197,7 +197,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): block_sizes=[self.cache_config.block_size], ) - self.use_cuda_graph = (self.vllm_config.compilation_config.level + self.use_cuda_graph = (self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. @@ -205,8 +205,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): # self.cudagraph_batch_sizes sorts in ascending order. # The batch sizes in the config are in descending order. self.cudagraph_batch_sizes = list( - reversed( - self.vllm_config.compilation_config.cudagraph_capture_sizes)) + reversed(self.compilation_config.cudagraph_capture_sizes)) + + self.full_cuda_graph = self.compilation_config.full_cuda_graph # Cache the device properties. self._init_device_properties() @@ -555,7 +556,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata]]: + ) -> tuple[dict[str, Any], bool, torch.Tensor, + Optional[SpecDecodeMetadata]]: + """ + :return: tuple[ + attn_metadata: layer-to-attention_metadata mapping, + attention_cuda_graphs: whether attention can run in cudagraph + logits_indices, spec_decode_metadata + ] + """ total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs @@ -669,7 +678,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): seq_lens = self.seq_lens[:num_reqs] common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc, seq_lens=seq_lens) + query_start_loc=query_start_loc, + seq_lens=seq_lens, + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + ) attn_metadata: dict[str, Any] = {} # Prepare the attention metadata for each KV cache group and make layers @@ -679,25 +693,28 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Prepare for cascade attention if enabled & beneficial. common_prefix_len = 0 + builder = self.attn_metadata_builders[kv_cache_group_id] if self.cascade_attn_enabled: common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, scheduler_output. num_common_prefix_blocks[kv_cache_group_id], kv_cache_group_spec.kv_cache_spec, - self.attn_metadata_builders[kv_cache_group_id], + builder, ) - attn_metadata_i = ( - self.attn_metadata_builders[kv_cache_group_id].build( - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata)) + attn_metadata_i = (builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + )) + for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i + attention_cuda_graphs = all( + b.can_run_in_cudagraph(common_attn_metadata) + for b in self.attn_metadata_builders) + use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: @@ -726,7 +743,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return attn_metadata, logits_indices, spec_decode_metadata + return (attn_metadata, attention_cuda_graphs, logits_indices, + spec_decode_metadata) def _compute_cascade_attn_prefix_len( self, @@ -1121,7 +1139,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): assert self.intermediate_tensors is not None tp = self.vllm_config.parallel_config.tensor_parallel_size - enabled_sp = self.vllm_config.compilation_config.pass_config. \ + enabled_sp = self.compilation_config.pass_config. \ enable_sequence_parallelism if enabled_sp: # When sequence parallelism is enabled, we always pad num_tokens @@ -1189,8 +1207,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): return self.kv_connector_no_forward(scheduler_output) # Prepare the decoder inputs. - attn_metadata, logits_indices, spec_decode_metadata = ( - self._prepare_inputs(scheduler_output)) + (attn_metadata, attention_cuda_graphs, logits_indices, + spec_decode_metadata) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1203,7 +1221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Pad tokens to multiple of tensor_parallel_size when # enabled collective fusion for SP tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if self.vllm_config.compilation_config.pass_config. \ + if self.compilation_config.pass_config. \ enable_sequence_parallelism and tp_size > 1: from vllm.utils import round_up num_input_tokens = round_up(num_scheduled_tokens, tp_size) @@ -1255,12 +1273,20 @@ class GPUModelRunner(LoRAModelRunnerMixin): intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_input_tokens, intermediate_tensors, True) + # Some attention backends only support CUDA Graphs in pure decode. + # If attention doesn't support CUDA Graphs for this batch, but we + # compiled with full CUDA graphs, we have to skip them entirely. + skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs + # Run the decoder. # Use persistent buffers for CUDA graphs. - with set_forward_context(attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp): + with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + skip_cuda_graphs=skip_cuda_graphs, + ): self.maybe_setup_kv_connector(scheduler_output) model_output = self.model( @@ -1769,7 +1795,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): def _dummy_run( self, num_tokens: int, - skip_attn: bool = True, + capture_attn_cudagraph: bool = False, ) -> torch.Tensor: # Padding for DP @@ -1790,9 +1816,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) - if skip_attn: - attn_metadata: Optional[dict[str, Any]] = None - else: + attn_metadata: Optional[dict[str, Any]] = None + if capture_attn_cudagraph: + attn_metadata = {} + query_start_loc = self.query_start_loc[:num_reqs + 1] # Make sure max_model_len is used at the graph capture time. self.seq_lens_np[:num_reqs] = self.max_model_len @@ -1802,19 +1829,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): seq_lens = self.seq_lens[:num_reqs] common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc, seq_lens=seq_lens) + query_start_loc=query_start_loc, + seq_lens=seq_lens, + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=num_tokens, + ) - attn_metadata = {} for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): - attn_metadata_i = ( - self.attn_metadata_builders[kv_cache_group_id].build( - num_reqs=num_reqs, - num_actual_tokens=num_tokens, - max_query_len=num_tokens, - common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - )) + + attn_metadata_i = self.attn_metadata_builders[ + kv_cache_group_id].build_for_cudagraph_capture( + common_attn_metadata) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -2039,14 +2066,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. with graph_capture(device=self.device): - skip_attn = not self.vllm_config.compilation_config.full_cuda_graph + full_cg = self.full_cuda_graph for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes), desc="Capturing CUDA graphs", total=len(self.cudagraph_batch_sizes)): - for _ in range(self.vllm_config.compilation_config. - cudagraph_num_of_warmups): - self._dummy_run(num_tokens, skip_attn=skip_attn) - self._dummy_run(num_tokens, skip_attn=skip_attn) + for _ in range( + self.compilation_config.cudagraph_num_of_warmups): + self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg) + self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] @@ -2089,20 +2116,20 @@ class GPUModelRunner(LoRAModelRunnerMixin): "Non-Attention backend is not supported by V1 " "GPUModelRunner.") - if self.vllm_config.compilation_config.full_cuda_graph: - attn_backend_name = attn_backend_i.__name__ - flash_attn_version = get_flash_attn_version() - if attn_backend_name != "FlashAttentionBackend" or \ - flash_attn_version != 3: - raise ValueError( - f"full_cuda_graph is only supported with " - f"FA3. Current attention backend is " - f"{attn_backend_name}, FlashAttention version is " - f"{flash_attn_version}.") - block_table_i = self.input_batch.block_table[i] attn_metadata_builder_i = attn_backend_i.get_builder_cls()( - weakref.proxy(self), kv_cache_spec, block_table_i) + weakref.proxy(self), + kv_cache_spec, + block_table_i, + ) + + if (self.full_cuda_graph + and not attn_metadata_builder_i.full_cudagraph_supported): + raise ValueError( + f"Full CUDAGraph not supported for " + f"{attn_backend_i.__name__}. Turn off CompilationConfig." + f"full_cuda_graph or use a different attention backend.") + self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) @@ -2142,9 +2169,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): to be reshaped to the desired shape before being used by the models. Args: - kv_cache_config: The KV cache config + kv_cache_config: The KV cache config Returns: - dict[str, torch.Tensor]: A map between layer names to their + dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. """ kv_cache_raw_tensors: dict[str, torch.Tensor] = {} @@ -2171,11 +2198,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): Reshape the KV cache tensors to the desired shape and dtype. Args: - kv_cache_config: The KV cache config - kv_cache_raw_tensors: The KV cache buffer of each layer, with + kv_cache_config: The KV cache config + kv_cache_raw_tensors: The KV cache buffer of each layer, with correct size but uninitialized shape. Returns: - Dict[str, torch.Tensor]: A map between layer names to their + Dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. """ kv_caches: dict[str, torch.Tensor] = {} @@ -2227,7 +2254,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): Args: kv_cache_config: The KV cache config Returns: - Dict[str, torch.Tensor]: A map between layer names to their + Dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. """ # Initialize the memory buffer for KV cache @@ -2245,10 +2272,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): kv_caches, ) - bind_kv_cache( - kv_caches, - self.vllm_config.compilation_config.static_forward_context, - self.kv_caches) + bind_kv_cache(kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches) return kv_caches def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: