[Attention] Make seq_lens_cpu optional in CommonAttentionMetadata to enable true async spec-decode (#29624)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
Lucas Wilkinson 2025-12-09 20:18:10 -05:00 committed by GitHub
parent 2e7035dd8c
commit abe93bce59
9 changed files with 200 additions and 37 deletions
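At a glance: the eager `seq_lens_cpu` / `num_computed_tokens_cpu` fields of `CommonAttentionMetadata` (vllm/v1/attention/backends/utils.py) become private, optional `_seq_lens_cpu` / `_num_computed_tokens_cpu` fields backed by deprecated lazy properties, so attention metadata can be built without forcing a device-to-host copy on every step. Below is a minimal standalone sketch of that lazy-copy pattern; the simplified `Metadata` class is illustrative only, not vLLM's actual dataclass (the real change appears further down in the diff).

# Minimal standalone sketch (illustrative only) of the lazy CPU-copy pattern
# this commit introduces; the real dataclass is
# vllm.v1.attention.backends.utils.CommonAttentionMetadata.
from dataclasses import dataclass

import torch


@dataclass
class Metadata:
    seq_lens: torch.Tensor                     # device tensor, always present
    _seq_lens_cpu: torch.Tensor | None = None  # optional precomputed CPU copy

    @property
    def seq_lens_cpu(self) -> torch.Tensor:
        # Deprecated-style accessor: materializes the CPU copy on first use,
        # which forces a device->host sync if no copy was supplied up front.
        if self._seq_lens_cpu is None:
            self._seq_lens_cpu = self.seq_lens.to("cpu")
        return self._seq_lens_cpu


meta = Metadata(seq_lens=torch.tensor([5, 7, 9]))
assert meta._seq_lens_cpu is None  # nothing copied yet, no sync
_ = meta.seq_lens_cpu              # first access triggers the copy
assert meta._seq_lens_cpu is not None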

View File

@@ -106,8 +106,8 @@ def create_common_attn_metadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=batch_spec.batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,

View File

@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test that verifies no implicit GPU-CPU synchronization occurs during
speculative decoding generation under expected conditions.
"""
import multiprocessing
import sys
import traceback
import pytest
import torch
@pytest.fixture
def sync_tracker():
"""
Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect
lazy init syncs. Prints stack traces immediately when syncs occur.
"""
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
# Shared counter for cross-process communication (inherited by fork)
sync_count = multiprocessing.Value("i", 0)
# Save original property
original_prop = CommonAttentionMetadata.seq_lens_cpu
original_fget = original_prop.fget
# Create tracking wrapper
def tracking_seq_lens_cpu(self):
if self._seq_lens_cpu is None:
# Increment counter
with sync_count.get_lock():
sync_count.value += 1
count = sync_count.value
# Print stack trace immediately (shows in subprocess output)
print(f"\n{'=' * 60}", file=sys.stderr)
print(f"SYNC #{count}: seq_lens_cpu lazy init triggered!", file=sys.stderr)
print(f"{'=' * 60}", file=sys.stderr)
traceback.print_stack(file=sys.stderr)
print(f"{'=' * 60}\n", file=sys.stderr)
sys.stderr.flush()
return original_fget(self)
# Apply patch
CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu)
class SyncTracker:
@property
def count(self) -> int:
return sync_count.value
def assert_no_sync(self, msg: str = ""):
count = sync_count.value
assert count == 0, (
f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered "
f"{count} times. See stack traces above. {msg}"
)
yield SyncTracker()
# Restore original property
CommonAttentionMetadata.seq_lens_cpu = original_prop
torch._dynamo.reset()
# Test configurations: (model, spec_model, method, num_spec_tokens)
SPEC_DECODE_CONFIGS = [
pytest.param(
"meta-llama/Llama-3.2-1B-Instruct",
"nm-testing/Llama3_2_1B_speculator.eagle3",
"eagle3",
2,
id="eagle3-llama",
),
pytest.param(
"eagle618/deepseek-v3-random",
"eagle618/eagle-deepseek-v3-random",
"eagle",
2,
id="eagle-mla-deepseek",
),
]
@pytest.mark.parametrize(
"model,spec_model,method,num_spec_tokens",
SPEC_DECODE_CONFIGS,
)
def test_no_sync_with_spec_decode(
sync_tracker,
model: str,
spec_model: str,
method: str,
num_spec_tokens: int,
):
"""
Test that no implicit GPU-CPU sync occurs during speculative decoding
generation.
"""
# Import vLLM AFTER sync_tracker fixture has applied the patch
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
llm = LLM(
model=model,
max_model_len=256,
speculative_config={
"method": method,
"num_speculative_tokens": num_spec_tokens,
"model": spec_model,
},
enforce_eager=True,
async_scheduling=True,
)
outputs = llm.generate(
["Hello, my name is"],
SamplingParams(temperature=0, max_tokens=10),
)
assert len(outputs) == 1
assert len(outputs[0].outputs[0].text) > 0
del llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
sync_tracker.assert_no_sync()

View File

@@ -88,8 +88,8 @@ def forward_attention(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc.cpu(),
seq_lens=seq_lens,
seq_lens_cpu=seq_lens.cpu(),
num_computed_tokens_cpu=context_lens.cpu(),
_seq_lens_cpu=seq_lens.cpu(),
_num_computed_tokens_cpu=context_lens.cpu(),
num_reqs=batch_size,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,

View File

@@ -103,7 +103,7 @@ def create_cross_attention_backend(
# needed here to know how many tokens to attend to from the cached
# cross-attention KV cache.
new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
new_metadata.seq_lens_cpu = torch.from_numpy(
new_metadata._seq_lens_cpu = torch.from_numpy(
common_attn_metadata.encoder_seq_lens_cpu
)

View File

@@ -370,6 +370,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
num_accepted_tokens = torch.diff(m.query_start_loc)
num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
m._num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)

View File

@@ -18,7 +18,7 @@ from typing import (
import numpy as np
import torch
from typing_extensions import runtime_checkable
from typing_extensions import deprecated, runtime_checkable
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils.math_utils import cdiv
@@ -66,11 +66,6 @@ class CommonAttentionMetadata:
"""(batch_size + 1,), the start location of each request in query Tensor"""
seq_lens: torch.Tensor
seq_lens_cpu: torch.Tensor
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
num_computed_tokens_cpu: torch.Tensor
"""(batch_size,), the number of computed tokens for each request"""
num_reqs: int
@@ -81,7 +76,7 @@ class CommonAttentionMetadata:
max_query_len: int
"""Longest query in batch"""
max_seq_len: int
"""Longest context length in batch"""
"""Longest context length (may be an upper bound)"""
block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor
@@ -100,6 +95,40 @@ class CommonAttentionMetadata:
dcp_local_seq_lens_cpu: torch.Tensor | None = None
"""Sequence lengths of the local rank in decode context parallelism world"""
# WARNING: Deprecated fields. Will be removed in a future release (v0.14.0)
_seq_lens_cpu: torch.Tensor | None = None
_num_computed_tokens_cpu: torch.Tensor | None = None
@property
@deprecated(
"""
Prefer using device seq_lens directly to avoid implicit H<>D sync.
If a CPU copy is needed, use `seq_lens.cpu()` instead.
Will be removed in a future release (v0.14.0)
"""
)
def seq_lens_cpu(self) -> torch.Tensor:
if self._seq_lens_cpu is None:
self._seq_lens_cpu = self.seq_lens.to("cpu")
return self._seq_lens_cpu
@property
@deprecated(
"""
Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full
async scheduling. If a CPU copy is needed, it can be derived from
query_start_loc_cpu and seq_lens.
Will be removed in a future release (v0.14.0)
"""
)
def num_computed_tokens_cpu(self) -> torch.Tensor:
if self._num_computed_tokens_cpu is None:
query_seq_lens = (
self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1]
)
self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens
return self._num_computed_tokens_cpu
# TODO(lucas): remove once we have FULL-CG spec-decode support
def unpadded(
self, num_actual_tokens: int, num_actual_reqs: int
@@ -109,8 +138,12 @@ class CommonAttentionMetadata:
query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
seq_lens=self.seq_lens[:num_actual_reqs],
seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs],
num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs],
_seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs]
if self._seq_lens_cpu is not None
else None,
_num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs]
if self._num_computed_tokens_cpu is not None
else None,
num_reqs=num_actual_reqs,
num_actual_tokens=num_actual_tokens,
max_query_len=self.max_query_len,
@@ -224,14 +257,14 @@ def _make_metadata_with_slice(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_requests,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=block_table_tensor,
slot_mapping=slot_mapping,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
)
@@ -689,9 +722,7 @@ def make_local_attention_virtual_batches(
return CommonAttentionMetadata(
query_start_loc_cpu=query_start_loc_cpu,
query_start_loc=query_start_loc_cpu.to(device=device, non_blocking=True),
seq_lens_cpu=seq_lens_cpu,
seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
num_reqs=len(seq_lens_cpu),
num_actual_tokens=common_attn_metadata.num_actual_tokens,
max_query_len=seqlens_q_local.max(),
@@ -699,6 +730,8 @@
block_table_tensor=block_table_local,
slot_mapping=common_attn_metadata.slot_mapping,
causal=True,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
)
@@ -719,7 +752,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
logits_indices = logits_indices_padded[:num_logits_indices]
num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
# Example inputs
# num_reqs: 3
# generation_indices: [14, 18, 19, 27]
@@ -748,9 +780,7 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=decode_query_start_loc,
query_start_loc_cpu=decode_query_start_loc.to("cpu", non_blocking=True),
seq_lens=seq_lens,
seq_lens_cpu=seq_lens.to("cpu", non_blocking=True),
num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
seq_lens=common_attn_metadata.seq_lens,
num_reqs=num_reqs,
num_actual_tokens=total_num_decode_tokens,
max_query_len=decode_max_query_len,
@@ -758,6 +788,8 @@
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping,
causal=True,
_seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
)
return common_attn_metadata
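For callers migrating off the deprecated properties: per-request query lengths are available on the host from query_start_loc_cpu alone, which is the same derivation the num_computed_tokens_cpu property above uses. A small illustrative helper follows; the function name is hypothetical and not part of this diff.

import torch


def query_lens_from_start_loc(query_start_loc_cpu: torch.Tensor) -> torch.Tensor:
    # Consecutive start-location differences give each request's query length,
    # computed entirely on the CPU without touching any device tensor.
    return query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]


# Example: 3 requests with query lengths 4, 2, and 3.
start_loc_cpu = torch.tensor([0, 4, 6, 9])
print(query_lens_from_start_loc(start_loc_cpu))  # tensor([4, 2, 3])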

View File

@@ -440,16 +440,16 @@ class EagleProposer:
# of main model.
# Increment the sequence lengths.
common_attn_metadata.seq_lens += 1
# This is an out-of-place operation to avoid modifying the original tensor.
common_attn_metadata.seq_lens_cpu = common_attn_metadata.seq_lens_cpu + 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
common_attn_metadata.num_computed_tokens_cpu = (
common_attn_metadata.seq_lens_cpu - 1
)
# Also update the CPU-side shadow; NOTE: this is hacky and should be
# removed when common_attn_metadata.seq_lens_cpu is removed.
if common_attn_metadata._seq_lens_cpu is not None:
common_attn_metadata._seq_lens_cpu += 1
if common_attn_metadata._num_computed_tokens_cpu is not None:
common_attn_metadata._num_computed_tokens_cpu += 1
# Compute the slot mapping.
if self.uses_mrope:
@@ -656,8 +656,8 @@ class EagleProposer:
query_start_loc=common_attn_metadata.query_start_loc,
seq_lens=common_attn_metadata.seq_lens,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
_seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(),
@@ -932,8 +932,8 @@ class EagleProposer:
query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
query_start_loc_cpu=new_query_start_loc_cpu,
seq_lens_cpu=new_seq_lens_cpu,
num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
_seq_lens_cpu=new_seq_lens_cpu,
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(),

View File

@@ -168,9 +168,9 @@ def build_attn_metadata(
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
_seq_lens_cpu=seq_lens_cpu,
max_seq_len=max_seq_len,
num_computed_tokens_cpu=num_computed_tokens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,

View File

@@ -1626,8 +1626,8 @@ class GPUModelRunner(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
num_actual_tokens=num_tokens_padded,
num_reqs=num_reqs_padded,
max_query_len=max_query_len,