[CUDA] Enable full cudagraph for FlashMLA (#18581)

Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
Luka Govedič 2025-06-13 14:12:26 -04:00 committed by GitHub
parent 1015296b79
commit 3597b06a4f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 452 additions and 219 deletions

View File

@ -2,15 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import weakref
from contextlib import ExitStack
import pytest
from tests.utils import wait_for_gpu_memory_to_clear
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.platforms import current_platform
MODEL = "Qwen/Qwen2-1.5B-Instruct"
@contextlib.contextmanager
def temporary_environ(env_vars):
@ -31,64 +32,119 @@ def temporary_environ(env_vars):
os.environ[k] = v
@pytest.fixture(scope="module")
def full_cudagraph_llm():
@pytest.fixture(scope="class")
def llm_pair(request):
model = request.param
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}):
return LLM(model=MODEL,
gpu_memory_utilization=0.3,
compilation_config=CompilationConfig(full_cuda_graph=True))
@pytest.fixture(scope="module")
def piecewise_llm():
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}):
return LLM(model=MODEL,
gpu_memory_utilization=0.6,
compilation_config=CompilationConfig())
def generate_text(llm: LLM, batch_size: int, max_tokens: int):
prompts = ["Hi my name is"] * batch_size
sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
top_p=0.95)
return llm.generate(prompts, sampling_params)
full = LLM(
model=model,
gpu_memory_utilization=0.45,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(full_cuda_graph=True),
)
piecewise = LLM(
model=model,
gpu_memory_utilization=0.45,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(),
)
# PyTest caches the fixture values so we use weakref.proxy to enable GC
yield weakref.proxy(full), weakref.proxy(piecewise)
del full
del piecewise
wait_for_gpu_memory_to_clear(
devices=[0],
threshold_ratio=0.1,
)
@pytest.mark.parametrize(
"llm_pair",
[
# Model names for the llm_pair fixture
"deepseek-ai/DeepSeek-V2-Lite",
"Qwen/Qwen2-1.5B-Instruct"
],
indirect=True)
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
reason="Only Hopper GPUs support FlashAttention 3")
@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
(16, 10), (25, 10),
(32, 10), (45, 10),
(64, 10), (8, 5),
(8, 20), (8, 200)])
def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm,
piecewise_llm):
reason="Only Hopper GPUs support FA3 and FlashMLA")
class TestFullCUDAGraph:
"""
Load full cudagraph model and piecewise model once, and at the same time to
reuse them across various test cases.
Use a class such that an llm pair is constructed once for all
batch_size/max_tokens combinations and released immediately after.
Test various batch sizes and max_tokens to ensure that the full cudagraph
compilation works for padded cases too.
Module-scope fixtures would stick around the whole time,
meaning there would be multiple LLM instances hogging memory simultaneously.
"""
piecewise_responses = generate_text(piecewise_llm,
batch_size=batch_size,
max_tokens=max_tokens)
full_cudagraph_responses = generate_text(full_cudagraph_llm,
batch_size=batch_size,
max_tokens=max_tokens)
# Check that all responses are the same
for i in range(len(piecewise_responses)):
assert piecewise_responses[i].outputs[
0].text == full_cudagraph_responses[i].outputs[0].text
@pytest.mark.parametrize(("batch_size", "max_tokens"), [
(1, 10),
(7, 10),
(16, 10),
(25, 10),
(32, 10),
(45, 10),
(64, 10),
(123, 10),
(8, 5),
(8, 30),
])
def test_full_cudagraph(self, batch_size, max_tokens,
llm_pair: tuple[LLM, LLM]):
"""
Test various batch sizes and max_tokens to ensure that the
full cudagraph compilation works for padded cases too.
"""
piecewise_llm, full_cudagraph_llm = llm_pair
prompts = ["Hello, my name is"] * batch_size
sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
top_p=0.95)
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
# Check that all responses are the same
for piecewise_res, full_res in zip(piecewise_responses,
full_responses):
assert piecewise_res.outputs[0].text == full_res.outputs[0].text
@pytest.mark.parametrize(
"model, supported",
[
("Qwen/Qwen2-1.5B-Instruct", True),
# MLA does not support capturing CUDA Graphs with size > max_num_seqs
("deepseek-ai/DeepSeek-V2-Lite", False),
])
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
reason="Only Hopper GPUs support FA3 and FlashMLA")
def test_lower_max_num_seqs(model, supported):
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}), ExitStack() as stack:
if not supported:
stack.enter_context(pytest.raises(RuntimeError))
llm = LLM(model=model,
max_num_seqs=256,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(
full_cuda_graph=True,
cudagraph_capture_sizes=[64, 256, 512]))
llm.generate(["Hello, my name is"] * 10)
def test_full_cudagraph_with_invalid_backend():
@ -97,5 +153,5 @@ def test_full_cudagraph_with_invalid_backend():
"VLLM_FLASH_ATTN_VERSION":
"2" #FA2 not supported with full_cuda_graph
}), pytest.raises(RuntimeError):
LLM(model=MODEL,
LLM(model="Qwen/Qwen2-1.5B-Instruct",
compilation_config=CompilationConfig(full_cuda_graph=True))

View File

@ -4,7 +4,7 @@
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""
import pytest
import torch
from torch import nn
from torch.library import Library
@ -14,6 +14,7 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import set_forward_context
from vllm.utils import direct_register_custom_op
global_counter = 0
@ -76,7 +77,8 @@ class SillyModel(nn.Module):
return x
def _test_simple_piecewise_compile(*, use_inductor):
@pytest.mark.parametrize("use_inductor", [True, False])
def test_simple_piecewise_compile(use_inductor):
assert VLLM_USE_V1
vllm_config = VllmConfig(compilation_config=CompilationConfig(
@ -99,7 +101,7 @@ def _test_simple_piecewise_compile(*, use_inductor):
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
), set_forward_context({}, vllm_config=vllm_config):
model(inputs)
@ -112,11 +114,3 @@ def _test_simple_piecewise_compile(*, use_inductor):
output = model(input)
assert global_counter == 2
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
def test_simple_piecewise_compile_inductor():
_test_simple_piecewise_compile(use_inductor=True)
def test_simple_piecewise_compile_no_inductor():
_test_simple_piecewise_compile(use_inductor=False)

View File

@ -11,6 +11,7 @@ initialized randomly with a fixed seed.
from dataclasses import dataclass
from typing import Any, Optional
import pytest
import torch
from torch import nn
from torch.library import Library
@ -19,6 +20,7 @@ from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.forward_context import set_forward_context
from vllm.utils import direct_register_custom_op
# create a library to hold the custom op
@ -285,29 +287,32 @@ def run_model(llama_config,
vllm_config=vllm_config,
prefix="").eval().cuda()
B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
positions = torch.arange(B).cuda()
with set_forward_context({}, vllm_config=vllm_config):
B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
positions = torch.arange(B).cuda()
model(input_ids, positions)
model(input_ids[:2], positions[:2])
model(input_ids[:1], positions[:1])
model(input_ids, positions)
model(input_ids[:2], positions[:2])
model(input_ids[:1], positions[:1])
input_ids[:2].zero_()
output = model(input_ids[:2], positions[:2])
input_ids[:2].zero_()
output = model(input_ids[:2], positions[:2])
output = output.cpu()
output = output.cpu()
if llama_config.tractable_init:
expected_output = tractable_computation(input_ids[:2], positions[:2],
llama_config).cpu()
if llama_config.tractable_init:
expected_output = tractable_computation(input_ids[:2],
positions[:2],
llama_config).cpu()
assert torch.allclose(output, expected_output)
else:
return output.cpu()
assert torch.allclose(output, expected_output)
else:
return output.cpu()
def _test_toy_llama(*, use_inductor):
@pytest.mark.parametrize("use_inductor", [True, False])
def test_toy_llama(use_inductor: bool):
# compare output with and without piecewise compilation
llama_config = LlamaConfig(hidden_size=128,
@ -379,14 +384,6 @@ def _test_toy_llama(*, use_inductor):
assert torch.allclose(outputs[0], outputs[i])
def test_toy_llama_inductor():
_test_toy_llama(use_inductor=True)
def test_toy_no_inductor():
_test_toy_llama(use_inductor=False)
@torch.inference_mode
def benchmark():
from triton.testing import do_bench

View File

@ -667,42 +667,54 @@ def get_physical_device_indices(devices):
@_nvml()
def wait_for_gpu_memory_to_clear(devices: list[int],
threshold_bytes: int,
def wait_for_gpu_memory_to_clear(*,
devices: list[int],
threshold_bytes: Optional[int] = None,
threshold_ratio: Optional[float] = None,
timeout_s: float = 120) -> None:
assert threshold_bytes is not None or threshold_ratio is not None
# Use nvml instead of pytorch to reduce measurement error from torch cuda
# context.
devices = get_physical_device_indices(devices)
start_time = time.time()
while True:
output: dict[int, str] = {}
output_raw: dict[int, float] = {}
output_raw: dict[int, tuple[float, float]] = {}
for device in devices:
if current_platform.is_rocm():
dev_handle = amdsmi_get_processor_handles()[device]
mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
gb_used = mem_info["vram_used"] / 2**10
gb_total = mem_info["vram_total"] / 2**10
else:
dev_handle = nvmlDeviceGetHandleByIndex(device)
mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
gb_used = mem_info.used / 2**30
output_raw[device] = gb_used
output[device] = f'{gb_used:.02f}'
gb_total = mem_info.total / 2**30
output_raw[device] = (gb_used, gb_total)
output[device] = f'{gb_used:.02f}/{gb_total:.02f}'
print('gpu memory used (GB): ', end='')
print('gpu memory used/total (GiB): ', end='')
for k, v in output.items():
print(f'{k}={v}; ', end='')
print('')
if threshold_bytes is not None:
is_free = lambda used, total: used <= threshold_bytes / 2**30
threshold = f"{threshold_bytes/2**30} GiB"
else:
is_free = lambda used, total: used / total <= threshold_ratio
threshold = f"{threshold_ratio:.2f}"
dur_s = time.time() - start_time
if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
if all(is_free(used, total) for used, total in output_raw.values()):
print(f'Done waiting for free GPU memory on devices {devices=} '
f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
f'({threshold=}) {dur_s=:.02f}')
break
if dur_s >= timeout_s:
raise ValueError(f'Memory of devices {devices=} not free after '
f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
f'{dur_s=:.02f} ({threshold=})')
time.sleep(5)

View File

@ -14,6 +14,7 @@ from vllm.compilation.backends import VllmBackend
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors
@ -137,7 +138,10 @@ class CUDAPiecewiseBackend:
if self.is_last_graph and not self.to_be_compiled_sizes:
self.check_for_ending_compilation()
if not entry.use_cudagraph:
# Skip CUDA graphs if this entry doesn't use them OR
# if we're supposed to skip them globally
skip_cuda_graphs = get_forward_context().skip_cuda_graphs
if not entry.use_cudagraph or skip_cuda_graphs:
return entry.runnable(*args)
if entry.cudagraph is None:

View File

@ -179,7 +179,8 @@ class LLM:
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
override_pooler_config: Optional[PoolerConfig] = None,
compilation_config: Optional[Union[int, dict[str, Any]]] = None,
compilation_config: Optional[Union[int, dict[str, Any],
CompilationConfig]] = None,
**kwargs,
) -> None:
"""LLM constructor."""

View File

@ -94,6 +94,7 @@ class ForwardContext:
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass
dp_metadata: Optional[DPMetadata] = None
skip_cuda_graphs: bool = False
_forward_context: Optional[ForwardContext] = None
@ -108,11 +109,14 @@ def get_forward_context() -> ForwardContext:
@contextmanager
def set_forward_context(attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: Optional[int] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None):
def set_forward_context(
attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: Optional[int] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
skip_cuda_graphs: bool = False,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
@ -135,7 +139,9 @@ def set_forward_context(attn_metadata: Any,
static_forward_context,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata)
dp_metadata=dp_metadata,
skip_cuda_graphs=skip_cuda_graphs,
)
try:
yield

View File

@ -7,7 +7,8 @@ from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
TorchSDPAMetadata)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
@ -53,7 +54,7 @@ class TorchSDPABackend:
return False
class TorchSDPAMetadataBuilderV1:
class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
block_table: BlockTable) -> None:
@ -118,9 +119,12 @@ class TorchSDPAMetadataBuilderV1:
return True
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int,
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
runner = self.runner
block_table = self.block_table
seq_lens_np = runner.seq_lens_np[:num_reqs]

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, ClassVar, Optional
import numpy as np
import torch
@ -21,13 +21,12 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if current_platform.is_cuda():
@ -306,7 +305,9 @@ def _get_sliding_window_configs(
return sliding_window_configs
class FlashAttentionMetadataBuilder:
class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
block_table: BlockTable):
@ -336,13 +337,14 @@ class FlashAttentionMetadataBuilder:
# populated on first build() call.
self.aot_sliding_window: Optional[tuple[int, int]] = None
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
return False
def build(
self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata
) -> FlashAttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
@ -496,6 +498,11 @@ class FlashAttentionMetadataBuilder:
)
return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
# Full CUDA Graph always supported (FA2 support checked separately)
return True
def use_cascade_attention(self, *args, **kwargs) -> bool:
return use_cascade_attention(*args, **kwargs)

View File

@ -18,7 +18,8 @@ from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
@ -202,7 +203,7 @@ class FlashInferMetadata:
f" received {self.head_dim}.")
class FlashInferMetadataBuilder:
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
block_table: BlockTable):
@ -399,9 +400,11 @@ class FlashInferMetadataBuilder:
kv_data_type=attn_metadata.data_type,
)
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int,
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
assert self._num_decodes + self._num_prefills == num_reqs
assert (self._num_decode_tokens +
self._num_prefill_tokens == num_actual_tokens)

View File

@ -15,7 +15,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
@ -25,8 +26,6 @@ if current_platform.is_cuda():
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
create_block_mask_compiled = torch.compile(create_block_mask,
@ -256,7 +255,8 @@ class FlexAttentionMetadata:
self.block_mask = self.build_block_mask()
class FlexAttentionMetadataBuilder:
class FlexAttentionMetadataBuilder(
AttentionMetadataBuilder[FlexAttentionMetadata]):
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
block_table: BlockTable):
@ -272,13 +272,12 @@ class FlexAttentionMetadataBuilder:
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
return False
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int,
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
@ -332,9 +331,6 @@ class FlexAttentionMetadataBuilder:
)
return out
def use_cascade_attention(self, *args, **kwargs) -> bool:
return False
class FlexAttentionImpl(AttentionImpl):
sliding_window: Optional[tuple[int, int]]

View File

@ -207,7 +207,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
UnquantizedLinearMethod)
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
@ -329,7 +330,7 @@ class MLACommonMetadata(Generic[D]):
M = TypeVar("M", bound=MLACommonMetadata)
class MLACommonMetadataBuilder(Generic[M]):
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
@ -450,9 +451,32 @@ class MLACommonMetadataBuilder(Generic[M]):
seq_lens=seq_lens,
)
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int,
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata) -> M:
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with MLA.
"""
m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, \
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
m.max_query_len = 1 # decode-only
# Update state usually set in reorder_batch.
self._num_decodes = m.num_reqs
self._num_decode_tokens = m.num_actual_tokens
self._num_prefills = 0
self._num_prefill_tokens = 0
return self.build(0, m)
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata) -> M:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
assert self._num_decodes + self._num_prefills == num_reqs
# Note(simon): be careful about the CPU <> GPU memory movement in this
@ -461,8 +485,11 @@ class MLACommonMetadataBuilder(Generic[M]):
device = self.runner.device
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to(
device, non_blocking=True).long()
block_table.slot_mapping[:num_actual_tokens].copy_(
block_table.slot_mapping_cpu[:num_actual_tokens],
non_blocking=True)
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
@ -564,8 +591,9 @@ class MLACommonMetadataBuilder(Generic[M]):
decode=decode_metadata,
)
def use_cascade_attention(self, *args, **kwargs) -> bool:
return False
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any, ClassVar, Optional
import torch
@ -44,7 +44,7 @@ class FlashMLABackend(MLACommonBackend):
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
tile_scheduler_metadata: tuple[torch.Tensor, torch.Tensor]
tile_scheduler_metadata: torch.Tensor
num_splits: torch.Tensor
@ -54,14 +54,18 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # Decode-only
def __init__(self, runner, kv_cache_spec: AttentionSpec,
block_table: BlockTable):
super().__init__(runner, kv_cache_spec, block_table)
super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata)
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config)
self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = \
@ -71,6 +75,30 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
1, # MQA for the decode path
)
if self.runner.full_cuda_graph:
# First time around (CUDAGraph capture), allocate the static buffer
if self.cg_buf_tile_scheduler_metadata is None:
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
self.cg_buf_num_splits = num_splits
else:
assert self.cg_buf_num_splits is not None
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
assert (self.cg_buf_tile_scheduler_metadata.size() ==
tile_scheduler_metadata.size())
self.cg_buf_tile_scheduler_metadata.\
copy_(tile_scheduler_metadata)
tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
# Num splits is per-batch, varying size (batch_size,)
n = num_splits.size(0)
# make sure static buffer is large enough
assert n <= self.cg_buf_num_splits.size(0)
num_splits_view = self.cg_buf_num_splits[:n]
num_splits_view.copy_(num_splits)
self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s
num_splits = num_splits_view
return FlashMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,

View File

@ -66,7 +66,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
def __init__(self, runner, kv_cache_spec: AttentionSpec,
block_table: BlockTable):
super().__init__(runner, kv_cache_spec, block_table)
super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata)
assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
"only supports block size 1."

View File

@ -1,15 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
import numpy as np
import torch
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
@dataclass
class CommonAttentionMetadata:
"""
Attention metadata attributes that can be shared by layers in different KV
cache groups and thus having different block table.
Per-batch attention metadata, shared across layers and backends.
AttentionMetadataBuilder instances use it to construct per-layer metadata.
"""
query_start_loc: torch.Tensor
@ -18,6 +26,67 @@ class CommonAttentionMetadata:
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
num_reqs: int
"""Number of requests"""
num_actual_tokens: int
"""Total number of tokens in batch"""
max_query_len: int
"""Longest query in batch"""
M = TypeVar("M")
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention.
full_cudagraph_supported: ClassVar[bool] = False
@abstractmethod
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata) -> M:
"""
Central method that builds attention metadata.
Some builders (MLA) require reorder_batch to be called prior to build.
"""
raise NotImplementedError
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
"""
Can this batch (with given metadata) use CUDA Graphs for attention.
"""
return False
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata) -> M:
"""
Build attention metadata for CUDA graph capture. Uses build by default.
Subclasses that override this method should call self.build or
super().build_for_cudagraph_capture.
"""
return self.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
def use_cascade_attention(
self,
common_prefix_len: int,
query_lens: np.ndarray,
num_query_heads: int,
num_kv_heads: int,
use_alibi: bool,
use_sliding_window: bool,
num_sms: int,
) -> bool:
return False
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
"""
This method can reorder the batch if desired by the backend.
:return: Has the batch been reordered (default False).
"""
return False
def validate_kv_sharing_target(current_layer_name, target_layer_name,
static_forward_context):

View File

@ -138,15 +138,17 @@ class EagleProposer:
max_query_len = query_lens.max().item()
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=cu_num_tokens, seq_lens=seq_lens)
query_start_loc=cu_num_tokens,
seq_lens=seq_lens,
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
)
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builder.build(
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)

View File

@ -16,10 +16,8 @@ from tqdm import tqdm
import vllm.envs as envs
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadataBuilder)
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@ -41,7 +39,8 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
check_use_alibi, is_pin_memory_available)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec,
@ -89,6 +88,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
@ -197,7 +197,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_sizes=[self.cache_config.block_size],
)
self.use_cuda_graph = (self.vllm_config.compilation_config.level
self.use_cuda_graph = (self.compilation_config.level
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
@ -205,8 +205,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# self.cudagraph_batch_sizes sorts in ascending order.
# The batch sizes in the config are in descending order.
self.cudagraph_batch_sizes = list(
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
reversed(self.compilation_config.cudagraph_capture_sizes))
self.full_cuda_graph = self.compilation_config.full_cuda_graph
# Cache the device properties.
self._init_device_properties()
@ -555,7 +556,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata]]:
) -> tuple[dict[str, Any], bool, torch.Tensor,
Optional[SpecDecodeMetadata]]:
"""
:return: tuple[
attn_metadata: layer-to-attention_metadata mapping,
attention_cuda_graphs: whether attention can run in cudagraph
logits_indices, spec_decode_metadata
]
"""
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
@ -669,7 +678,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
seq_lens = self.seq_lens[:num_reqs]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)
query_start_loc=query_start_loc,
seq_lens=seq_lens,
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
)
attn_metadata: dict[str, Any] = {}
# Prepare the attention metadata for each KV cache group and make layers
@ -679,25 +693,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
builder = self.attn_metadata_builders[kv_cache_group_id]
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id],
kv_cache_group_spec.kv_cache_spec,
self.attn_metadata_builders[kv_cache_group_id],
builder,
)
attn_metadata_i = (
self.attn_metadata_builders[kv_cache_group_id].build(
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata))
attn_metadata_i = (builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
))
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
attention_cuda_graphs = all(
b.can_run_in_cudagraph(common_attn_metadata)
for b in self.attn_metadata_builders)
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
@ -726,7 +743,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
return attn_metadata, logits_indices, spec_decode_metadata
return (attn_metadata, attention_cuda_graphs, logits_indices,
spec_decode_metadata)
def _compute_cascade_attn_prefix_len(
self,
@ -1121,7 +1139,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert self.intermediate_tensors is not None
tp = self.vllm_config.parallel_config.tensor_parallel_size
enabled_sp = self.vllm_config.compilation_config.pass_config. \
enabled_sp = self.compilation_config.pass_config. \
enable_sequence_parallelism
if enabled_sp:
# When sequence parallelism is enabled, we always pad num_tokens
@ -1189,8 +1207,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return self.kv_connector_no_forward(scheduler_output)
# Prepare the decoder inputs.
attn_metadata, logits_indices, spec_decode_metadata = (
self._prepare_inputs(scheduler_output))
(attn_metadata, attention_cuda_graphs, logits_indices,
spec_decode_metadata) = (self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@ -1203,7 +1221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.vllm_config.compilation_config.pass_config. \
if self.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
from vllm.utils import round_up
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
@ -1255,12 +1273,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True)
# Some attention backends only support CUDA Graphs in pure decode.
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp):
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs,
):
self.maybe_setup_kv_connector(scheduler_output)
model_output = self.model(
@ -1769,7 +1795,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _dummy_run(
self,
num_tokens: int,
skip_attn: bool = True,
capture_attn_cudagraph: bool = False,
) -> torch.Tensor:
# Padding for DP
@ -1790,9 +1816,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
if skip_attn:
attn_metadata: Optional[dict[str, Any]] = None
else:
attn_metadata: Optional[dict[str, Any]] = None
if capture_attn_cudagraph:
attn_metadata = {}
query_start_loc = self.query_start_loc[:num_reqs + 1]
# Make sure max_model_len is used at the graph capture time.
self.seq_lens_np[:num_reqs] = self.max_model_len
@ -1802,19 +1829,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
seq_lens = self.seq_lens[:num_reqs]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)
query_start_loc=query_start_loc,
seq_lens=seq_lens,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
)
attn_metadata = {}
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
attn_metadata_i = (
self.attn_metadata_builders[kv_cache_group_id].build(
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
))
attn_metadata_i = self.attn_metadata_builders[
kv_cache_group_id].build_for_cudagraph_capture(
common_attn_metadata)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
@ -2039,14 +2066,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
with graph_capture(device=self.device):
skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
full_cg = self.full_cuda_graph
for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes),
desc="Capturing CUDA graphs",
total=len(self.cudagraph_batch_sizes)):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens, skip_attn=skip_attn)
self._dummy_run(num_tokens, skip_attn=skip_attn)
for _ in range(
self.compilation_config.cudagraph_num_of_warmups):
self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg)
self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg)
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
@ -2089,20 +2116,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"Non-Attention backend is not supported by V1 "
"GPUModelRunner.")
if self.vllm_config.compilation_config.full_cuda_graph:
attn_backend_name = attn_backend_i.__name__
flash_attn_version = get_flash_attn_version()
if attn_backend_name != "FlashAttentionBackend" or \
flash_attn_version != 3:
raise ValueError(
f"full_cuda_graph is only supported with "
f"FA3. Current attention backend is "
f"{attn_backend_name}, FlashAttention version is "
f"{flash_attn_version}.")
block_table_i = self.input_batch.block_table[i]
attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
weakref.proxy(self), kv_cache_spec, block_table_i)
weakref.proxy(self),
kv_cache_spec,
block_table_i,
)
if (self.full_cuda_graph
and not attn_metadata_builder_i.full_cudagraph_supported):
raise ValueError(
f"Full CUDAGraph not supported for "
f"{attn_backend_i.__name__}. Turn off CompilationConfig."
f"full_cuda_graph or use a different attention backend.")
self.attn_backends.append(attn_backend_i)
self.attn_metadata_builders.append(attn_metadata_builder_i)
@ -2142,9 +2169,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
to be reshaped to the desired shape before being used by the models.
Args:
kv_cache_config: The KV cache config
kv_cache_config: The KV cache config
Returns:
dict[str, torch.Tensor]: A map between layer names to their
dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
@ -2171,11 +2198,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Reshape the KV cache tensors to the desired shape and dtype.
Args:
kv_cache_config: The KV cache config
kv_cache_raw_tensors: The KV cache buffer of each layer, with
kv_cache_config: The KV cache config
kv_cache_raw_tensors: The KV cache buffer of each layer, with
correct size but uninitialized shape.
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
kv_caches: dict[str, torch.Tensor] = {}
@ -2227,7 +2254,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Args:
kv_cache_config: The KV cache config
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
# Initialize the memory buffer for KV cache
@ -2245,10 +2272,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_caches,
)
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context,
self.kv_caches)
return kv_caches
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: