# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
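"""Per-request state tracking for the vLLM v1 GPU worker.

Defines the request-level containers (`RequestData`, `SamplingStates`,
`RequestState`) and the Triton kernels that gather per-request sampling
parameters and speculative-decoding indices into batch order.
"""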
from dataclasses import dataclass
from typing import Optional, Union

import numpy as np
import torch
import triton
import triton.language as tl
from typing_extensions import deprecated

from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (MultiModalKwargsItem,
                                    MultiModalKwargsItems, PlaceholderRange)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

_MAX_SPEC_LEN = 32


@dataclass
class RequestData:
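    """Static per-request inputs cached on the worker: multimodal items and
    placeholder positions, sampling or pooling parameters, and optional
    M-RoPE and LoRA state."""
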
    mm_kwargs: list[MultiModalKwargsItem]
    mm_positions: list[PlaceholderRange]
    sampling_params: Optional[SamplingParams]
    pooling_params: Optional[PoolingParams]

    mm_hashes: list[str]
    # M-RoPE (only for Qwen2/2.5-VL)
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    lora_request: Optional[LoRARequest] = None

    # Temporary back-compatibility for plugins that define model runner
    @property
    @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
                "removed in v0.13. Please use `mm_kwargs` instead.")
    def mm_inputs(self) -> list[MultiModalKwargsItems]:
        return [
            MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs
        ]


class SamplingStates:
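    """Per-request sampling parameters and membership sets, indexed by the
    request's cached slot on this worker."""
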
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_cached_reqs: int,
        vocab_size: int,
        device: torch.device,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_cached_reqs = max_num_cached_reqs
        self.vocab_size = vocab_size
        self.device = device

        self.temperature = self._make_param(torch.float32)
        self.greedy_req_indices: set[int] = set()
        self.top_p = self._make_param(torch.float32)
        self.top_p_req_indices: set[int] = set()
        self.top_k = self._make_param(torch.int32)
        self.top_k_req_indices: set[int] = set()

        self.frequency_penalties = self._make_param(torch.float32)
        self.presence_penalties = self._make_param(torch.float32)
        self.repetition_penalties = self._make_param(torch.float32)
        self.penalty_req_indices: set[int] = set()

        self.generators: dict[int, torch.Generator] = {}

    def _make_param(self, dtype: torch.dtype) -> torch.Tensor:
        # One slot per cached request, indexed by the request's cached index.
        return torch.zeros(self.max_num_cached_reqs,
                           dtype=dtype,
                           device=self.device)

    def add_requests(
        self,
        req_indices: list[int],
        sampling_params: list[SamplingParams],
    ) -> None:
        num_reqs = len(req_indices)
        for i in range(num_reqs):
            req_idx = req_indices[i]
            sampling_param = sampling_params[i]

            temp = sampling_param.temperature
            self.temperature[req_idx] = temp
            if temp == 0.0:
                self.greedy_req_indices.add(req_idx)

            top_p = sampling_param.top_p
            self.top_p[req_idx] = top_p
            if top_p < 1.0:
                self.top_p_req_indices.add(req_idx)
            top_k = sampling_param.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_req_indices.add(req_idx)
            else:
                top_k = self.vocab_size
            self.top_k[req_idx] = top_k

            self.frequency_penalties[req_idx] = (
                sampling_param.frequency_penalty)
            self.presence_penalties[req_idx] = (
                sampling_param.presence_penalty)
            self.repetition_penalties[req_idx] = (
                sampling_param.repetition_penalty)
            if (sampling_param.frequency_penalty != 0.0
                    or sampling_param.presence_penalty != 0.0
                    or sampling_param.repetition_penalty != 1.0):
                self.penalty_req_indices.add(req_idx)

            if sampling_param.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_param.seed)
                self.generators[req_idx] = generator

    def remove_request(self, req_idx: int) -> None:
        self.greedy_req_indices.discard(req_idx)
        self.top_p_req_indices.discard(req_idx)
        self.top_k_req_indices.discard(req_idx)
        self.penalty_req_indices.discard(req_idx)
        self.generators.pop(req_idx, None)


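# NOTE: `Param` is required by `RequestState` below but is not defined
# elsewhere in this file. The class that follows is a minimal, assumed sketch
# inferred from its call sites (`.np`, `.gpu`, `.gpu_buffer`,
# `mirror_to_gpu()`); the actual helper may differ.
class Param:
    """Minimal sketch (assumed): a CPU-resident, request-indexed parameter
    table with an optional GPU staging buffer and a batch-ordered GPU view.
    """

    def __init__(
        self,
        num_rows: int,
        num_cols: int,
        num_gpu_rows: int,
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool,
        is_scalar: bool = False,
    ):
        if is_scalar:
            cpu_shape: tuple[int, ...] = (num_rows, )
            gpu_shape: tuple[int, ...] = (num_gpu_rows, )
        else:
            cpu_shape = (num_rows, num_cols)
            gpu_shape = (num_gpu_rows, num_cols)
        # Persistent per-request values, indexed by the cached request index.
        self.cpu = torch.zeros(cpu_shape,
                               dtype=dtype,
                               device="cpu",
                               pin_memory=pin_memory)
        self.np = self.cpu.numpy()
        if num_gpu_rows > 0:
            # Device-side staging copy of the CPU table (request-indexed).
            self.gpu_buffer = torch.zeros(cpu_shape,
                                          dtype=dtype,
                                          device=device)
            # Batch-ordered values, filled in by the gather kernels.
            self.gpu = torch.zeros(gpu_shape, dtype=dtype, device=device)
        else:
            # CPU-only parameters (e.g. the token id table).
            self.gpu_buffer = None
            self.gpu = None

    def mirror_to_gpu(self) -> torch.Tensor:
        # Refresh the staging buffer from the (pinned) CPU table.
        self.gpu_buffer.copy_(self.cpu, non_blocking=True)
        return self.gpu_buffer

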
class RequestState:
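    """State for every request cached on this GPU worker.

    Holds the token id table and progress counters per cached request, plus
    the per-request sampling parameters needed to build `SamplingMetadata`
    and `SpecDecodeMetadata` for each scheduled batch.
    """
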
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        max_num_cached_reqs: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        logitsprocs: Optional[LogitsProcessors] = None,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_cached_reqs = max_num_cached_reqs
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size
        self.is_spec_decode = is_spec_decode
        self.pooling_params = None
        self.block_sizes = block_sizes
        self.num_prompt_logprobs: dict[int, int] = {}

        self.req_id_to_index: dict[str, int] = {}
        self.index_to_req_id: dict[int, str] = {}
        self.free_indices = list(range(max_num_cached_reqs))

        # Request states.
        self.req_data: dict[int, RequestData] = {}
        # TODO(woosuk): Because the token_ids tensor can be very big, we only
        # initialize it on CPU memory.
        self.token_ids = self._make_param(
            num_cols=self.max_model_len,
            dtype=torch.int32,
            cpu_only=True,
        )
        self.num_prompt_tokens = self._make_param(torch.int32)
        self.num_tokens = self._make_param(torch.int32)
        self.num_computed_tokens = self._make_param(torch.int32)

        # Per-request sampling parameters and bookkeeping used by
        # add_request(), remove_request(), and make_sampling_metadata().
        self.temperature = self._make_param(torch.float32)
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()
        self.top_p = self._make_param(torch.float32)
        self.top_p_reqs: set[str] = set()
        self.top_k = self._make_param(torch.int32)
        self.top_k_reqs: set[str] = set()
        self.frequency_penalties = self._make_param(torch.float32)
        self.frequency_penalties_reqs: set[str] = set()
        self.presence_penalties = self._make_param(torch.float32)
        self.presence_penalties_reqs: set[str] = set()
        self.repetition_penalties = self._make_param(torch.float32)
        self.repetition_penalties_reqs: set[str] = set()
        self.generators: dict[int, torch.Generator] = {}

        self.sampling_states = SamplingStates(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_cached_reqs=max_num_cached_reqs,
            vocab_size=vocab_size,
            device=device,
        )

    def _make_param(
        self,
        dtype: torch.dtype,
        num_cols: int = 1,
        cpu_only: bool = False,
    ) -> Param:
        return Param(
            self.max_num_cached_reqs,
            num_cols,
            self.max_num_reqs if not cpu_only else 0,
            dtype,
            self.device,
            self.pin_memory,
            is_scalar=num_cols == 1,
        )

    @property
    def num_cached_reqs(self) -> int:
        return len(self.req_id_to_index)

    def add_request(
        self,
        req_id: str,
        prompt_token_ids: list[int],
        num_computed_tokens: int,
        sampling_params: SamplingParams,
    ) -> None:
        assert len(self.free_indices) > 0, "No free space in GPU worker states"
        req_idx = self.free_indices.pop()
        self.req_id_to_index[req_id] = req_idx
        self.index_to_req_id[req_idx] = req_id

        prompt_len = len(prompt_token_ids)
        self.num_prompt_tokens.np[req_idx] = prompt_len
        self.num_tokens.np[req_idx] = prompt_len
        self.token_ids.np[req_idx, :prompt_len] = prompt_token_ids
        self.num_computed_tokens.np[req_idx] = num_computed_tokens

        self.temperature.np[req_idx] = sampling_params.temperature
        if sampling_params.sampling_type == SamplingType.GREEDY:
            # NOTE: Be careful about division by zero. Store -1.0 so that
            # dividing logits by the temperature stays safe; greedy sampling
            # ignores this value.
            self.temperature.np[req_idx] = -1.0
            self.greedy_reqs.add(req_id)
        elif sampling_params.sampling_type == SamplingType.RANDOM:
            self.random_reqs.add(req_id)

        self.top_p.np[req_idx] = sampling_params.top_p
        if sampling_params.top_p < 1.0:
            self.top_p_reqs.add(req_id)

        top_k = sampling_params.top_k
        if 0 < top_k < self.vocab_size:
            self.top_k_reqs.add(req_id)
        else:
            top_k = self.vocab_size
        self.top_k.np[req_idx] = top_k

        self.frequency_penalties.np[
            req_idx] = sampling_params.frequency_penalty
        if sampling_params.frequency_penalty != 0.0:
            self.frequency_penalties_reqs.add(req_id)
        self.presence_penalties.np[req_idx] = sampling_params.presence_penalty
        if sampling_params.presence_penalty != 0.0:
            self.presence_penalties_reqs.add(req_id)
        self.repetition_penalties.np[
            req_idx] = sampling_params.repetition_penalty
        if sampling_params.repetition_penalty != 1.0:
            self.repetition_penalties_reqs.add(req_id)

        if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
            generator = torch.Generator(device=self.device)
            generator.manual_seed(sampling_params.seed)
            self.generators[req_idx] = generator

    def append_token_ids(
        self,
        req_idx: int,
        token_ids: Union[list[int], np.ndarray],
    ) -> None:
        start_idx = self.num_tokens.np[req_idx]
        end_idx = start_idx + len(token_ids)
        self.token_ids.np[req_idx, start_idx:end_idx] = token_ids
        self.num_tokens.np[req_idx] = end_idx

    def remove_request(self, req_id: str) -> None:
        req_idx = self.req_id_to_index.pop(req_id, None)
        if req_idx is None:
            # Request not found.
            return
        self.index_to_req_id.pop(req_idx, None)
        self.free_indices.append(req_idx)

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_idx, None)

    def make_sampling_metadata(
        self,
        batch_idx_to_req_idx: torch.Tensor,
    ) -> SamplingMetadata:
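        """Gather the sampling parameters of the scheduled requests into
        batch order and wrap them in a `SamplingMetadata`.

        `batch_idx_to_req_idx` maps each batch slot to the request's cached
        index. Parameters left at their defaults for every request (top-p,
        top-k) are not copied and are reported as None.
        """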
        batch_size = batch_idx_to_req_idx.shape[0]
        if self.top_p_reqs:
            top_p_buffer = self.top_p.mirror_to_gpu()
            top_p = self.top_p.gpu[:batch_size]
        else:
            top_p_buffer = self.top_p.gpu_buffer
            top_p = None
        if self.top_k_reqs:
            top_k_buffer = self.top_k.mirror_to_gpu()
            top_k = self.top_k.gpu[:batch_size]
        else:
            top_k_buffer = self.top_k.gpu_buffer
            top_k = None
        # TODO(woosuk): Use UVA to optimize CPU -> GPU copy.
        _make_sampling_metadata_kernel[(batch_size, )](
            batch_idx_to_req_idx,
            self.temperature.mirror_to_gpu(),
            self.temperature.gpu,
            top_p_buffer,
            self.top_p.gpu,
            top_k_buffer,
            self.top_k.gpu,
            self.frequency_penalties.mirror_to_gpu(),
            self.frequency_penalties.gpu,
            self.presence_penalties.mirror_to_gpu(),
            self.presence_penalties.gpu,
            self.repetition_penalties.mirror_to_gpu(),
            self.repetition_penalties.gpu,
            num_warps=1,
            num_stages=1,
        )
        no_penalties = not (self.frequency_penalties_reqs
                            or self.presence_penalties_reqs
                            or self.repetition_penalties_reqs)
        return SamplingMetadata(
            temperature=self.temperature.gpu[:batch_size],
            all_greedy=not self.random_reqs,
            all_random=not self.greedy_reqs,
            top_p=top_p,
            top_k=top_k,
            frequency_penalties=self.frequency_penalties.gpu[:batch_size],
            presence_penalties=self.presence_penalties.gpu[:batch_size],
            repetition_penalties=self.repetition_penalties.gpu[:batch_size],
            no_penalties=no_penalties,
            # TODO
            generators={},
            token_ids=None,
            num_tokens=None,
            num_prompt_tokens=None,
            max_num_logprobs=None,
            allowed_token_ids_mask=None,
            bad_words_token_ids={},
            logitsprocs=None,
        )

    def make_spec_decode_metadata(
        self,
        query_start_loc: torch.Tensor,
        cu_num_draft_tokens: torch.Tensor,
        cu_num_draft_tokens_np: np.ndarray,
        input_ids: torch.Tensor,
    ) -> SpecDecodeMetadata:
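        """Build the index tensors used for rejection sampling of drafts.

        Illustrative example (pure-decode batch of three requests):
            num_draft_tokens:      [3, 0, 2]
            cu_num_draft_tokens:   [3, 3, 5]
            query_start_loc:       [0, 4, 5, 8]
            logits_indices:        [0, 1, 2, 3, 4, 5, 6, 7]
            target_logits_indices: [0, 1, 2, 5, 6]
            bonus_logits_indices:  [3, 4, 7]
        """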
        batch_size = query_start_loc.shape[0] - 1
        total_num_draft_tokens = cu_num_draft_tokens_np[batch_size - 1]
        logits_indices = torch.empty(total_num_draft_tokens + batch_size,
                                     dtype=torch.int32,
                                     device=self.device)
        target_logits_indices = torch.empty(total_num_draft_tokens,
                                            dtype=torch.int32,
                                            device=self.device)
        bonus_logits_indices = torch.empty(batch_size,
                                           dtype=torch.int32,
                                           device=self.device)
        _prepare_spec_decode_kernel[(batch_size, )](
            query_start_loc,
            cu_num_draft_tokens,
            logits_indices,
            target_logits_indices,
            bonus_logits_indices,
            BLOCK_SIZE=triton.next_power_of_2(_MAX_SPEC_LEN + 1),
        )

        draft_token_ids = input_ids[logits_indices]
        draft_token_ids = draft_token_ids[target_logits_indices + 1]
        return SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            # Per-request draft counts, recovered from the cumulative sums.
            num_draft_tokens=np.diff(cu_num_draft_tokens_np,
                                     prepend=0).tolist(),
            cu_num_draft_tokens=cu_num_draft_tokens,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )


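# Gather kernel: one program per batch element. Looks up the request's cached
# index and copies its sampling parameters from the request-indexed source
# buffers (src_*) into the batch-ordered destination buffers (dst_*).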
@triton.jit
def _make_sampling_metadata_kernel(
    batch_idx_to_req_idx,  # [batch_size]
    src_temperature,
    dst_temperature,
    src_top_p,
    dst_top_p,
    src_top_k,
    dst_top_k,
    src_frequency_penalties,
    dst_frequency_penalties,
    src_presence_penalties,
    dst_presence_penalties,
    src_repetition_penalties,
    dst_repetition_penalties,
):
    batch_idx = tl.program_id(0)
    req_idx = tl.load(batch_idx_to_req_idx + batch_idx)

    temperature = tl.load(src_temperature + req_idx)
    tl.store(dst_temperature + batch_idx, temperature)

    top_p = tl.load(src_top_p + req_idx)
    tl.store(dst_top_p + batch_idx, top_p)

    top_k = tl.load(src_top_k + req_idx)
    tl.store(dst_top_k + batch_idx, top_k)

    frequency_penalties = tl.load(src_frequency_penalties + req_idx)
    tl.store(dst_frequency_penalties + batch_idx, frequency_penalties)

    presence_penalties = tl.load(src_presence_penalties + req_idx)
    tl.store(dst_presence_penalties + batch_idx, presence_penalties)

    repetition_penalties = tl.load(src_repetition_penalties + req_idx)
    tl.store(dst_repetition_penalties + batch_idx, repetition_penalties)


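# Index-building kernel: one program per batch element. For each request it
# writes the query positions whose logits are needed (its drafted tokens plus
# one bonus position) into logits_indices, and records which of those entries
# verify drafts (target_logits_indices) and which sample the bonus token
# (bonus_logits_indices).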
@triton.jit
def _prepare_spec_decode_kernel(
    query_start_loc,  # [B + 1]
    cu_num_draft_tokens,  # [B]
    logits_indices,  # [N + B]
    target_logits_indices,  # [N]
    bonus_logits_indices,  # [B]
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)

    if batch_idx == 0:
        draft_start_idx = 0
    else:
        draft_start_idx = tl.load(cu_num_draft_tokens + batch_idx - 1)
    draft_end_idx = tl.load(cu_num_draft_tokens + batch_idx)
    draft_len = draft_end_idx - draft_start_idx
    sample_len = draft_len + 1

    q_end_idx = tl.load(query_start_loc + batch_idx + 1)

    sample_start_idx = draft_start_idx + batch_idx
    sample_end_idx = sample_start_idx + sample_len
    offset = tl.arange(0, BLOCK_SIZE)
    tl.store(logits_indices + sample_start_idx + offset,
             q_end_idx - sample_len + offset,
             mask=offset < sample_len)
    tl.store(target_logits_indices + draft_start_idx + offset,
             sample_start_idx + offset,
             mask=offset < draft_len)
    tl.store(bonus_logits_indices + batch_idx, sample_end_idx - 1)