From 62d23b3006826d9b1e98c9530b87b5fa9213b614 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 31 Aug 2025 21:00:16 -0700 Subject: [PATCH] fix Signed-off-by: Woosuk Kwon --- vllm/v1/attention/backends/tree_attn.py | 22 ++------- vllm/v1/attention/backends/utils.py | 63 ------------------------- vllm/v1/attention/backends/xformers.py | 18 ++----- vllm/v1/worker/gpu_model_runner.py | 36 +++++++------- vllm/v1/worker/gpu_worker_states.py | 3 +- 5 files changed, 30 insertions(+), 112 deletions(-) diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index b96d957a150b5..d0b163fc9bed2 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -4,26 +4,21 @@ import ast from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import Optional import torch +from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, - reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch - -from vllm import _custom_ops as ops - logger = init_logger(__name__) @@ -183,13 +178,6 @@ class TreeAttentionMetadataBuilder( device=device, ) - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: - return reorder_batch_to_split_decodes_and_prefills( - input_batch, - 
scheduler_output, - decode_threshold=self.tree_attn_bias.shape[0]) - def build( self, common_prefix_len: int, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 011a90ece01bd..79fcd928393f9 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -684,69 +684,6 @@ def split_decodes_and_prefills( return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) -def reorder_batch_to_split_decodes_and_prefills( - input_batch: "InputBatch", - scheduler_output: "SchedulerOutput", - decode_threshold: int = 1, -) -> bool: - """ - Reorders the batch to split into prefill and decode requests; places all - requests with <= decode_threshold tokens at the front of the batch. - - Returns: - True if the batch was modified, False otherwise. - """ - # We now want to reorder the batch so that the "decode" requests are at - # the front and the "prefill" requests are at the back using the least - # amount of swaps possible. 
(NOTE for now we loosely use "decode" to mean - # requests where attention is likely memory-bound and "prefill" to mean - # requests where attention is likely compute-bound, TODO(lucas): figure out - # a better naming here) - decodes = [] - prefills = [] - num_decode_tokens = 0 - num_prefill_tokens = 0 - - for i, req_id in enumerate(input_batch.req_ids): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - # for now treat 1 scheduled token as "decode" even if its not, - # we should update this to something like < 8 in the future but - # currently the TritonMLA._forward_decode only supports - # num_tokens = 1 - if num_tokens <= decode_threshold: - decodes.append(i) - num_decode_tokens += num_tokens - else: - prefills.append(i) - num_prefill_tokens += num_tokens - - # We hope that this is fairly minimal since decodes - # should be around for a number of iterations so hopefully they are - # relatively stationary (and new request are generally appended to the - # persistent batch so already should be at the back) - # To achieve this we loop over the decodes in descending order and - # the prefills in ascending order. We swap decodes from the "back" - # i.e. past where the last decode should be in the reodorered with - # prefills from the front of the batch. 
- # `decodes` and `prefills` are already in ascending order just based on - # the above loop - num_decodes = len(decodes) - num_prefills = len(prefills) - modified_batch = False - - for i in range(1, min(num_decodes, num_prefills) + 1): - # If the decode is at the "back" of the batch, i, we can swap it - # with the prefill closest to the front of the batch - decode_idx = decodes[num_decodes - i] - if decode_idx < num_decodes: - break - - input_batch.swap_states(prefills[i - 1], decode_idx) - modified_batch = True - - return modified_batch - - KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [ ('logits_indices_padded', Optional[torch.Tensor], None), ('num_logits_indices', int, 0), diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index 7f888c1135743..9edc27b9e2da0 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -3,7 +3,7 @@ """Attention layer with XFormersAttention.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import Optional import torch @@ -12,9 +12,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, - reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec try: @@ -26,10 +26,6 @@ try: except ImportError: XFORMERS_AVAILABLE = False -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch - from vllm import _custom_ops as ops logger = init_logger(__name__) @@ -210,12 +206,6 @@ class 
XFormersAttentionMetadataBuilder( self._num_decodes = 0 self._num_decode_tokens = 0 - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: - return reorder_batch_to_split_decodes_and_prefills(input_batch, - scheduler_output, - decode_threshold=1) - def build( self, common_prefix_len: int, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ab388384208b4..3892208f44ffa 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -462,6 +462,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _init_mrope_positions(self, req_id: str) -> None: req_idx = self.requests.req_id_to_index[req_id] req_data = self.requests.req_data[req_idx] + prompt_len = self.requests.num_prompt_tokens.np[req_idx] + prompt_token_ids = self.requests.token_ids.np[req_idx, :prompt_len] image_grid_thw = [] video_grid_thw = [] @@ -483,7 +485,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_data.mrope_positions, req_data.mrope_position_delta = \ MRotaryEmbedding.get_input_positions_tensor( - req_data.prompt_token_ids, + prompt_token_ids, hf_config=self.model_config.hf_config, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, @@ -905,7 +907,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # list of tuple (mm_hash, position_info) mm_hashes_pos = list[tuple[str, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): - req_data = self.requests.req_data[req_id] + req_idx = self.requests.req_id_to_index[req_id] + req_data = self.requests.req_data[req_idx] for mm_input_id in encoder_input_ids: mm_hash = req_data.mm_hashes[mm_input_id] @@ -1259,11 +1262,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return self.kv_connector_no_forward(scheduler_output, self.vllm_config) - if self.cache_config.kv_sharing_fast_prefill: - assert 
not self.input_batch.num_prompt_logprobs, ( - "--kv-sharing-fast-prefill produces incorrect logprobs for " - "prompt tokens, tokens, please disable it when the requests " - "need prompt logprobs") + # if self.cache_config.kv_sharing_fast_prefill: + # assert not self.input_batch.num_prompt_logprobs, ( + # "--kv-sharing-fast-prefill produces incorrect logprobs for " + # "prompt tokens, please disable it when the requests " + # "need prompt logprobs") # Prepare the decoder inputs. input_batch = self._prepare_inputs(scheduler_output) @@ -1296,7 +1299,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.supports_mm_inputs: # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) + mm_embeds = self._gather_mm_embeddings(input_batch) else: mm_embeds = [] @@ -1328,7 +1331,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): inputs_embeds = None model_kwargs = self._init_model_kwargs(num_input_tokens) if self.uses_mrope: - positions = self.mrope_positions[:, :num_input_tokens] + positions = self.mrope_positions.gpu[:, :num_input_tokens] else: positions = self.positions.gpu[:num_input_tokens] @@ -1448,7 +1451,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_nans_in_logits = {} if envs.VLLM_COMPUTE_NANS_IN_LOGITS: - num_nans_in_logits = self._get_nans_in_logits(logits) + num_nans_in_logits = self._get_nans_in_logits( + input_batch.req_ids, logits) # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point.
@@ -1488,14 +1492,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.speculative_config: assert input_batch.spec_decode_common_attn_metadata is not None self._draft_token_ids = self.propose_draft_token_ids( - scheduler_output, + input_batch, valid_sampled_token_ids, sampling_metadata, hidden_states, sample_hidden_states, aux_hidden_states, - input_batch.spec_decode_metadata, - input_batch.spec_decode_common_attn_metadata, ) self._draft_req_ids = input_batch.req_ids @@ -1889,16 +1891,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _get_nans_in_logits( self, - input_batch: InputBatch, + req_ids: list[str], logits: Optional[torch.Tensor], ) -> dict[str, int]: try: if logits is None: - return {req_id: 0 for req_id in input_batch.req_ids} + return {req_id: 0 for req_id in req_ids} num_nans_in_logits = {} num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy() - for i, req_id in enumerate(input_batch.req_ids): + for i, req_id in enumerate(req_ids): num_nans_in_logits[req_id] = (int(num_nans_for_index[i]) if num_nans_for_index is not None and i < logits.shape[0] else 0) @@ -2092,7 +2094,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): model_kwargs = self._init_model_kwargs(num_tokens) if self.uses_mrope: - positions = self.mrope_positions[:, :num_tokens] + positions = self.mrope_positions.gpu[:, :num_tokens] else: positions = self.positions.gpu[:num_tokens] diff --git a/vllm/v1/worker/gpu_worker_states.py b/vllm/v1/worker/gpu_worker_states.py index 626129349648c..ae3829291413d 100644 --- a/vllm/v1/worker/gpu_worker_states.py +++ b/vllm/v1/worker/gpu_worker_states.py @@ -110,7 +110,7 @@ class RequestState: self.is_spec_decode = is_spec_decode self.pooling_params = None self.block_sizes = block_sizes - self.num_prompt_logprobs = {} + self.num_prompt_logprobs: dict[int, int] = {} self.req_id_to_index: dict[str, int] = {} self.index_to_req_id: dict[int, str] = {} @@ 
-378,6 +378,7 @@ def _make_sampling_metadata_kernel( tl.store(dst_repetition_penalties + batch_idx, repetition_penalties) +@triton.jit def _prepare_spec_decode_kernel( query_start_loc, # [B + 1] cu_num_draft_tokens, # [B]