mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-20 06:47:04 +08:00
wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
c472982746
commit
79e5eb3643
@ -6,6 +6,8 @@ from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -17,6 +19,8 @@ from vllm.utils import cdiv, get_cuda_view_from_cpu_tensor, is_uva_available
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedRequestState:
|
||||
@ -27,6 +31,7 @@ class CachedRequestState:
|
||||
sampling_params: Optional[SamplingParams]
|
||||
pooling_params: Optional[PoolingParams]
|
||||
|
||||
# M-RoPE (only for Qwen2/2.5-VL)
|
||||
mrope_positions: Optional[torch.Tensor] = None
|
||||
mrope_position_delta: Optional[int] = None
|
||||
|
||||
@ -46,21 +51,31 @@ class PerRequestAttribute:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
N: int,
|
||||
M: int,
|
||||
K: int,
|
||||
num_rows_cpu: int,
|
||||
num_cols: int,
|
||||
num_rows_gpu: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
is_scalar: bool = False,
|
||||
):
|
||||
assert is_uva_available(), "UVA is not available."
|
||||
self.cpu_tensor = torch.zeros(N,
|
||||
M,
|
||||
dtype=dtype,
|
||||
device="cpu",
|
||||
pin_memory=True)
|
||||
self.np = self.cpu_tensor.numpy()
|
||||
self.uva_tensor = get_cuda_view_from_cpu_tensor(self.cpu_tensor)
|
||||
self.gpu_tensor = torch.zeros(K, M, dtype=dtype, device=device)
|
||||
self.cpu = torch.zeros(num_rows_cpu,
|
||||
num_cols,
|
||||
dtype=dtype,
|
||||
device="cpu",
|
||||
pin_memory=True)
|
||||
self.np = self.cpu.numpy()
|
||||
self.uva = get_cuda_view_from_cpu_tensor(self.cpu)
|
||||
self.gpu = torch.zeros(num_rows_gpu,
|
||||
num_cols,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
if is_scalar:
|
||||
assert num_cols == 1
|
||||
self.cpu.squeeze_(1)
|
||||
self.np = self.cpu.numpy()
|
||||
self.uva.squeeze_(1)
|
||||
self.gpu.squeeze_(1)
|
||||
|
||||
|
||||
class InputBatch:
|
||||
@ -87,15 +102,19 @@ class InputBatch:
|
||||
self.pin_memory = pin_memory
|
||||
self.vocab_size = vocab_size
|
||||
self.is_spec_decode = is_spec_decode
|
||||
self.pooling_params = None
|
||||
self.block_sizes = block_sizes
|
||||
self.num_prompt_logprobs = {}
|
||||
|
||||
self.req_id_to_index: dict[str, int] = {}
|
||||
self.index_to_req_id: dict[int, str] = {}
|
||||
self.free_indices = list(range(max_num_cached_reqs))
|
||||
self._add_scalar_attr("idx_mapping", torch.int32)
|
||||
|
||||
# Request states.
|
||||
# TODO(woosuk): This buffer could be too large if max_model_len is big.
|
||||
# Find a way to reduce the memory usage.
|
||||
self._add_vector_attr("token_ids", self.max_model_len, torch.int32)
|
||||
# TODO(woosuk): Because the token_ids tensor can be very big, we only
|
||||
# initialize it on CPU memory.
|
||||
self._add_vector_attr_cpu("token_ids", self.max_model_len, torch.int32)
|
||||
self._add_scalar_attr("num_prompt_tokens", torch.int32)
|
||||
self._add_scalar_attr("num_tokens", torch.int32)
|
||||
self._add_scalar_attr("num_computed_tokens", torch.int32)
|
||||
@ -119,20 +138,7 @@ class InputBatch:
|
||||
self.generators: dict[int, torch.Generator] = {}
|
||||
|
||||
# Block table(s).
|
||||
self.block_tables = []
|
||||
self.num_blocks = []
|
||||
for block_size in block_sizes:
|
||||
max_num_blocks = cdiv(max_model_len, block_size)
|
||||
block_table = PerRequestAttribute(self.max_num_cached_reqs,
|
||||
max_num_blocks,
|
||||
self.max_num_reqs, torch.int32,
|
||||
self.device)
|
||||
self.block_tables.append(block_table)
|
||||
num_blocks = PerRequestAttribute(self.max_num_cached_reqs, 1,
|
||||
self.max_num_reqs, torch.int32,
|
||||
self.device)
|
||||
self.num_blocks.append(num_blocks)
|
||||
self.num_block_tables = len(block_sizes)
|
||||
self._init_block_tables()
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
@ -146,11 +152,10 @@ class InputBatch:
|
||||
self.req_id_to_index[req_id] = req_idx
|
||||
self.index_to_req_id[req_idx] = req_id
|
||||
|
||||
num_prompt_tokens = len(prompt_token_ids)
|
||||
self.token_ids.np[req_idx, :num_prompt_tokens] = prompt_token_ids
|
||||
self.num_prompt_tokens.np[req_idx] = num_prompt_tokens
|
||||
self.num_tokens.np[req_idx] = num_prompt_tokens
|
||||
self.num_prompt_tokens.np[req_idx] = len(prompt_token_ids)
|
||||
self.num_computed_tokens.np[req_idx] = num_computed_tokens
|
||||
self.append_token_ids(req_idx, prompt_token_ids)
|
||||
self.append_block_ids(req_idx, block_ids, overwrite=True)
|
||||
|
||||
self.temperature.np[req_idx] = sampling_params.temperature
|
||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||
@ -171,56 +176,48 @@ class InputBatch:
|
||||
self.top_k.np[req_idx] = top_k
|
||||
|
||||
self.frequency_penalties.np[
|
||||
req_idx] = sampling_params.frequency_penalties
|
||||
if sampling_params.frequency_penalties != 0.0:
|
||||
req_idx] = sampling_params.frequency_penalty
|
||||
if sampling_params.frequency_penalty != 0.0:
|
||||
self.frequency_penalties_reqs.add(req_id)
|
||||
self.presence_penalties.np[
|
||||
req_idx] = sampling_params.presence_penalties
|
||||
if sampling_params.presence_penalties != 0.0:
|
||||
self.presence_penalties.np[req_idx] = sampling_params.presence_penalty
|
||||
if sampling_params.presence_penalty != 0.0:
|
||||
self.presence_penalties_reqs.add(req_id)
|
||||
self.repetition_penalties.np[
|
||||
req_idx] = sampling_params.repetition_penalties
|
||||
if sampling_params.repetition_penalties != 1.0:
|
||||
req_idx] = sampling_params.repetition_penalty
|
||||
if sampling_params.repetition_penalty != 1.0:
|
||||
self.repetition_penalties_reqs.add(req_id)
|
||||
|
||||
if sampling_params.seed is not None:
|
||||
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
||||
generator = torch.Generator(device=self.device)
|
||||
generator.manual_seed(sampling_params.seed)
|
||||
self.generators[req_idx] = generator
|
||||
|
||||
for i in range(self.num_block_tables):
|
||||
self.block_tables[i].np[req_idx, :len(block_ids[i])] = block_ids[i]
|
||||
self.num_blocks[i].np[req_idx] = len(block_ids[i])
|
||||
|
||||
def append_token_ids(self, req_id: str, token_ids: list[int]) -> None:
|
||||
req_idx = self.req_id_to_index.get(req_id)
|
||||
assert req_idx is not None
|
||||
def append_token_ids(self, req_idx: int, token_ids: list[int]) -> None:
|
||||
start_idx = self.num_tokens.np[req_idx]
|
||||
end_idx = start_idx + len(token_ids)
|
||||
self.token_ids.np[req_idx, start_idx:end_idx] = token_ids
|
||||
self.num_tokens.np[req_idx] = end_idx
|
||||
|
||||
# TODO(woosuk): Further vectorize this to minimize overheads.
|
||||
def append_block_ids(
|
||||
self,
|
||||
req_id: str,
|
||||
req_idx: int,
|
||||
new_block_ids: tuple[list[int], ...],
|
||||
overwrite: bool,
|
||||
) -> None:
|
||||
req_idx = self.req_id_to_index.get(req_id)
|
||||
assert req_idx is not None
|
||||
for i in range(self.num_block_tables):
|
||||
block_table = self.block_tables[i]
|
||||
num_blocks = self.num_blocks[i]
|
||||
num_new_blocks = len(new_block_ids[i])
|
||||
if overwrite:
|
||||
# Replace the existing block IDs with the new ones.
|
||||
# This happens when the request is resumed from preemption.
|
||||
block_table.np[
|
||||
req_idx, :len(new_block_ids[i])] = new_block_ids[i]
|
||||
num_blocks.np[req_idx] = len(new_block_ids[i])
|
||||
block_table.np[req_idx, :num_new_blocks] = new_block_ids[i]
|
||||
num_blocks.np[req_idx] = num_new_blocks
|
||||
else:
|
||||
# Append the new block IDs to the existing ones (common case).
|
||||
start_idx = num_blocks.np[req_idx]
|
||||
end_idx = start_idx + len(new_block_ids[i])
|
||||
end_idx = start_idx + num_new_blocks
|
||||
block_table.np[req_idx, start_idx:end_idx] = new_block_ids[i]
|
||||
num_blocks.np[req_idx] = end_idx
|
||||
|
||||
@ -241,52 +238,50 @@ class InputBatch:
|
||||
self.repetition_penalties_reqs.discard(req_id)
|
||||
self.generators.pop(req_idx, None)
|
||||
|
||||
def make_block_table(self, req_idx: int) -> tuple[torch.Tensor, ...]:
|
||||
pass
|
||||
def get_index_mapping(self, idx_mapping: list[int]) -> torch.Tensor:
|
||||
num_reqs = len(idx_mapping)
|
||||
self.idx_mapping.np[:num_reqs] = idx_mapping
|
||||
return self.idx_mapping.gpu[:num_reqs].copy_(
|
||||
self.idx_mapping.uva[:num_reqs], non_blocking=True)
|
||||
|
||||
def make_sampling_metadata(self,
|
||||
req_indices: list[int]) -> SamplingMetadata:
|
||||
batch_size = len(req_indices)
|
||||
def make_sampling_metadata(
|
||||
self,
|
||||
batch_idx_to_req_idx: torch.Tensor,
|
||||
) -> SamplingMetadata:
|
||||
batch_size = batch_idx_to_req_idx.shape[0]
|
||||
_make_sampling_metadata_kernel[(batch_size, )](
|
||||
req_indices,
|
||||
self.temperature.uva_tensor,
|
||||
self.temperature.gpu_tensor,
|
||||
self.top_p.uva_tensor,
|
||||
self.top_p.gpu_tensor,
|
||||
self.top_k.uva_tensor,
|
||||
self.top_k.gpu_tensor,
|
||||
self.frequency_penalties.uva_tensor,
|
||||
self.frequency_penalties.gpu_tensor,
|
||||
self.presence_penalties.uva_tensor,
|
||||
self.presence_penalties.gpu_tensor,
|
||||
self.repetition_penalties.uva_tensor,
|
||||
self.repetition_penalties.gpu_tensor,
|
||||
batch_idx_to_req_idx,
|
||||
self.temperature.uva,
|
||||
self.temperature.gpu,
|
||||
self.top_p.uva,
|
||||
self.top_p.gpu,
|
||||
self.top_k.uva,
|
||||
self.top_k.gpu,
|
||||
self.frequency_penalties.uva,
|
||||
self.frequency_penalties.gpu,
|
||||
self.presence_penalties.uva,
|
||||
self.presence_penalties.gpu,
|
||||
self.repetition_penalties.uva,
|
||||
self.repetition_penalties.gpu,
|
||||
num_warps=1,
|
||||
num_stages=1,
|
||||
)
|
||||
generators = {}
|
||||
if self.generators:
|
||||
for i, req_idx in enumerate(req_indices):
|
||||
generator = self.generators.get(req_idx)
|
||||
if generator is not None:
|
||||
generators[i] = generator
|
||||
no_penalties = not (self.frequency_penalties_reqs
|
||||
or self.presence_penalties_reqs
|
||||
or self.repetition_penalties_reqs)
|
||||
return SamplingMetadata(
|
||||
temperature=self.temperature.gpu_tensor[:batch_size],
|
||||
temperature=self.temperature.gpu[:batch_size],
|
||||
all_greedy=not self.random_reqs,
|
||||
all_random=not self.greedy_reqs,
|
||||
top_p=self.top_p.gpu_tensor[:batch_size],
|
||||
top_k=self.top_k.gpu_tensor[:batch_size],
|
||||
frequency_penalties=self.frequency_penalties.
|
||||
gpu_tensor[:batch_size],
|
||||
presence_penalties=self.presence_penalties.gpu_tensor[:batch_size],
|
||||
repetition_penalties=self.repetition_penalties.
|
||||
gpu_tensor[:batch_size],
|
||||
top_p=self.top_p.gpu[:batch_size],
|
||||
top_k=self.top_k.gpu[:batch_size],
|
||||
frequency_penalties=self.frequency_penalties.gpu[:batch_size],
|
||||
presence_penalties=self.presence_penalties.gpu[:batch_size],
|
||||
repetition_penalties=self.repetition_penalties.gpu[:batch_size],
|
||||
no_penalties=no_penalties,
|
||||
generators=generators,
|
||||
token_ids=self.token_ids.gpu_tensor[:batch_size],
|
||||
# TODO
|
||||
generators={},
|
||||
token_ids=self.token_ids.gpu[:batch_size],
|
||||
max_num_logprobs=None,
|
||||
allowed_token_ids_mask=None,
|
||||
bad_words_token_ids={},
|
||||
@ -297,22 +292,113 @@ class InputBatch:
|
||||
def num_reqs(self) -> int:
|
||||
return len(self.req_id_to_index)
|
||||
|
||||
def _add_scalar_attr(self, name: str, dtype: torch.dtype):
|
||||
attr = PerRequestAttribute(self.max_num_cached_reqs,
|
||||
1,
|
||||
self.max_num_reqs,
|
||||
dtype,
|
||||
self.device,
|
||||
is_scalar=True)
|
||||
setattr(self, name, attr)
|
||||
|
||||
def _add_vector_attr(self, name: str, max_len: int, dtype: torch.dtype):
|
||||
attr = PerRequestAttribute(self.max_num_cached_reqs, max_len,
|
||||
self.max_num_reqs, dtype, self.device)
|
||||
setattr(self, name, attr)
|
||||
|
||||
def _add_scalar_attr(self, name: str, dtype: torch.dtype):
|
||||
self._add_vector_attr(name, max_len=1, dtype=dtype)
|
||||
def _add_vector_attr_cpu(self, name: str, max_len: int,
|
||||
dtype: torch.dtype):
|
||||
attr = PerRequestAttribute(self.max_num_cached_reqs, max_len, 0, dtype,
|
||||
self.device)
|
||||
setattr(self, name, attr)
|
||||
|
||||
def _init_block_tables(self):
|
||||
self.num_block_tables = len(self.block_sizes)
|
||||
self.block_tables = []
|
||||
self.num_blocks = []
|
||||
self.slot_mappings: list[torch.Tensor] = []
|
||||
for i in range(self.num_block_tables):
|
||||
max_num_blocks = cdiv(self.max_model_len, self.block_sizes[i])
|
||||
block_table = PerRequestAttribute(self.max_num_cached_reqs,
|
||||
max_num_blocks,
|
||||
self.max_num_reqs, torch.int32,
|
||||
self.device)
|
||||
self.block_tables.append(block_table)
|
||||
num_blocks = PerRequestAttribute(self.max_num_cached_reqs,
|
||||
1,
|
||||
self.max_num_reqs,
|
||||
torch.int32,
|
||||
self.device,
|
||||
is_scalar=True)
|
||||
self.num_blocks.append(num_blocks)
|
||||
slot_mapping = torch.zeros(self.max_num_batched_tokens,
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
self.slot_mappings.append(slot_mapping)
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
def make_ptr_tensor(x: list[torch.Tensor]) -> torch.Tensor:
|
||||
return torch.tensor([t.data_ptr() for t in x],
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
|
||||
self.uva_block_table_ptrs = make_ptr_tensor(
|
||||
[b.uva for b in self.block_tables])
|
||||
self.gpu_block_table_ptrs = make_ptr_tensor(
|
||||
[b.gpu for b in self.block_tables])
|
||||
self.uva_num_blocks_ptrs = make_ptr_tensor(
|
||||
[n.uva for n in self.num_blocks])
|
||||
self.gpu_num_blocks_ptrs = make_ptr_tensor(
|
||||
[n.gpu for n in self.num_blocks])
|
||||
self.block_table_strides = torch.tensor(
|
||||
[b.gpu.shape[1] for b in self.block_tables],
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
self.block_sizes_tensor = torch.tensor(self.block_sizes,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.slot_mapping_ptrs = make_ptr_tensor(self.slot_mappings)
|
||||
|
||||
def make_block_tables(
|
||||
self,
|
||||
idx_mapping: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
batch_size = idx_mapping.shape[0]
|
||||
_make_block_tables_kernel[(batch_size, self.num_block_tables)](
|
||||
idx_mapping,
|
||||
self.uva_block_table_ptrs,
|
||||
self.gpu_block_table_ptrs,
|
||||
self.block_table_strides,
|
||||
self.uva_num_blocks_ptrs,
|
||||
self.gpu_num_blocks_ptrs,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
return tuple(b.gpu[:batch_size] for b in self.block_tables)
|
||||
|
||||
def make_slot_mappings(
|
||||
self,
|
||||
cu_num_tokens: torch.Tensor,
|
||||
pos: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
num_tokens = pos.shape[0]
|
||||
num_reqs = cu_num_tokens.shape[0] - 1
|
||||
_make_slot_mappings_kernel[(num_reqs + 1, self.num_block_tables)](
|
||||
num_tokens,
|
||||
self.max_num_batched_tokens,
|
||||
cu_num_tokens,
|
||||
pos,
|
||||
self.gpu_block_table_ptrs,
|
||||
self.block_table_strides,
|
||||
self.block_sizes_tensor,
|
||||
self.slot_mapping_ptrs,
|
||||
PAD_ID=PAD_SLOT_ID,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
return tuple(x[:num_tokens] for x in self.slot_mappings)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _make_sampling_metadata_kernel(
|
||||
req_indices, # [batch_size]
|
||||
batch_idx_to_req_idx, # [batch_size]
|
||||
src_temperature,
|
||||
dst_temperature,
|
||||
src_top_p,
|
||||
@ -327,52 +413,104 @@ def _make_sampling_metadata_kernel(
|
||||
dst_repetition_penalties,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_index = tl.load(req_indices + batch_idx)
|
||||
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
|
||||
|
||||
temperature = tl.load(src_temperature + req_index)
|
||||
tl.store(dst_temperature + req_index, temperature)
|
||||
temperature = tl.load(src_temperature + req_idx)
|
||||
tl.store(dst_temperature + batch_idx, temperature)
|
||||
|
||||
top_p = tl.load(src_top_p + req_index)
|
||||
tl.store(dst_top_p + req_index, top_p)
|
||||
top_p = tl.load(src_top_p + req_idx)
|
||||
tl.store(dst_top_p + batch_idx, top_p)
|
||||
|
||||
top_k = tl.load(src_top_k + req_index)
|
||||
tl.store(dst_top_k + req_index, top_k)
|
||||
top_k = tl.load(src_top_k + req_idx)
|
||||
tl.store(dst_top_k + batch_idx, top_k)
|
||||
|
||||
frequency_penalties = tl.load(src_frequency_penalties + req_index)
|
||||
tl.store(dst_frequency_penalties + req_index, frequency_penalties)
|
||||
frequency_penalties = tl.load(src_frequency_penalties + req_idx)
|
||||
tl.store(dst_frequency_penalties + batch_idx, frequency_penalties)
|
||||
|
||||
presence_penalties = tl.load(src_presence_penalties + req_index)
|
||||
tl.store(dst_presence_penalties + req_index, presence_penalties)
|
||||
presence_penalties = tl.load(src_presence_penalties + req_idx)
|
||||
tl.store(dst_presence_penalties + batch_idx, presence_penalties)
|
||||
|
||||
repetition_penalties = tl.load(src_repetition_penalties + req_index)
|
||||
tl.store(dst_repetition_penalties + req_index, repetition_penalties)
|
||||
repetition_penalties = tl.load(src_repetition_penalties + req_idx)
|
||||
tl.store(dst_repetition_penalties + batch_idx, repetition_penalties)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _make_block_table_kernel(
|
||||
req_indices, # [batch_size]
|
||||
src_block_table_ptrs,
|
||||
dst_block_table_ptrs,
|
||||
src_num_blocks_ptrs,
|
||||
dst_num_blocks_ptrs,
|
||||
num_block_tables: tl.constexpr,
|
||||
def _make_block_tables_kernel(
|
||||
batch_idx_to_req_idx, # [batch_size]
|
||||
src_block_table_ptrs, # [num_block_tables]
|
||||
dst_block_table_ptrs, # [num_block_tables]
|
||||
block_table_strides, # [num_block_tables]
|
||||
src_num_blocks_ptrs, # [num_block_tables]
|
||||
dst_num_blocks_ptrs, # [num_block_tables]
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_index = tl.load(req_indices + batch_idx)
|
||||
# kv cache group id
|
||||
group_id = tl.program_id(1)
|
||||
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
|
||||
|
||||
for i in tl.range(num_block_tables):
|
||||
src_num_blocks_ptr = tl.load(src_num_blocks_ptrs + i)
|
||||
dst_num_blocks_ptr = tl.load(dst_num_blocks_ptrs + i)
|
||||
num_blocks = tl.load(src_num_blocks_ptr + req_index)
|
||||
tl.store(dst_num_blocks_ptr + req_index, num_blocks)
|
||||
src_num_blocks_ptr = _load_ptr(src_num_blocks_ptrs, group_id, tl.int32)
|
||||
dst_num_blocks_ptr = _load_ptr(dst_num_blocks_ptrs, group_id, tl.int32)
|
||||
num_blocks = tl.load(src_num_blocks_ptr + req_idx)
|
||||
tl.store(dst_num_blocks_ptr + batch_idx, num_blocks)
|
||||
|
||||
src_block_table_ptr = tl.load(src_block_table_ptrs + i)
|
||||
dst_block_table_ptr = tl.load(dst_block_table_ptrs + i)
|
||||
for j in tl.range(num_blocks, BLOCK_SIZE):
|
||||
offset = tl.arange(0, BLOCK_SIZE)
|
||||
block_ids = tl.load(src_block_table_ptr + j * BLOCK_SIZE + offset,
|
||||
mask=offset < num_blocks)
|
||||
tl.store(dst_block_table_ptr + j * BLOCK_SIZE + offset,
|
||||
block_ids,
|
||||
mask=offset < num_blocks)
|
||||
stride = tl.load(block_table_strides + group_id)
|
||||
src_block_table_ptr = _load_ptr(src_block_table_ptrs, group_id, tl.int32)
|
||||
src_row_ptr = src_block_table_ptr + req_idx * stride
|
||||
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs, group_id, tl.int32)
|
||||
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
|
||||
|
||||
for i in tl.range(0, num_blocks, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
|
||||
tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _make_slot_mappings_kernel(
|
||||
num_tokens,
|
||||
max_num_tokens,
|
||||
cu_num_tokens, # [num_reqs + 1]
|
||||
pos, # [num_tokens]
|
||||
block_table_ptrs, # [num_block_tables]
|
||||
block_table_strides, # [num_block_tables]
|
||||
page_sizes, # [num_block_tables]
|
||||
slot_mapping_ptrs, # [num_block_tables]
|
||||
PAD_ID: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
num_reqs = tl.num_programs(0)
|
||||
# kv cache group id
|
||||
group_id = tl.program_id(1)
|
||||
slot_mapping_ptr = _load_ptr(slot_mapping_ptrs, group_id, tl.int64)
|
||||
|
||||
if req_idx == num_reqs - 1:
|
||||
# Pad remaining slots to -1. This is needed for CUDA graphs.
|
||||
for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE):
|
||||
offset = num_tokens + i + tl.arange(0, BLOCK_SIZE)
|
||||
tl.store(slot_mapping_ptr + offset,
|
||||
PAD_ID,
|
||||
mask=offset < max_num_tokens)
|
||||
return
|
||||
|
||||
block_table_ptr = _load_ptr(block_table_ptrs, group_id, tl.int32)
|
||||
block_table_stride = tl.load(block_table_strides + group_id)
|
||||
page_size = tl.load(page_sizes + group_id)
|
||||
|
||||
start_idx = tl.load(cu_num_tokens + req_idx)
|
||||
end_idx = tl.load(cu_num_tokens + req_idx + 1)
|
||||
for i in tl.range(start_idx, end_idx, BLOCK_SIZE):
|
||||
offset = start_idx + i + tl.arange(0, BLOCK_SIZE)
|
||||
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
|
||||
block_indices = positions // page_size
|
||||
block_numbers = tl.load(block_table_ptr +
|
||||
req_idx * block_table_stride + block_indices)
|
||||
slot_ids = block_numbers * page_size + positions % page_size
|
||||
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _load_ptr(base, offset, elem_dtype):
|
||||
ptr = tl.load(base + offset)
|
||||
return tl.cast(ptr, tl.pointer_type(elem_dtype))
|
||||
|
||||
@ -47,7 +47,6 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
|
||||
PlaceholderRange)
|
||||
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
@ -215,6 +214,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_model_len=self.max_model_len,
|
||||
max_num_batched_tokens=self.max_num_tokens,
|
||||
max_num_cached_reqs=2 * self.max_num_reqs,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
vocab_size=self.model_config.get_vocab_size(),
|
||||
@ -289,6 +289,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
dtype=self.dtype,
|
||||
device=self.device)
|
||||
|
||||
# OPTIMIZATION: Cache the tensors rather than creating them every step.
|
||||
# Keep in int64 to avoid overflow with long context
|
||||
self.arange_np = np.arange(max(self.max_num_reqs + 1,
|
||||
self.max_model_len,
|
||||
self.max_num_tokens),
|
||||
dtype=np.int64)
|
||||
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
|
||||
# a faster version of creating a new tensor every time. Thus, we should
|
||||
# not make any assumptions about the values in these tensors.
|
||||
self.input_ids_cpu = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
self.positions_cpu = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
self.positions_np = self.positions_cpu.numpy()
|
||||
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
|
||||
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
||||
|
||||
# Layer pairings for cross-layer KV sharing.
|
||||
# If an Attention layer `layer_name` is in the keys of this dict, it
|
||||
# means this layer will perform attention using the keys and values
|
||||
@ -318,6 +347,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
torch.Tensor]] = None
|
||||
|
||||
def _init_model_kwargs(self, num_tokens: int):
|
||||
return {}
|
||||
model_kwargs = dict[str, Any]()
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
|
||||
@ -410,20 +440,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
for req_id in unscheduled_req_ids:
|
||||
self.input_batch.remove_request(req_id)
|
||||
|
||||
reqs_to_add: list[CachedRequestState] = []
|
||||
# Add new requests to the cached states.
|
||||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||||
req_id = new_req_data.req_id
|
||||
sampling_params = new_req_data.sampling_params
|
||||
pooling_params = new_req_data.pooling_params
|
||||
|
||||
if sampling_params and \
|
||||
sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
||||
generator = torch.Generator(device=self.device)
|
||||
generator.manual_seed(sampling_params.seed)
|
||||
else:
|
||||
generator = None
|
||||
|
||||
if pooling_params:
|
||||
task = pooling_params.task
|
||||
assert task is not None, "You did not set `task` in the API"
|
||||
@ -434,141 +456,81 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
req_state = CachedRequestState(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||
mm_kwargs=new_req_data.mm_kwargs,
|
||||
mm_positions=new_req_data.mm_positions,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=pooling_params,
|
||||
generator=generator,
|
||||
block_ids=new_req_data.block_ids,
|
||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||
output_token_ids=[],
|
||||
lora_request=new_req_data.lora_request,
|
||||
)
|
||||
|
||||
self.requests[req_id] = req_state
|
||||
self.input_batch.add_request(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||
block_ids=new_req_data.block_ids,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
if self.uses_mrope:
|
||||
image_grid_thw = []
|
||||
video_grid_thw = []
|
||||
second_per_grid_ts = []
|
||||
audio_feature_lengths = []
|
||||
use_audio_in_video = False
|
||||
for mm_item in req_state.mm_kwargs:
|
||||
mm_input = mm_item.get_data()
|
||||
if (t := mm_input.get("image_grid_thw")) is not None:
|
||||
image_grid_thw.append(t.tolist())
|
||||
if (t := mm_input.get("video_grid_thw")) is not None:
|
||||
video_grid_thw.append(t.tolist())
|
||||
if (t := mm_input.get("second_per_grid_ts")) is not None:
|
||||
second_per_grid_ts.append(t)
|
||||
if (t :=
|
||||
mm_input.get("audio_feature_lengths")) is not None:
|
||||
audio_feature_lengths.append(t)
|
||||
if mm_input.get("use_audio_in_video") is True:
|
||||
use_audio_in_video = True
|
||||
|
||||
hf_config = self.model_config.hf_config
|
||||
|
||||
req_state.mrope_positions, req_state.mrope_position_delta = \
|
||||
MRotaryEmbedding.get_input_positions_tensor(
|
||||
req_state.prompt_token_ids,
|
||||
hf_config=hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
|
||||
reqs_to_add.append(req_state)
|
||||
self._init_mrope_states(req_state)
|
||||
|
||||
# Update the states of the running/resumed requests.
|
||||
is_last_rank = get_pp_group().is_last_rank
|
||||
req_data = scheduler_output.scheduled_cached_reqs
|
||||
for i, req_id in enumerate(req_data.req_ids):
|
||||
req_state = self.requests[req_id]
|
||||
num_computed_tokens = req_data.num_computed_tokens[i]
|
||||
new_block_ids = req_data.new_block_ids[i]
|
||||
resumed_from_preemption = req_data.resumed_from_preemption[i]
|
||||
|
||||
# Update the cached states.
|
||||
req_state.num_computed_tokens = num_computed_tokens
|
||||
req_index = self.input_batch.req_id_to_index[req_id]
|
||||
|
||||
# Update input batch.
|
||||
if not is_last_rank:
|
||||
# When using PP, the scheduler sends the sampled tokens back,
|
||||
# because there's no direct communication between the first-
|
||||
# stage worker and the last-stage worker.
|
||||
new_token_ids = req_data.new_token_ids[i]
|
||||
# Add the sampled token(s) from the previous step (if any).
|
||||
# This doesn't include "unverified" tokens like spec tokens.
|
||||
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
|
||||
req_state.num_tokens)
|
||||
if num_new_tokens == 1:
|
||||
# Avoid slicing list in most common case.
|
||||
req_state.output_token_ids.append(new_token_ids[-1])
|
||||
elif num_new_tokens > 0:
|
||||
req_state.output_token_ids.extend(
|
||||
new_token_ids[-num_new_tokens:])
|
||||
self.input_batch.append_token_ids(req_index, new_token_ids)
|
||||
|
||||
# Update the block IDs.
|
||||
if not resumed_from_preemption:
|
||||
if new_block_ids is not None:
|
||||
# Append the new blocks to the existing block IDs.
|
||||
for block_ids, new_ids in zip(req_state.block_ids,
|
||||
new_block_ids):
|
||||
block_ids.extend(new_ids)
|
||||
else:
|
||||
assert new_block_ids is not None
|
||||
# The request is resumed from preemption.
|
||||
# Replace the existing block IDs with the new ones.
|
||||
req_state.block_ids = new_block_ids
|
||||
|
||||
req_index = self.input_batch.req_id_to_index.get(req_id)
|
||||
if req_index is None:
|
||||
# The request is not in the persistent batch.
|
||||
# The request was either preempted and resumed later, or was not
|
||||
# scheduled in the previous step and needs to be added again.
|
||||
reqs_to_add.append(req_state)
|
||||
continue
|
||||
|
||||
# Update the persistent batch.
|
||||
self.input_batch.num_computed_tokens_cpu[req_index] = (
|
||||
num_computed_tokens)
|
||||
new_block_ids = req_data.new_block_ids[i]
|
||||
if new_block_ids is not None:
|
||||
self.input_batch.block_table.append_row(
|
||||
new_block_ids, req_index)
|
||||
|
||||
# For the last rank, we don't need to update the token_ids_cpu
|
||||
# because the sampled tokens are already cached.
|
||||
if not is_last_rank:
|
||||
# Add new_token_ids to token_ids_cpu.
|
||||
start_token_index = num_computed_tokens
|
||||
end_token_index = num_computed_tokens + len(new_token_ids)
|
||||
self.input_batch.token_ids_cpu[
|
||||
# If the request is resumed from preemption, we need to
|
||||
# overwrite the existing block IDs.
|
||||
self.input_batch.append_block_ids(
|
||||
req_index,
|
||||
start_token_index:end_token_index] = new_token_ids
|
||||
self.input_batch.num_tokens_no_spec[
|
||||
req_index] = end_token_index
|
||||
self.input_batch.num_tokens[req_index] = end_token_index
|
||||
new_block_ids,
|
||||
overwrite=req_data.resumed_from_preemption[i],
|
||||
)
|
||||
|
||||
# Add spec_token_ids to token_ids_cpu.
|
||||
spec_token_ids = (
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
|
||||
if spec_token_ids:
|
||||
num_spec_tokens = len(spec_token_ids)
|
||||
start_index = self.input_batch.num_tokens_no_spec[req_index]
|
||||
end_token_index = start_index + num_spec_tokens
|
||||
self.input_batch.token_ids_cpu[
|
||||
req_index, start_index:end_token_index] = spec_token_ids
|
||||
# NOTE(woosuk): `num_tokens` here may include spec tokens.
|
||||
self.input_batch.num_tokens[req_index] += num_spec_tokens
|
||||
self.input_batch.num_computed_tokens.np[req_index] = (
|
||||
req_data.num_computed_tokens[i])
|
||||
|
||||
# Add the new or resumed requests to the persistent batch.
|
||||
# The smaller empty indices are filled first.
|
||||
for request in reqs_to_add:
|
||||
self.input_batch.add_request(request)
|
||||
def _init_mrope_states(self, req_state: CachedRequestState) -> None:
|
||||
image_grid_thw = []
|
||||
video_grid_thw = []
|
||||
second_per_grid_ts = []
|
||||
audio_feature_lengths = []
|
||||
use_audio_in_video = False
|
||||
for mm_item in req_state.mm_kwargs:
|
||||
mm_input = mm_item.get_data()
|
||||
if (t := mm_input.get("image_grid_thw")) is not None:
|
||||
image_grid_thw.append(t.tolist())
|
||||
if (t := mm_input.get("video_grid_thw")) is not None:
|
||||
video_grid_thw.append(t.tolist())
|
||||
if (t := mm_input.get("second_per_grid_ts")) is not None:
|
||||
second_per_grid_ts.append(t)
|
||||
if (t := mm_input.get("audio_feature_lengths")) is not None:
|
||||
audio_feature_lengths.append(t)
|
||||
if mm_input.get("use_audio_in_video") is True:
|
||||
use_audio_in_video = True
|
||||
|
||||
req_state.mrope_positions, req_state.mrope_position_delta = \
|
||||
MRotaryEmbedding.get_input_positions_tensor(
|
||||
req_state.prompt_token_ids,
|
||||
hf_config=self.model_config.hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
|
||||
def _extract_mm_kwargs(
|
||||
self,
|
||||
@ -637,12 +599,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
assert num_reqs > 0
|
||||
|
||||
# FIXME
|
||||
# batch_idx -> req_id
|
||||
req_ids = list(scheduler_output.num_scheduled_tokens.keys())
|
||||
# req_id -> batch_idx
|
||||
req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)}
|
||||
# batch_idx -> req_idx
|
||||
idx_mapping = [
|
||||
self.input_batch.req_id_to_index[req_id] for req_id in req_ids
|
||||
]
|
||||
# batch_idx -> req_idx
|
||||
idx_mapping_tensor = self.input_batch.get_index_mapping(idx_mapping)
|
||||
num_reqs = len(req_ids)
|
||||
|
||||
# OPTIMIZATION: Start copying the block table first.
|
||||
# This way, we can overlap the copy with the following CPU operations.
|
||||
self.input_batch.block_table.commit_block_table(num_reqs)
|
||||
block_tables = self.input_batch.make_block_tables(idx_mapping_tensor)
|
||||
|
||||
# Get the number of scheduled tokens for each request.
|
||||
req_ids = self.input_batch.req_ids
|
||||
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
||||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||||
max_num_scheduled_tokens = max(tokens)
|
||||
@ -659,7 +633,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
# Get positions.
|
||||
positions_np = self.positions_np[:total_num_scheduled_tokens]
|
||||
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
|
||||
np.add(self.input_batch.num_computed_tokens.np[req_indices],
|
||||
arange,
|
||||
out=positions_np)
|
||||
|
||||
@ -673,21 +647,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
|
||||
# where M is the max_model_len.
|
||||
token_indices = (positions_np +
|
||||
req_indices * self.input_batch.token_ids_cpu.shape[1])
|
||||
req_indices * self.input_batch.token_ids.np.shape[1])
|
||||
|
||||
# NOTE(woosuk): We use torch.index_select instead of np.take here
|
||||
# because torch.index_select is much faster than np.take for large
|
||||
# tensors.
|
||||
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
|
||||
torch.index_select(self.input_batch.token_ids.cpu.flatten(),
|
||||
0,
|
||||
torch.from_numpy(token_indices),
|
||||
out=self.input_ids_cpu[:total_num_scheduled_tokens])
|
||||
|
||||
self.input_batch.block_table.compute_slot_mapping(
|
||||
req_indices, positions_np)
|
||||
self.input_batch.block_table.commit_slot_mapping(
|
||||
total_num_scheduled_tokens)
|
||||
|
||||
# Prepare the attention metadata.
|
||||
self.query_start_loc_np[0] = 0
|
||||
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
|
||||
@ -698,7 +667,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
query_start_loc = self.query_start_loc[:num_reqs + 1]
|
||||
|
||||
self.seq_lens_np[:num_reqs] = (
|
||||
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||||
self.input_batch.num_computed_tokens.np[:num_reqs] +
|
||||
num_scheduled_tokens)
|
||||
# Fill unused with 0 for full cuda graph mode.
|
||||
self.seq_lens_np[num_reqs:].fill(0)
|
||||
@ -737,8 +706,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
|
||||
for req_id, draft_token_ids in (
|
||||
scheduler_output.scheduled_spec_decode_tokens.items()):
|
||||
req_idx = self.input_batch.req_id_to_index[req_id]
|
||||
num_draft_tokens[req_idx] = len(draft_token_ids)
|
||||
batch_idx = req_id_to_batch_idx[req_id]
|
||||
num_draft_tokens[batch_idx] = len(draft_token_ids)
|
||||
|
||||
spec_decode_metadata = self._calc_spec_decode_metadata(
|
||||
num_draft_tokens, cu_num_tokens)
|
||||
@ -788,11 +757,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
per_layer_metadata[layer_name]
|
||||
attn_metadata[layer_name] = encoder_attn_metadata
|
||||
|
||||
slot_mappings = self.input_batch.make_slot_mappings(
|
||||
query_start_loc,
|
||||
self.positions[:total_num_scheduled_tokens],
|
||||
)
|
||||
# Used in the below loop.
|
||||
query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
|
||||
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
|
||||
num_computed_tokens_cpu = (
|
||||
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
||||
self.input_batch.num_computed_tokens.cpu[:num_reqs])
|
||||
spec_decode_common_attn_metadata = None
|
||||
|
||||
# Prepare the attention metadata for each KV cache group and make layers
|
||||
@ -800,14 +773,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups):
|
||||
|
||||
blk_table = self.input_batch.block_table[kv_cache_group_id]
|
||||
blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
|
||||
slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens]
|
||||
|
||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda
|
||||
# graph mode.
|
||||
blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
|
||||
|
||||
blk_table_tensor = block_tables[kv_cache_group_id]
|
||||
slot_mapping = slot_mappings[kv_cache_group_id]
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
@ -948,7 +915,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_reqs = len(num_scheduled_tokens)
|
||||
common_prefix_len = min(
|
||||
common_prefix_len,
|
||||
self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
|
||||
self.input_batch.num_computed_tokens.np[:num_reqs].min())
|
||||
# common_prefix_len should be a multiple of the block size.
|
||||
common_prefix_len = (common_prefix_len // kv_cache_spec.block_size *
|
||||
kv_cache_spec.block_size)
|
||||
@ -1454,6 +1421,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
|
||||
max_query_len) = (self._prepare_inputs(scheduler_output))
|
||||
|
||||
# FIXME
|
||||
# batch_idx -> req_id
|
||||
req_ids = list(scheduler_output.num_scheduled_tokens.keys())
|
||||
# req_id -> batch_idx
|
||||
req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)}
|
||||
# batch_idx -> req_idx
|
||||
idx_mapping = [
|
||||
self.input_batch.req_id_to_index[req_id] for req_id in req_ids
|
||||
]
|
||||
# batch_idx -> req_idx
|
||||
idx_mapping_tensor = self.input_batch.get_index_mapping(idx_mapping)
|
||||
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
|
||||
@ -1593,7 +1572,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.apply_grammar_bitmask(scheduler_output, logits)
|
||||
|
||||
# Sample the next token and get logprobs if needed.
|
||||
sampling_metadata = self.input_batch.sampling_metadata
|
||||
sampling_metadata = self.input_batch.make_sampling_metadata(
|
||||
idx_mapping_tensor)
|
||||
if spec_decode_metadata is None:
|
||||
sampler_output = self.sampler(
|
||||
logits=logits,
|
||||
@ -1629,24 +1609,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
|
||||
num_nans_in_logits = self._get_nans_in_logits(logits)
|
||||
|
||||
# TODO(woosuk): The following loop can be slow since it iterates over
|
||||
# the requests one by one. Optimize.
|
||||
discard_sampled_tokens_req_indices = []
|
||||
for i, req_id in enumerate(self.input_batch.req_ids):
|
||||
req_state = self.requests[req_id]
|
||||
seq_len = (req_state.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
if seq_len < req_state.num_tokens:
|
||||
# Ignore the sampled token for partial prefills.
|
||||
# Rewind the generator state as if the token was not sampled.
|
||||
# This relies on cuda-specific torch-internal impl details
|
||||
generator = self.input_batch.generators.get(i)
|
||||
if generator is not None:
|
||||
generator.set_offset(generator.get_offset() - 4)
|
||||
# Record the index of the request that should not be sampled,
|
||||
# so that we could clear the sampled tokens before returning.
|
||||
discard_sampled_tokens_req_indices.append(i)
|
||||
|
||||
# NOTE: GPU -> CPU Sync happens here.
|
||||
# Move as many CPU operations as possible before this sync point.
|
||||
logprobs_tensors = sampler_output.logprobs_tensors
|
||||
@ -1668,37 +1630,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
else:
|
||||
# Includes spec decode tokens.
|
||||
valid_sampled_token_ids = self.rejection_sampler.parse_output(
|
||||
sampled_token_ids,
|
||||
self.input_batch.vocab_size,
|
||||
)
|
||||
# Mask out the sampled tokens that should not be sampled.
|
||||
for i in discard_sampled_tokens_req_indices:
|
||||
valid_sampled_token_ids[i].clear()
|
||||
sampled_token_ids, self.input_batch.vocab_size)
|
||||
# # Mask out the sampled tokens that should not be sampled.
|
||||
# for i in discard_sampled_tokens_req_indices:
|
||||
# valid_sampled_token_ids[i].clear()
|
||||
|
||||
# Cache the sampled tokens in the model runner, so that the scheduler
|
||||
# doesn't need to send them back.
|
||||
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
|
||||
# the sampled tokens back, because there's no direct communication
|
||||
# between the first-stage worker and the last-stage worker.
|
||||
req_ids = self.input_batch.req_ids
|
||||
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
|
||||
if not sampled_ids:
|
||||
continue
|
||||
|
||||
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
|
||||
end_idx = start_idx + len(sampled_ids)
|
||||
assert end_idx <= self.max_model_len, (
|
||||
"Sampled token IDs exceed the max model length. "
|
||||
f"Total number of tokens: {end_idx} > max_model_len: "
|
||||
f"{self.max_model_len}")
|
||||
|
||||
self.input_batch.token_ids_cpu[req_idx,
|
||||
start_idx:end_idx] = sampled_ids
|
||||
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
|
||||
self.input_batch.num_tokens[req_idx] = end_idx
|
||||
req_id = req_ids[req_idx]
|
||||
req_state = self.requests[req_id]
|
||||
req_state.output_token_ids.extend(sampled_ids)
|
||||
self.input_batch.append_token_ids(req_idx, sampled_ids)
|
||||
|
||||
if self.speculative_config:
|
||||
assert spec_decode_common_attn_metadata is not None
|
||||
@ -1716,8 +1661,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.eplb_step()
|
||||
|
||||
return ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_id_to_batch_idx,
|
||||
sampled_token_ids=valid_sampled_token_ids,
|
||||
logprobs=logprobs_lists,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
@ -2389,14 +2334,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
generators={},
|
||||
max_num_logprobs=None,
|
||||
no_penalties=True,
|
||||
prompt_token_ids=None,
|
||||
frequency_penalties=dummy_tensors(0.1),
|
||||
presence_penalties=dummy_tensors(0.1),
|
||||
repetition_penalties=dummy_tensors(0.1),
|
||||
output_token_ids=[[] for _ in range(num_reqs)],
|
||||
allowed_token_ids_mask=None,
|
||||
bad_words_token_ids={},
|
||||
logitsprocs=LogitsProcessors(),
|
||||
token_ids=None,
|
||||
)
|
||||
try:
|
||||
sampler_output = self.sampler(logits=logits,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user