Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-08-31 21:00:16 -07:00
parent ba1a58f51b
commit 62d23b3006
5 changed files with 30 additions and 112 deletions

View File

@ -4,26 +4,21 @@
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import Optional
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm import _custom_ops as ops
logger = init_logger(__name__)
@ -183,13 +178,6 @@ class TreeAttentionMetadataBuilder(
device=device,
)
def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:
    """Reorder the batch so decode-sized requests come before prefills.

    Uses the first dimension of ``self.tree_attn_bias`` as the decode
    threshold, then delegates the actual reordering. Returns True iff the
    batch order was modified.
    """
    threshold = self.tree_attn_bias.shape[0]
    return reorder_batch_to_split_decodes_and_prefills(
        input_batch, scheduler_output, decode_threshold=threshold)
def build(
self,
common_prefix_len: int,

View File

@ -684,69 +684,6 @@ def split_decodes_and_prefills(
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
def reorder_batch_to_split_decodes_and_prefills(
    input_batch: "InputBatch",
    scheduler_output: "SchedulerOutput",
    decode_threshold: int = 1,
) -> bool:
    """
    Reorders the batch to split into prefill and decode requests; places all
    requests with <= decode_threshold tokens at the front of the batch.

    Args:
        input_batch: The persistent batch; reordered in place via
            ``input_batch.swap_states``.
        scheduler_output: Provides ``num_scheduled_tokens`` per request id.
        decode_threshold: Requests scheduled for at most this many tokens
            are treated as decodes.

    Returns:
        True if the batch was modified, False otherwise.
    """
    # We want the "decode" requests at the front and the "prefill" requests
    # at the back using the least amount of swaps possible. (NOTE: for now we
    # loosely use "decode" to mean requests where attention is likely
    # memory-bound and "prefill" to mean requests where attention is likely
    # compute-bound, TODO(lucas): figure out a better naming here)
    decodes = []
    prefills = []
    for i, req_id in enumerate(input_batch.req_ids):
        num_tokens = scheduler_output.num_scheduled_tokens[req_id]
        if num_tokens <= decode_threshold:
            decodes.append(i)
        else:
            prefills.append(i)

    # We hope the number of swaps is fairly minimal since decodes should be
    # around for a number of iterations, so they are relatively stationary
    # (and new requests are generally appended to the persistent batch, so
    # they already should be at the back).
    # To achieve this we loop over the decodes in descending order and the
    # prefills in ascending order: we swap decodes from the "back" (i.e.
    # past where the last decode should sit in the reordered batch) with
    # prefills from the front of the batch.
    # `decodes` and `prefills` are already in ascending order just based on
    # the enumeration loop above.
    num_decodes = len(decodes)
    num_prefills = len(prefills)
    modified_batch = False

    for i in range(1, min(num_decodes, num_prefills) + 1):
        decode_idx = decodes[num_decodes - i]
        # Once a decode already sits within the first num_decodes slots, it
        # (and every decode before it) is in its final region — stop early.
        if decode_idx < num_decodes:
            break

        input_batch.swap_states(prefills[i - 1], decode_idx)
        modified_batch = True

    return modified_batch
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
('logits_indices_padded', Optional[torch.Tensor], None),
('num_logits_indices', int, 0),

View File

@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import Optional
import torch
@ -12,9 +12,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
try:
@ -26,10 +26,6 @@ try:
except ImportError:
XFORMERS_AVAILABLE = False
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm import _custom_ops as ops
logger = init_logger(__name__)
@ -210,12 +206,6 @@ class XFormersAttentionMetadataBuilder(
self._num_decodes = 0
self._num_decode_tokens = 0
def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:
    """Reorder the batch so single-token (decode) requests come first.

    Delegates to the shared helper with a decode threshold of one token;
    returns True iff the batch order was modified.
    """
    return reorder_batch_to_split_decodes_and_prefills(
        input_batch, scheduler_output, decode_threshold=1)
def build(
self,
common_prefix_len: int,

View File

@ -462,6 +462,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _init_mrope_positions(self, req_id: str) -> None:
req_idx = self.requests.req_id_to_index[req_id]
req_data = self.requests.req_data[req_idx]
prompt_len = self.requests.num_prompt_tokens.np[req_idx]
prompt_token_ids = self.requests.token_ids.np[req_idx, :prompt_len]
image_grid_thw = []
video_grid_thw = []
@ -483,7 +485,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_data.mrope_positions, req_data.mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
req_data.prompt_token_ids,
prompt_token_ids,
hf_config=self.model_config.hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
@ -905,7 +907,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# list of tuple (mm_hash, position_info)
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_data = self.requests.req_data[req_id]
req_idx = self.requests.req_id_to_index[req_id]
req_data = self.requests.req_data[req_idx]
for mm_input_id in encoder_input_ids:
mm_hash = req_data.mm_hashes[mm_input_id]
@ -1259,11 +1262,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return self.kv_connector_no_forward(scheduler_output,
self.vllm_config)
if self.cache_config.kv_sharing_fast_prefill:
assert not self.input_batch.num_prompt_logprobs, (
"--kv-sharing-fast-prefill produces incorrect logprobs for "
"prompt tokens, tokens, please disable it when the requests "
"need prompt logprobs")
# if self.cache_config.kv_sharing_fast_prefill:
# assert not self.input_batch.num_prompt_logprobs, (
# "--kv-sharing-fast-prefill produces incorrect logprobs for "
# "prompt tokens, tokens, please disable it when the requests "
# "need prompt logprobs")
# Prepare the decoder inputs.
input_batch = self._prepare_inputs(scheduler_output)
@ -1296,7 +1299,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.supports_mm_inputs:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
mm_embeds = self._gather_mm_embeddings(input_batch)
else:
mm_embeds = []
@ -1328,7 +1331,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
inputs_embeds = None
model_kwargs = self._init_model_kwargs(num_input_tokens)
if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens]
positions = self.mrope_positions.gpu[:, :num_input_tokens]
else:
positions = self.positions.gpu[:num_input_tokens]
@ -1448,7 +1451,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_nans_in_logits = {}
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
num_nans_in_logits = self._get_nans_in_logits(logits)
num_nans_in_logits = self._get_nans_in_logits(
input_batch.req_ids, logits)
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
@ -1488,14 +1492,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.speculative_config:
assert input_batch.spec_decode_common_attn_metadata is not None
self._draft_token_ids = self.propose_draft_token_ids(
scheduler_output,
input_batch,
valid_sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
input_batch.spec_decode_metadata,
input_batch.spec_decode_common_attn_metadata,
)
self._draft_req_ids = input_batch.req_ids
@ -1889,16 +1891,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _get_nans_in_logits(
self,
input_batch: InputBatch,
req_ids: list[str],
logits: Optional[torch.Tensor],
) -> dict[str, int]:
try:
if logits is None:
return {req_id: 0 for req_id in input_batch.req_ids}
return {req_id: 0 for req_id in req_ids}
num_nans_in_logits = {}
num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy()
for i, req_id in enumerate(input_batch.req_ids):
for i, req_id in enumerate(req_ids):
num_nans_in_logits[req_id] = (int(num_nans_for_index[i])
if num_nans_for_index is not None
and i < logits.shape[0] else 0)
@ -2092,7 +2094,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model_kwargs = self._init_model_kwargs(num_tokens)
if self.uses_mrope:
positions = self.mrope_positions[:, :num_tokens]
positions = self.mrope_positions.gpu[:, :num_tokens]
else:
positions = self.positions.gpu[:num_tokens]

View File

@ -110,7 +110,7 @@ class RequestState:
self.is_spec_decode = is_spec_decode
self.pooling_params = None
self.block_sizes = block_sizes
self.num_prompt_logprobs = {}
self.num_prompt_logprobs: dict[int, int] = {}
self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {}
@ -378,6 +378,7 @@ def _make_sampling_metadata_kernel(
tl.store(dst_repetition_penalties + batch_idx, repetition_penalties)
@triton.jit
def _prepare_spec_decode_kernel(
query_start_loc, # [B + 1]
cu_num_draft_tokens, # [B]