commit 6283995a6c
parent 0c56069c7e

    minor

    Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
vllm/v1/worker/gpu_input_batch.py

@@ -3,8 +3,10 @@
 from dataclasses import dataclass
 from typing import Any, Optional
 
+import numba
 import numpy as np
 import torch
+from numba import types
 
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
@@ -33,3 +35,58 @@ class InputBatch:
     spec_decode_metadata: Optional[SpecDecodeMetadata]
 
     logits_indices: torch.Tensor
+
+
+# NOTE: With the type annotations, this function is pre-compiled
+# before the first call.
+@numba.jit(
+    [
+        types.none(
+            types.int32[:],     # idx_mapping
+            types.int32[:, :],  # token_ids
+            types.int32[:],     # num_computed_tokens
+            types.int32[:],     # num_scheduled_tokens
+            types.int32[:],     # input_ids
+            types.int32[:],     # query_start_loc
+            types.int32[:],     # seq_lens
+            types.int64[:],     # positions
+        )
+    ],
+    nopython=True,
+    cache=True,
+)
+def prepare_inputs(
+    idx_mapping: np.ndarray,           # batch_idx -> req_idx
+    token_ids: np.ndarray,             # [N, max_model_len]
+    num_computed_tokens: np.ndarray,   # [N]
+    num_scheduled_tokens: np.ndarray,  # [B]
+    input_ids: np.ndarray,             # [num_input_tokens]
+    query_start_loc: np.ndarray,       # [B + 1]
+    seq_lens: np.ndarray,              # [B]
+    positions: np.ndarray,             # [num_input_tokens]
+) -> None:
+    num_reqs = num_scheduled_tokens.shape[0]
+    query_start_loc[0] = 0
+
+    cu_num_tokens = 0
+    for i in range(num_reqs):
+        req_idx = idx_mapping[i]
+        query_len = num_scheduled_tokens[i]
+        start = num_computed_tokens[req_idx]
+        end = start + query_len
+        seq_lens[i] = end
+
+        start_idx = cu_num_tokens
+        end_idx = start_idx + query_len
+        input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
+        positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
+
+        cu_num_tokens = end_idx
+        query_start_loc[i + 1] = cu_num_tokens
+
+    # Pad the inputs for CUDA graphs.
+    # Note: pad query_start_loc to be non-decreasing, as kernels
+    # like FlashAttention require that.
+    query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
+    # Fill unused with 0 for full cuda graph mode.
+    seq_lens[num_reqs:].fill(0)
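For context, here is what the new prepare_inputs produces on a toy batch. This is a minimal NumPy-only sketch of the same loop with the @numba.jit decorator stripped so it runs anywhere; all values (2 scheduled requests, a cached table of 3 requests, max_model_len of 8, padding to 4 batch slots) are hypothetical.

import numpy as np

idx_mapping = np.array([2, 0], dtype=np.int32)           # batch_idx -> req_idx
token_ids = np.arange(24, dtype=np.int32).reshape(3, 8)  # [N, max_model_len]
num_computed_tokens = np.array([4, 0, 1], dtype=np.int32)    # [N]
num_scheduled_tokens = np.array([3, 2], dtype=np.int32)      # [B]

input_ids = np.zeros(5, dtype=np.int32)   # 5 = total scheduled tokens
positions = np.zeros(5, dtype=np.int64)
query_start_loc = np.zeros(4 + 1, dtype=np.int32)  # padded to 4 batch slots
seq_lens = np.zeros(4, dtype=np.int32)

# Same loop body as the diff above, minus the numba compilation.
cu_num_tokens = 0
for i in range(num_scheduled_tokens.shape[0]):
    req_idx = idx_mapping[i]
    query_len = num_scheduled_tokens[i]
    start = num_computed_tokens[req_idx]
    end = start + query_len
    seq_lens[i] = end
    input_ids[cu_num_tokens:cu_num_tokens + query_len] = token_ids[req_idx, start:end]
    positions[cu_num_tokens:cu_num_tokens + query_len] = np.arange(start, end, dtype=np.int64)
    cu_num_tokens += query_len
    query_start_loc[i + 1] = cu_num_tokens

num_reqs = 2
query_start_loc[num_reqs + 1:] = cu_num_tokens  # keep non-decreasing for FlashAttention
seq_lens[num_reqs:] = 0                         # zero unused slots for full CUDA graphs

print(input_ids)        # [17 18 19  4  5]
print(query_start_loc)  # [0 3 5 5 5]
print(seq_lens)         # [4 6 0 0]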
vllm/v1/worker/gpu_model_runner.py

@@ -80,8 +80,8 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
 from vllm.v1.worker.gpu_block_table import BlockTables
-from vllm.v1.worker.gpu_input_batch import InputBatch
-from vllm.v1.worker.gpu_worker_states import RequestState, prepare_inputs
+from vllm.v1.worker.gpu_input_batch import InputBatch, prepare_inputs
+from vllm.v1.worker.gpu_worker_states import RequestState
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
     KVConnectorModelRunnerMixin, KVConnectorOutput)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
vllm/v1/worker/gpu_worker_states.py

@@ -4,12 +4,10 @@
 from dataclasses import dataclass
 from typing import Optional, Union
 
-import numba
 import numpy as np
 import torch
 import triton
 import triton.language as tl
-from numba import types
 from typing_extensions import deprecated
 
 from vllm.lora.request import LoRARequest
@@ -49,38 +47,76 @@ class RequestData:
     ]
 
 
-class Param:
+class SamplingStates:
 
     def __init__(
         self,
-        num_rows_cpu: int,
-        num_cols: int,
-        num_rows_gpu: int,
-        dtype: torch.dtype,
+        max_num_reqs: int,
+        max_model_len: int,
+        max_num_cached_reqs: int,
+        vocab_size: int,
         device: torch.device,
-        pin_memory: bool,
-        is_scalar: bool = False,
     ):
-        self.cpu = torch.zeros(num_rows_cpu,
-                               num_cols,
-                               dtype=dtype,
-                               device="cpu",
-                               pin_memory=pin_memory)
-        self.np = self.cpu.numpy()
-        self.gpu = torch.zeros(num_rows_gpu,
-                               num_cols,
-                               dtype=dtype,
-                               device=device)
-        if is_scalar:
-            self.cpu.squeeze_(1)
-            self.np = self.cpu.numpy()
-            self.gpu.squeeze_(1)
+        self.max_num_reqs = max_num_reqs
+        self.max_model_len = max_model_len
+        self.max_num_cached_reqs = max_num_cached_reqs
+        self.vocab_size = vocab_size
+        self.device = device
 
-        # TODO(woosuk): Optimize this.
-        self.gpu_buffer = self.cpu.to(device)
+        self.temperature = self._make_param(torch.float32)
+        self.greedy_req_indices: set[int] = set()
+        self.top_p = self._make_param(torch.float32)
+        self.top_p_req_indices: set[int] = set()
+        self.top_k = self._make_param(torch.int32)
+        self.top_k_req_indices: set[int] = set()
 
-    def mirror_to_gpu(self) -> torch.Tensor:
-        return self.gpu_buffer.copy_(self.cpu, non_blocking=True)
+        self.frequency_penalties = self._make_param(torch.float32)
+        self.presence_penalties = self._make_param(torch.float32)
+        self.repetition_penalties = self._make_param(torch.float32)
+        self.penalty_req_indices: set[int] = set()
+
+        self.generators: dict[int, torch.Generator] = {}
+
+    def _make_param(self, dtype: torch.dtype) -> torch.Tensor:
+        return torch.zeros(self.max_num_reqs, dtype=dtype, device=self.device)
+
+    def add_requests(
+        self,
+        req_indices: list[int],
+        sampling_params: list[SamplingParams],
+    ) -> None:
+        num_reqs = len(req_indices)
+        for i in range(num_reqs):
+            req_idx = req_indices[i]
+            sampling_param = sampling_params[i]
+
+            temp = sampling_param.temperature
+            if temp == 0.0:
+                self.greedy_req_indices.add(req_idx)
+
+            top_p = sampling_param.top_p
+            if top_p < 1.0:
+                self.top_p_req_indices.add(req_idx)
+            top_k = sampling_param.top_k
+            if 0 < top_k < self.vocab_size:
+                self.top_k_req_indices.add(req_idx)
+            else:
+                top_k = self.vocab_size
+
+            if sampling_param.frequency_penalty != 0.0 or sampling_param.presence_penalty != 0.0 or sampling_param.repetition_penalty != 1.0:
+                self.penalty_req_indices.add(req_idx)
+
+            if sampling_param.sampling_type == SamplingType.RANDOM_SEED:
+                generator = torch.Generator(device=self.device)
+                generator.manual_seed(sampling_param.seed)
+                self.generators[req_idx] = generator
+
+    def remove_request(self, req_idx: int) -> None:
+        self.greedy_req_indices.discard(req_idx)
+        self.top_p_req_indices.discard(req_idx)
+        self.top_k_req_indices.discard(req_idx)
+        self.penalty_req_indices.discard(req_idx)
+        self.generators.pop(req_idx, None)
 
 
 class RequestState:
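To see the new bookkeeping in action, here is a small usage sketch of the SamplingStates class introduced above. It assumes this commit is applied and that vLLM's SamplingParams is importable; the request indices and parameter values are hypothetical.

import torch
from vllm import SamplingParams

from vllm.v1.worker.gpu_worker_states import SamplingStates

states = SamplingStates(
    max_num_reqs=4,
    max_model_len=2048,
    max_num_cached_reqs=16,
    vocab_size=32000,
    device=torch.device("cpu"),  # CPU suffices for the index bookkeeping shown here
)

# Request 0 samples greedily; request 1 uses top-p with a fixed seed.
states.add_requests(
    req_indices=[0, 1],
    sampling_params=[
        SamplingParams(temperature=0.0),
        SamplingParams(temperature=0.8, top_p=0.9, seed=42),
    ],
)
assert 0 in states.greedy_req_indices
assert 1 in states.top_p_req_indices
assert 1 in states.generators  # a seeded torch.Generator is kept per request

# Freeing a slot clears every per-request index set in one call.
states.remove_request(1)
assert 1 not in states.top_p_req_indices and 1 not in states.generators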
@@ -128,23 +164,12 @@ class RequestState:
         self.num_tokens = self._make_param(torch.int32)
         self.num_computed_tokens = self._make_param(torch.int32)
 
-        # Sampling-related.
-        self.temperature = self._make_param(torch.float32)
-        self.greedy_reqs: set[str] = set()
-        self.random_reqs: set[str] = set()
-        self.top_p = self._make_param(torch.float32)
-        self.top_p_reqs: set[str] = set()
-        self.top_k = self._make_param(torch.int32)
-        self.top_k_reqs: set[str] = set()
-        self.frequency_penalties = self._make_param(torch.float32)
-        self.frequency_penalties_reqs: set[str] = set()
-        self.presence_penalties = self._make_param(torch.float32)
-        self.presence_penalties_reqs: set[str] = set()
-        self.repetition_penalties = self._make_param(torch.float32)
-        self.repetition_penalties_reqs: set[str] = set()
-
-        # req_idx -> generator
-        self.generators: dict[int, torch.Generator] = {}
+        self.sampling_states = SamplingStates(
+            max_num_reqs=max_num_reqs,
+            max_model_len=max_model_len,
+            max_num_cached_reqs=max_num_cached_reqs,
+            device=device,
+        )
 
     def _make_param(
         self,
@@ -413,58 +438,3 @@ def _prepare_spec_decode_kernel(
              sample_start_idx + offset,
              mask=offset < draft_len)
     tl.store(bonus_logits_indices + batch_idx, sample_end_idx - 1)
-
-
-# NOTE: With the type annotations, this function is pre-compiled
-# before the first call.
-@numba.jit(
-    [
-        types.none(
-            types.int32[:],     # idx_mapping
-            types.int32[:, :],  # token_ids
-            types.int32[:],     # num_computed_tokens
-            types.int32[:],     # num_scheduled_tokens
-            types.int32[:],     # input_ids
-            types.int32[:],     # query_start_loc
-            types.int32[:],     # seq_lens
-            types.int64[:],     # positions
-        )
-    ],
-    nopython=True,
-    cache=True,
-)
-def prepare_inputs(
-    idx_mapping: np.ndarray,           # batch_idx -> req_idx
-    token_ids: np.ndarray,             # [N, max_model_len]
-    num_computed_tokens: np.ndarray,   # [N]
-    num_scheduled_tokens: np.ndarray,  # [B]
-    input_ids: np.ndarray,             # [num_input_tokens]
-    query_start_loc: np.ndarray,       # [B + 1]
-    seq_lens: np.ndarray,              # [B]
-    positions: np.ndarray,             # [num_input_tokens]
-) -> None:
-    num_reqs = num_scheduled_tokens.shape[0]
-    query_start_loc[0] = 0
-
-    cu_num_tokens = 0
-    for i in range(num_reqs):
-        req_idx = idx_mapping[i]
-        query_len = num_scheduled_tokens[i]
-        start = num_computed_tokens[req_idx]
-        end = start + query_len
-        seq_lens[i] = end
-
-        start_idx = cu_num_tokens
-        end_idx = start_idx + query_len
-        input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
-        positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
-
-        cu_num_tokens = end_idx
-        query_start_loc[i + 1] = cu_num_tokens
-
-    # Pad the inputs for CUDA graphs.
-    # Note: pad query_start_loc to be non-decreasing, as kernels
-    # like FlashAttention require that.
-    query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
-    # Fill unused with 0 for full cuda graph mode.
-    seq_lens[num_reqs:].fill(0)
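The NOTE in the moved function refers to numba's eager compilation: when @numba.jit receives an explicit list of signatures, the function is compiled at decoration (import) time rather than on the first call, keeping JIT latency out of the first scheduling step, and cache=True persists the compiled artifact across processes. A minimal sketch of the same pattern, using a hypothetical toy function rather than anything from this diff:

import numba
import numpy as np
from numba import types

# The explicit signature triggers eager compilation at import time,
# so the first real call pays no JIT cost.
@numba.jit([types.int64(types.int32[:])], nopython=True, cache=True)
def total(xs: np.ndarray) -> int:
    s = 0
    for x in xs:
        s += x
    return s

print(total(np.array([1, 2, 3], dtype=np.int32)))  # -> 6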