From 8407fa02ed07f1349ea6f4774ca363f9d1920b48 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 18 Sep 2025 14:52:23 -0700 Subject: [PATCH] fix Signed-off-by: Woosuk Kwon --- vllm/v1/attention/backends/mla/cutlass_mla.py | 7 +-- vllm/v1/attention/backends/tree_attn.py | 22 +++++-- vllm/v1/attention/backends/utils.py | 63 +++++++++++++++++++ vllm/v1/attention/backends/xformers.py | 19 ++++-- vllm/v1/sample/metadata.py | 26 +++++++- vllm/v1/sample/sampler.py | 14 ++--- vllm/v1/worker/gpu/model_runner.py | 3 + vllm/v1/worker/gpu/states.py | 40 ++++++------ vllm/v1/worker/gpu_worker.py | 1 - vllm/v1/worker/tpu_input_batch.py | 1 + 10 files changed, 153 insertions(+), 43 deletions(-) diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 21be17a750df4..ae534f3207b51 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -206,12 +206,11 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): ) if H < MAX_HEADS: - # Extract the subsets of the outputs - returned_lse = lse[:, :H].contiguous( - ) if self.need_to_return_lse_for_decode else lse out = out[:, :H] + if self.need_to_return_lse_for_decode: + lse = lse[:, :H].contiguous() - return out, returned_lse + return out, lse def _forward_decode( self, diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 6c7feab57be83..10238f36455d2 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -4,21 +4,26 @@ import ast from dataclasses import dataclass -from typing import Optional +from typing import TYPE_CHECKING, Optional import torch -from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import 
init_logger -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, CommonAttentionMetadata, + reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + +from vllm import _custom_ops as ops + logger = init_logger(__name__) @@ -179,6 +184,13 @@ class TreeAttentionMetadataBuilder( device=device, ) + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + return reorder_batch_to_split_decodes_and_prefills( + input_batch, + scheduler_output, + decode_threshold=self.tree_attn_bias.shape[0]) + def build( self, common_prefix_len: int, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 7d80f09d9959b..63326d19194f0 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -703,6 +703,69 @@ def split_decodes_and_prefills( return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) +def reorder_batch_to_split_decodes_and_prefills( + input_batch: "InputBatch", + scheduler_output: "SchedulerOutput", + decode_threshold: int = 1, +) -> bool: + """ + Reorders the batch to split into prefill and decode requests; places all + requests with <= decode_threshold tokens at the front of the batch. + + Returns: + True if the batch was modified, False otherwise. + """ + # We now want to reorder the batch so that the "decode" requests are at + # the front and the "prefill" requests are at the back using the least + # amount of swaps possible. 
(NOTE for now we loosely use "decode" to mean
+    # requests where attention is likely memory-bound and "prefill" to mean
+    # requests where attention is likely compute-bound, TODO(lucas): figure out
+    # a better naming here)
+    decodes = []
+    prefills = []
+    num_decode_tokens = 0
+    num_prefill_tokens = 0
+
+    for i, req_id in enumerate(input_batch.req_ids):
+        num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+        # for now treat 1 scheduled token as "decode" even if it's not,
+        # we should update this to something like < 8 in the future but
+        # currently the TritonMLA._forward_decode only supports
+        # num_tokens = 1
+        if num_tokens <= decode_threshold:
+            decodes.append(i)
+            num_decode_tokens += num_tokens
+        else:
+            prefills.append(i)
+            num_prefill_tokens += num_tokens
+
+    # We hope that this is fairly minimal since decodes
+    # should be around for a number of iterations so hopefully they are
+    # relatively stationary (and new requests are generally appended to the
+    # persistent batch so already should be at the back)
+    # To achieve this we loop over the decodes in descending order and
+    # the prefills in ascending order. We swap decodes from the "back"
+    # i.e. past where the last decode should be in the reordered batch, with
+    # prefills from the front of the batch. 
+ # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + decode_idx = decodes[num_decodes - i] + if decode_idx < num_decodes: + break + + input_batch.swap_states(prefills[i - 1], decode_idx) + modified_batch = True + + return modified_batch + + KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [ ('logits_indices_padded', Optional[torch.Tensor], None), ('num_logits_indices', int, 0), diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index 371220a9f72b7..a6ca334912353 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -3,7 +3,7 @@ """Attention layer with XFormersAttention.""" from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import TYPE_CHECKING, ClassVar, Optional import torch @@ -12,9 +12,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, CommonAttentionMetadata, + reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec try: @@ -26,6 +26,10 @@ try: except ImportError: XFORMERS_AVAILABLE = False +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + from vllm import _custom_ops as ops logger = init_logger(__name__) @@ -209,6 
+213,13 @@ class XFormersAttentionMetadataBuilder( self._num_decodes = 0 self._num_decode_tokens = 0 + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + return reorder_batch_to_split_decodes_and_prefills( + input_batch, + scheduler_output, + decode_threshold=self.reorder_batch_threshold) + def build( self, common_prefix_len: int, diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index af8179baca0ed..9d6a87cea3d07 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -6,14 +6,38 @@ from typing import Optional import torch +from vllm.v1.sample.logits_processor import LogitsProcessors + @dataclass class SamplingMetadata: - temperature: torch.Tensor + temperature: Optional[torch.Tensor] + all_greedy: bool + all_random: bool top_p: Optional[torch.Tensor] top_k: Optional[torch.Tensor] + generators: dict[int, torch.Generator] + # None means no logprobs, 0 means sampled token logprobs only max_num_logprobs: Optional[int] + + no_penalties: bool + prompt_token_ids: Optional[torch.Tensor] + frequency_penalties: torch.Tensor + presence_penalties: torch.Tensor + repetition_penalties: torch.Tensor + + output_token_ids: list[list[int]] + + # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size, + # vocab size). + allowed_token_ids_mask: Optional[torch.Tensor] + + # req_index -> bad_words_token_ids + bad_words_token_ids: dict[int, list[list[int]]] + + # Loaded logits processors + logitsprocs: LogitsProcessors diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 4fb6654ef00cb..546531a91610f 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -90,9 +90,9 @@ class Sampler(nn.Module): # Apply bad words exclusion. 
logits = self.apply_bad_words(logits, sampling_metadata) - # # Apply logits processors which can impact greedy sampling - # for processor in sampling_metadata.logitsprocs.non_argmax_invariant: - # logits = processor.apply(logits) + # Apply logits processors which can impact greedy sampling + for processor in sampling_metadata.logitsprocs.non_argmax_invariant: + logits = processor.apply(logits) # Apply penalties (e.g., min_tokens, freq_penalties). logits = self.apply_penalties(logits, sampling_metadata) @@ -167,10 +167,10 @@ class Sampler(nn.Module): # Apply temperature. logits = self.apply_temperature(logits, sampling_metadata.temperature) - # # Apply logits processors that only apply to random sampling - # # (argmax invariant) - # for processor in sampling_metadata.logitsprocs.argmax_invariant: - # logits = processor.apply(logits) + # Apply logits processors that only apply to random sampling + # (argmax invariant) + for processor in sampling_metadata.logitsprocs.argmax_invariant: + logits = processor.apply(logits) # Apply top_k and/or top_p. 
random_sampled, processed_logprobs = self.topk_topp_sampler( diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 7435e3eceb69f..c83a79a3bea06 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -79,6 +79,9 @@ class GPUModelRunner: ) self.sampler = Sampler() + def get_supported_tasks(self) -> tuple[str]: + return ("generate", ) + def load_model(self, *args, **kwargs) -> None: time_before_load = time.perf_counter() with DeviceMemoryProfiler() as m: diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py index 173d1857e68af..25d4bf808cedf 100644 --- a/vllm/v1/worker/gpu/states.py +++ b/vllm/v1/worker/gpu/states.py @@ -9,6 +9,7 @@ import numpy as np import torch from vllm.sampling_params import SamplingParams +from vllm.v1.utils import CpuGpuBuffer _NP_INT64_MIN = np.iinfo(np.int64).min _NP_INT64_MAX = np.iinfo(np.int64).max @@ -69,21 +70,21 @@ class RequestState: ) # Sampling parameters. - self.temperature = self._make_buffer(self.max_num_reqs, torch.float32) - self.top_p = self._make_buffer(self.max_num_reqs, torch.float32) - self.top_k = self._make_buffer(self.max_num_reqs, torch.int32) - self.seeds = self._make_buffer(self.max_num_reqs, torch.int64) + self.temperature = self._make_param(self.max_num_reqs, torch.float32) + self.top_p = self._make_param(self.max_num_reqs, torch.float32) + self.top_k = self._make_param(self.max_num_reqs, torch.int32) + self.seeds = self._make_param(self.max_num_reqs, torch.int64) self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32) # -1 means no logprobs are requested. 
self.num_logprobs.fill(-1) self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool) - def _make_buffer(self, size, dtype: torch.dtype) -> "Buffer": - return Buffer(size, - dtype=dtype, - pin_memory=self.pin_memory, - device=self.device) + def _make_param(self, size: int, dtype: torch.dtype) -> "Param": + return Param(size, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory) @property def num_reqs(self) -> int: @@ -217,27 +218,24 @@ def _append_token_ids( num_tokens[req_idx] = end_idx -class Buffer: +class Param: def __init__( self, - *args, + size: int, dtype: torch.dtype, - pin_memory: bool, device: torch.device, + pin_memory: bool, ): - # NOTE(woosuk): Unlike CpuGpuBuffer, the Numpy array and CPU tensor - # in this class do not share the same storage. - self.np = np.zeros(*args, dtype=dtype) - self.cpu = torch.zeros( - *args, + self.buffer = CpuGpuBuffer( + size, dtype=dtype, - pin_memory=pin_memory, device=device, + pin_memory=pin_memory, ) - self.gpu = self.cpu.to(device) + self.np = np.zeros_like(self.buffer.np) def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor: n = x.shape[0] - self.cpu[:n] = x - return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True) + self.buffer.np[:n] = x + return self.buffer.copy_to_gpu(n) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 7c0ee599040bd..a44abf6466529 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -423,7 +423,6 @@ class Worker(WorkerBase): return self.model_runner.get_model() def get_supported_tasks(self) -> tuple[SupportedTask, ...]: - return "generate" return self.model_runner.get_supported_tasks() @torch.inference_mode() diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py index 01d1904778278..dfa54d0ad83b6 100644 --- a/vllm/v1/worker/tpu_input_batch.py +++ b/vllm/v1/worker/tpu_input_batch.py @@ -12,6 +12,7 @@ from vllm.sampling_params import SamplingType from vllm.utils import 
swap_dict_values from vllm.v1.outputs import LogprobsTensors from vllm.v1.worker.block_table import MultiGroupBlockTable +from vllm.v1.worker.gpu_input_batch import CachedRequestState _SAMPLING_EPS = 1e-5