[CUDA] Enable full cudagraph for FlashMLA (#18581)

Signed-off-by: luka <luka@neuralmagic.com>
Luka Govedič, 2025-06-13 14:12:26 -04:00, committed by GitHub
parent 1015296b79
commit 3597b06a4f
17 changed files with 452 additions and 219 deletions

View File

@@ -2,15 +2,16 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
 import os
+import weakref
+from contextlib import ExitStack

 import pytest

+from tests.utils import wait_for_gpu_memory_to_clear
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig
 from vllm.platforms import current_platform

-MODEL = "Qwen/Qwen2-1.5B-Instruct"

 @contextlib.contextmanager
 def temporary_environ(env_vars):
@@ -31,64 +32,119 @@ def temporary_environ(env_vars):
         os.environ[k] = v


-@pytest.fixture(scope="module")
-def full_cudagraph_llm():
+@pytest.fixture(scope="class")
+def llm_pair(request):
+    model = request.param
+
     with temporary_environ({
             "VLLM_USE_V1": "1",
             "VLLM_FLASH_ATTN_VERSION": "3"
     }):
-        return LLM(model=MODEL,
-                   gpu_memory_utilization=0.3,
-                   compilation_config=CompilationConfig(full_cuda_graph=True))
-
-
-@pytest.fixture(scope="module")
-def piecewise_llm():
-    with temporary_environ({
-            "VLLM_USE_V1": "1",
-            "VLLM_FLASH_ATTN_VERSION": "3"
-    }):
-        return LLM(model=MODEL,
-                   gpu_memory_utilization=0.6,
-                   compilation_config=CompilationConfig())
-
-
-def generate_text(llm: LLM, batch_size: int, max_tokens: int):
-    prompts = ["Hi my name is"] * batch_size
-    sampling_params = SamplingParams(temperature=0.0,
-                                     max_tokens=max_tokens,
-                                     top_p=0.95)
-
-    return llm.generate(prompts, sampling_params)
+        full = LLM(
+            model=model,
+            gpu_memory_utilization=0.45,
+            trust_remote_code=True,
+            max_model_len=1024,
+            compilation_config=CompilationConfig(full_cuda_graph=True),
+        )
+        piecewise = LLM(
+            model=model,
+            gpu_memory_utilization=0.45,
+            trust_remote_code=True,
+            max_model_len=1024,
+            compilation_config=CompilationConfig(),
+        )
+
+    # PyTest caches the fixture values so we use weakref.proxy to enable GC
+    yield weakref.proxy(full), weakref.proxy(piecewise)
+    del full
+    del piecewise
+
+    wait_for_gpu_memory_to_clear(
+        devices=[0],
+        threshold_ratio=0.1,
+    )


+@pytest.mark.parametrize(
+    "llm_pair",
+    [
+        # Model names for the llm_pair fixture
+        "deepseek-ai/DeepSeek-V2-Lite",
+        "Qwen/Qwen2-1.5B-Instruct"
+    ],
+    indirect=True)
 @pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
-                    reason="Only Hopper GPUs support FlashAttention 3")
-@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
-                                                        (16, 10), (25, 10),
-                                                        (32, 10), (45, 10),
-                                                        (64, 10), (8, 5),
-                                                        (8, 20), (8, 200)])
-def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm,
-                        piecewise_llm):
+                    reason="Only Hopper GPUs support FA3 and FlashMLA")
+class TestFullCUDAGraph:
     """
-    Load full cudagraph model and piecewise model once, and at the same time to
-    reuse them across various test cases.
-
-    Test various batch sizes and max_tokens to ensure that the full cudagraph
-    compilation works for padded cases too.
+    Use a class such that an llm pair is constructed once for all
+    batch_size/max_tokens combinations and released immediately after.
+
+    Module-scope fixtures would stick around the whole time,
+    meaning there would be multiple LLM instances hogging memory simultaneously.
     """
-    piecewise_responses = generate_text(piecewise_llm,
-                                        batch_size=batch_size,
-                                        max_tokens=max_tokens)
-    full_cudagraph_responses = generate_text(full_cudagraph_llm,
-                                             batch_size=batch_size,
-                                             max_tokens=max_tokens)
-
-    # Check that all responses are the same
-    for i in range(len(piecewise_responses)):
-        assert piecewise_responses[i].outputs[
-            0].text == full_cudagraph_responses[i].outputs[0].text
+
+    @pytest.mark.parametrize(("batch_size", "max_tokens"), [
+        (1, 10),
+        (7, 10),
+        (16, 10),
+        (25, 10),
+        (32, 10),
+        (45, 10),
+        (64, 10),
+        (123, 10),
+        (8, 5),
+        (8, 30),
+    ])
+    def test_full_cudagraph(self, batch_size, max_tokens,
+                            llm_pair: tuple[LLM, LLM]):
+        """
+        Test various batch sizes and max_tokens to ensure that the
+        full cudagraph compilation works for padded cases too.
+        """
+        piecewise_llm, full_cudagraph_llm = llm_pair
+
+        prompts = ["Hello, my name is"] * batch_size
+        sampling_params = SamplingParams(temperature=0.0,
+                                         max_tokens=max_tokens,
+                                         top_p=0.95)
+
+        piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
+        full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
+
+        # Check that all responses are the same
+        for piecewise_res, full_res in zip(piecewise_responses,
+                                           full_responses):
+            assert piecewise_res.outputs[0].text == full_res.outputs[0].text
+
+
+@pytest.mark.parametrize(
+    "model, supported",
+    [
+        ("Qwen/Qwen2-1.5B-Instruct", True),
+        # MLA does not support capturing CUDA Graphs with size > max_num_seqs
+        ("deepseek-ai/DeepSeek-V2-Lite", False),
+    ])
+@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
+                    reason="Only Hopper GPUs support FA3 and FlashMLA")
+def test_lower_max_num_seqs(model, supported):
+    with temporary_environ({
+            "VLLM_USE_V1": "1",
+            "VLLM_FLASH_ATTN_VERSION": "3"
+    }), ExitStack() as stack:
+        if not supported:
+            stack.enter_context(pytest.raises(RuntimeError))
+
+        llm = LLM(model=model,
                  max_num_seqs=256,
                  trust_remote_code=True,
                  max_model_len=1024,
                  compilation_config=CompilationConfig(
                      full_cuda_graph=True,
                      cudagraph_capture_sizes=[64, 256, 512]))
+        llm.generate(["Hello, my name is"] * 10)


 def test_full_cudagraph_with_invalid_backend():
@@ -97,5 +153,5 @@ def test_full_cudagraph_with_invalid_backend():
             "VLLM_FLASH_ATTN_VERSION":
             "2"  # FA2 not supported with full_cuda_graph
     }), pytest.raises(RuntimeError):
-        LLM(model=MODEL,
+        LLM(model="Qwen/Qwen2-1.5B-Instruct",
             compilation_config=CompilationConfig(full_cuda_graph=True))
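
The body of temporary_environ is elided by the hunk above (only the restore line os.environ[k] = v is visible). A minimal sketch of such a helper, as a plausible reconstruction rather than the exact code in the file:

    import contextlib
    import os

    @contextlib.contextmanager
    def temporary_environ(env_vars):
        """Temporarily set env vars, restoring prior values on exit."""
        saved = {k: os.environ.get(k) for k in env_vars}
        os.environ.update(env_vars)
        try:
            yield
        finally:
            for k, v in saved.items():
                if v is None:
                    os.environ.pop(k, None)  # was unset before
                else:
                    os.environ[k] = v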

View File

@@ -4,7 +4,7 @@
 Test the piecewise compilation with a simple model so that we
 can exactly calculate the expected output and side effects.
 """
+import pytest
 import torch
 from torch import nn
 from torch.library import Library
@@ -14,6 +14,7 @@ from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
                          set_current_vllm_config)
 from vllm.envs import VLLM_USE_V1
+from vllm.forward_context import set_forward_context
 from vllm.utils import direct_register_custom_op

 global_counter = 0
@@ -76,7 +77,8 @@ class SillyModel(nn.Module):
         return x


-def _test_simple_piecewise_compile(*, use_inductor):
+@pytest.mark.parametrize("use_inductor", [True, False])
+def test_simple_piecewise_compile(use_inductor):
     assert VLLM_USE_V1

     vllm_config = VllmConfig(compilation_config=CompilationConfig(
@@ -99,7 +101,7 @@ def _test_simple_piecewise_compile(*, use_inductor):
             num_backend_compilations=3,  # num_piecewise_capturable_graphs_seen
             num_cudagraph_captured=
             6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ):
+    ), set_forward_context({}, vllm_config=vllm_config):
         model(inputs)

@@ -112,11 +114,3 @@
         output = model(input)
     assert global_counter == 2
     assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
-
-
-def test_simple_piecewise_compile_inductor():
-    _test_simple_piecewise_compile(use_inductor=True)
-
-
-def test_simple_piecewise_compile_no_inductor():
-    _test_simple_piecewise_compile(use_inductor=False)

View File

@@ -11,6 +11,7 @@ initialized randomly with a fixed seed.
 from dataclasses import dataclass
 from typing import Any, Optional

+import pytest
 import torch
 from torch import nn
 from torch.library import Library
@@ -19,6 +20,7 @@ from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
                          set_current_vllm_config)
+from vllm.forward_context import set_forward_context
 from vllm.utils import direct_register_custom_op

 # create a library to hold the custom op
@@ -285,29 +287,32 @@ def run_model(llama_config,
                       vllm_config=vllm_config,
                       prefix="").eval().cuda()

-    B = 16  # max batch size
-    input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
-    positions = torch.arange(B).cuda()
+    with set_forward_context({}, vllm_config=vllm_config):
+        B = 16  # max batch size
+        input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
+        positions = torch.arange(B).cuda()

-    model(input_ids, positions)
-    model(input_ids[:2], positions[:2])
-    model(input_ids[:1], positions[:1])
+        model(input_ids, positions)
+        model(input_ids[:2], positions[:2])
+        model(input_ids[:1], positions[:1])

-    input_ids[:2].zero_()
-    output = model(input_ids[:2], positions[:2])
+        input_ids[:2].zero_()
+        output = model(input_ids[:2], positions[:2])

-    output = output.cpu()
+        output = output.cpu()

-    if llama_config.tractable_init:
-        expected_output = tractable_computation(input_ids[:2], positions[:2],
-                                                llama_config).cpu()
-
-        assert torch.allclose(output, expected_output)
-    else:
-        return output.cpu()
+        if llama_config.tractable_init:
+            expected_output = tractable_computation(input_ids[:2],
+                                                    positions[:2],
+                                                    llama_config).cpu()
+
+            assert torch.allclose(output, expected_output)
+        else:
+            return output.cpu()


-def _test_toy_llama(*, use_inductor):
+@pytest.mark.parametrize("use_inductor", [True, False])
+def test_toy_llama(use_inductor: bool):
     # compare output with and without piecewise compilation

     llama_config = LlamaConfig(hidden_size=128,
@@ -379,14 +384,6 @@ def _test_toy_llama(*, use_inductor):
         assert torch.allclose(outputs[0], outputs[i])


-def test_toy_llama_inductor():
-    _test_toy_llama(use_inductor=True)
-
-
-def test_toy_no_inductor():
-    _test_toy_llama(use_inductor=False)
-
-
 @torch.inference_mode
 def benchmark():
     from triton.testing import do_bench

View File

@@ -667,42 +667,54 @@ def get_physical_device_indices(devices):

 @_nvml()
-def wait_for_gpu_memory_to_clear(devices: list[int],
-                                 threshold_bytes: int,
+def wait_for_gpu_memory_to_clear(*,
+                                 devices: list[int],
+                                 threshold_bytes: Optional[int] = None,
+                                 threshold_ratio: Optional[float] = None,
                                  timeout_s: float = 120) -> None:
+    assert threshold_bytes is not None or threshold_ratio is not None
     # Use nvml instead of pytorch to reduce measurement error from torch cuda
     # context.
     devices = get_physical_device_indices(devices)
     start_time = time.time()
     while True:
         output: dict[int, str] = {}
-        output_raw: dict[int, float] = {}
+        output_raw: dict[int, tuple[float, float]] = {}
         for device in devices:
             if current_platform.is_rocm():
                 dev_handle = amdsmi_get_processor_handles()[device]
                 mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
                 gb_used = mem_info["vram_used"] / 2**10
+                gb_total = mem_info["vram_total"] / 2**10
             else:
                 dev_handle = nvmlDeviceGetHandleByIndex(device)
                 mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
                 gb_used = mem_info.used / 2**30
-            output_raw[device] = gb_used
-            output[device] = f'{gb_used:.02f}'
+                gb_total = mem_info.total / 2**30
+            output_raw[device] = (gb_used, gb_total)
+            output[device] = f'{gb_used:.02f}/{gb_total:.02f}'

-        print('gpu memory used (GB): ', end='')
+        print('gpu memory used/total (GiB): ', end='')
         for k, v in output.items():
             print(f'{k}={v}; ', end='')
         print('')

+        if threshold_bytes is not None:
+            is_free = lambda used, total: used <= threshold_bytes / 2**30
+            threshold = f"{threshold_bytes/2**30} GiB"
+        else:
+            is_free = lambda used, total: used / total <= threshold_ratio
+            threshold = f"{threshold_ratio:.2f}"
+
         dur_s = time.time() - start_time
-        if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
+        if all(is_free(used, total) for used, total in output_raw.values()):
             print(f'Done waiting for free GPU memory on devices {devices=} '
-                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
+                  f'({threshold=}) {dur_s=:.02f}')
             break

         if dur_s >= timeout_s:
             raise ValueError(f'Memory of devices {devices=} not free after '
-                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
+                             f'{dur_s=:.02f} ({threshold=})')
         time.sleep(5)
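
With the keyword-only signature above, callers choose either an absolute or a relative threshold. A usage sketch (the byte value is illustrative; the ratio call matches the new llm_pair fixture):

    # Absolute: wait until each device uses at most 2 GiB.
    wait_for_gpu_memory_to_clear(devices=[0], threshold_bytes=2 * 2**30)

    # Relative: wait until each device uses at most 10% of its total memory,
    # as the llm_pair fixture does after deleting both LLM instances.
    wait_for_gpu_memory_to_clear(devices=[0], threshold_ratio=0.1)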

View File

@@ -14,6 +14,7 @@ from vllm.compilation.backends import VllmBackend
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.monitor import end_monitoring_torch_compile
 from vllm.config import VllmConfig
+from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.utils import weak_ref_tensors

@@ -137,7 +138,10 @@ class CUDAPiecewiseBackend:
             if self.is_last_graph and not self.to_be_compiled_sizes:
                 self.check_for_ending_compilation()

-        if not entry.use_cudagraph:
+        # Skip CUDA graphs if this entry doesn't use them OR
+        # if we're supposed to skip them globally
+        skip_cuda_graphs = get_forward_context().skip_cuda_graphs
+        if not entry.use_cudagraph or skip_cuda_graphs:
             return entry.runnable(*args)

         if entry.cudagraph is None:

View File

@@ -179,7 +179,8 @@ class LLM:
         hf_overrides: Optional[HfOverrides] = None,
         mm_processor_kwargs: Optional[dict[str, Any]] = None,
         override_pooler_config: Optional[PoolerConfig] = None,
-        compilation_config: Optional[Union[int, dict[str, Any]]] = None,
+        compilation_config: Optional[Union[int, dict[str, Any],
+                                           CompilationConfig]] = None,
         **kwargs,
     ) -> None:
         """LLM constructor."""

View File

@@ -94,6 +94,7 @@ class ForwardContext:
     virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
     dp_metadata: Optional[DPMetadata] = None
+    skip_cuda_graphs: bool = False


 _forward_context: Optional[ForwardContext] = None
@@ -108,11 +109,14 @@ def get_forward_context() -> ForwardContext:

 @contextmanager
-def set_forward_context(attn_metadata: Any,
-                        vllm_config: VllmConfig,
-                        virtual_engine: int = 0,
-                        num_tokens: Optional[int] = None,
-                        num_tokens_across_dp: Optional[torch.Tensor] = None):
+def set_forward_context(
+    attn_metadata: Any,
+    vllm_config: VllmConfig,
+    virtual_engine: int = 0,
+    num_tokens: Optional[int] = None,
+    num_tokens_across_dp: Optional[torch.Tensor] = None,
+    skip_cuda_graphs: bool = False,
+):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
@@ -135,7 +139,9 @@ def set_forward_context(attn_metadata: Any,
         static_forward_context,
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
-        dp_metadata=dp_metadata)
+        dp_metadata=dp_metadata,
+        skip_cuda_graphs=skip_cuda_graphs,
+    )
     try:
         yield
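
Together with the piecewise-backend hunk above, this gives a context-scoped flag: the model runner decides per batch whether graphs are safe, and every compiled piece reads the decision at call time. A self-contained toy model of the mechanism (names like _Ctx and set_ctx are illustrative, not vLLM APIs):

    from contextlib import contextmanager
    from dataclasses import dataclass

    @dataclass
    class _Ctx:
        skip_cuda_graphs: bool = False

    _ctx = _Ctx()

    @contextmanager
    def set_ctx(skip_cuda_graphs: bool = False):
        prev = _ctx.skip_cuda_graphs
        _ctx.skip_cuda_graphs = skip_cuda_graphs
        try:
            yield
        finally:
            _ctx.skip_cuda_graphs = prev

    def piecewise_call(entry_use_cudagraph, runnable, replay):
        # Mirrors the CUDAPiecewiseBackend check: eager fallback when the
        # forward context says to skip CUDA graphs for this batch.
        if not entry_use_cudagraph or _ctx.skip_cuda_graphs:
            return runnable()
        return replay()

    with set_ctx(skip_cuda_graphs=True):
        assert piecewise_call(True, lambda: "eager", lambda: "replay") == "eager"
    assert piecewise_call(True, lambda: "eager", lambda: "replay") == "replay"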

View File

@@ -7,7 +7,8 @@ from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
                                                 TorchSDPAMetadata)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.ipex_attn import PagedAttention
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable
@@ -53,7 +54,7 @@ class TorchSDPABackend:
         return False


-class TorchSDPAMetadataBuilderV1:
+class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):

     def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
                  block_table: BlockTable) -> None:
@@ -118,9 +119,12 @@ class TorchSDPAMetadataBuilderV1:
         return True

-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int,
+    def build(self, common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        max_query_len = common_attn_metadata.max_query_len
+
         runner = self.runner
         block_table = self.block_table
         seq_lens_np = runner.seq_lens_np[:num_reqs]

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with FlashAttention."""
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, ClassVar, Optional

 import numpy as np
 import torch
@@ -21,13 +21,12 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

 if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner

 if current_platform.is_cuda():
@@ -306,7 +305,9 @@ def _get_sliding_window_configs(
     return sliding_window_configs


-class FlashAttentionMetadataBuilder:
+class FlashAttentionMetadataBuilder(
+        AttentionMetadataBuilder[FlashAttentionMetadata]):
+    full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3

     def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
@@ -336,13 +337,14 @@ class FlashAttentionMetadataBuilder:
         # populated on first build() call.
         self.aot_sliding_window: Optional[tuple[int, int]] = None

-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
-        return False
-
-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata):
+    def build(
+        self, common_prefix_len: int,
+        common_attn_metadata: CommonAttentionMetadata
+    ) -> FlashAttentionMetadata:
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        max_query_len = common_attn_metadata.max_query_len
+
         max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
@@ -496,6 +498,11 @@ class FlashAttentionMetadataBuilder:
         )
         return attn_metadata

+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        # Full CUDA Graph always supported (FA2 support checked separately)
+        return True
+
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         return use_cascade_attention(*args, **kwargs)

View File

@@ -18,7 +18,8 @@ from vllm.attention.layer import Attention
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

@@ -202,7 +203,7 @@ class FlashInferMetadata:
                 f" received {self.head_dim}.")


-class FlashInferMetadataBuilder:
+class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

     def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
@@ -399,9 +400,11 @@ class FlashInferMetadataBuilder:
             kv_data_type=attn_metadata.data_type,
         )

-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int,
+    def build(self, common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+
         assert self._num_decodes + self._num_prefills == num_reqs
         assert (self._num_decode_tokens +
                 self._num_prefill_tokens == num_actual_tokens)

View File

@@ -15,7 +15,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               is_quantized_kv_cache)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

@@ -25,8 +26,6 @@ if current_platform.is_cuda():
 logger = init_logger(__name__)

 if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner

 create_block_mask_compiled = torch.compile(create_block_mask,
@@ -256,7 +255,8 @@ class FlexAttentionMetadata:
         self.block_mask = self.build_block_mask()


-class FlexAttentionMetadataBuilder:
+class FlexAttentionMetadataBuilder(
+        AttentionMetadataBuilder[FlexAttentionMetadata]):

     def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
@@ -272,13 +272,12 @@ class FlexAttentionMetadataBuilder:
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table

-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
-        return False
-
-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int,
+    def build(self, common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        max_query_len = common_attn_metadata.max_query_len
+
         max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
@@ -332,9 +331,6 @@ class FlexAttentionMetadataBuilder:
         )
         return out

-    def use_cascade_attention(self, *args, **kwargs) -> bool:
-        return False
-

 class FlexAttentionImpl(AttentionImpl):
     sliding_window: Optional[tuple[int, int]]

View File

@@ -207,7 +207,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                UnquantizedLinearMethod)
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

@@ -329,7 +330,7 @@ class MLACommonMetadata(Generic[D]):
 M = TypeVar("M", bound=MLACommonMetadata)


-class MLACommonMetadataBuilder(Generic[M]):
+class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
@@ -450,9 +451,32 @@ class MLACommonMetadataBuilder(Generic[M]):
             seq_lens=seq_lens,
         )

-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int,
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata) -> M:
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only decode is supported for full cudagraphs with MLA.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "MLA only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seq."
+
+        m.max_query_len = 1  # decode-only
+
+        # Update state usually set in reorder_batch.
+        self._num_decodes = m.num_reqs
+        self._num_decode_tokens = m.num_actual_tokens
+        self._num_prefills = 0
+        self._num_prefill_tokens = 0
+        return self.build(0, m)
+
+    def build(self, common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata) -> M:
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        max_query_len = common_attn_metadata.max_query_len
+
         assert self._num_decodes + self._num_prefills == num_reqs

         # Note(simon): be careful about the CPU <> GPU memory movement in this
@@ -461,8 +485,11 @@ class MLACommonMetadataBuilder(Generic[M]):
         device = self.runner.device
         block_table = self.block_table
         block_table_tensor = block_table.get_device_tensor()[:num_reqs]
-        slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to(
-            device, non_blocking=True).long()
+        block_table.slot_mapping[:num_actual_tokens].copy_(
+            block_table.slot_mapping_cpu[:num_actual_tokens],
+            non_blocking=True)
+        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
+        slot_mapping = block_table.slot_mapping[:num_actual_tokens]

         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
@@ -564,8 +591,9 @@ class MLACommonMetadataBuilder(Generic[M]):
             decode=decode_metadata,
         )

-    def use_cascade_attention(self, *args, **kwargs) -> bool:
-        return False
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1


 class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
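
The assert in build_for_cudagraph_capture encodes a sizing rule: in pure decode every request contributes exactly one token, so a captured graph of size N implies N requests and must satisfy N <= max_num_seqs. A hypothetical helper (not part of the PR) that makes the arithmetic explicit:

    def check_mla_capture_sizes(capture_sizes: list[int], max_num_seqs: int):
        # N tokens in a decode-only capture == N requests.
        bad = [n for n in capture_sizes if n > max_num_seqs]
        if bad:
            raise RuntimeError(
                f"capture sizes {bad} exceed max_num_seqs={max_num_seqs}")

    check_mla_capture_sizes([64, 256], max_num_seqs=256)  # fine
    # check_mla_capture_sizes([64, 256, 512], 256) would raise, matching the
    # DeepSeek-V2-Lite failure exercised by test_lower_max_num_seqs above.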

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any, ClassVar, Optional

 import torch

@@ -44,7 +44,7 @@ class FlashMLABackend(MLACommonBackend):

 @dataclass
 class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
-    tile_scheduler_metadata: tuple[torch.Tensor, torch.Tensor]
+    tile_scheduler_metadata: torch.Tensor
     num_splits: torch.Tensor

@@ -54,14 +54,18 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):

 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
+    full_cudagraph_supported: ClassVar[bool] = True  # Decode-only

     def __init__(self, runner, kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
-        super().__init__(runner, kv_cache_spec, block_table)
+        super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata)

         self.num_q_heads = self.runner.model_config.get_num_attention_heads(
             self.runner.parallel_config)

+        self.cg_buf_tile_scheduler_metadata = None
+        self.cg_buf_num_splits = None
+
     def _build_decode(self, block_table_tensor: torch.Tensor,
                       seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
         tile_scheduler_metadata, num_splits = \
@@ -71,6 +75,30 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
             1,  # MQA for the decode path
         )

+        if self.runner.full_cuda_graph:
+            # First time around (CUDAGraph capture), allocate the static buffer
+            if self.cg_buf_tile_scheduler_metadata is None:
+                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
+                self.cg_buf_num_splits = num_splits
+            else:
+                assert self.cg_buf_num_splits is not None
+
+                # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
+                assert (self.cg_buf_tile_scheduler_metadata.size() ==
+                        tile_scheduler_metadata.size())
+                self.cg_buf_tile_scheduler_metadata.\
+                    copy_(tile_scheduler_metadata)
+                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
+
+                # Num splits is per-batch, varying size (batch_size,)
+                n = num_splits.size(0)
+                # make sure static buffer is large enough
+                assert n <= self.cg_buf_num_splits.size(0)
+                num_splits_view = self.cg_buf_num_splits[:n]
+                num_splits_view.copy_(num_splits)
+                self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
+                num_splits = num_splits_view
+
         return FlashMLADecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens,
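
A replayed CUDA graph re-reads fixed memory addresses, which is why the builder copies fresh FlashMLA scheduler metadata into buffers adopted at capture time instead of returning newly allocated tensors. A self-contained toy sketch of the same cg_buf_* pattern (PersistentBuffer is illustrative, not a vLLM class):

    from typing import Optional

    import torch

    class PersistentBuffer:
        """Allocate once at capture, then copy new values into the same
        storage before each replay so the graph sees fresh data."""

        def __init__(self):
            self.buf: Optional[torch.Tensor] = None

        def update(self, fresh: torch.Tensor) -> torch.Tensor:
            if self.buf is None:
                self.buf = fresh          # capture time: adopt the allocation
                return self.buf
            n = fresh.size(0)
            assert n <= self.buf.size(0)  # static buffer must be large enough
            view = self.buf[:n]
            view.copy_(fresh)             # replay time: same address, new data
            self.buf[n:].fill_(0)         # zero the padded tail
            return view

    pb = PersistentBuffer()
    a = pb.update(torch.arange(4))
    b = pb.update(torch.tensor([7, 8]))
    assert b.data_ptr() == a.data_ptr()   # same storage across updates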

View File

@@ -66,7 +66,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):

     def __init__(self, runner, kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
-        super().__init__(runner, kv_cache_spec, block_table)
+        super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata)
         assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
             "only supports block size 1."

View File

@@ -1,15 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import abc
+from abc import abstractmethod
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar

+import numpy as np
 import torch

+if TYPE_CHECKING:
+    from vllm.v1.core.sched.output import SchedulerOutput
+    from vllm.v1.worker.gpu_input_batch import InputBatch
+

 @dataclass
 class CommonAttentionMetadata:
     """
-    Attention metadata attributes that can be shared by layers in different KV
-    cache groups and thus having different block table.
+    Per-batch attention metadata, shared across layers and backends.
+    AttentionMetadataBuilder instances use it to construct per-layer metadata.
     """

     query_start_loc: torch.Tensor
@@ -18,6 +26,67 @@ class CommonAttentionMetadata:
     """(batch_size,), the length of each request including both computed tokens
     and newly scheduled tokens"""

+    num_reqs: int
+    """Number of requests"""
+    num_actual_tokens: int
+    """Total number of tokens in batch"""
+    max_query_len: int
+    """Longest query in batch"""
+
+
+M = TypeVar("M")
+
+
+class AttentionMetadataBuilder(abc.ABC, Generic[M]):
+    # Does this backend/builder support CUDA Graphs for attention.
+    full_cudagraph_supported: ClassVar[bool] = False
+
+    @abstractmethod
+    def build(self, common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata) -> M:
+        """
+        Central method that builds attention metadata.
+        Some builders (MLA) require reorder_batch to be called prior to build.
+        """
+        raise NotImplementedError
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        """
+        Can this batch (with given metadata) use CUDA Graphs for attention.
+        """
+        return False
+
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata) -> M:
+        """
+        Build attention metadata for CUDA graph capture. Uses build by default.
+        Subclasses that override this method should call self.build or
+        super().build_for_cudagraph_capture.
+        """
+        return self.build(common_prefix_len=0,
+                          common_attn_metadata=common_attn_metadata)
+
+    def use_cascade_attention(
+        self,
+        common_prefix_len: int,
+        query_lens: np.ndarray,
+        num_query_heads: int,
+        num_kv_heads: int,
+        use_alibi: bool,
+        use_sliding_window: bool,
+        num_sms: int,
+    ) -> bool:
+        return False
+
+    def reorder_batch(self, input_batch: "InputBatch",
+                      scheduler_output: "SchedulerOutput") -> bool:
+        """
+        This method can reorder the batch if desired by the backend.
+        :return: Has the batch been reordered (default False).
+        """
+        return False
+

 def validate_kv_sharing_target(current_layer_name, target_layer_name,
                                static_forward_context):
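
A hedged sketch of what a backend implements against this new interface. MyMetadata and MyMetadataBuilder are made-up names for illustration; the imports mirror the file above:

    from dataclasses import dataclass
    from typing import ClassVar

    from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                                  CommonAttentionMetadata)

    @dataclass
    class MyMetadata:  # hypothetical per-backend metadata type
        num_actual_tokens: int
        max_query_len: int

    class MyMetadataBuilder(AttentionMetadataBuilder[MyMetadata]):
        # Opt this (hypothetical) backend into full CUDA graph support.
        full_cudagraph_supported: ClassVar[bool] = True

        def build(self, common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata) -> MyMetadata:
            # Batch-level sizes now come from CommonAttentionMetadata
            # instead of separate build() arguments.
            return MyMetadata(
                num_actual_tokens=common_attn_metadata.num_actual_tokens,
                max_query_len=common_attn_metadata.max_query_len)

        def can_run_in_cudagraph(
                self, common_attn_metadata: CommonAttentionMetadata) -> bool:
            # Decode-only backends would check max_query_len == 1 here.
            return common_attn_metadata.max_query_len == 1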

View File

@@ -138,15 +138,17 @@ class EagleProposer:
         max_query_len = query_lens.max().item()

         common_attn_metadata = CommonAttentionMetadata(
-            query_start_loc=cu_num_tokens, seq_lens=seq_lens)
+            query_start_loc=cu_num_tokens,
+            seq_lens=seq_lens,
+            num_reqs=batch_size,
+            num_actual_tokens=num_tokens,
+            max_query_len=max_query_len,
+        )

         assert self.runner is not None

         # FIXME: need to consider multiple kv_cache_groups
         attn_metadata = self.runner.attn_metadata_builder.build(
-            num_reqs=batch_size,
-            num_actual_tokens=num_tokens,
-            max_query_len=max_query_len,
             common_prefix_len=0,
             common_attn_metadata=common_attn_metadata,
         )

View File

@@ -16,10 +16,8 @@ from tqdm import tqdm

 import vllm.envs as envs
 from vllm.attention import AttentionType, get_attn_backend
-from vllm.attention.backends.abstract import (AttentionBackend,
-                                              AttentionMetadataBuilder)
+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.layer import Attention
-from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -41,7 +39,8 @@ from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
                         check_use_alibi, is_pin_memory_available)
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
@@ -89,6 +88,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
+        self.compilation_config = vllm_config.compilation_config
         self.lora_config = vllm_config.lora_config
         self.load_config = vllm_config.load_config
         self.parallel_config = vllm_config.parallel_config
@@ -197,7 +197,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             block_sizes=[self.cache_config.block_size],
         )

-        self.use_cuda_graph = (self.vllm_config.compilation_config.level
+        self.use_cuda_graph = (self.compilation_config.level
                                == CompilationLevel.PIECEWISE
                                and not self.model_config.enforce_eager)
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
@@ -205,8 +205,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # self.cudagraph_batch_sizes sorts in ascending order.
         # The batch sizes in the config are in descending order.
         self.cudagraph_batch_sizes = list(
-            reversed(
-                self.vllm_config.compilation_config.cudagraph_capture_sizes))
+            reversed(self.compilation_config.cudagraph_capture_sizes))
+
+        self.full_cuda_graph = self.compilation_config.full_cuda_graph

         # Cache the device properties.
         self._init_device_properties()
@@ -555,7 +556,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata]]:
+    ) -> tuple[dict[str, Any], bool, torch.Tensor,
+               Optional[SpecDecodeMetadata]]:
+        """
+        :return: tuple[
+            attn_metadata: layer-to-attention_metadata mapping,
+            attention_cuda_graphs: whether attention can run in cudagraph
+            logits_indices, spec_decode_metadata
+        ]
+        """
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
@@ -669,7 +678,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         seq_lens = self.seq_lens[:num_reqs]

         common_attn_metadata = CommonAttentionMetadata(
-            query_start_loc=query_start_loc, seq_lens=seq_lens)
+            query_start_loc=query_start_loc,
+            seq_lens=seq_lens,
+            num_reqs=num_reqs,
+            num_actual_tokens=total_num_scheduled_tokens,
+            max_query_len=max_num_scheduled_tokens,
+        )

         attn_metadata: dict[str, Any] = {}
         # Prepare the attention metadata for each KV cache group and make layers
@@ -679,25 +693,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):

             # Prepare for cascade attention if enabled & beneficial.
             common_prefix_len = 0
+            builder = self.attn_metadata_builders[kv_cache_group_id]
             if self.cascade_attn_enabled:
                 common_prefix_len = self._compute_cascade_attn_prefix_len(
                     num_scheduled_tokens,
                     scheduler_output.
                     num_common_prefix_blocks[kv_cache_group_id],
                     kv_cache_group_spec.kv_cache_spec,
-                    self.attn_metadata_builders[kv_cache_group_id],
+                    builder,
                 )

-            attn_metadata_i = (
-                self.attn_metadata_builders[kv_cache_group_id].build(
-                    num_reqs=num_reqs,
-                    num_actual_tokens=total_num_scheduled_tokens,
-                    max_query_len=max_num_scheduled_tokens,
-                    common_prefix_len=common_prefix_len,
-                    common_attn_metadata=common_attn_metadata))
+            attn_metadata_i = (builder.build(
+                common_prefix_len=common_prefix_len,
+                common_attn_metadata=common_attn_metadata,
+            ))

             for layer_name in kv_cache_group_spec.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i

+        attention_cuda_graphs = all(
+            b.can_run_in_cudagraph(common_attn_metadata)
+            for b in self.attn_metadata_builders)
+
         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
         if not use_spec_decode:
@@ -726,7 +743,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)

-        return attn_metadata, logits_indices, spec_decode_metadata
+        return (attn_metadata, attention_cuda_graphs, logits_indices,
+                spec_decode_metadata)

     def _compute_cascade_attn_prefix_len(
         self,
@@ -1121,7 +1139,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             assert self.intermediate_tensors is not None

         tp = self.vllm_config.parallel_config.tensor_parallel_size
-        enabled_sp = self.vllm_config.compilation_config.pass_config. \
+        enabled_sp = self.compilation_config.pass_config. \
             enable_sequence_parallelism
         if enabled_sp:
             # When sequence parallelism is enabled, we always pad num_tokens
@@ -1189,8 +1207,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             return self.kv_connector_no_forward(scheduler_output)

         # Prepare the decoder inputs.
-        attn_metadata, logits_indices, spec_decode_metadata = (
-            self._prepare_inputs(scheduler_output))
+        (attn_metadata, attention_cuda_graphs, logits_indices,
+         spec_decode_metadata) = (self._prepare_inputs(scheduler_output))
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.use_cuda_graph
                 and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@@ -1203,7 +1221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Pad tokens to multiple of tensor_parallel_size when
             # enabled collective fusion for SP
             tp_size = self.vllm_config.parallel_config.tensor_parallel_size
-            if self.vllm_config.compilation_config.pass_config. \
+            if self.compilation_config.pass_config. \
                 enable_sequence_parallelism and tp_size > 1:
                 from vllm.utils import round_up
                 num_input_tokens = round_up(num_scheduled_tokens, tp_size)
@@ -1255,12 +1273,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                 num_input_tokens, intermediate_tensors, True)

+        # Some attention backends only support CUDA Graphs in pure decode.
+        # If attention doesn't support CUDA Graphs for this batch, but we
+        # compiled with full CUDA graphs, we have to skip them entirely.
+        skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
+
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
-        with set_forward_context(attn_metadata,
-                                 self.vllm_config,
-                                 num_tokens=num_input_tokens,
-                                 num_tokens_across_dp=num_tokens_across_dp):
+        with set_forward_context(
+                attn_metadata,
+                self.vllm_config,
+                num_tokens=num_input_tokens,
+                num_tokens_across_dp=num_tokens_across_dp,
+                skip_cuda_graphs=skip_cuda_graphs,
+        ):
             self.maybe_setup_kv_connector(scheduler_output)

             model_output = self.model(
@@ -1769,7 +1795,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _dummy_run(
         self,
         num_tokens: int,
-        skip_attn: bool = True,
+        capture_attn_cudagraph: bool = False,
     ) -> torch.Tensor:

         # Padding for DP
@@ -1790,9 +1816,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                         dtype=np.int32)

-        if skip_attn:
-            attn_metadata: Optional[dict[str, Any]] = None
-        else:
+        attn_metadata: Optional[dict[str, Any]] = None
+        if capture_attn_cudagraph:
+            attn_metadata = {}
+
             query_start_loc = self.query_start_loc[:num_reqs + 1]
             # Make sure max_model_len is used at the graph capture time.
             self.seq_lens_np[:num_reqs] = self.max_model_len
@@ -1802,19 +1829,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             seq_lens = self.seq_lens[:num_reqs]

             common_attn_metadata = CommonAttentionMetadata(
-                query_start_loc=query_start_loc, seq_lens=seq_lens)
+                query_start_loc=query_start_loc,
+                seq_lens=seq_lens,
+                num_reqs=num_reqs,
+                num_actual_tokens=num_tokens,
+                max_query_len=num_tokens,
+            )

-            attn_metadata = {}
             for kv_cache_group_id, kv_cache_group_spec in enumerate(
                     self.kv_cache_config.kv_cache_groups):
-                attn_metadata_i = (
-                    self.attn_metadata_builders[kv_cache_group_id].build(
-                        num_reqs=num_reqs,
-                        num_actual_tokens=num_tokens,
-                        max_query_len=num_tokens,
-                        common_prefix_len=0,
-                        common_attn_metadata=common_attn_metadata,
-                    ))
+
+                attn_metadata_i = self.attn_metadata_builders[
+                    kv_cache_group_id].build_for_cudagraph_capture(
+                        common_attn_metadata)
                 for layer_name in kv_cache_group_spec.layer_names:
                     attn_metadata[layer_name] = attn_metadata_i

@@ -2039,14 +2066,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
         with graph_capture(device=self.device):
-            skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
+            full_cg = self.full_cuda_graph
             for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes),
                                    desc="Capturing CUDA graphs",
                                    total=len(self.cudagraph_batch_sizes)):
-                for _ in range(self.vllm_config.compilation_config.
-                               cudagraph_num_of_warmups):
-                    self._dummy_run(num_tokens, skip_attn=skip_attn)
-                self._dummy_run(num_tokens, skip_attn=skip_attn)
+                for _ in range(
+                        self.compilation_config.cudagraph_num_of_warmups):
+                    self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg)
+                self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg)

         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
@@ -2089,20 +2116,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     "Non-Attention backend is not supported by V1 "
                     "GPUModelRunner.")

-            if self.vllm_config.compilation_config.full_cuda_graph:
-                attn_backend_name = attn_backend_i.__name__
-                flash_attn_version = get_flash_attn_version()
-                if attn_backend_name != "FlashAttentionBackend" or \
-                    flash_attn_version != 3:
-                    raise ValueError(
-                        f"full_cuda_graph is only supported with "
-                        f"FA3. Current attention backend is "
-                        f"{attn_backend_name}, FlashAttention version is "
-                        f"{flash_attn_version}.")
-
             block_table_i = self.input_batch.block_table[i]
             attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
-                weakref.proxy(self), kv_cache_spec, block_table_i)
+                weakref.proxy(self),
+                kv_cache_spec,
+                block_table_i,
+            )
+
+            if (self.full_cuda_graph
+                    and not attn_metadata_builder_i.full_cudagraph_supported):
+                raise ValueError(
+                    f"Full CUDAGraph not supported for "
+                    f"{attn_backend_i.__name__}. Turn off CompilationConfig."
+                    f"full_cuda_graph or use a different attention backend.")
+
             self.attn_backends.append(attn_backend_i)
             self.attn_metadata_builders.append(attn_metadata_builder_i)

@@ -2142,9 +2169,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         to be reshaped to the desired shape before being used by the models.

         Args:
             kv_cache_config: The KV cache config
         Returns:
             dict[str, torch.Tensor]: A map between layer names to their
-            corresponding memory buffer for KV cache.
+                corresponding memory buffer for KV cache.
         """
         kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
@@ -2171,11 +2198,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         Reshape the KV cache tensors to the desired shape and dtype.

         Args:
             kv_cache_config: The KV cache config
             kv_cache_raw_tensors: The KV cache buffer of each layer, with
-            correct size but uninitialized shape.
+                correct size but uninitialized shape.
         Returns:
             Dict[str, torch.Tensor]: A map between layer names to their
-            corresponding memory buffer for KV cache.
+                corresponding memory buffer for KV cache.
         """
         kv_caches: dict[str, torch.Tensor] = {}
@@ -2227,7 +2254,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         Args:
             kv_cache_config: The KV cache config
         Returns:
             Dict[str, torch.Tensor]: A map between layer names to their
-            corresponding memory buffer for KV cache.
+                corresponding memory buffer for KV cache.
         """
         # Initialize the memory buffer for KV cache
@@ -2245,10 +2272,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             kv_caches,
         )

-        bind_kv_cache(
-            kv_caches,
-            self.vllm_config.compilation_config.static_forward_context,
-            self.kv_caches)
+        bind_kv_cache(kv_caches,
+                      self.compilation_config.static_forward_context,
+                      self.kv_caches)

         return kv_caches

     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
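
Pulling the pieces together: _prepare_inputs now asks every builder can_run_in_cudagraph, and execute_model skips full graphs for any batch the attention backends cannot replay. A self-contained toy model of that per-step decision (the function and its arguments are illustrative, not vLLM APIs):

    def decide_cudagraph_skip(full_cuda_graph: bool, max_query_len: int,
                              builders_decode_only: list[bool]) -> bool:
        # A decode-only builder (e.g. MLA) can use graphs iff the whole
        # batch is decode (max_query_len == 1), mirroring
        # MLACommonMetadataBuilder.can_run_in_cudagraph.
        attention_cuda_graphs = all(
            (not decode_only) or max_query_len == 1
            for decode_only in builders_decode_only)
        return full_cuda_graph and not attention_cuda_graphs

    assert decide_cudagraph_skip(True, 1, [True]) is False  # pure decode: replay
    assert decide_cudagraph_skip(True, 5, [True]) is True   # prefill: skip graphs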