mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 17:47:13 +08:00
[CUDA] Enable full cudagraph for FlashMLA (#18581)
Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
parent
1015296b79
commit
3597b06a4f
@ -2,15 +2,16 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import os
|
||||
import weakref
|
||||
from contextlib import ExitStack
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.utils import wait_for_gpu_memory_to_clear
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODEL = "Qwen/Qwen2-1.5B-Instruct"
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def temporary_environ(env_vars):
|
||||
@ -31,64 +32,119 @@ def temporary_environ(env_vars):
|
||||
os.environ[k] = v
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def full_cudagraph_llm():
|
||||
@pytest.fixture(scope="class")
|
||||
def llm_pair(request):
|
||||
model = request.param
|
||||
|
||||
with temporary_environ({
|
||||
"VLLM_USE_V1": "1",
|
||||
"VLLM_FLASH_ATTN_VERSION": "3"
|
||||
}):
|
||||
return LLM(model=MODEL,
|
||||
gpu_memory_utilization=0.3,
|
||||
compilation_config=CompilationConfig(full_cuda_graph=True))
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def piecewise_llm():
|
||||
with temporary_environ({
|
||||
"VLLM_USE_V1": "1",
|
||||
"VLLM_FLASH_ATTN_VERSION": "3"
|
||||
}):
|
||||
return LLM(model=MODEL,
|
||||
gpu_memory_utilization=0.6,
|
||||
compilation_config=CompilationConfig())
|
||||
|
||||
|
||||
def generate_text(llm: LLM, batch_size: int, max_tokens: int):
|
||||
prompts = ["Hi my name is"] * batch_size
|
||||
sampling_params = SamplingParams(temperature=0.0,
|
||||
max_tokens=max_tokens,
|
||||
top_p=0.95)
|
||||
|
||||
return llm.generate(prompts, sampling_params)
|
||||
full = LLM(
|
||||
model=model,
|
||||
gpu_memory_utilization=0.45,
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
compilation_config=CompilationConfig(full_cuda_graph=True),
|
||||
)
|
||||
piecewise = LLM(
|
||||
model=model,
|
||||
gpu_memory_utilization=0.45,
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
compilation_config=CompilationConfig(),
|
||||
)
|
||||
|
||||
# PyTest caches the fixture values so we use weakref.proxy to enable GC
|
||||
yield weakref.proxy(full), weakref.proxy(piecewise)
|
||||
del full
|
||||
del piecewise
|
||||
|
||||
wait_for_gpu_memory_to_clear(
|
||||
devices=[0],
|
||||
threshold_ratio=0.1,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"llm_pair",
|
||||
[
|
||||
# Model names for the llm_pair fixture
|
||||
"deepseek-ai/DeepSeek-V2-Lite",
|
||||
"Qwen/Qwen2-1.5B-Instruct"
|
||||
],
|
||||
indirect=True)
|
||||
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
|
||||
reason="Only Hopper GPUs support FlashAttention 3")
|
||||
@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
|
||||
(16, 10), (25, 10),
|
||||
(32, 10), (45, 10),
|
||||
(64, 10), (8, 5),
|
||||
(8, 20), (8, 200)])
|
||||
def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm,
|
||||
piecewise_llm):
|
||||
reason="Only Hopper GPUs support FA3 and FlashMLA")
|
||||
class TestFullCUDAGraph:
|
||||
"""
|
||||
Load full cudagraph model and piecewise model once, and at the same time to
|
||||
reuse them across various test cases.
|
||||
Use a class such that an llm pair is constructed once for all
|
||||
batch_size/max_tokens combinations and released immediately after.
|
||||
|
||||
Test various batch sizes and max_tokens to ensure that the full cudagraph
|
||||
compilation works for padded cases too.
|
||||
Module-scope fixtures would stick around the whole time,
|
||||
meaning there would be multiple LLM instances hogging memory simultaneously.
|
||||
"""
|
||||
piecewise_responses = generate_text(piecewise_llm,
|
||||
batch_size=batch_size,
|
||||
max_tokens=max_tokens)
|
||||
full_cudagraph_responses = generate_text(full_cudagraph_llm,
|
||||
batch_size=batch_size,
|
||||
max_tokens=max_tokens)
|
||||
|
||||
# Check that all responses are the same
|
||||
for i in range(len(piecewise_responses)):
|
||||
assert piecewise_responses[i].outputs[
|
||||
0].text == full_cudagraph_responses[i].outputs[0].text
|
||||
@pytest.mark.parametrize(("batch_size", "max_tokens"), [
|
||||
(1, 10),
|
||||
(7, 10),
|
||||
(16, 10),
|
||||
(25, 10),
|
||||
(32, 10),
|
||||
(45, 10),
|
||||
(64, 10),
|
||||
(123, 10),
|
||||
(8, 5),
|
||||
(8, 30),
|
||||
])
|
||||
def test_full_cudagraph(self, batch_size, max_tokens,
|
||||
llm_pair: tuple[LLM, LLM]):
|
||||
"""
|
||||
Test various batch sizes and max_tokens to ensure that the
|
||||
full cudagraph compilation works for padded cases too.
|
||||
"""
|
||||
|
||||
piecewise_llm, full_cudagraph_llm = llm_pair
|
||||
|
||||
prompts = ["Hello, my name is"] * batch_size
|
||||
sampling_params = SamplingParams(temperature=0.0,
|
||||
max_tokens=max_tokens,
|
||||
top_p=0.95)
|
||||
|
||||
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
|
||||
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
|
||||
|
||||
# Check that all responses are the same
|
||||
for piecewise_res, full_res in zip(piecewise_responses,
|
||||
full_responses):
|
||||
assert piecewise_res.outputs[0].text == full_res.outputs[0].text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, supported",
|
||||
[
|
||||
("Qwen/Qwen2-1.5B-Instruct", True),
|
||||
# MLA does not support capturing CUDA Graphs with size > max_num_seqs
|
||||
("deepseek-ai/DeepSeek-V2-Lite", False),
|
||||
])
|
||||
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
|
||||
reason="Only Hopper GPUs support FA3 and FlashMLA")
|
||||
def test_lower_max_num_seqs(model, supported):
|
||||
with temporary_environ({
|
||||
"VLLM_USE_V1": "1",
|
||||
"VLLM_FLASH_ATTN_VERSION": "3"
|
||||
}), ExitStack() as stack:
|
||||
if not supported:
|
||||
stack.enter_context(pytest.raises(RuntimeError))
|
||||
|
||||
llm = LLM(model=model,
|
||||
max_num_seqs=256,
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
compilation_config=CompilationConfig(
|
||||
full_cuda_graph=True,
|
||||
cudagraph_capture_sizes=[64, 256, 512]))
|
||||
llm.generate(["Hello, my name is"] * 10)
|
||||
|
||||
|
||||
def test_full_cudagraph_with_invalid_backend():
|
||||
@ -97,5 +153,5 @@ def test_full_cudagraph_with_invalid_backend():
|
||||
"VLLM_FLASH_ATTN_VERSION":
|
||||
"2" #FA2 not supported with full_cuda_graph
|
||||
}), pytest.raises(RuntimeError):
|
||||
LLM(model=MODEL,
|
||||
LLM(model="Qwen/Qwen2-1.5B-Instruct",
|
||||
compilation_config=CompilationConfig(full_cuda_graph=True))
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
Test the piecewise compilation with a simple model so that we
|
||||
can exactly calculate the expected output and side effects.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
@ -14,6 +14,7 @@ from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
from vllm.envs import VLLM_USE_V1
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
global_counter = 0
|
||||
@ -76,7 +77,8 @@ class SillyModel(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def _test_simple_piecewise_compile(*, use_inductor):
|
||||
@pytest.mark.parametrize("use_inductor", [True, False])
|
||||
def test_simple_piecewise_compile(use_inductor):
|
||||
assert VLLM_USE_V1
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
@ -99,7 +101,7 @@ def _test_simple_piecewise_compile(*, use_inductor):
|
||||
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_captured=
|
||||
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
), set_forward_context({}, vllm_config=vllm_config):
|
||||
|
||||
model(inputs)
|
||||
|
||||
@ -112,11 +114,3 @@ def _test_simple_piecewise_compile(*, use_inductor):
|
||||
output = model(input)
|
||||
assert global_counter == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
||||
|
||||
|
||||
def test_simple_piecewise_compile_inductor():
|
||||
_test_simple_piecewise_compile(use_inductor=True)
|
||||
|
||||
|
||||
def test_simple_piecewise_compile_no_inductor():
|
||||
_test_simple_piecewise_compile(use_inductor=False)
|
||||
|
||||
@ -11,6 +11,7 @@ initialized randomly with a fixed seed.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
@ -19,6 +20,7 @@ from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# create a library to hold the custom op
|
||||
@ -285,29 +287,32 @@ def run_model(llama_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix="").eval().cuda()
|
||||
|
||||
B = 16 # max batch size
|
||||
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
|
||||
positions = torch.arange(B).cuda()
|
||||
with set_forward_context({}, vllm_config=vllm_config):
|
||||
B = 16 # max batch size
|
||||
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
|
||||
positions = torch.arange(B).cuda()
|
||||
|
||||
model(input_ids, positions)
|
||||
model(input_ids[:2], positions[:2])
|
||||
model(input_ids[:1], positions[:1])
|
||||
model(input_ids, positions)
|
||||
model(input_ids[:2], positions[:2])
|
||||
model(input_ids[:1], positions[:1])
|
||||
|
||||
input_ids[:2].zero_()
|
||||
output = model(input_ids[:2], positions[:2])
|
||||
input_ids[:2].zero_()
|
||||
output = model(input_ids[:2], positions[:2])
|
||||
|
||||
output = output.cpu()
|
||||
output = output.cpu()
|
||||
|
||||
if llama_config.tractable_init:
|
||||
expected_output = tractable_computation(input_ids[:2], positions[:2],
|
||||
llama_config).cpu()
|
||||
if llama_config.tractable_init:
|
||||
expected_output = tractable_computation(input_ids[:2],
|
||||
positions[:2],
|
||||
llama_config).cpu()
|
||||
|
||||
assert torch.allclose(output, expected_output)
|
||||
else:
|
||||
return output.cpu()
|
||||
assert torch.allclose(output, expected_output)
|
||||
else:
|
||||
return output.cpu()
|
||||
|
||||
|
||||
def _test_toy_llama(*, use_inductor):
|
||||
@pytest.mark.parametrize("use_inductor", [True, False])
|
||||
def test_toy_llama(use_inductor: bool):
|
||||
# compare output with and without piecewise compilation
|
||||
|
||||
llama_config = LlamaConfig(hidden_size=128,
|
||||
@ -379,14 +384,6 @@ def _test_toy_llama(*, use_inductor):
|
||||
assert torch.allclose(outputs[0], outputs[i])
|
||||
|
||||
|
||||
def test_toy_llama_inductor():
|
||||
_test_toy_llama(use_inductor=True)
|
||||
|
||||
|
||||
def test_toy_no_inductor():
|
||||
_test_toy_llama(use_inductor=False)
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def benchmark():
|
||||
from triton.testing import do_bench
|
||||
|
||||
@ -667,42 +667,54 @@ def get_physical_device_indices(devices):
|
||||
|
||||
|
||||
@_nvml()
|
||||
def wait_for_gpu_memory_to_clear(devices: list[int],
|
||||
threshold_bytes: int,
|
||||
def wait_for_gpu_memory_to_clear(*,
|
||||
devices: list[int],
|
||||
threshold_bytes: Optional[int] = None,
|
||||
threshold_ratio: Optional[float] = None,
|
||||
timeout_s: float = 120) -> None:
|
||||
assert threshold_bytes is not None or threshold_ratio is not None
|
||||
# Use nvml instead of pytorch to reduce measurement error from torch cuda
|
||||
# context.
|
||||
devices = get_physical_device_indices(devices)
|
||||
start_time = time.time()
|
||||
while True:
|
||||
output: dict[int, str] = {}
|
||||
output_raw: dict[int, float] = {}
|
||||
output_raw: dict[int, tuple[float, float]] = {}
|
||||
for device in devices:
|
||||
if current_platform.is_rocm():
|
||||
dev_handle = amdsmi_get_processor_handles()[device]
|
||||
mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
|
||||
gb_used = mem_info["vram_used"] / 2**10
|
||||
gb_total = mem_info["vram_total"] / 2**10
|
||||
else:
|
||||
dev_handle = nvmlDeviceGetHandleByIndex(device)
|
||||
mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
|
||||
gb_used = mem_info.used / 2**30
|
||||
output_raw[device] = gb_used
|
||||
output[device] = f'{gb_used:.02f}'
|
||||
gb_total = mem_info.total / 2**30
|
||||
output_raw[device] = (gb_used, gb_total)
|
||||
output[device] = f'{gb_used:.02f}/{gb_total:.02f}'
|
||||
|
||||
print('gpu memory used (GB): ', end='')
|
||||
print('gpu memory used/total (GiB): ', end='')
|
||||
for k, v in output.items():
|
||||
print(f'{k}={v}; ', end='')
|
||||
print('')
|
||||
|
||||
if threshold_bytes is not None:
|
||||
is_free = lambda used, total: used <= threshold_bytes / 2**30
|
||||
threshold = f"{threshold_bytes/2**30} GiB"
|
||||
else:
|
||||
is_free = lambda used, total: used / total <= threshold_ratio
|
||||
threshold = f"{threshold_ratio:.2f}"
|
||||
|
||||
dur_s = time.time() - start_time
|
||||
if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
|
||||
if all(is_free(used, total) for used, total in output_raw.values()):
|
||||
print(f'Done waiting for free GPU memory on devices {devices=} '
|
||||
f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
|
||||
f'({threshold=}) {dur_s=:.02f}')
|
||||
break
|
||||
|
||||
if dur_s >= timeout_s:
|
||||
raise ValueError(f'Memory of devices {devices=} not free after '
|
||||
f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
|
||||
f'{dur_s=:.02f} ({threshold=})')
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@ from vllm.compilation.backends import VllmBackend
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.monitor import end_monitoring_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import weak_ref_tensors
|
||||
|
||||
@ -137,7 +138,10 @@ class CUDAPiecewiseBackend:
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
self.check_for_ending_compilation()
|
||||
|
||||
if not entry.use_cudagraph:
|
||||
# Skip CUDA graphs if this entry doesn't use them OR
|
||||
# if we're supposed to skip them globally
|
||||
skip_cuda_graphs = get_forward_context().skip_cuda_graphs
|
||||
if not entry.use_cudagraph or skip_cuda_graphs:
|
||||
return entry.runnable(*args)
|
||||
|
||||
if entry.cudagraph is None:
|
||||
|
||||
@ -179,7 +179,8 @@ class LLM:
|
||||
hf_overrides: Optional[HfOverrides] = None,
|
||||
mm_processor_kwargs: Optional[dict[str, Any]] = None,
|
||||
override_pooler_config: Optional[PoolerConfig] = None,
|
||||
compilation_config: Optional[Union[int, dict[str, Any]]] = None,
|
||||
compilation_config: Optional[Union[int, dict[str, Any],
|
||||
CompilationConfig]] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""LLM constructor."""
|
||||
|
||||
@ -94,6 +94,7 @@ class ForwardContext:
|
||||
virtual_engine: int # set dynamically for each forward pass
|
||||
# set dynamically for each forward pass
|
||||
dp_metadata: Optional[DPMetadata] = None
|
||||
skip_cuda_graphs: bool = False
|
||||
|
||||
|
||||
_forward_context: Optional[ForwardContext] = None
|
||||
@ -108,11 +109,14 @@ def get_forward_context() -> ForwardContext:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_forward_context(attn_metadata: Any,
|
||||
vllm_config: VllmConfig,
|
||||
virtual_engine: int = 0,
|
||||
num_tokens: Optional[int] = None,
|
||||
num_tokens_across_dp: Optional[torch.Tensor] = None):
|
||||
def set_forward_context(
|
||||
attn_metadata: Any,
|
||||
vllm_config: VllmConfig,
|
||||
virtual_engine: int = 0,
|
||||
num_tokens: Optional[int] = None,
|
||||
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
||||
skip_cuda_graphs: bool = False,
|
||||
):
|
||||
"""A context manager that stores the current forward context,
|
||||
can be attention metadata, etc.
|
||||
Here we can inject common logic for every model forward pass.
|
||||
@ -135,7 +139,9 @@ def set_forward_context(attn_metadata: Any,
|
||||
static_forward_context,
|
||||
virtual_engine=virtual_engine,
|
||||
attn_metadata=attn_metadata,
|
||||
dp_metadata=dp_metadata)
|
||||
dp_metadata=dp_metadata,
|
||||
skip_cuda_graphs=skip_cuda_graphs,
|
||||
)
|
||||
|
||||
try:
|
||||
yield
|
||||
|
||||
@ -7,7 +7,8 @@ from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
|
||||
TorchSDPAMetadata)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.attention.ops.ipex_attn import PagedAttention
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
@ -53,7 +54,7 @@ class TorchSDPABackend:
|
||||
return False
|
||||
|
||||
|
||||
class TorchSDPAMetadataBuilderV1:
|
||||
class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
|
||||
|
||||
def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable) -> None:
|
||||
@ -118,9 +119,12 @@ class TorchSDPAMetadataBuilderV1:
|
||||
|
||||
return True
|
||||
|
||||
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
|
||||
common_prefix_len: int,
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
runner = self.runner
|
||||
block_table = self.block_table
|
||||
seq_lens_np = runner.seq_lens_np[:num_reqs]
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with FlashAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -21,13 +21,12 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
if current_platform.is_cuda():
|
||||
@ -306,7 +305,9 @@ def _get_sliding_window_configs(
|
||||
return sliding_window_configs
|
||||
|
||||
|
||||
class FlashAttentionMetadataBuilder:
|
||||
class FlashAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[FlashAttentionMetadata]):
|
||||
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
|
||||
|
||||
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
@ -336,13 +337,14 @@ class FlashAttentionMetadataBuilder:
|
||||
# populated on first build() call.
|
||||
self.aot_sliding_window: Optional[tuple[int, int]] = None
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
def build(
|
||||
self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata
|
||||
) -> FlashAttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
@ -496,6 +498,11 @@ class FlashAttentionMetadataBuilder:
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
# Full CUDA Graph always supported (FA2 support checked separately)
|
||||
return True
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
return use_cascade_attention(*args, **kwargs)
|
||||
|
||||
|
||||
@ -18,7 +18,8 @@ from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
@ -202,7 +203,7 @@ class FlashInferMetadata:
|
||||
f" received {self.head_dim}.")
|
||||
|
||||
|
||||
class FlashInferMetadataBuilder:
|
||||
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
|
||||
def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
@ -399,9 +400,11 @@ class FlashInferMetadataBuilder:
|
||||
kv_data_type=attn_metadata.data_type,
|
||||
)
|
||||
|
||||
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
|
||||
common_prefix_len: int,
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
assert (self._num_decode_tokens +
|
||||
self._num_prefill_tokens == num_actual_tokens)
|
||||
|
||||
@ -15,7 +15,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
@ -25,8 +26,6 @@ if current_platform.is_cuda():
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
create_block_mask_compiled = torch.compile(create_block_mask,
|
||||
@ -256,7 +255,8 @@ class FlexAttentionMetadata:
|
||||
self.block_mask = self.build_block_mask()
|
||||
|
||||
|
||||
class FlexAttentionMetadataBuilder:
|
||||
class FlexAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[FlexAttentionMetadata]):
|
||||
|
||||
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
@ -272,13 +272,12 @@ class FlexAttentionMetadataBuilder:
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_table = block_table
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
|
||||
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
|
||||
common_prefix_len: int,
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
@ -332,9 +331,6 @@ class FlexAttentionMetadataBuilder:
|
||||
)
|
||||
return out
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class FlexAttentionImpl(AttentionImpl):
|
||||
sliding_window: Optional[tuple[int, int]]
|
||||
|
||||
@ -207,7 +207,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv, round_down
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
@ -329,7 +330,7 @@ class MLACommonMetadata(Generic[D]):
|
||||
M = TypeVar("M", bound=MLACommonMetadata)
|
||||
|
||||
|
||||
class MLACommonMetadataBuilder(Generic[M]):
|
||||
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
@ -450,9 +451,32 @@ class MLACommonMetadataBuilder(Generic[M]):
|
||||
seq_lens=seq_lens,
|
||||
)
|
||||
|
||||
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
|
||||
common_prefix_len: int,
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> M:
|
||||
"""
|
||||
This method builds the metadata for full cudagraph capture.
|
||||
Currently, only decode is supported for full cudagraphs with MLA.
|
||||
"""
|
||||
m = common_attn_metadata
|
||||
assert m.num_reqs == m.num_actual_tokens, \
|
||||
"MLA only supports decode-only full CUDAGraph capture. " \
|
||||
"Make sure all cudagraph capture sizes <= max_num_seq."
|
||||
|
||||
m.max_query_len = 1 # decode-only
|
||||
|
||||
# Update state usually set in reorder_batch.
|
||||
self._num_decodes = m.num_reqs
|
||||
self._num_decode_tokens = m.num_actual_tokens
|
||||
self._num_prefills = 0
|
||||
self._num_prefill_tokens = 0
|
||||
return self.build(0, m)
|
||||
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata) -> M:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
|
||||
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
||||
@ -461,8 +485,11 @@ class MLACommonMetadataBuilder(Generic[M]):
|
||||
device = self.runner.device
|
||||
block_table = self.block_table
|
||||
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
|
||||
slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
device, non_blocking=True).long()
|
||||
block_table.slot_mapping[:num_actual_tokens].copy_(
|
||||
block_table.slot_mapping_cpu[:num_actual_tokens],
|
||||
non_blocking=True)
|
||||
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
|
||||
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
|
||||
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
@ -564,8 +591,9 @@ class MLACommonMetadataBuilder(Generic[M]):
|
||||
decode=decode_metadata,
|
||||
)
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
return False
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
return common_attn_metadata.max_query_len == 1
|
||||
|
||||
|
||||
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Any, ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -44,7 +44,7 @@ class FlashMLABackend(MLACommonBackend):
|
||||
|
||||
@dataclass
|
||||
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
tile_scheduler_metadata: tuple[torch.Tensor, torch.Tensor]
|
||||
tile_scheduler_metadata: torch.Tensor
|
||||
num_splits: torch.Tensor
|
||||
|
||||
|
||||
@ -54,14 +54,18 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
|
||||
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
full_cudagraph_supported: ClassVar[bool] = True # Decode-only
|
||||
|
||||
def __init__(self, runner, kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
super().__init__(runner, kv_cache_spec, block_table)
|
||||
super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata)
|
||||
|
||||
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
|
||||
self.runner.parallel_config)
|
||||
|
||||
self.cg_buf_tile_scheduler_metadata = None
|
||||
self.cg_buf_num_splits = None
|
||||
|
||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
|
||||
tile_scheduler_metadata, num_splits = \
|
||||
@ -71,6 +75,30 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
if self.runner.full_cuda_graph:
|
||||
# First time around (CUDAGraph capture), allocate the static buffer
|
||||
if self.cg_buf_tile_scheduler_metadata is None:
|
||||
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
|
||||
self.cg_buf_num_splits = num_splits
|
||||
else:
|
||||
assert self.cg_buf_num_splits is not None
|
||||
|
||||
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
|
||||
assert (self.cg_buf_tile_scheduler_metadata.size() ==
|
||||
tile_scheduler_metadata.size())
|
||||
self.cg_buf_tile_scheduler_metadata.\
|
||||
copy_(tile_scheduler_metadata)
|
||||
tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
|
||||
|
||||
# Num splits is per-batch, varying size (batch_size,)
|
||||
n = num_splits.size(0)
|
||||
# make sure static buffer is large enough
|
||||
assert n <= self.cg_buf_num_splits.size(0)
|
||||
num_splits_view = self.cg_buf_num_splits[:n]
|
||||
num_splits_view.copy_(num_splits)
|
||||
self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s
|
||||
num_splits = num_splits_view
|
||||
|
||||
return FlashMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens,
|
||||
|
||||
@ -66,7 +66,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
|
||||
def __init__(self, runner, kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
super().__init__(runner, kv_cache_spec, block_table)
|
||||
super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata)
|
||||
assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
|
||||
"only supports block size 1."
|
||||
|
||||
|
||||
@ -1,15 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import abc
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommonAttentionMetadata:
|
||||
"""
|
||||
Attention metadata attributes that can be shared by layers in different KV
|
||||
cache groups and thus having different block table.
|
||||
Per-batch attention metadata, shared across layers and backends.
|
||||
AttentionMetadataBuilder instances use it to construct per-layer metadata.
|
||||
"""
|
||||
|
||||
query_start_loc: torch.Tensor
|
||||
@ -18,6 +26,67 @@ class CommonAttentionMetadata:
|
||||
"""(batch_size,), the length of each request including both computed tokens
|
||||
and newly scheduled tokens"""
|
||||
|
||||
num_reqs: int
|
||||
"""Number of requests"""
|
||||
num_actual_tokens: int
|
||||
"""Total number of tokens in batch"""
|
||||
max_query_len: int
|
||||
"""Longest query in batch"""
|
||||
|
||||
|
||||
M = TypeVar("M")
|
||||
|
||||
|
||||
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
# Does this backend/builder support CUDA Graphs for attention.
|
||||
full_cudagraph_supported: ClassVar[bool] = False
|
||||
|
||||
@abstractmethod
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata) -> M:
|
||||
"""
|
||||
Central method that builds attention metadata.
|
||||
Some builders (MLA) require reorder_batch to be called prior to build.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
"""
|
||||
Can this batch (with given metadata) use CUDA Graphs for attention.
|
||||
"""
|
||||
return False
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> M:
|
||||
"""
|
||||
Build attention metadata for CUDA graph capture. Uses build by default.
|
||||
Subclasses that override this method should call self.build or
|
||||
super().build_for_cudagraph_capture.
|
||||
"""
|
||||
return self.build(common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata)
|
||||
|
||||
def use_cascade_attention(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
query_lens: np.ndarray,
|
||||
num_query_heads: int,
|
||||
num_kv_heads: int,
|
||||
use_alibi: bool,
|
||||
use_sliding_window: bool,
|
||||
num_sms: int,
|
||||
) -> bool:
|
||||
return False
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
"""
|
||||
This method can reorder the batch if desired by the backend.
|
||||
:return: Has the batch been reordered (default False).
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
def validate_kv_sharing_target(current_layer_name, target_layer_name,
|
||||
static_forward_context):
|
||||
|
||||
@ -138,15 +138,17 @@ class EagleProposer:
|
||||
max_query_len = query_lens.max().item()
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=cu_num_tokens, seq_lens=seq_lens)
|
||||
query_start_loc=cu_num_tokens,
|
||||
seq_lens=seq_lens,
|
||||
num_reqs=batch_size,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
)
|
||||
|
||||
assert self.runner is not None
|
||||
|
||||
# FIXME: need to consider multiple kv_cache_groups
|
||||
attn_metadata = self.runner.attn_metadata_builder.build(
|
||||
num_reqs=batch_size,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
|
||||
@ -16,10 +16,8 @@ from tqdm import tqdm
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionType, get_attn_backend
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadataBuilder)
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||
from vllm.config import (CompilationLevel, VllmConfig,
|
||||
get_layers_from_vllm_config)
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
@ -41,7 +39,8 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
|
||||
check_use_alibi, is_pin_memory_available)
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
|
||||
KVCacheConfig, KVCacheSpec,
|
||||
@ -89,6 +88,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
@ -197,7 +197,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
block_sizes=[self.cache_config.block_size],
|
||||
)
|
||||
|
||||
self.use_cuda_graph = (self.vllm_config.compilation_config.level
|
||||
self.use_cuda_graph = (self.compilation_config.level
|
||||
== CompilationLevel.PIECEWISE
|
||||
and not self.model_config.enforce_eager)
|
||||
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
|
||||
@ -205,8 +205,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# self.cudagraph_batch_sizes sorts in ascending order.
|
||||
# The batch sizes in the config are in descending order.
|
||||
self.cudagraph_batch_sizes = list(
|
||||
reversed(
|
||||
self.vllm_config.compilation_config.cudagraph_capture_sizes))
|
||||
reversed(self.compilation_config.cudagraph_capture_sizes))
|
||||
|
||||
self.full_cuda_graph = self.compilation_config.full_cuda_graph
|
||||
|
||||
# Cache the device properties.
|
||||
self._init_device_properties()
|
||||
@ -555,7 +556,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
def _prepare_inputs(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata]]:
|
||||
) -> tuple[dict[str, Any], bool, torch.Tensor,
|
||||
Optional[SpecDecodeMetadata]]:
|
||||
"""
|
||||
:return: tuple[
|
||||
attn_metadata: layer-to-attention_metadata mapping,
|
||||
attention_cuda_graphs: whether attention can run in cudagraph
|
||||
logits_indices, spec_decode_metadata
|
||||
]
|
||||
"""
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
@ -669,7 +678,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
seq_lens = self.seq_lens[:num_reqs]
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc, seq_lens=seq_lens)
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens=seq_lens,
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
)
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
# Prepare the attention metadata for each KV cache group and make layers
|
||||
@ -679,25 +693,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# Prepare for cascade attention if enabled & beneficial.
|
||||
common_prefix_len = 0
|
||||
builder = self.attn_metadata_builders[kv_cache_group_id]
|
||||
if self.cascade_attn_enabled:
|
||||
common_prefix_len = self._compute_cascade_attn_prefix_len(
|
||||
num_scheduled_tokens,
|
||||
scheduler_output.
|
||||
num_common_prefix_blocks[kv_cache_group_id],
|
||||
kv_cache_group_spec.kv_cache_spec,
|
||||
self.attn_metadata_builders[kv_cache_group_id],
|
||||
builder,
|
||||
)
|
||||
|
||||
attn_metadata_i = (
|
||||
self.attn_metadata_builders[kv_cache_group_id].build(
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
common_prefix_len=common_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata))
|
||||
attn_metadata_i = (builder.build(
|
||||
common_prefix_len=common_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
))
|
||||
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
|
||||
attention_cuda_graphs = all(
|
||||
b.can_run_in_cudagraph(common_attn_metadata)
|
||||
for b in self.attn_metadata_builders)
|
||||
|
||||
use_spec_decode = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||
if not use_spec_decode:
|
||||
@ -726,7 +743,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.lora_config:
|
||||
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
||||
|
||||
return attn_metadata, logits_indices, spec_decode_metadata
|
||||
return (attn_metadata, attention_cuda_graphs, logits_indices,
|
||||
spec_decode_metadata)
|
||||
|
||||
def _compute_cascade_attn_prefix_len(
|
||||
self,
|
||||
@ -1121,7 +1139,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
assert self.intermediate_tensors is not None
|
||||
|
||||
tp = self.vllm_config.parallel_config.tensor_parallel_size
|
||||
enabled_sp = self.vllm_config.compilation_config.pass_config. \
|
||||
enabled_sp = self.compilation_config.pass_config. \
|
||||
enable_sequence_parallelism
|
||||
if enabled_sp:
|
||||
# When sequence parallelism is enabled, we always pad num_tokens
|
||||
@ -1189,8 +1207,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
return self.kv_connector_no_forward(scheduler_output)
|
||||
|
||||
# Prepare the decoder inputs.
|
||||
attn_metadata, logits_indices, spec_decode_metadata = (
|
||||
self._prepare_inputs(scheduler_output))
|
||||
(attn_metadata, attention_cuda_graphs, logits_indices,
|
||||
spec_decode_metadata) = (self._prepare_inputs(scheduler_output))
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
if (self.use_cuda_graph
|
||||
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
|
||||
@ -1203,7 +1221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Pad tokens to multiple of tensor_parallel_size when
|
||||
# enabled collective fusion for SP
|
||||
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
||||
if self.vllm_config.compilation_config.pass_config. \
|
||||
if self.compilation_config.pass_config. \
|
||||
enable_sequence_parallelism and tp_size > 1:
|
||||
from vllm.utils import round_up
|
||||
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
|
||||
@ -1255,12 +1273,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
||||
num_input_tokens, intermediate_tensors, True)
|
||||
|
||||
# Some attention backends only support CUDA Graphs in pure decode.
|
||||
# If attention doesn't support CUDA Graphs for this batch, but we
|
||||
# compiled with full CUDA graphs, we have to skip them entirely.
|
||||
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
|
||||
|
||||
# Run the decoder.
|
||||
# Use persistent buffers for CUDA graphs.
|
||||
with set_forward_context(attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp):
|
||||
with set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
skip_cuda_graphs=skip_cuda_graphs,
|
||||
):
|
||||
self.maybe_setup_kv_connector(scheduler_output)
|
||||
|
||||
model_output = self.model(
|
||||
@ -1769,7 +1795,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
def _dummy_run(
|
||||
self,
|
||||
num_tokens: int,
|
||||
skip_attn: bool = True,
|
||||
capture_attn_cudagraph: bool = False,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# Padding for DP
|
||||
@ -1790,9 +1816,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
||||
dtype=np.int32)
|
||||
|
||||
if skip_attn:
|
||||
attn_metadata: Optional[dict[str, Any]] = None
|
||||
else:
|
||||
attn_metadata: Optional[dict[str, Any]] = None
|
||||
if capture_attn_cudagraph:
|
||||
attn_metadata = {}
|
||||
|
||||
query_start_loc = self.query_start_loc[:num_reqs + 1]
|
||||
# Make sure max_model_len is used at the graph capture time.
|
||||
self.seq_lens_np[:num_reqs] = self.max_model_len
|
||||
@ -1802,19 +1829,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
seq_lens = self.seq_lens[:num_reqs]
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc, seq_lens=seq_lens)
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens=seq_lens,
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=num_tokens,
|
||||
)
|
||||
|
||||
attn_metadata = {}
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups):
|
||||
attn_metadata_i = (
|
||||
self.attn_metadata_builders[kv_cache_group_id].build(
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=num_tokens,
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
))
|
||||
|
||||
attn_metadata_i = self.attn_metadata_builders[
|
||||
kv_cache_group_id].build_for_cudagraph_capture(
|
||||
common_attn_metadata)
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
|
||||
@ -2039,14 +2066,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Capture the large shapes first so that the smaller shapes
|
||||
# can reuse the memory pool allocated for the large shapes.
|
||||
with graph_capture(device=self.device):
|
||||
skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
|
||||
full_cg = self.full_cuda_graph
|
||||
for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes),
|
||||
desc="Capturing CUDA graphs",
|
||||
total=len(self.cudagraph_batch_sizes)):
|
||||
for _ in range(self.vllm_config.compilation_config.
|
||||
cudagraph_num_of_warmups):
|
||||
self._dummy_run(num_tokens, skip_attn=skip_attn)
|
||||
self._dummy_run(num_tokens, skip_attn=skip_attn)
|
||||
for _ in range(
|
||||
self.compilation_config.cudagraph_num_of_warmups):
|
||||
self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg)
|
||||
self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
@ -2089,20 +2116,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
"Non-Attention backend is not supported by V1 "
|
||||
"GPUModelRunner.")
|
||||
|
||||
if self.vllm_config.compilation_config.full_cuda_graph:
|
||||
attn_backend_name = attn_backend_i.__name__
|
||||
flash_attn_version = get_flash_attn_version()
|
||||
if attn_backend_name != "FlashAttentionBackend" or \
|
||||
flash_attn_version != 3:
|
||||
raise ValueError(
|
||||
f"full_cuda_graph is only supported with "
|
||||
f"FA3. Current attention backend is "
|
||||
f"{attn_backend_name}, FlashAttention version is "
|
||||
f"{flash_attn_version}.")
|
||||
|
||||
block_table_i = self.input_batch.block_table[i]
|
||||
attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
|
||||
weakref.proxy(self), kv_cache_spec, block_table_i)
|
||||
weakref.proxy(self),
|
||||
kv_cache_spec,
|
||||
block_table_i,
|
||||
)
|
||||
|
||||
if (self.full_cuda_graph
|
||||
and not attn_metadata_builder_i.full_cudagraph_supported):
|
||||
raise ValueError(
|
||||
f"Full CUDAGraph not supported for "
|
||||
f"{attn_backend_i.__name__}. Turn off CompilationConfig."
|
||||
f"full_cuda_graph or use a different attention backend.")
|
||||
|
||||
self.attn_backends.append(attn_backend_i)
|
||||
self.attn_metadata_builders.append(attn_metadata_builder_i)
|
||||
|
||||
@ -2142,9 +2169,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
to be reshaped to the desired shape before being used by the models.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
kv_cache_config: The KV cache config
|
||||
Returns:
|
||||
dict[str, torch.Tensor]: A map between layer names to their
|
||||
dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||
@ -2171,11 +2198,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
Reshape the KV cache tensors to the desired shape and dtype.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
kv_cache_raw_tensors: The KV cache buffer of each layer, with
|
||||
kv_cache_config: The KV cache config
|
||||
kv_cache_raw_tensors: The KV cache buffer of each layer, with
|
||||
correct size but uninitialized shape.
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A map between layer names to their
|
||||
Dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
@ -2227,7 +2254,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A map between layer names to their
|
||||
Dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
# Initialize the memory buffer for KV cache
|
||||
@ -2245,10 +2272,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
kv_caches,
|
||||
)
|
||||
|
||||
bind_kv_cache(
|
||||
kv_caches,
|
||||
self.vllm_config.compilation_config.static_forward_context,
|
||||
self.kv_caches)
|
||||
bind_kv_cache(kv_caches,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.kv_caches)
|
||||
return kv_caches
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user