Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-09-18 14:52:23 -07:00
parent 82e591f7eb
commit 8407fa02ed
10 changed files with 153 additions and 43 deletions

View File

@ -206,12 +206,11 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
)
if H < MAX_HEADS:
# Extract the subsets of the outputs
returned_lse = lse[:, :H].contiguous(
) if self.need_to_return_lse_for_decode else lse
out = out[:, :H]
if self.need_to_return_lse_for_decode:
lse = lse[:, :H].contiguous()
return out, returned_lse
return out, lse
def _forward_decode(
self,

View File

@ -4,21 +4,26 @@
import ast
from dataclasses import dataclass
from typing import Optional
from typing import TYPE_CHECKING, Optional
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm import _custom_ops as ops
logger = init_logger(__name__)
@ -179,6 +184,13 @@ class TreeAttentionMetadataBuilder(
device=device,
)
def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:
    """Reorder the batch so decode requests come before prefill requests.

    Delegates to `reorder_batch_to_split_decodes_and_prefills`, classifying
    a request as a decode when it has at most
    `self.tree_attn_bias.shape[0]` scheduled tokens.
    NOTE(review): presumably shape[0] is the number of tree positions per
    request — confirm against where `tree_attn_bias` is built.

    Returns:
        True if the batch was modified, False otherwise.
    """
    return reorder_batch_to_split_decodes_and_prefills(
        input_batch,
        scheduler_output,
        decode_threshold=self.tree_attn_bias.shape[0])
def build(
self,
common_prefix_len: int,

View File

@ -703,6 +703,69 @@ def split_decodes_and_prefills(
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
def reorder_batch_to_split_decodes_and_prefills(
    input_batch: "InputBatch",
    scheduler_output: "SchedulerOutput",
    decode_threshold: int = 1,
) -> bool:
    """Reorder the batch so "decode" requests precede "prefill" requests.

    Requests with at most ``decode_threshold`` scheduled tokens are treated
    as decodes and moved to the front of the batch; all other requests are
    treated as prefills and end up at the back. The reorder is done in place
    via ``input_batch.swap_states``.

    Args:
        input_batch: The persistent batch to reorder in place.
        scheduler_output: Supplies ``num_scheduled_tokens`` per request id.
        decode_threshold: Maximum number of scheduled tokens for a request
            to be classified as a decode. Defaults to 1.

    Returns:
        True if the batch was modified, False otherwise.
    """
    # We want the "decode" requests at the front and the "prefill" requests
    # at the back using the least amount of swaps possible. (NOTE for now we
    # loosely use "decode" to mean requests where attention is likely
    # memory-bound and "prefill" to mean requests where attention is likely
    # compute-bound, TODO(lucas): figure out a better naming here)
    decodes = []
    prefills = []
    for i, req_id in enumerate(input_batch.req_ids):
        num_tokens = scheduler_output.num_scheduled_tokens[req_id]
        if num_tokens <= decode_threshold:
            decodes.append(i)
        else:
            prefills.append(i)

    # We hope the number of swaps is fairly minimal since decodes should be
    # around for a number of iterations, so hopefully they are relatively
    # stationary (and new requests are generally appended to the persistent
    # batch, so they already should be at the back).
    # To achieve this we loop over the decodes in descending order and the
    # prefills in ascending order. We swap decodes from the "back" — i.e.
    # past where the last decode should be in the reordered batch — with
    # prefills from the front of the batch.
    # `decodes` and `prefills` are already in ascending order just based on
    # the enumeration loop above.
    num_decodes = len(decodes)
    num_prefills = len(prefills)
    modified_batch = False

    for i in range(1, min(num_decodes, num_prefills) + 1):
        decode_idx = decodes[num_decodes - i]
        if decode_idx < num_decodes:
            # This decode (and therefore all earlier ones) is already within
            # the first `num_decodes` slots; no more swaps are needed.
            break

        input_batch.swap_states(prefills[i - 1], decode_idx)
        modified_batch = True

    return modified_batch
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
('logits_indices_padded', Optional[torch.Tensor], None),
('num_logits_indices', int, 0),

View File

@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention."""
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import TYPE_CHECKING, ClassVar, Optional
import torch
@ -12,9 +12,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
try:
@ -26,6 +26,10 @@ try:
except ImportError:
XFORMERS_AVAILABLE = False
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm import _custom_ops as ops
logger = init_logger(__name__)
@ -209,6 +213,13 @@ class XFormersAttentionMetadataBuilder(
self._num_decodes = 0
self._num_decode_tokens = 0
def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:
    """Reorder the batch so decode requests come before prefill requests.

    Delegates to `reorder_batch_to_split_decodes_and_prefills`, classifying
    a request as a decode when it has at most
    `self.reorder_batch_threshold` scheduled tokens.

    Returns:
        True if the batch was modified, False otherwise.
    """
    return reorder_batch_to_split_decodes_and_prefills(
        input_batch,
        scheduler_output,
        decode_threshold=self.reorder_batch_threshold)
def build(
self,
common_prefix_len: int,

View File

@ -6,14 +6,38 @@ from typing import Optional
import torch
from vllm.v1.sample.logits_processor import LogitsProcessors
@dataclass
class SamplingMetadata:
temperature: torch.Tensor
temperature: Optional[torch.Tensor]
all_greedy: bool
all_random: bool
top_p: Optional[torch.Tensor]
top_k: Optional[torch.Tensor]
generators: dict[int, torch.Generator]
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: Optional[int]
no_penalties: bool
prompt_token_ids: Optional[torch.Tensor]
frequency_penalties: torch.Tensor
presence_penalties: torch.Tensor
repetition_penalties: torch.Tensor
output_token_ids: list[list[int]]
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
# vocab size).
allowed_token_ids_mask: Optional[torch.Tensor]
# req_index -> bad_words_token_ids
bad_words_token_ids: dict[int, list[list[int]]]
# Loaded logits processors
logitsprocs: LogitsProcessors

View File

@ -90,9 +90,9 @@ class Sampler(nn.Module):
# Apply bad words exclusion.
logits = self.apply_bad_words(logits, sampling_metadata)
# # Apply logits processors which can impact greedy sampling
# for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
# logits = processor.apply(logits)
# Apply logits processors which can impact greedy sampling
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
logits = processor.apply(logits)
# Apply penalties (e.g., min_tokens, freq_penalties).
logits = self.apply_penalties(logits, sampling_metadata)
@ -167,10 +167,10 @@ class Sampler(nn.Module):
# Apply temperature.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
# # Apply logits processors that only apply to random sampling
# # (argmax invariant)
# for processor in sampling_metadata.logitsprocs.argmax_invariant:
# logits = processor.apply(logits)
# Apply logits processors that only apply to random sampling
# (argmax invariant)
for processor in sampling_metadata.logitsprocs.argmax_invariant:
logits = processor.apply(logits)
# Apply top_k and/or top_p.
random_sampled, processed_logprobs = self.topk_topp_sampler(

View File

@ -79,6 +79,9 @@ class GPUModelRunner:
)
self.sampler = Sampler()
def get_supported_tasks(self) -> tuple[str]:
    """Return the tasks this runner supports; only text generation."""
    supported = ("generate", )
    return supported
def load_model(self, *args, **kwargs) -> None:
time_before_load = time.perf_counter()
with DeviceMemoryProfiler() as m:

View File

@ -9,6 +9,7 @@ import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.v1.utils import CpuGpuBuffer
_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max
@ -69,21 +70,21 @@ class RequestState:
)
# Sampling parameters.
self.temperature = self._make_buffer(self.max_num_reqs, torch.float32)
self.top_p = self._make_buffer(self.max_num_reqs, torch.float32)
self.top_k = self._make_buffer(self.max_num_reqs, torch.int32)
self.seeds = self._make_buffer(self.max_num_reqs, torch.int64)
self.temperature = self._make_param(self.max_num_reqs, torch.float32)
self.top_p = self._make_param(self.max_num_reqs, torch.float32)
self.top_k = self._make_param(self.max_num_reqs, torch.int32)
self.seeds = self._make_param(self.max_num_reqs, torch.int64)
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
# -1 means no logprobs are requested.
self.num_logprobs.fill(-1)
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
def _make_buffer(self, size, dtype: torch.dtype) -> "Buffer":
return Buffer(size,
dtype=dtype,
pin_memory=self.pin_memory,
device=self.device)
def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
return Param(size,
dtype=dtype,
device=self.device,
pin_memory=self.pin_memory)
@property
def num_reqs(self) -> int:
@ -217,27 +218,24 @@ def _append_token_ids(
num_tokens[req_idx] = end_idx
class Buffer:
class Param:
def __init__(
self,
*args,
size: int,
dtype: torch.dtype,
pin_memory: bool,
device: torch.device,
pin_memory: bool,
):
# NOTE(woosuk): Unlike CpuGpuBuffer, the Numpy array and CPU tensor
# in this class do not share the same storage.
self.np = np.zeros(*args, dtype=dtype)
self.cpu = torch.zeros(
*args,
self.buffer = CpuGpuBuffer(
size,
dtype=dtype,
pin_memory=pin_memory,
device=device,
pin_memory=pin_memory,
)
self.gpu = self.cpu.to(device)
self.np = np.zeros_like(self.buffer.np)
def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
n = x.shape[0]
self.cpu[:n] = x
return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
self.buffer.np[:n] = x
return self.buffer.copy_to_gpu(n)

View File

@ -423,7 +423,6 @@ class Worker(WorkerBase):
return self.model_runner.get_model()
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return "generate"
return self.model_runner.get_supported_tasks()
@torch.inference_mode()

View File

@ -12,6 +12,7 @@ from vllm.sampling_params import SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState
_SAMPLING_EPS = 1e-5