vllm/vllm/v1/worker/gpu_worker_states.py
Woosuk Kwon 6283995a6c minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-06 21:18:16 -07:00

441 lines
16 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
import torch
import triton
import triton.language as tl
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (MultiModalKwargsItem,
MultiModalKwargsItems, PlaceholderRange)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
# Number of draft tokens per request that the spec-decode kernel is sized for:
# make_spec_decode_metadata launches it with
# BLOCK_SIZE = triton.next_power_of_2(_MAX_SPEC_LEN + 1). The kernel writes at
# most BLOCK_SIZE entries per request, so a request with more samples than
# that would have the remainder left unwritten.
_MAX_SPEC_LEN = 32
@dataclass
class RequestData:
    """Multimodal, sampling/pooling, M-RoPE and LoRA inputs of one request."""

    # Multimodal inputs and their placeholder ranges.
    mm_kwargs: list[MultiModalKwargsItem]
    mm_positions: list[PlaceholderRange]
    sampling_params: Optional[SamplingParams]
    pooling_params: Optional[PoolingParams]
    mm_hashes: list[str]

    # M-RoPE (only for Qwen2/2.5-VL)
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    lora_request: Optional[LoRARequest] = None

    # Temporary back-compatibility for plugins that define model runner
    @property
    @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
                "removed in v0.13. Please use `mm_kwargs` instead.")
    def mm_inputs(self) -> list[MultiModalKwargsItems]:
        """Legacy view of ``mm_kwargs``.

        Each item is wrapped in its own single-element
        ``MultiModalKwargsItems``.
        """
        wrap = MultiModalKwargsItems.from_seq
        return [wrap([kwargs_item]) for kwargs_item in self.mm_kwargs]
class SamplingStates:
    """Per-request sampling parameters plus index sets of optional steps.

    Parameter values live in 1-D device tensors indexed by request slot. The
    ``*_req_indices`` sets record which slots need greedy sampling, top-p,
    top-k or penalties. ``generators`` holds per-slot RNGs for seeded
    requests.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_cached_reqs: int,
        vocab_size: int,
        device: torch.device,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_cached_reqs = max_num_cached_reqs
        self.vocab_size = vocab_size
        self.device = device

        self.temperature = self._make_param(torch.float32)
        self.greedy_req_indices: set[int] = set()

        self.top_p = self._make_param(torch.float32)
        self.top_p_req_indices: set[int] = set()

        self.top_k = self._make_param(torch.int32)
        self.top_k_req_indices: set[int] = set()

        self.frequency_penalties = self._make_param(torch.float32)
        self.presence_penalties = self._make_param(torch.float32)
        self.repetition_penalties = self._make_param(torch.float32)
        self.penalty_req_indices: set[int] = set()

        self.generators: dict[int, torch.Generator] = {}

    def _make_param(self, dtype: torch.dtype) -> torch.Tensor:
        """Allocate a zero-filled 1-D tensor with one entry per request slot.

        RequestState hands out slot indices from
        ``range(max_num_cached_reqs)``, so the tensor must cover that many
        slots. It is never made smaller than ``max_num_reqs``.
        """
        num_slots = max(self.max_num_reqs, self.max_num_cached_reqs)
        return torch.zeros(num_slots, dtype=dtype, device=self.device)

    def add_requests(
        self,
        req_indices: list[int],
        sampling_params: list[SamplingParams],
    ) -> None:
        """Record the sampling parameters of newly added requests.

        Args:
            req_indices: Slot index of each request. It is used to index the
                parameter tensors and as the key in the index sets.
            sampling_params: Sampling parameters, aligned with
                ``req_indices``.

        Raises:
            ValueError: If the two lists have different lengths.
        """
        if len(req_indices) != len(sampling_params):
            raise ValueError(
                f"Got {len(req_indices)} request indices but "
                f"{len(sampling_params)} sampling params.")
        if not req_indices:
            return

        temperatures: list[float] = []
        top_ps: list[float] = []
        top_ks: list[int] = []
        freq_pens: list[float] = []
        pres_pens: list[float] = []
        rep_pens: list[float] = []
        for req_idx, params in zip(req_indices, sampling_params):
            temperatures.append(params.temperature)
            if params.temperature == 0.0:
                self.greedy_req_indices.add(req_idx)

            top_ps.append(params.top_p)
            if params.top_p < 1.0:
                self.top_p_req_indices.add(req_idx)

            top_k = params.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_req_indices.add(req_idx)
            else:
                # top-k disabled: store the full vocabulary size.
                top_k = self.vocab_size
            top_ks.append(top_k)

            freq_pens.append(params.frequency_penalty)
            pres_pens.append(params.presence_penalty)
            rep_pens.append(params.repetition_penalty)
            if (params.frequency_penalty != 0.0
                    or params.presence_penalty != 0.0
                    or params.repetition_penalty != 1.0):
                self.penalty_req_indices.add(req_idx)

            if params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(params.seed)
                self.generators[req_idx] = generator

        # Write the collected values with one indexed assignment per tensor
        # instead of one device write per element.
        index = torch.tensor(req_indices, dtype=torch.int64, device=self.device)
        for dst, values in (
            (self.temperature, temperatures),
            (self.top_p, top_ps),
            (self.top_k, top_ks),
            (self.frequency_penalties, freq_pens),
            (self.presence_penalties, pres_pens),
            (self.repetition_penalties, rep_pens),
        ):
            dst[index] = torch.tensor(values,
                                      dtype=dst.dtype,
                                      device=self.device)

    def remove_request(self, req_idx: int) -> None:
        """Drop slot ``req_idx`` from every index set and its RNG, if any.

        Tensor entries are left as-is. The next ``add_requests`` call for the
        slot overwrites them.
        """
        self.greedy_req_indices.discard(req_idx)
        self.top_p_req_indices.discard(req_idx)
        self.top_k_req_indices.discard(req_idx)
        self.penalty_req_indices.discard(req_idx)
        self.generators.pop(req_idx, None)
class RequestState:
    """Persistent per-request state of the GPU worker, indexed by slot.

    ``add_request`` pops a slot index from ``free_indices`` and
    ``remove_request`` pushes it back, so slot indices lie in
    ``[0, max_num_cached_reqs)``. Per-request values are stored through the
    objects returned by ``_make_param`` (accessed via their ``.np`` arrays).
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        max_num_cached_reqs: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        logitsprocs: Optional[LogitsProcessors] = None,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_cached_reqs = max_num_cached_reqs
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size
        self.is_spec_decode = is_spec_decode
        self.pooling_params = None
        self.block_sizes = block_sizes
        self.num_prompt_logprobs: dict[int, int] = {}

        self.req_id_to_index: dict[str, int] = {}
        self.index_to_req_id: dict[int, str] = {}
        self.free_indices = list(range(max_num_cached_reqs))

        # Request states.
        self.req_data: dict[int, RequestData] = {}
        # TODO(woosuk): Because the token_ids tensor can be very big, we only
        # initialize it on CPU memory.
        self.token_ids = self._make_param(
            num_cols=self.max_model_len,
            dtype=torch.int32,
            cpu_only=True,
        )
        self.num_prompt_tokens = self._make_param(torch.int32)
        self.num_tokens = self._make_param(torch.int32)
        self.num_computed_tokens = self._make_param(torch.int32)

        # Sampling state read and written by add_request, remove_request and
        # make_sampling_metadata. Without these attributes those methods
        # raise AttributeError.
        self.temperature = self._make_param(torch.float32)
        self.top_p = self._make_param(torch.float32)
        self.top_k = self._make_param(torch.int32)
        self.frequency_penalties = self._make_param(torch.float32)
        self.presence_penalties = self._make_param(torch.float32)
        self.repetition_penalties = self._make_param(torch.float32)
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()
        self.top_p_reqs: set[str] = set()
        self.top_k_reqs: set[str] = set()
        self.frequency_penalties_reqs: set[str] = set()
        self.presence_penalties_reqs: set[str] = set()
        self.repetition_penalties_reqs: set[str] = set()
        self.generators: dict[int, torch.Generator] = {}

        # NOTE(review): no method of this class reads or writes
        # sampling_states yet; it overlaps with the attributes above.
        self.sampling_states = SamplingStates(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_cached_reqs=max_num_cached_reqs,
            vocab_size=vocab_size,  # required by SamplingStates.__init__
            device=device,
        )

    def _make_param(
        self,
        dtype: torch.dtype,
        num_cols: int = 1,
        cpu_only: bool = False,
    ) -> "Param":
        """Create a per-slot parameter with ``num_cols`` columns.

        When ``cpu_only`` is set, ``0`` is passed in place of
        ``max_num_reqs`` (presumably the size of the GPU-side copy; TODO
        confirm against Param).
        """
        # NOTE(review): `Param` is not imported by this module. The return
        # annotation is quoted so that defining the class does not require
        # the name to be in scope at import time.
        return Param(
            self.max_num_cached_reqs,
            num_cols,
            self.max_num_reqs if not cpu_only else 0,
            dtype,
            self.device,
            self.pin_memory,
            is_scalar=num_cols == 1,
        )

    @property
    def num_cached_reqs(self) -> int:
        """Number of requests currently holding a slot."""
        return len(self.req_id_to_index)

    def add_request(
        self,
        req_id: str,
        prompt_token_ids: list[int],
        num_computed_tokens: int,
        sampling_params: SamplingParams,
    ) -> None:
        """Allocate a slot for ``req_id``; record its prompt and sampling
        parameters.

        Raises:
            AssertionError: If no slot is free (when assertions are enabled).
        """
        assert len(self.free_indices) > 0, "No free space in GPU worker states"
        req_idx = self.free_indices.pop()
        self.req_id_to_index[req_id] = req_idx
        self.index_to_req_id[req_idx] = req_id

        prompt_len = len(prompt_token_ids)
        self.num_prompt_tokens.np[req_idx] = prompt_len
        self.num_tokens.np[req_idx] = prompt_len
        self.token_ids.np[req_idx, :prompt_len] = prompt_token_ids
        self.num_computed_tokens.np[req_idx] = num_computed_tokens

        self.temperature.np[req_idx] = sampling_params.temperature
        if sampling_params.sampling_type == SamplingType.GREEDY:
            # NOTE: Be careful about division by zero.
            self.greedy_reqs.add(req_id)
        else:
            # RANDOM_SEED requests also sample stochastically. They must
            # count toward random_reqs, otherwise a batch containing only
            # seeded requests is reported as all_greedy.
            self.random_reqs.add(req_id)

        self.top_p.np[req_idx] = sampling_params.top_p
        if sampling_params.top_p < 1.0:
            self.top_p_reqs.add(req_id)

        top_k = sampling_params.top_k
        if 0 < top_k < self.vocab_size:
            self.top_k_reqs.add(req_id)
        else:
            # top-k disabled: store the full vocabulary size.
            top_k = self.vocab_size
        self.top_k.np[req_idx] = top_k

        self.frequency_penalties.np[
            req_idx] = sampling_params.frequency_penalty
        if sampling_params.frequency_penalty != 0.0:
            self.frequency_penalties_reqs.add(req_id)
        self.presence_penalties.np[req_idx] = sampling_params.presence_penalty
        if sampling_params.presence_penalty != 0.0:
            self.presence_penalties_reqs.add(req_id)
        self.repetition_penalties.np[
            req_idx] = sampling_params.repetition_penalty
        if sampling_params.repetition_penalty != 1.0:
            self.repetition_penalties_reqs.add(req_id)

        if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
            generator = torch.Generator(device=self.device)
            generator.manual_seed(sampling_params.seed)
            self.generators[req_idx] = generator

    def append_token_ids(
        self,
        req_idx: int,
        token_ids: Union[list[int], np.ndarray],
    ) -> None:
        """Append ``token_ids`` after the tokens already stored for slot
        ``req_idx`` and advance its token count."""
        start_idx = self.num_tokens.np[req_idx]
        end_idx = start_idx + len(token_ids)
        self.token_ids.np[req_idx, start_idx:end_idx] = token_ids
        self.num_tokens.np[req_idx] = end_idx

    def remove_request(self, req_id: str) -> None:
        """Release the slot of ``req_id``; unknown ids are ignored."""
        req_idx = self.req_id_to_index.pop(req_id, None)
        if req_idx is None:
            # Request not found.
            return
        self.index_to_req_id.pop(req_idx, None)
        self.free_indices.append(req_idx)

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_idx, None)

    def make_sampling_metadata(
        self,
        batch_idx_to_req_idx: torch.Tensor,
    ) -> SamplingMetadata:
        """Build SamplingMetadata for the current batch.

        ``batch_idx_to_req_idx[i]`` is the slot of the request at batch
        position ``i``. ``_make_sampling_metadata_kernel`` copies each
        parameter from that slot of the source buffer into position ``i`` of
        the parameter's ``.gpu`` tensor. Only positions ``[0, batch_size)``
        are filled, so every returned tensor is sliced to ``batch_size``.
        """
        batch_size = batch_idx_to_req_idx.shape[0]
        if self.top_p_reqs:
            top_p_buffer = self.top_p.mirror_to_gpu()
            top_p = self.top_p.gpu[:batch_size]
        else:
            top_p_buffer = self.top_p.gpu_buffer
            top_p = None
        if self.top_k_reqs:
            top_k_buffer = self.top_k.mirror_to_gpu()
            top_k = self.top_k.gpu[:batch_size]
        else:
            top_k_buffer = self.top_k.gpu_buffer
            top_k = None

        # TODO(woosuk): Use UVA to optimize CPU -> GPU copy.
        _make_sampling_metadata_kernel[(batch_size, )](
            batch_idx_to_req_idx,
            self.temperature.mirror_to_gpu(),
            self.temperature.gpu,
            top_p_buffer,
            self.top_p.gpu,
            top_k_buffer,
            self.top_k.gpu,
            self.frequency_penalties.mirror_to_gpu(),
            self.frequency_penalties.gpu,
            self.presence_penalties.mirror_to_gpu(),
            self.presence_penalties.gpu,
            self.repetition_penalties.mirror_to_gpu(),
            self.repetition_penalties.gpu,
            num_warps=1,
            num_stages=1,
        )

        no_penalties = not (self.frequency_penalties_reqs
                            or self.presence_penalties_reqs
                            or self.repetition_penalties_reqs)
        return SamplingMetadata(
            temperature=self.temperature.gpu[:batch_size],
            all_greedy=not self.random_reqs,
            all_random=not self.greedy_reqs,
            top_p=top_p,
            top_k=top_k,
            frequency_penalties=self.frequency_penalties.gpu[:batch_size],
            presence_penalties=self.presence_penalties.gpu[:batch_size],
            repetition_penalties=self.repetition_penalties.gpu[:batch_size],
            no_penalties=no_penalties,
            # TODO
            generators={},
            token_ids=None,
            num_tokens=None,
            num_prompt_tokens=None,
            max_num_logprobs=None,
            allowed_token_ids_mask=None,
            bad_words_token_ids={},
            logitsprocs=None,
        )

    def make_spec_decode_metadata(
        self,
        query_start_loc: torch.Tensor,
        cu_num_draft_tokens: torch.Tensor,
        cu_num_draft_tokens_np: np.ndarray,
        input_ids: torch.Tensor,
    ) -> SpecDecodeMetadata:
        """Compute logits/target/bonus indices for speculative decoding.

        Args:
            query_start_loc: Query start offsets, ``[batch_size + 1]``.
            cu_num_draft_tokens: Inclusive prefix sum of per-request draft
                counts, ``[batch_size]``. Request ``i`` owns draft slots
                ``[cu[i-1], cu[i])``, with ``cu[-1]`` taken as 0.
            cu_num_draft_tokens_np: CPU copy of ``cu_num_draft_tokens``.
            input_ids: Flattened input token ids of the batch.
        """
        batch_size = query_start_loc.shape[0] - 1
        cu_np = cu_num_draft_tokens_np[:batch_size]
        total_num_draft_tokens = int(cu_np[batch_size - 1])
        # SpecDecodeMetadata takes per-request draft counts here (the
        # cumulative values go in cu_num_draft_tokens), so undo the prefix
        # sum.
        num_draft_tokens = np.diff(cu_np, prepend=0).tolist()

        logits_indices = torch.empty(total_num_draft_tokens + batch_size,
                                     dtype=torch.int32,
                                     device=self.device)
        target_logits_indices = torch.empty(total_num_draft_tokens,
                                            dtype=torch.int32,
                                            device=self.device)
        bonus_logits_indices = torch.empty(batch_size,
                                           dtype=torch.int32,
                                           device=self.device)
        _prepare_spec_decode_kernel[(batch_size, )](
            query_start_loc,
            cu_num_draft_tokens,
            logits_indices,
            target_logits_indices,
            bonus_logits_indices,
            BLOCK_SIZE=triton.next_power_of_2(_MAX_SPEC_LEN + 1),
        )

        # input_ids[logits_indices] holds, per request, the last
        # (draft_len + 1) input tokens. Shifting target_logits_indices by
        # one skips the first of them and selects the draft tokens.
        draft_token_ids = input_ids[logits_indices]
        draft_token_ids = draft_token_ids[target_logits_indices + 1]
        return SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens,
            cu_num_draft_tokens=cu_num_draft_tokens,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
@triton.jit
def _make_sampling_metadata_kernel(
    batch_idx_to_req_idx,  # [batch_size]
    src_temperature,
    dst_temperature,
    src_top_p,
    dst_top_p,
    src_top_k,
    dst_top_k,
    src_frequency_penalties,
    dst_frequency_penalties,
    src_presence_penalties,
    dst_presence_penalties,
    src_repetition_penalties,
    dst_repetition_penalties,
):
    """Gather per-request sampling parameters into batch order.

    Launched with one program per batch position ``b``. For every parameter
    it copies ``src_*[batch_idx_to_req_idx[b]]`` into ``dst_*[b]``.
    """
    b = tl.program_id(0)
    slot = tl.load(batch_idx_to_req_idx + b)

    tl.store(dst_temperature + b, tl.load(src_temperature + slot))
    tl.store(dst_top_p + b, tl.load(src_top_p + slot))
    tl.store(dst_top_k + b, tl.load(src_top_k + slot))
    tl.store(dst_frequency_penalties + b,
             tl.load(src_frequency_penalties + slot))
    tl.store(dst_presence_penalties + b,
             tl.load(src_presence_penalties + slot))
    tl.store(dst_repetition_penalties + b,
             tl.load(src_repetition_penalties + slot))
@triton.jit
def _prepare_spec_decode_kernel(
    query_start_loc,  # [B + 1]
    cu_num_draft_tokens,  # [B]
    logits_indices,  # [N + B]
    target_logits_indices,  # [N]
    bonus_logits_indices,  # [B]
    BLOCK_SIZE: tl.constexpr,
):
    """Fill spec-decode index arrays for one request per program.

    Request ``r`` owns ``num_drafts + 1`` consecutive entries of
    ``logits_indices`` starting at ``lo + r``. These entries point at the
    last ``num_drafts + 1`` query positions of the request. Its draft
    entries in ``target_logits_indices`` point at the first ``num_drafts``
    of those slots, and ``bonus_logits_indices[r]`` points at the last one.
    """
    r = tl.program_id(0)

    # cu_num_draft_tokens is an inclusive prefix sum: request r owns draft
    # slots [lo, hi).
    if r == 0:
        lo = 0
    else:
        lo = tl.load(cu_num_draft_tokens + r - 1)
    hi = tl.load(cu_num_draft_tokens + r)
    num_drafts = hi - lo
    num_samples = num_drafts + 1

    # Each earlier request contributes one extra (bonus) slot, hence `+ r`.
    out_start = lo + r
    query_end = tl.load(query_start_loc + r + 1)

    lanes = tl.arange(0, BLOCK_SIZE)
    sample_mask = lanes < num_samples
    draft_mask = lanes < num_drafts
    tl.store(logits_indices + out_start + lanes,
             query_end - num_samples + lanes,
             mask=sample_mask)
    tl.store(target_logits_indices + lo + lanes,
             out_start + lanes,
             mask=draft_mask)
    tl.store(bonus_logits_indices + r, out_start + num_samples - 1)