Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-09-06 21:18:16 -07:00
parent 0c56069c7e
commit 6283995a6c
3 changed files with 128 additions and 101 deletions

View File

@@ -3,8 +3,10 @@
from dataclasses import dataclass
from typing import Any, Optional
import numba
import numpy as np
import torch
from numba import types
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -33,3 +35,58 @@ class InputBatch:
    spec_decode_metadata: Optional[SpecDecodeMetadata]
    logits_indices: torch.Tensor


# NOTE: With the type annotations, this function is pre-compiled
# before the first call.
@numba.jit(
    [
        types.none(
            types.int32[:],  # idx_mapping
            types.int32[:, :],  # token_ids
            types.int32[:],  # num_computed_tokens
            types.int32[:],  # num_scheduled_tokens
            types.int32[:],  # input_ids
            types.int32[:],  # query_start_loc
            types.int32[:],  # seq_lens
            types.int64[:],  # positions
        )
    ],
    nopython=True,
    cache=True,
)
def prepare_inputs(
    idx_mapping: np.ndarray,  # batch_idx -> req_idx
    token_ids: np.ndarray,  # [N, max_model_len]
    num_computed_tokens: np.ndarray,  # [N]
    num_scheduled_tokens: np.ndarray,  # [B]
    input_ids: np.ndarray,  # [num_input_tokens]
    query_start_loc: np.ndarray,  # [B + 1]
    seq_lens: np.ndarray,  # [B]
    positions: np.ndarray,  # [num_input_tokens]
) -> None:
    num_reqs = num_scheduled_tokens.shape[0]
    query_start_loc[0] = 0
    cu_num_tokens = 0
    for i in range(num_reqs):
        req_idx = idx_mapping[i]
        query_len = num_scheduled_tokens[i]
        start = num_computed_tokens[req_idx]
        end = start + query_len
        seq_lens[i] = end

        start_idx = cu_num_tokens
        end_idx = start_idx + query_len
        input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
        positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
        cu_num_tokens = end_idx
        query_start_loc[i + 1] = cu_num_tokens

    # Pad the inputs for CUDA graphs.
    # NOTE: Pad query_start_loc to be non-decreasing, as kernels
    # like FlashAttention require that.
    query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
    # Fill the unused entries with 0 for full CUDA graph mode.
    seq_lens[num_reqs:].fill(0)
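
A hedged usage sketch, not part of this commit: it shows the flat-buffer convention prepare_inputs expects, with made-up buffer sizes and token values, and imports the function from vllm.v1.worker.gpu_input_batch as the model-runner change below does.

import numpy as np

from vllm.v1.worker.gpu_input_batch import prepare_inputs

# Persistent per-request state (sizes are illustrative).
max_num_cached_reqs, max_model_len = 8, 32
max_num_reqs, max_num_tokens = 4, 64
token_ids = np.zeros((max_num_cached_reqs, max_model_len), dtype=np.int32)
num_computed_tokens = np.zeros(max_num_cached_reqs, dtype=np.int32)

# Cached request 3 is decoding (9 tokens computed, 1 scheduled);
# cached request 0 is prefilling 5 tokens from scratch.
token_ids[3, :10] = np.arange(10, dtype=np.int32)
num_computed_tokens[3] = 9
token_ids[0, :5] = np.arange(5, dtype=np.int32)
idx_mapping = np.array([3, 0], dtype=np.int32)  # batch_idx -> req_idx
num_scheduled_tokens = np.array([1, 5], dtype=np.int32)

# Preallocated, CUDA-graph-sized output buffers.
input_ids = np.zeros(max_num_tokens, dtype=np.int32)
query_start_loc = np.zeros(max_num_reqs + 1, dtype=np.int32)
seq_lens = np.zeros(max_num_reqs, dtype=np.int32)
positions = np.zeros(max_num_tokens, dtype=np.int64)

prepare_inputs(idx_mapping, token_ids, num_computed_tokens,
               num_scheduled_tokens, input_ids, query_start_loc,
               seq_lens, positions)
# query_start_loc -> [0, 1, 6, 6, 6]: the tail is padded so it stays
# non-decreasing.  seq_lens -> [10, 5, 0, 0]: unused slots are zeroed.

Because the numba signature is declared explicitly, the function is compiled when the module is imported rather than on the first call, and cache=True persists the compiled code on disk across processes. The output buffers keep a fixed, padded size, which is what lets the padded tail be reused under CUDA graph capture.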

View File

@@ -80,8 +80,8 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.gpu_block_table import BlockTables
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_worker_states import RequestState, prepare_inputs
from vllm.v1.worker.gpu_input_batch import InputBatch, prepare_inputs
from vllm.v1.worker.gpu_worker_states import RequestState
from vllm.v1.worker.kv_connector_model_runner_mixin import (
    KVConnectorModelRunnerMixin, KVConnectorOutput)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

View File

@@ -4,12 +4,10 @@
from dataclasses import dataclass
from typing import Optional, Union
import numba
import numpy as np
import torch
import triton
import triton.language as tl
from numba import types
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest
@@ -49,38 +47,76 @@ class RequestData:
    ]


class Param:
class SamplingStates:

    def __init__(
        self,
        num_rows_cpu: int,
        num_cols: int,
        num_rows_gpu: int,
        dtype: torch.dtype,
        max_num_reqs: int,
        max_model_len: int,
        max_num_cached_reqs: int,
        vocab_size: int,
        device: torch.device,
        pin_memory: bool,
        is_scalar: bool = False,
    ):
        self.cpu = torch.zeros(num_rows_cpu,
                               num_cols,
                               dtype=dtype,
                               device="cpu",
                               pin_memory=pin_memory)
        self.np = self.cpu.numpy()
        self.gpu = torch.zeros(num_rows_gpu,
                               num_cols,
                               dtype=dtype,
                               device=device)
        if is_scalar:
            self.cpu.squeeze_(1)
            self.np = self.cpu.numpy()
            self.gpu.squeeze_(1)
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_cached_reqs = max_num_cached_reqs
        self.vocab_size = vocab_size
        self.device = device

        # TODO(woosuk): Optimize this.
        self.gpu_buffer = self.cpu.to(device)
        self.temperature = self._make_param(torch.float32)
        self.greedy_req_indices: set[int] = set()
        self.top_p = self._make_param(torch.float32)
        self.top_p_req_indices: set[int] = set()
        self.top_k = self._make_param(torch.int32)
        self.top_k_req_indices: set[int] = set()

    def mirror_to_gpu(self) -> torch.Tensor:
        return self.gpu_buffer.copy_(self.cpu, non_blocking=True)
        self.frequency_penalties = self._make_param(torch.float32)
        self.presence_penalties = self._make_param(torch.float32)
        self.repetition_penalties = self._make_param(torch.float32)
        self.penalty_req_indices: set[int] = set()

        self.generators: dict[int, torch.Generator] = {}

    def _make_param(self, dtype: torch.dtype) -> torch.Tensor:
        return torch.zeros(self.max_num_reqs, dtype=dtype, device=self.device)

    def add_requests(
        self,
        req_indices: list[int],
        sampling_params: list[SamplingParams],
    ) -> None:
        num_reqs = len(req_indices)
        for i in range(num_reqs):
            req_idx = req_indices[i]
            sampling_param = sampling_params[i]

            temp = sampling_param.temperature
            if temp == 0.0:
                self.greedy_req_indices.add(req_idx)
            top_p = sampling_param.top_p
            if top_p < 1.0:
                self.top_p_req_indices.add(req_idx)
            top_k = sampling_param.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_req_indices.add(req_idx)
            else:
                top_k = self.vocab_size
            if (sampling_param.frequency_penalty != 0.0
                    or sampling_param.presence_penalty != 0.0
                    or sampling_param.repetition_penalty != 1.0):
                self.penalty_req_indices.add(req_idx)
            if sampling_param.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_param.seed)
                self.generators[req_idx] = generator

    def remove_request(self, req_idx: int) -> None:
        self.greedy_req_indices.discard(req_idx)
        self.top_p_req_indices.discard(req_idx)
        self.top_k_req_indices.discard(req_idx)
        self.penalty_req_indices.discard(req_idx)
        self.generators.pop(req_idx, None)
class RequestState:
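
A hedged sketch, not part of this commit, of how the new SamplingStates container might be exercised; the constructor arguments and SamplingParams values are illustrative, and the import path assumes the class lives in vllm.v1.worker.gpu_worker_states alongside RequestState.

import torch

from vllm import SamplingParams
from vllm.v1.worker.gpu_worker_states import SamplingStates

states = SamplingStates(
    max_num_reqs=4,
    max_model_len=2048,
    max_num_cached_reqs=8,
    vocab_size=32_000,
    device=torch.device("cpu"),
)

# Slot 0 samples greedily; slot 1 uses seeded top-p sampling.
states.add_requests(
    req_indices=[0, 1],
    sampling_params=[
        SamplingParams(temperature=0.0),
        SamplingParams(temperature=0.8, top_p=0.9, seed=42),
    ],
)
assert 0 in states.greedy_req_indices
assert 1 in states.top_p_req_indices
assert 1 in states.generators  # seeded request gets its own torch.Generator

# When a request finishes, its slot is cleared so it can be reused.
states.remove_request(1)
assert 1 not in states.generators

Keeping per-feature index sets mirrors the per-request-id sets (greedy_reqs, top_p_reqs, and so on) that RequestState held before this change, but keyed by the request's slot index rather than its string id, and gathered into one object below.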
@@ -128,23 +164,12 @@ class RequestState:
        self.num_tokens = self._make_param(torch.int32)
        self.num_computed_tokens = self._make_param(torch.int32)

        # Sampling-related.
        self.temperature = self._make_param(torch.float32)
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()
        self.top_p = self._make_param(torch.float32)
        self.top_p_reqs: set[str] = set()
        self.top_k = self._make_param(torch.int32)
        self.top_k_reqs: set[str] = set()
        self.frequency_penalties = self._make_param(torch.float32)
        self.frequency_penalties_reqs: set[str] = set()
        self.presence_penalties = self._make_param(torch.float32)
        self.presence_penalties_reqs: set[str] = set()
        self.repetition_penalties = self._make_param(torch.float32)
        self.repetition_penalties_reqs: set[str] = set()
        # req_idx -> generator
        self.generators: dict[int, torch.Generator] = {}
        self.sampling_states = SamplingStates(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_cached_reqs=max_num_cached_reqs,
            device=device,
        )

    def _make_param(
        self,
@@ -413,58 +438,3 @@ def _prepare_spec_decode_kernel(
             sample_start_idx + offset,
             mask=offset < draft_len)
    tl.store(bonus_logits_indices + batch_idx, sample_end_idx - 1)
# NOTE: With the type annotations, this function is pre-compiled
# before the first call.
@numba.jit(
    [
        types.none(
            types.int32[:],  # idx_mapping
            types.int32[:, :],  # token_ids
            types.int32[:],  # num_computed_tokens
            types.int32[:],  # num_scheduled_tokens
            types.int32[:],  # input_ids
            types.int32[:],  # query_start_loc
            types.int32[:],  # seq_lens
            types.int64[:],  # positions
        )
    ],
    nopython=True,
    cache=True,
)
def prepare_inputs(
    idx_mapping: np.ndarray,  # batch_idx -> req_idx
    token_ids: np.ndarray,  # [N, max_model_len]
    num_computed_tokens: np.ndarray,  # [N]
    num_scheduled_tokens: np.ndarray,  # [B]
    input_ids: np.ndarray,  # [num_input_tokens]
    query_start_loc: np.ndarray,  # [B + 1]
    seq_lens: np.ndarray,  # [B]
    positions: np.ndarray,  # [num_input_tokens]
) -> None:
    num_reqs = num_scheduled_tokens.shape[0]
    query_start_loc[0] = 0
    cu_num_tokens = 0
    for i in range(num_reqs):
        req_idx = idx_mapping[i]
        query_len = num_scheduled_tokens[i]
        start = num_computed_tokens[req_idx]
        end = start + query_len
        seq_lens[i] = end

        start_idx = cu_num_tokens
        end_idx = start_idx + query_len
        input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
        positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
        cu_num_tokens = end_idx
        query_start_loc[i + 1] = cu_num_tokens

    # Pad the inputs for CUDA graphs.
    # NOTE: Pad query_start_loc to be non-decreasing, as kernels
    # like FlashAttention require that.
    query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
    # Fill the unused entries with 0 for full CUDA graph mode.
    seq_lens[num_reqs:].fill(0)