[Attention] Make seq_lens_cpu optional in CommonAttentionMetadata to enable true async spec-decode (#29624)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>

parent 2e7035dd8c
commit abe93bce59
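
Context, not part of the diff itself: the change turns the previously required seq_lens_cpu and num_computed_tokens_cpu fields into optional _seq_lens_cpu and _num_computed_tokens_cpu shadows plus deprecated lazy properties, so the device-to-host copy (and the sync it implies) only happens when a caller actually asks for the CPU view. A minimal sketch of that pattern, using a hypothetical Metadata class rather than vLLM's real one:

    from dataclasses import dataclass

    import torch


    @dataclass
    class Metadata:  # hypothetical stand-in, not vLLM's CommonAttentionMetadata
        seq_lens: torch.Tensor  # device tensor, always present
        _seq_lens_cpu: torch.Tensor | None = None  # optional CPU shadow

        @property
        def seq_lens_cpu(self) -> torch.Tensor:
            # Lazy init: pay the device-to-host sync only if someone asks.
            if self._seq_lens_cpu is None:
                self._seq_lens_cpu = self.seq_lens.to("cpu")
            return self._seq_lens_cpu

Callers that can stay on the device keep using seq_lens; code that genuinely needs host values either supplies _seq_lens_cpu at construction time or accepts the one-off sync.
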
@@ -106,8 +106,8 @@ def create_common_attn_metadata(
         query_start_loc=query_start_loc,
         query_start_loc_cpu=query_start_loc_cpu,
         seq_lens=seq_lens,
-        seq_lens_cpu=seq_lens_cpu,
-        num_computed_tokens_cpu=num_computed_tokens_cpu,
+        _seq_lens_cpu=seq_lens_cpu,
+        _num_computed_tokens_cpu=num_computed_tokens_cpu,
         num_reqs=batch_spec.batch_size,
         num_actual_tokens=num_tokens,
         max_query_len=max_query_len,

tests/v1/e2e/test_async_spec_decode.py (new file, 131 lines):

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test that verifies no implicit GPU-CPU synchronization occurs during
speculative decoding generation under expected conditions.
"""

import multiprocessing
import sys
import traceback

import pytest
import torch


@pytest.fixture
def sync_tracker():
    """
    Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect
    lazy init syncs. Prints stack traces immediately when syncs occur.
    """
    from vllm.v1.attention.backends.utils import CommonAttentionMetadata

    # Shared counter for cross-process communication (inherited by fork)
    sync_count = multiprocessing.Value("i", 0)

    # Save original property
    original_prop = CommonAttentionMetadata.seq_lens_cpu
    original_fget = original_prop.fget

    # Create tracking wrapper
    def tracking_seq_lens_cpu(self):
        if self._seq_lens_cpu is None:
            # Increment counter
            with sync_count.get_lock():
                sync_count.value += 1
                count = sync_count.value
            # Print stack trace immediately (shows in subprocess output)
            print(f"\n{'=' * 60}", file=sys.stderr)
            print(f"SYNC #{count}: seq_lens_cpu lazy init triggered!", file=sys.stderr)
            print(f"{'=' * 60}", file=sys.stderr)
            traceback.print_stack(file=sys.stderr)
            print(f"{'=' * 60}\n", file=sys.stderr)
            sys.stderr.flush()
        return original_fget(self)

    # Apply patch
    CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu)

    class SyncTracker:
        @property
        def count(self) -> int:
            return sync_count.value

        def assert_no_sync(self, msg: str = ""):
            count = sync_count.value
            assert count == 0, (
                f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered "
                f"{count} times. See stack traces above. {msg}"
            )

    yield SyncTracker()

    # Restore original property
    CommonAttentionMetadata.seq_lens_cpu = original_prop
    torch._dynamo.reset()


# Test configurations: (model, spec_model, method, num_spec_tokens, backend_env)
SPEC_DECODE_CONFIGS = [
    pytest.param(
        "meta-llama/Llama-3.2-1B-Instruct",
        "nm-testing/Llama3_2_1B_speculator.eagle3",
        "eagle3",
        2,
        id="eagle3-llama",
    ),
    pytest.param(
        "eagle618/deepseek-v3-random",
        "eagle618/eagle-deepseek-v3-random",
        "eagle",
        2,
        id="eagle-mla-deepseek",
    ),
]


@pytest.mark.parametrize(
    "model,spec_model,method,num_spec_tokens",
    SPEC_DECODE_CONFIGS,
)
def test_no_sync_with_spec_decode(
    sync_tracker,
    model: str,
    spec_model: str,
    method: str,
    num_spec_tokens: int,
):
    """
    Test that no implicit GPU-CPU sync occurs during speculative decoding
    generation.
    """
    # Import vLLM AFTER sync_tracker fixture has applied the patch
    from vllm import LLM, SamplingParams
    from vllm.distributed import cleanup_dist_env_and_memory

    llm = LLM(
        model=model,
        max_model_len=256,
        speculative_config={
            "method": method,
            "num_speculative_tokens": num_spec_tokens,
            "model": spec_model,
        },
        enforce_eager=True,
        async_scheduling=True,
    )

    outputs = llm.generate(
        ["Hello, my name is"],
        SamplingParams(temperature=0, max_tokens=10),
    )

    assert len(outputs) == 1
    assert len(outputs[0].outputs[0].text) > 0

    del llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    sync_tracker.assert_no_sync()

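Assuming the repository layout matches the path above and a GPU plus the listed models are available, the new check should be runnable directly with pytest; -s keeps the fixture's stack traces visible on stderr:

    pytest -s tests/v1/e2e/test_async_spec_decode.py
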
@@ -88,8 +88,8 @@ def forward_attention(
         query_start_loc=query_start_loc,
         query_start_loc_cpu=query_start_loc.cpu(),
         seq_lens=seq_lens,
-        seq_lens_cpu=seq_lens.cpu(),
-        num_computed_tokens_cpu=context_lens.cpu(),
+        _seq_lens_cpu=seq_lens.cpu(),
+        _num_computed_tokens_cpu=context_lens.cpu(),
         num_reqs=batch_size,
         num_actual_tokens=num_actual_tokens,
         max_query_len=max_query_len,

@@ -103,7 +103,7 @@ def create_cross_attention_backend(
     # needed here to know how many tokens to attend to from the cached
     # cross-attention KV cache.
     new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
-    new_metadata.seq_lens_cpu = torch.from_numpy(
+    new_metadata._seq_lens_cpu = torch.from_numpy(
         common_attn_metadata.encoder_seq_lens_cpu
     )

@@ -370,6 +370,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]

         num_accepted_tokens = torch.diff(m.query_start_loc)
         num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
-        m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
+        m._num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()

         return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)

@@ -18,7 +18,7 @@ from typing import (

 import numpy as np
 import torch
-from typing_extensions import runtime_checkable
+from typing_extensions import deprecated, runtime_checkable

 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.utils.math_utils import cdiv

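Side note on the new import: typing_extensions.deprecated (PEP 702) marks an API for type checkers and, in the typing_extensions runtime implementation, also emits a DeprecationWarning when the wrapped callable is invoked, which is why stacking it under @property flags every seq_lens_cpu access. A small standalone illustration with a made-up function (not from the diff); the runtime-warning behaviour described here is my reading of the library, not something stated in this commit:

    import warnings

    from typing_extensions import deprecated


    @deprecated("use new_api() instead")  # hypothetical example function
    def old_api() -> int:
        return 1


    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        old_api()

    # Calling the deprecated function should have raised a DeprecationWarning.
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
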
@@ -66,11 +66,6 @@ class CommonAttentionMetadata:
     """(batch_size + 1,), the start location of each request in query Tensor"""

     seq_lens: torch.Tensor
-    seq_lens_cpu: torch.Tensor
-    """(batch_size,), the length of each request including both computed tokens
-    and newly scheduled tokens"""
-
-    num_computed_tokens_cpu: torch.Tensor
     """(batch_size,), the number of computed tokens for each request"""

     num_reqs: int

@@ -81,7 +76,7 @@ class CommonAttentionMetadata:
     max_query_len: int
     """Longest query in batch"""
     max_seq_len: int
-    """Longest context length in batch"""
+    """Longest context length (may be an upper bound)"""

     block_table_tensor: torch.Tensor
     slot_mapping: torch.Tensor

@@ -100,6 +95,40 @@ class CommonAttentionMetadata:
     dcp_local_seq_lens_cpu: torch.Tensor | None = None
     """Sequence lengths of the local rank in decode context parallelism world"""

+    # WARNING: Deprecated fields. Will be removed in a future release (v0.14.0)
+    _seq_lens_cpu: torch.Tensor | None = None
+    _num_computed_tokens_cpu: torch.Tensor | None = None
+
+    @property
+    @deprecated(
+        """
+        Prefer using device seq_lens directly to avoid implicit H<>D sync.
+        If a CPU copy is needed, use `seq_lens.cpu()` instead.
+        Will be removed in a future release (v0.14.0)
+        """
+    )
+    def seq_lens_cpu(self) -> torch.Tensor:
+        if self._seq_lens_cpu is None:
+            self._seq_lens_cpu = self.seq_lens.to("cpu")
+        return self._seq_lens_cpu
+
+    @property
+    @deprecated(
+        """
+        Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full
+        async scheduling. If a CPU copy is needed, it can be derived from
+        query_start_loc_cpu and seq_lens.
+        Will be removed in a future release (v0.14.0)
+        """
+    )
+    def num_computed_tokens_cpu(self) -> torch.Tensor:
+        if self._num_computed_tokens_cpu is None:
+            query_seq_lens = (
+                self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1]
+            )
+            self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens
+        return self._num_computed_tokens_cpu
+
     # TODO(lucas): remove once we have FULL-CG spec-decode support
     def unpadded(
         self, num_actual_tokens: int, num_actual_reqs: int

@@ -109,8 +138,12 @@ class CommonAttentionMetadata:
             query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
             query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
             seq_lens=self.seq_lens[:num_actual_reqs],
-            seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs],
-            num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs],
+            _seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs]
+            if self._seq_lens_cpu is not None
+            else None,
+            _num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs]
+            if self._num_computed_tokens_cpu is not None
+            else None,
             num_reqs=num_actual_reqs,
             num_actual_tokens=num_actual_tokens,
             max_query_len=self.max_query_len,

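The deprecated num_computed_tokens_cpu property above reconstructs the value from host-side metadata when no CPU shadow was provided. Callers that already hold the CPU tensors can do the same subtraction themselves and never touch the device. A standalone sketch with made-up example values:

    import torch

    # Host-side copies the scheduler already has (example values).
    query_start_loc_cpu = torch.tensor([0, 3, 5, 9])  # (num_reqs + 1,)
    seq_lens_cpu = torch.tensor([7, 6, 12])           # (num_reqs,)

    # Per-request query lengths, then computed-token counts; no GPU involved.
    query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    num_computed_tokens_cpu = seq_lens_cpu - query_lens_cpu
    assert num_computed_tokens_cpu.tolist() == [4, 4, 8]
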
@@ -224,14 +257,14 @@ def _make_metadata_with_slice(
         query_start_loc=query_start_loc,
         query_start_loc_cpu=query_start_loc_cpu,
         seq_lens=seq_lens,
-        seq_lens_cpu=seq_lens_cpu,
-        num_computed_tokens_cpu=num_computed_tokens_cpu,
         num_reqs=num_requests,
         num_actual_tokens=num_actual_tokens,
         max_query_len=max_query_len,
         max_seq_len=max_seq_len,
         block_table_tensor=block_table_tensor,
         slot_mapping=slot_mapping,
+        _seq_lens_cpu=seq_lens_cpu,
+        _num_computed_tokens_cpu=num_computed_tokens_cpu,
     )

@@ -689,9 +722,7 @@ def make_local_attention_virtual_batches(
     return CommonAttentionMetadata(
         query_start_loc_cpu=query_start_loc_cpu,
         query_start_loc=query_start_loc_cpu.to(device=device, non_blocking=True),
-        seq_lens_cpu=seq_lens_cpu,
         seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
-        num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
         num_reqs=len(seq_lens_cpu),
         num_actual_tokens=common_attn_metadata.num_actual_tokens,
         max_query_len=seqlens_q_local.max(),

@@ -699,6 +730,8 @@ def make_local_attention_virtual_batches(
         block_table_tensor=block_table_local,
         slot_mapping=common_attn_metadata.slot_mapping,
         causal=True,
+        _seq_lens_cpu=seq_lens_cpu,
+        _num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
     )

@@ -719,7 +752,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
     logits_indices = logits_indices_padded[:num_logits_indices]
     num_reqs = common_attn_metadata.num_reqs
     query_start_loc = common_attn_metadata.query_start_loc
-    seq_lens = common_attn_metadata.seq_lens
     # Example inputs
     # num_reqs: 3
     # generation_indices: [14, 18, 19, 27]

@@ -748,9 +780,7 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
     common_attn_metadata = CommonAttentionMetadata(
         query_start_loc=decode_query_start_loc,
         query_start_loc_cpu=decode_query_start_loc.to("cpu", non_blocking=True),
-        seq_lens=seq_lens,
-        seq_lens_cpu=seq_lens.to("cpu", non_blocking=True),
-        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
+        seq_lens=common_attn_metadata.seq_lens,
         num_reqs=num_reqs,
         num_actual_tokens=total_num_decode_tokens,
         max_query_len=decode_max_query_len,

@@ -758,6 +788,8 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
         block_table_tensor=common_attn_metadata.block_table_tensor,
         slot_mapping=common_attn_metadata.slot_mapping,
         causal=True,
+        _seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
+        _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
     )
     return common_attn_metadata

@@ -440,16 +440,16 @@ class EagleProposer:
         # of main model.
         # Increment the sequence lengths.
         common_attn_metadata.seq_lens += 1
-        # This is an out-of-place operation to avoid modifying the original tensor.
-        common_attn_metadata.seq_lens_cpu = common_attn_metadata.seq_lens_cpu + 1
         # For the requests that exceed the max model length, we set the
         # sequence length to 1 to minimize their overheads in attention.

         common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

-        common_attn_metadata.num_computed_tokens_cpu = (
-            common_attn_metadata.seq_lens_cpu - 1
-        )
+        # Also update the CPU-side shadow; NOTE: this is hacky and should be
+        # removed in when common_attn_metadata.seq_lens_cpu is deprecated.
+        if common_attn_metadata._seq_lens_cpu is not None:
+            common_attn_metadata._seq_lens_cpu += 1
+        if common_attn_metadata._num_computed_tokens_cpu is not None:
+            common_attn_metadata._num_computed_tokens_cpu += 1

         # Compute the slot mapping.
         if self.uses_mrope:

@@ -656,8 +656,8 @@ class EagleProposer:
             query_start_loc=common_attn_metadata.query_start_loc,
             seq_lens=common_attn_metadata.seq_lens,
             query_start_loc_cpu=query_start_loc_cpu,
-            seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
-            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
+            _seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
+            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
             num_reqs=common_attn_metadata.num_reqs,
             num_actual_tokens=total_num_tokens,
             max_query_len=new_query_len_per_req.max().item(),

@@ -932,8 +932,8 @@ class EagleProposer:
             query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
             seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
             query_start_loc_cpu=new_query_start_loc_cpu,
-            seq_lens_cpu=new_seq_lens_cpu,
-            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
+            _seq_lens_cpu=new_seq_lens_cpu,
+            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
             num_reqs=common_attn_metadata.num_reqs,
             num_actual_tokens=total_num_tokens,
             max_query_len=new_query_len_per_req.max().item(),

@@ -168,9 +168,9 @@ def build_attn_metadata(
         query_start_loc=query_start_loc_gpu,
         query_start_loc_cpu=query_start_loc_cpu,
         seq_lens=seq_lens,
-        seq_lens_cpu=seq_lens_cpu,
+        _seq_lens_cpu=seq_lens_cpu,
         max_seq_len=max_seq_len,
-        num_computed_tokens_cpu=num_computed_tokens_cpu,
+        _num_computed_tokens_cpu=num_computed_tokens_cpu,
         num_reqs=num_reqs,
         num_actual_tokens=num_tokens,
         max_query_len=max_query_len,

@@ -1626,8 +1626,8 @@ class GPUModelRunner(
             query_start_loc=query_start_loc,
             query_start_loc_cpu=query_start_loc_cpu,
             seq_lens=seq_lens,
-            seq_lens_cpu=seq_lens_cpu,
-            num_computed_tokens_cpu=num_computed_tokens_cpu,
+            _seq_lens_cpu=seq_lens_cpu,
+            _num_computed_tokens_cpu=num_computed_tokens_cpu,
             num_actual_tokens=num_tokens_padded,
             num_reqs=num_reqs_padded,
             max_query_len=max_query_len,