mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-09 00:17:03 +08:00
fix
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
This commit is contained in:
parent
d30c0d50a6
commit
4be2c66e37
@ -5,9 +5,11 @@ from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numba
|
||||
import numba.types as types
|
||||
import numpy as np
|
||||
import torch
|
||||
from numba import types
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
@ -161,3 +163,49 @@ def prepare_inputs(
|
||||
query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
|
||||
# Zero the unused trailing entries of seq_lens (padding) for full CUDA graph mode.
|
||||
seq_lens[num_reqs:].fill(0)
|
||||
|
||||
|
||||
@triton.jit
def _combine_last_token_ids_kernel(
    input_ids_ptr,
    idx_mapping_ptr,
    last_token_ids_ptr,
    query_start_loc_ptr,
    seq_lens_ptr,
    num_tokens_ptr,
):
    """Overwrite the final input id of a request with its last token id.

    Each program instance (axis 0) handles one batch entry. The batch entry
    is translated through ``idx_mapping_ptr`` into a request-state index,
    which is used to read ``num_tokens_ptr`` and ``last_token_ids_ptr``;
    ``seq_lens_ptr`` and ``query_start_loc_ptr`` are read by batch index.
    """
    row = tl.program_id(0)
    state_slot = tl.load(idx_mapping_ptr + row)

    cur_seq_len = tl.load(seq_lens_ptr + row)
    total_tokens = tl.load(num_tokens_ptr + state_slot)
    if cur_seq_len < total_tokens:
        # Chunked prefilling: leave the inputs of this request untouched.
        return

    sampled = tl.load(last_token_ids_ptr + state_slot)
    if sampled == -1:
        # -1 is treated as "nothing to write" for this request.
        return

    # query_start_loc[row + 1] is read as the end offset of this request's
    # tokens; the slot just before it receives the token id.
    query_end = tl.load(query_start_loc_ptr + row + 1)
    tl.store(input_ids_ptr + query_end - 1, sampled)
|
||||
|
||||
|
||||
def combine_last_token_ids(
    input_ids: torch.Tensor,
    idx_mapping: torch.Tensor,
    last_token_ids: torch.Tensor,
    query_start_loc: torch.Tensor,
    seq_lens: torch.Tensor,
    num_tokens: torch.Tensor,
) -> torch.Tensor:
    """Launch ``_combine_last_token_ids_kernel`` over the batch.

    The kernel is started with one program per entry along the first
    dimension of ``seq_lens`` and writes into ``input_ids`` in place.

    Returns:
        The same ``input_ids`` tensor that was passed in.
    """
    # 1-D launch grid: one program per entry of seq_lens.
    grid = (seq_lens.size(0), )
    _combine_last_token_ids_kernel[grid](
        input_ids,
        idx_mapping,
        last_token_ids,
        query_start_loc,
        seq_lens,
        num_tokens,
    )
    return input_ids
|
||||
|
||||
@ -27,6 +27,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.dist_utils import (all_gather_sampler_output,
|
||||
evenly_split)
|
||||
from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
|
||||
combine_last_token_ids,
|
||||
prepare_inputs)
|
||||
from vllm.v1.worker.gpu.sampler import Sampler
|
||||
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
|
||||
@ -158,8 +159,8 @@ class GPUModelRunner:
|
||||
num_tokens=num_tokens,
|
||||
):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_batch.input_ids[:num_tokens],
|
||||
positions=input_batch.positions[:num_tokens],
|
||||
input_ids=input_batch.input_ids,
|
||||
positions=input_batch.positions,
|
||||
)
|
||||
sample_hidden_states = hidden_states[input_batch.logits_indices]
|
||||
return hidden_states, sample_hidden_states
|
||||
@ -205,7 +206,7 @@ class GPUModelRunner:
|
||||
[] for _ in range(self.block_tables.num_kv_cache_groups))
|
||||
overwrite: list[bool] = []
|
||||
|
||||
# Add new requests to the cached states.
|
||||
# Add new requests.
|
||||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||||
req_id = new_req_data.req_id
|
||||
self.req_states.add_request(
|
||||
@ -223,7 +224,7 @@ class GPUModelRunner:
|
||||
new_block_ids[i].extend(block_ids)
|
||||
overwrite.append(True)
|
||||
|
||||
# Update the states of the running/resumed requests.
|
||||
# Add new blocks for the existing requests.
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
for i, req_id in enumerate(cached_reqs.req_ids):
|
||||
req_index = self.req_states.req_id_to_index[req_id]
|
||||
@ -237,9 +238,6 @@ class GPUModelRunner:
|
||||
new_block_ids[group_id].extend(block_ids)
|
||||
overwrite.append(False)
|
||||
|
||||
self.req_states.num_computed_tokens[req_index] = (
|
||||
cached_reqs.num_computed_tokens[i])
|
||||
|
||||
if req_indices:
|
||||
self.block_tables.append_block_ids(
|
||||
req_indices=req_indices,
|
||||
@ -275,54 +273,61 @@ class GPUModelRunner:
|
||||
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
|
||||
block_tables = self.block_tables.gather_block_tables(idx_mapping)
|
||||
|
||||
input_ids = self.input_buffers.input_ids
|
||||
positions = self.input_buffers.positions
|
||||
query_start_loc = self.input_buffers.query_start_loc
|
||||
seq_lens = self.input_buffers.seq_lens
|
||||
prepare_inputs(
|
||||
idx_mapping_np,
|
||||
self.req_states.token_ids,
|
||||
self.req_states.prompt_token_ids,
|
||||
self.req_states.num_computed_tokens,
|
||||
num_scheduled_tokens,
|
||||
input_ids.np,
|
||||
positions.np,
|
||||
query_start_loc.np,
|
||||
seq_lens.np,
|
||||
self.input_buffers.input_ids.np,
|
||||
self.input_buffers.positions.np,
|
||||
self.input_buffers.query_start_loc.np,
|
||||
self.input_buffers.seq_lens.np,
|
||||
)
|
||||
input_ids.copy_to_gpu(num_tokens)
|
||||
positions.copy_to_gpu(num_tokens)
|
||||
|
||||
self.input_buffers.input_ids.copy_to_gpu(num_tokens)
|
||||
self.input_buffers.positions.copy_to_gpu(num_tokens)
|
||||
# NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
|
||||
# tensors from CPU to GPU, because they may include paddings needed
|
||||
# for full CUDA graph mode.
|
||||
query_start_loc.copy_to_gpu()
|
||||
query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1]
|
||||
self.input_buffers.query_start_loc.copy_to_gpu()
|
||||
self.input_buffers.seq_lens.copy_to_gpu()
|
||||
query_start_loc = self.input_buffers.query_start_loc
|
||||
query_start_loc_gpu = query_start_loc.gpu[:num_reqs + 1]
|
||||
query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1]
|
||||
max_query_len = int(num_scheduled_tokens.max())
|
||||
|
||||
seq_lens.copy_to_gpu()
|
||||
seq_lens_cpu = seq_lens.cpu[:num_reqs]
|
||||
seq_lens_np = seq_lens.np[:num_reqs]
|
||||
seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs]
|
||||
seq_lens_cpu = self.input_buffers.seq_lens.np[:num_reqs]
|
||||
seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs]
|
||||
max_seq_len = int(seq_lens_np.max())
|
||||
seq_lens_gpu = seq_lens.gpu[:num_reqs]
|
||||
|
||||
num_computed_tokens_np = self.req_states.num_computed_tokens[
|
||||
idx_mapping_np]
|
||||
num_computed_tokens_cpu = torch.from_numpy(num_computed_tokens_np)
|
||||
is_chunked_prefilling = (seq_lens_np
|
||||
< self.req_states.num_tokens[idx_mapping_np])
|
||||
# Some input token ids are directly read from the last sampled tokens.
|
||||
combine_last_token_ids(
|
||||
self.input_buffers.input_ids.gpu,
|
||||
idx_mapping,
|
||||
self.req_states.last_sampled_tokens,
|
||||
query_start_loc_gpu,
|
||||
seq_lens_gpu,
|
||||
self.req_states.num_tokens.copy_to_gpu(),
|
||||
)
|
||||
|
||||
# Slot mappings: [num_kv_cache_groups, num_tokens]
|
||||
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
query_start_loc_gpu, positions.gpu[:num_tokens])
|
||||
query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens])
|
||||
|
||||
num_computed_tokens_cpu = torch.from_numpy(
|
||||
self.req_states.num_computed_tokens[idx_mapping_np])
|
||||
|
||||
# Whether the request is chunked-prefilling or not.
|
||||
is_chunked_prefilling = (
|
||||
seq_lens_np < self.req_states.num_tokens.np[idx_mapping_np])
|
||||
|
||||
# Logits indices to sample next token from.
|
||||
logits_indices = query_start_loc_gpu[1:] - 1
|
||||
num_logits_indices = logits_indices.size(0)
|
||||
|
||||
# Layer name -> attention metadata.
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
for i, kv_cache_spec in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups):
|
||||
kv_cache_groups = self.kv_cache_config.kv_cache_groups
|
||||
for i, kv_cache_spec in enumerate(kv_cache_groups):
|
||||
block_table = block_tables[i]
|
||||
slot_mapping = slot_mappings[i]
|
||||
|
||||
@ -352,6 +357,8 @@ class GPUModelRunner:
|
||||
for layer_name in kv_cache_spec.layer_names:
|
||||
attn_metadata[layer_name] = metadata
|
||||
|
||||
input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
|
||||
positions = self.input_buffers.positions.gpu[:num_tokens_after_padding]
|
||||
return InputBatch(
|
||||
req_ids=req_ids,
|
||||
num_reqs=num_reqs,
|
||||
@ -361,8 +368,8 @@ class GPUModelRunner:
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_after_padding=num_tokens_after_padding,
|
||||
is_chunked_prefilling=is_chunked_prefilling,
|
||||
input_ids=input_ids.gpu,
|
||||
positions=positions.gpu,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
attn_metadata=attn_metadata,
|
||||
logits_indices=logits_indices,
|
||||
)
|
||||
@ -412,10 +419,20 @@ class GPUModelRunner:
|
||||
sampler_output: SamplerOutput,
|
||||
input_batch: InputBatch,
|
||||
) -> AsyncOutput:
|
||||
# Store the last sampled token ids.
|
||||
self.req_states.last_sampled_tokens[input_batch.idx_mapping] = (
|
||||
sampler_output.sampled_token_ids)
|
||||
|
||||
# Get the number of sampled tokens.
|
||||
# 0 if chunked-prefilling, 1 if not.
|
||||
is_chunked_prefilling = input_batch.is_chunked_prefilling
|
||||
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
|
||||
# Increment the number of tokens.
|
||||
idx_mapping_np = input_batch.idx_mapping_np
|
||||
self.req_states.num_tokens.np[idx_mapping_np] += num_sampled_tokens
|
||||
# Increment the number of computed tokens.
|
||||
self.req_states.num_computed_tokens[idx_mapping_np] += (
|
||||
input_batch.num_scheduled_tokens)
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=input_batch.req_ids,
|
||||
@ -450,8 +467,8 @@ class GPUModelRunner:
|
||||
num_tokens=num_tokens,
|
||||
):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_batch.input_ids[:num_tokens],
|
||||
positions=input_batch.positions[:num_tokens],
|
||||
input_ids=input_batch.input_ids,
|
||||
positions=input_batch.positions,
|
||||
)
|
||||
|
||||
sampler_output = self.sample(hidden_states, input_batch)
|
||||
|
||||
@ -3,8 +3,6 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import numba
|
||||
import numba.types as types
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@ -76,21 +74,22 @@ class RequestState:
|
||||
self.index_to_req_id: dict[int, str] = {}
|
||||
self.free_indices = list(range(max_num_reqs))
|
||||
|
||||
# TODO(woosuk): Because the token_ids tensor can be very big, we only
|
||||
# initialize it on CPU memory.
|
||||
self.token_ids = np.zeros(
|
||||
# NOTE(woosuk): Strictly speaking, it contains prompt + some output
|
||||
# because of preemption.
|
||||
self.prompt_token_ids = np.zeros(
|
||||
(self.max_num_reqs, self.max_model_len),
|
||||
dtype=np.int32,
|
||||
)
|
||||
self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
|
||||
self.num_tokens = self._make_buffer(self.max_num_reqs,
|
||||
dtype=torch.int32)
|
||||
self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
|
||||
self.num_prompt_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
|
||||
|
||||
# Last sampled token ids.
|
||||
self.last_token = torch.zeros(
|
||||
# Last sampled tokens.
|
||||
self.last_sampled_tokens = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
1,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Sampling parameters.
|
||||
@ -110,6 +109,12 @@ class RequestState:
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory)
|
||||
|
||||
def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer:
    """Build a ``CpuGpuBuffer`` using this object's device settings.

    ``size`` and ``dtype`` are forwarded as given; ``device`` and
    ``pin_memory`` are taken from the attributes of the same name on
    ``self``.
    """
    buffer = CpuGpuBuffer(
        size,
        dtype=dtype,
        device=self.device,
        pin_memory=self.pin_memory,
    )
    return buffer
|
||||
|
||||
@property
def num_reqs(self) -> int:
    """Return the number of entries in ``req_id_to_index``."""
    tracked = self.req_id_to_index
    return len(tracked)
|
||||
@ -126,11 +131,14 @@ class RequestState:
|
||||
self.req_id_to_index[req_id] = req_idx
|
||||
self.index_to_req_id[req_idx] = req_id
|
||||
|
||||
# NOTE(woosuk): Strictly speaking, "prompt_len" here may include
|
||||
# output tokens, if the request is resumed from preemption.
|
||||
prompt_len = len(prompt_token_ids)
|
||||
self.num_tokens[req_idx] = prompt_len
|
||||
self.num_prompt_tokens[req_idx] = prompt_len
|
||||
self.token_ids[req_idx, :prompt_len] = prompt_token_ids
|
||||
self.prompt_token_ids[req_idx, :prompt_len] = prompt_token_ids
|
||||
self.num_tokens.np[req_idx] = prompt_len
|
||||
self.num_computed_tokens[req_idx] = num_computed_tokens
|
||||
# TODO(woosuk): Optimize.
|
||||
self.last_sampled_tokens[req_idx].fill_(-1)
|
||||
|
||||
self.temperature.np[req_idx] = sampling_params.temperature
|
||||
self.top_p.np[req_idx] = sampling_params.top_p
|
||||
@ -197,50 +205,6 @@ class RequestState:
|
||||
max_num_logprobs=max_num_logprobs,
|
||||
)
|
||||
|
||||
def append_token_ids(
    self,
    req_indices: np.ndarray,
    sampled_ids: np.ndarray,
    num_sampled_tokens: np.ndarray,
) -> None:
    """Append sampled token ids to the per-request token storage.

    Delegates to ``_append_token_ids``, passing this object's
    ``token_ids`` and ``num_tokens`` arrays as the storage to update.
    """
    token_store = self.token_ids
    token_counts = self.num_tokens
    _append_token_ids(
        req_indices,
        sampled_ids,
        num_sampled_tokens,
        token_store,
        token_counts,
    )
|
||||
|
||||
|
||||
@numba.jit(
    [
        types.none(
            types.int32[:],
            types.int64[:, :],
            types.int32[:],
            types.int32[:, :],
            types.int32[:],
        )
    ],
    nopython=True,
    cache=True,
)
def _append_token_ids(
    req_indices: np.ndarray,
    sampled_ids: np.ndarray,
    num_sampled_tokens: np.ndarray,
    token_ids: np.ndarray,
    num_tokens: np.ndarray,
) -> None:
    """Copy newly sampled ids into ``token_ids`` and advance ``num_tokens``.

    For each position ``i`` along ``num_sampled_tokens``, the first
    ``num_sampled_tokens[i]`` entries of ``sampled_ids[i]`` are written to
    row ``req_indices[i]`` of ``token_ids``, starting at that row's current
    ``num_tokens`` value, which is then moved to the end of the written
    range. Both ``token_ids`` and ``num_tokens`` are modified in place.
    """
    for row in range(num_sampled_tokens.shape[0]):
        slot = req_indices[row]
        count = num_sampled_tokens[row]
        # Write range [begin, stop) within the request's row.
        begin = num_tokens[slot]
        stop = begin + count
        token_ids[slot, begin:stop] = sampled_ids[row, :count]
        num_tokens[slot] = stop
|
||||
|
||||
|
||||
class Param:
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user