[Attention] Make seq_lens_cpu optional in CommonAttentionMetadata to enable true async spec-decode (#29624)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
parent 2e7035dd8c
commit abe93bce59
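In short: the CPU-side copies seq_lens_cpu and num_computed_tokens_cpu become optional shadow fields (_seq_lens_cpu, _num_computed_tokens_cpu) with lazy, deprecated property fallbacks, so builders that never touch the CPU view no longer force a GPU-CPU sync that would stall async spec-decode. A minimal, self-contained sketch of that pattern (illustration only, not the vLLM class; it simply mirrors the property added in the diff below):

from dataclasses import dataclass

import torch


@dataclass
class MetadataSketch:
    """Toy stand-in for CommonAttentionMetadata (illustration only)."""

    seq_lens: torch.Tensor  # device tensor, always present
    _seq_lens_cpu: torch.Tensor | None = None  # optional CPU shadow

    @property
    def seq_lens_cpu(self) -> torch.Tensor:
        # Lazy fallback: copies (and therefore syncs) only if a caller still
        # needs the CPU view and no shadow was supplied up front.
        if self._seq_lens_cpu is None:
            self._seq_lens_cpu = self.seq_lens.to("cpu")
        return self._seq_lens_cpu


device = "cuda" if torch.cuda.is_available() else "cpu"
lens = torch.tensor([5, 9, 2], dtype=torch.int32)
meta = MetadataSketch(seq_lens=lens.to(device), _seq_lens_cpu=lens)
assert meta.seq_lens_cpu is lens  # shadow was provided, so no device round-trip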
@@ -106,8 +106,8 @@ def create_common_attn_metadata(
         query_start_loc=query_start_loc,
         query_start_loc_cpu=query_start_loc_cpu,
         seq_lens=seq_lens,
-        seq_lens_cpu=seq_lens_cpu,
-        num_computed_tokens_cpu=num_computed_tokens_cpu,
+        _seq_lens_cpu=seq_lens_cpu,
+        _num_computed_tokens_cpu=num_computed_tokens_cpu,
         num_reqs=batch_spec.batch_size,
         num_actual_tokens=num_tokens,
         max_query_len=max_query_len,
tests/v1/e2e/test_async_spec_decode.py (new file, 131 lines)
@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test that verifies no implicit GPU-CPU synchronization occurs during
speculative decoding generation under expected conditions.
"""

import multiprocessing
import sys
import traceback

import pytest
import torch


@pytest.fixture
def sync_tracker():
    """
    Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect
    lazy init syncs. Prints stack traces immediately when syncs occur.
    """
    from vllm.v1.attention.backends.utils import CommonAttentionMetadata

    # Shared counter for cross-process communication (inherited by fork)
    sync_count = multiprocessing.Value("i", 0)

    # Save original property
    original_prop = CommonAttentionMetadata.seq_lens_cpu
    original_fget = original_prop.fget

    # Create tracking wrapper
    def tracking_seq_lens_cpu(self):
        if self._seq_lens_cpu is None:
            # Increment counter
            with sync_count.get_lock():
                sync_count.value += 1
                count = sync_count.value
            # Print stack trace immediately (shows in subprocess output)
            print(f"\n{'=' * 60}", file=sys.stderr)
            print(f"SYNC #{count}: seq_lens_cpu lazy init triggered!", file=sys.stderr)
            print(f"{'=' * 60}", file=sys.stderr)
            traceback.print_stack(file=sys.stderr)
            print(f"{'=' * 60}\n", file=sys.stderr)
            sys.stderr.flush()
        return original_fget(self)

    # Apply patch
    CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu)

    class SyncTracker:
        @property
        def count(self) -> int:
            return sync_count.value

        def assert_no_sync(self, msg: str = ""):
            count = sync_count.value
            assert count == 0, (
                f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered "
                f"{count} times. See stack traces above. {msg}"
            )

    yield SyncTracker()

    # Restore original property
    CommonAttentionMetadata.seq_lens_cpu = original_prop
    torch._dynamo.reset()


# Test configurations: (model, spec_model, method, num_spec_tokens, backend_env)
SPEC_DECODE_CONFIGS = [
    pytest.param(
        "meta-llama/Llama-3.2-1B-Instruct",
        "nm-testing/Llama3_2_1B_speculator.eagle3",
        "eagle3",
        2,
        id="eagle3-llama",
    ),
    pytest.param(
        "eagle618/deepseek-v3-random",
        "eagle618/eagle-deepseek-v3-random",
        "eagle",
        2,
        id="eagle-mla-deepseek",
    ),
]


@pytest.mark.parametrize(
    "model,spec_model,method,num_spec_tokens",
    SPEC_DECODE_CONFIGS,
)
def test_no_sync_with_spec_decode(
    sync_tracker,
    model: str,
    spec_model: str,
    method: str,
    num_spec_tokens: int,
):
    """
    Test that no implicit GPU-CPU sync occurs during speculative decoding
    generation.
    """
    # Import vLLM AFTER sync_tracker fixture has applied the patch
    from vllm import LLM, SamplingParams
    from vllm.distributed import cleanup_dist_env_and_memory

    llm = LLM(
        model=model,
        max_model_len=256,
        speculative_config={
            "method": method,
            "num_speculative_tokens": num_spec_tokens,
            "model": spec_model,
        },
        enforce_eager=True,
        async_scheduling=True,
    )

    outputs = llm.generate(
        ["Hello, my name is"],
        SamplingParams(temperature=0, max_tokens=10),
    )

    assert len(outputs) == 1
    assert len(outputs[0].outputs[0].text) > 0

    del llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    sync_tracker.assert_no_sync()
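The fixture above detects one specific lazy-init path. When debugging locally, PyTorch's own sync debug mode can complement it by flagging implicit synchronizing calls in general (not part of this commit; as far as I recall, blocking device-to-host copies such as .cpu() are among the operations it reports):

import torch

if torch.cuda.is_available():
    torch.cuda.set_sync_debug_mode("warn")  # or "error" to fail fast on any sync

    seq_lens = torch.ones(8, dtype=torch.int32, device="cuda")
    seq_lens_cpu = seq_lens.cpu()  # blocking device-to-host copy; should be reported

    torch.cuda.set_sync_debug_mode("default")  # restore normal behaviour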
@@ -88,8 +88,8 @@ def forward_attention(
         query_start_loc=query_start_loc,
         query_start_loc_cpu=query_start_loc.cpu(),
         seq_lens=seq_lens,
-        seq_lens_cpu=seq_lens.cpu(),
-        num_computed_tokens_cpu=context_lens.cpu(),
+        _seq_lens_cpu=seq_lens.cpu(),
+        _num_computed_tokens_cpu=context_lens.cpu(),
         num_reqs=batch_size,
         num_actual_tokens=num_actual_tokens,
         max_query_len=max_query_len,
@@ -103,7 +103,7 @@ def create_cross_attention_backend(
         # needed here to know how many tokens to attend to from the cached
         # cross-attention KV cache.
         new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
-        new_metadata.seq_lens_cpu = torch.from_numpy(
+        new_metadata._seq_lens_cpu = torch.from_numpy(
             common_attn_metadata.encoder_seq_lens_cpu
         )

@@ -370,6 +370,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]

         num_accepted_tokens = torch.diff(m.query_start_loc)
         num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
-        m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
+        m._num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()

         return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)
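Side note on the hunk above: query_start_loc is a cumulative offset tensor, so torch.diff recovers a per-request count in one call; here that yields the number of accepted tokens per request. A tiny example of the shape involved (illustrative values only):

import torch

# query_start_loc holds cumulative token offsets per request (batch_size + 1 entries).
query_start_loc = torch.tensor([0, 3, 5, 9])  # 3 requests with 3, 2 and 4 tokens
per_request_lens = torch.diff(query_start_loc)
print(per_request_lens.tolist())  # [3, 2, 4]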
@@ -18,7 +18,7 @@ from typing import (

 import numpy as np
 import torch
-from typing_extensions import runtime_checkable
+from typing_extensions import deprecated, runtime_checkable

 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.utils.math_utils import cdiv
@@ -66,11 +66,6 @@ class CommonAttentionMetadata:
     """(batch_size + 1,), the start location of each request in query Tensor"""

     seq_lens: torch.Tensor
-    seq_lens_cpu: torch.Tensor
     """(batch_size,), the length of each request including both computed tokens
     and newly scheduled tokens"""
-
-    num_computed_tokens_cpu: torch.Tensor
-    """(batch_size,), the number of computed tokens for each request"""
-
     num_reqs: int
@@ -81,7 +76,7 @@ class CommonAttentionMetadata:
     max_query_len: int
     """Longest query in batch"""
     max_seq_len: int
-    """Longest context length in batch"""
+    """Longest context length (may be an upper bound)"""

     block_table_tensor: torch.Tensor
     slot_mapping: torch.Tensor
@@ -100,6 +95,40 @@ class CommonAttentionMetadata:
     dcp_local_seq_lens_cpu: torch.Tensor | None = None
     """Sequence lengths of the local rank in decode context parallelism world"""

+    # WARNING: Deprecated fields. Will be removed in a future release (v0.14.0)
+    _seq_lens_cpu: torch.Tensor | None = None
+    _num_computed_tokens_cpu: torch.Tensor | None = None
+
+    @property
+    @deprecated(
+        """
+        Prefer using device seq_lens directly to avoid implicit H<>D sync.
+        If a CPU copy is needed, use `seq_lens.cpu()` instead.
+        Will be removed in a future release (v0.14.0)
+        """
+    )
+    def seq_lens_cpu(self) -> torch.Tensor:
+        if self._seq_lens_cpu is None:
+            self._seq_lens_cpu = self.seq_lens.to("cpu")
+        return self._seq_lens_cpu
+
+    @property
+    @deprecated(
+        """
+        Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full
+        async scheduling. If a CPU copy is needed, it can be derived from
+        query_start_loc_cpu and seq_lens.
+        Will be removed in a future release (v0.14.0)
+        """
+    )
+    def num_computed_tokens_cpu(self) -> torch.Tensor:
+        if self._num_computed_tokens_cpu is None:
+            query_seq_lens = (
+                self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1]
+            )
+            self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens
+        return self._num_computed_tokens_cpu
+
     # TODO(lucas): remove once we have FULL-CG spec-decode support
     def unpadded(
         self, num_actual_tokens: int, num_actual_reqs: int
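For reference, a minimal sketch of how typing_extensions.deprecated behaves when stacked under @property as above (not from this diff): static type checkers flag use sites, and, to the best of my knowledge, a DeprecationWarning is also emitted at runtime when the decorated getter runs.

import warnings

from typing_extensions import deprecated


class Meta:
    @property
    @deprecated("Prefer the device tensor; copy to CPU explicitly if needed.")
    def seq_lens_cpu(self) -> list[int]:
        return [1, 2, 3]


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = Meta().seq_lens_cpu

print([w.category.__name__ for w in caught])  # typically ['DeprecationWarning']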
@@ -109,8 +138,12 @@ class CommonAttentionMetadata:
             query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
             query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
             seq_lens=self.seq_lens[:num_actual_reqs],
-            seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs],
-            num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs],
+            _seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs]
+            if self._seq_lens_cpu is not None
+            else None,
+            _num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs]
+            if self._num_computed_tokens_cpu is not None
+            else None,
             num_reqs=num_actual_reqs,
             num_actual_tokens=num_actual_tokens,
             max_query_len=self.max_query_len,
@@ -224,14 +257,14 @@ def _make_metadata_with_slice(
         query_start_loc=query_start_loc,
         query_start_loc_cpu=query_start_loc_cpu,
         seq_lens=seq_lens,
-        seq_lens_cpu=seq_lens_cpu,
-        num_computed_tokens_cpu=num_computed_tokens_cpu,
         num_reqs=num_requests,
         num_actual_tokens=num_actual_tokens,
         max_query_len=max_query_len,
         max_seq_len=max_seq_len,
         block_table_tensor=block_table_tensor,
         slot_mapping=slot_mapping,
+        _seq_lens_cpu=seq_lens_cpu,
+        _num_computed_tokens_cpu=num_computed_tokens_cpu,
     )

@@ -689,9 +722,7 @@ def make_local_attention_virtual_batches(
     return CommonAttentionMetadata(
         query_start_loc_cpu=query_start_loc_cpu,
         query_start_loc=query_start_loc_cpu.to(device=device, non_blocking=True),
-        seq_lens_cpu=seq_lens_cpu,
         seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
-        num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
         num_reqs=len(seq_lens_cpu),
         num_actual_tokens=common_attn_metadata.num_actual_tokens,
         max_query_len=seqlens_q_local.max(),
@@ -699,6 +730,8 @@ def make_local_attention_virtual_batches(
         block_table_tensor=block_table_local,
         slot_mapping=common_attn_metadata.slot_mapping,
         causal=True,
+        _seq_lens_cpu=seq_lens_cpu,
+        _num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
     )

@@ -719,7 +752,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
     logits_indices = logits_indices_padded[:num_logits_indices]
     num_reqs = common_attn_metadata.num_reqs
     query_start_loc = common_attn_metadata.query_start_loc
-    seq_lens = common_attn_metadata.seq_lens
     # Example inputs
     # num_reqs: 3
     # generation_indices: [14, 18, 19, 27]
@@ -748,9 +780,7 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
     common_attn_metadata = CommonAttentionMetadata(
         query_start_loc=decode_query_start_loc,
         query_start_loc_cpu=decode_query_start_loc.to("cpu", non_blocking=True),
-        seq_lens=seq_lens,
-        seq_lens_cpu=seq_lens.to("cpu", non_blocking=True),
-        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
+        seq_lens=common_attn_metadata.seq_lens,
         num_reqs=num_reqs,
         num_actual_tokens=total_num_decode_tokens,
         max_query_len=decode_max_query_len,
@@ -758,6 +788,8 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
         block_table_tensor=common_attn_metadata.block_table_tensor,
         slot_mapping=common_attn_metadata.slot_mapping,
         causal=True,
+        _seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
+        _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
     )
     return common_attn_metadata

@@ -440,16 +440,16 @@ class EagleProposer:
             # of main model.
             # Increment the sequence lengths.
             common_attn_metadata.seq_lens += 1
-            # This is an out-of-place operation to avoid modifying the original tensor.
-            common_attn_metadata.seq_lens_cpu = common_attn_metadata.seq_lens_cpu + 1
             # For the requests that exceed the max model length, we set the
             # sequence length to 1 to minimize their overheads in attention.

             common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
-
-            common_attn_metadata.num_computed_tokens_cpu = (
-                common_attn_metadata.seq_lens_cpu - 1
-            )
+            # Also update the CPU-side shadow; NOTE: this is hacky and should be
+            # removed in when common_attn_metadata.seq_lens_cpu is deprecated.
+            if common_attn_metadata._seq_lens_cpu is not None:
+                common_attn_metadata._seq_lens_cpu += 1
+            if common_attn_metadata._num_computed_tokens_cpu is not None:
+                common_attn_metadata._num_computed_tokens_cpu += 1

             # Compute the slot mapping.
             if self.uses_mrope:
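The hunk above keeps the per-step bookkeeping on the device and only bumps the CPU shadows when they already exist, precisely so the draft loop never has to materialize them. A rough, self-contained illustration of the distinction (assumed values; requires a CUDA device to exercise the GPU path):

import torch

if torch.cuda.is_available():
    seq_lens = torch.tensor([7, 12, 3], device="cuda")
    _seq_lens_cpu = None  # shadow is unset under fully async spec-decode

    # Device-side updates: no host synchronization is required.
    seq_lens += 1
    exceeds_max_model_len = seq_lens > 8
    seq_lens.masked_fill_(exceeds_max_model_len, 1)

    # Only update the CPU shadow if it already exists; calling seq_lens.cpu()
    # here instead would block on the GPU and defeat async scheduling.
    if _seq_lens_cpu is not None:
        _seq_lens_cpu += 1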
@@ -656,8 +656,8 @@ class EagleProposer:
             query_start_loc=common_attn_metadata.query_start_loc,
             seq_lens=common_attn_metadata.seq_lens,
             query_start_loc_cpu=query_start_loc_cpu,
-            seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
-            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
+            _seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
+            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
             num_reqs=common_attn_metadata.num_reqs,
             num_actual_tokens=total_num_tokens,
             max_query_len=new_query_len_per_req.max().item(),
@@ -932,8 +932,8 @@ class EagleProposer:
             query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
             seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
             query_start_loc_cpu=new_query_start_loc_cpu,
-            seq_lens_cpu=new_seq_lens_cpu,
-            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
+            _seq_lens_cpu=new_seq_lens_cpu,
+            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
             num_reqs=common_attn_metadata.num_reqs,
             num_actual_tokens=total_num_tokens,
             max_query_len=new_query_len_per_req.max().item(),
@@ -168,9 +168,9 @@ def build_attn_metadata(
         query_start_loc=query_start_loc_gpu,
         query_start_loc_cpu=query_start_loc_cpu,
         seq_lens=seq_lens,
-        seq_lens_cpu=seq_lens_cpu,
+        _seq_lens_cpu=seq_lens_cpu,
         max_seq_len=max_seq_len,
-        num_computed_tokens_cpu=num_computed_tokens_cpu,
+        _num_computed_tokens_cpu=num_computed_tokens_cpu,
         num_reqs=num_reqs,
         num_actual_tokens=num_tokens,
         max_query_len=max_query_len,
@@ -1626,8 +1626,8 @@ class GPUModelRunner(
             query_start_loc=query_start_loc,
             query_start_loc_cpu=query_start_loc_cpu,
             seq_lens=seq_lens,
-            seq_lens_cpu=seq_lens_cpu,
-            num_computed_tokens_cpu=num_computed_tokens_cpu,
+            _seq_lens_cpu=seq_lens_cpu,
+            _num_computed_tokens_cpu=num_computed_tokens_cpu,
             num_actual_tokens=num_tokens_padded,
             num_reqs=num_reqs_padded,
             max_query_len=max_query_len,