Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-14 21:55:49 +08:00
[V1] Get input tokens from scheduler (#13339)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
commit 4c21ce9eba
parent ce77eb9410
@@ -154,6 +154,7 @@ def test_update_states_request_resumed(model_runner):
     cached_req_data = CachedRequestData(
         req_id=req_id,
         resumed_from_preemption=False,
+        new_token_ids=[],
         new_block_ids=[],
         num_computed_tokens=0,
     )
@@ -121,6 +121,8 @@ class Scheduler:
         encoder_budget = self.max_num_encoder_input_tokens
         # Spec decode-related.
         scheduled_spec_decode_tokens: Dict[str, List[int]] = {}
+
+        # For logging.
         scheduled_timestamp = time.monotonic()
 
         # First, schedule the RUNNING requests.
@@ -187,6 +189,15 @@ class Scheduler:
             token_budget -= num_new_tokens
             req_index += 1
 
+            # Speculative decode related.
+            if request.spec_token_ids:
+                num_scheduled_spec_tokens = (num_new_tokens +
+                                             request.num_computed_tokens -
+                                             request.num_tokens)
+                if num_scheduled_spec_tokens > 0:
+                    scheduled_spec_decode_tokens[request.request_id] = (
+                        request.spec_token_ids[:num_scheduled_spec_tokens])
+
             # Encoder-related.
             if encoder_inputs_to_schedule:
                 scheduled_encoder_inputs[request.request_id] = (
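
For the arithmetic in the hunk above, a rough sketch with made-up numbers (illustrative only, not part of the commit): num_new_tokens counts everything scheduled for the request in this step, so the speculative share is whatever extends beyond the tokens the request already owns.

# Illustrative only: hypothetical numbers, not taken from the diff.
num_tokens = 10                  # prompt + sampled tokens the request already has
num_computed_tokens = 9          # tokens whose KV cache is already computed
spec_token_ids = [101, 102, 103] # draft tokens proposed for verification
num_new_tokens = 4               # scheduled this step: 1 regular token + 3 drafts

num_scheduled_spec_tokens = (num_new_tokens + num_computed_tokens - num_tokens)
assert num_scheduled_spec_tokens == 3
scheduled = spec_token_ids[:num_scheduled_spec_tokens]   # all three drafts
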
@@ -196,11 +207,6 @@ class Scheduler:
                     self.encoder_cache_manager.allocate(request, i)
                 encoder_budget = new_encoder_budget
 
-            # Speculative decode related.
-            if request.spec_token_ids:
-                scheduled_spec_decode_tokens[
-                    request.request_id] = request.spec_token_ids
-
         # Record the LoRAs in scheduled_running_reqs
         requested_loras: Set[int] = set()
         if self.lora_config:
@@ -324,23 +330,24 @@ class Scheduler:
         # Construct the scheduler output.
         new_reqs_data = [
             NewRequestData.from_request(req,
-                                        req_to_new_block_ids[req.request_id],
-                                        req.num_computed_tokens)
+                                        req_to_new_block_ids[req.request_id])
             for req in scheduled_new_reqs
         ]
         resumed_reqs_data = [
             self._make_cached_request_data(
                 req,
+                num_scheduled_tokens[req.request_id],
+                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
                 req_to_new_block_ids[req.request_id],
-                req.num_computed_tokens,
                 resumed_from_preemption=True,
             ) for req in scheduled_resumed_reqs
         ]
         running_reqs_data = [
             self._make_cached_request_data(
                 req,
+                num_scheduled_tokens[req.request_id],
+                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
                 req_to_new_block_ids[req.request_id],
-                req.num_computed_tokens,
                 resumed_from_preemption=False,
             ) for req in scheduled_running_reqs
         ]
@@ -349,8 +356,8 @@ class Scheduler:
             scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
             num_scheduled_tokens=num_scheduled_tokens,
             total_num_scheduled_tokens=total_num_scheduled_tokens,
-            scheduled_encoder_inputs=scheduled_encoder_inputs,
             scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
+            scheduled_encoder_inputs=scheduled_encoder_inputs,
             num_common_prefix_blocks=num_common_prefix_blocks,
             # finished_req_ids is an existing state in the scheduler,
             # instead of being newly scheduled in this step.
@@ -366,22 +373,28 @@ class Scheduler:
     def _make_cached_request_data(
         self,
         request: Request,
+        num_scheduled_tokens: int,
+        num_scheduled_spec_tokens: int,
         new_block_ids: List[int],
-        num_computed_tokens: int,
        resumed_from_preemption: bool,
     ) -> "CachedRequestData":
         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
         # them at each scheduling step.
-        if request.request_id in self._cached_reqs_data:
-            req_data = self._cached_reqs_data[request.request_id]
+        num_computed_tokens = request.num_computed_tokens
+        num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
+        new_token_ids = request.all_token_ids[
+            num_computed_tokens:num_computed_tokens + num_regular_tokens]
+        req_data = self._cached_reqs_data.get(request.request_id)
+        if req_data is not None:
             req_data.resumed_from_preemption = resumed_from_preemption
+            req_data.new_token_ids = new_token_ids
             req_data.new_block_ids = new_block_ids
             req_data.num_computed_tokens = num_computed_tokens
         else:
             req_data = CachedRequestData.from_request(request,
                                                       resumed_from_preemption,
-                                                      new_block_ids,
-                                                      num_computed_tokens)
+                                                      new_token_ids,
+                                                      new_block_ids)
             self._cached_reqs_data[request.request_id] = req_data
         return req_data
 
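
The slicing in _make_cached_request_data above can be pictured with hypothetical values; the sketch below is illustrative and not the scheduler's code. Only the regular (non-speculative) scheduled tokens are copied out of the request's known token ids, starting at the first uncomputed position.

# Illustrative sketch; all values are hypothetical.
all_token_ids = list(range(100, 112))   # prompt + output token ids known so far
num_computed_tokens = 9
num_scheduled_tokens = 4
num_scheduled_spec_tokens = 3

num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens   # 1
new_token_ids = all_token_ids[
    num_computed_tokens:num_computed_tokens + num_regular_tokens]
assert new_token_ids == [109]   # only the already-known (verified) token is sent
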
@@ -30,7 +30,6 @@ class NewRequestData:
         cls,
         request: "Request",
         block_ids: List[int],
-        num_computed_tokens: int,
     ) -> "NewRequestData":
         return cls(
             req_id=request.request_id,
@@ -41,7 +40,7 @@ class NewRequestData:
             mm_positions=request.mm_positions,
             sampling_params=request.sampling_params,
             block_ids=block_ids,
-            num_computed_tokens=num_computed_tokens,
+            num_computed_tokens=request.num_computed_tokens,
             lora_request=request.lora_request,
         )
 
@@ -54,6 +53,7 @@ class CachedRequestData:
     # the request's block IDs. If True, new_block_ids will be used as the
     # request's block IDs instead of appending to the existing block IDs.
     resumed_from_preemption: bool
+    new_token_ids: List[int]
     new_block_ids: List[int]
     num_computed_tokens: int
 
@@ -62,14 +62,15 @@ class CachedRequestData:
         cls,
         request: "Request",
         resumed_from_preemption: bool,
+        new_token_ids: List[int],
         new_block_ids: List[int],
-        num_computed_tokens: int,
     ) -> "CachedRequestData":
         return cls(
             req_id=request.request_id,
             resumed_from_preemption=resumed_from_preemption,
+            new_token_ids=new_token_ids,
             new_block_ids=new_block_ids,
-            num_computed_tokens=num_computed_tokens,
+            num_computed_tokens=request.num_computed_tokens,
         )
 
 
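
As a rough illustration of the per-step payload shape after this change (a hypothetical stand-in class, not the real CachedRequestData): the scheduler now ships the newly known token ids next to the new block ids, while num_computed_tokens is read from the request itself.

from dataclasses import dataclass
from typing import List

@dataclass
class CachedRequestDataSketch:   # stand-in for illustration, not the real class
    req_id: str
    resumed_from_preemption: bool
    new_token_ids: List[int]
    new_block_ids: List[int]
    num_computed_tokens: int

payload = CachedRequestDataSketch(
    req_id="req-0",
    resumed_from_preemption=False,
    new_token_ids=[109],    # verified tokens newly known to the scheduler
    new_block_ids=[17],     # newly allocated KV-cache blocks (hypothetical)
    num_computed_tokens=9,  # request.num_computed_tokens at scheduling time
)
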
@@ -91,9 +92,9 @@ class SchedulerOutput:
     # Total number of tokens scheduled for all requests.
     # Equal to sum(num_scheduled_tokens.values())
     total_num_scheduled_tokens: int
-    # req_id -> spec_decode_tokens
-    # If a request does not have any spec decode tokens, it will
-    # not be included in the dictionary.
+    # req_id -> spec_token_ids
+    # If a request does not have any spec decode tokens, it will not be
+    # included in the dictionary.
     scheduled_spec_decode_tokens: Dict[str, List[int]]
     # req_id -> encoder input indices that need processing.
     # E.g., if a request has [0, 1], it could mean the vision encoder needs
@@ -2,7 +2,7 @@
 
 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -184,7 +184,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                        self.max_model_len,
                                        self.max_num_tokens),
                                    dtype=np.int32)
-        self.arange_cpu = torch.from_numpy(self.arange_np)
         # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
         # a faster version of creating a new tensor every time. Thus, we should
         # not make any assumptions about the values in these tensors.
@@ -327,7 +326,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             req_state = self.requests[req_id]
 
             # Update the cached states.
-            req_state.num_computed_tokens = req_data.num_computed_tokens
+            num_computed_tokens = req_data.num_computed_tokens
+            req_state.num_computed_tokens = num_computed_tokens
+            # Add the sampled token(s) from the previous step (if any).
+            # This doesn't include "unverified" tokens like spec decode tokens.
+            num_new_tokens = (num_computed_tokens +
+                              len(req_data.new_token_ids) -
+                              req_state.num_tokens)
+            new_token_ids = (req_data.new_token_ids[-num_new_tokens:]
+                             if num_new_tokens > 0 else [])
+            req_state.output_token_ids.extend(new_token_ids)
+            # Update the block IDs.
             if not req_data.resumed_from_preemption:
                 # Append the new blocks to the existing block IDs.
                 req_state.block_ids.extend(req_data.new_block_ids)
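
A rough worked example of the update above, with invented numbers: new_token_ids may overlap tokens the runner already tracks, so only the tail that extends past req_state.num_tokens is appended to output_token_ids.

# Hypothetical values, for illustration only.
num_computed_tokens = 12      # reported by the scheduler for this step
new_token_ids = [7, 8, 9]     # tokens the scheduler says are now known
req_state_num_tokens = 13     # tokens the runner already tracks

num_new = num_computed_tokens + len(new_token_ids) - req_state_num_tokens  # 2
appended = new_token_ids[-num_new:] if num_new > 0 else []
assert appended == [8, 9]     # only the genuinely new tokens are appended
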
@@ -346,12 +355,30 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
-                req_data.num_computed_tokens)
-            start_index = len(req_state.block_ids) - len(
-                req_data.new_block_ids)
+                num_computed_tokens)
+            start_index = (len(req_state.block_ids) -
+                           len(req_data.new_block_ids))
             self.input_batch.block_table.append_row(req_index, start_index,
                                                     req_data.new_block_ids)
+            # Add new_token_ids to token_ids_cpu.
+            start_token_index = num_computed_tokens
+            end_token_index = num_computed_tokens + len(req_data.new_token_ids)
+            self.input_batch.token_ids_cpu[
+                req_index,
+                start_token_index:end_token_index] = req_data.new_token_ids
+            # Add spec_token_ids to token_ids_cpu.
+            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
+                req_id, [])
+            if spec_token_ids:
+                start_index = end_token_index
+                end_token_index += len(spec_token_ids)
+                self.input_batch.token_ids_cpu[
+                    req_index, start_index:end_token_index] = spec_token_ids
+            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
+            self.input_batch.num_tokens[req_index] = end_token_index
 
+        # Check if the batch has changed. If not, we can skip copying the
+        # sampling metadata from CPU to GPU.
         batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
 
         # Add the new or resumed requests to the persistent batch.
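
A small sketch of the two buffer writes above, using a hypothetical 1-D row in place of the real 2-D token_ids_cpu: verified tokens land at [num_computed_tokens, ...), draft tokens are appended right after them, and num_tokens then covers both.

import numpy as np

# Hypothetical per-request row of the persistent token buffer.
token_ids_cpu = np.zeros(16, dtype=np.int32)
num_computed_tokens = 9
new_token_ids = [109]          # verified tokens from the scheduler
spec_token_ids = [201, 202]    # draft tokens scheduled for verification

start = num_computed_tokens
end = start + len(new_token_ids)
token_ids_cpu[start:end] = new_token_ids
if spec_token_ids:
    token_ids_cpu[end:end + len(spec_token_ids)] = spec_token_ids
    end += len(spec_token_ids)
num_tokens = end               # may include unverified draft tokens
assert token_ids_cpu[9:12].tolist() == [109, 201, 202] and num_tokens == 12
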
@@ -374,7 +401,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         return batch_changed
 
     def _prepare_inputs(
-        self, scheduler_output: "SchedulerOutput"
+        self,
+        scheduler_output: "SchedulerOutput",
     ) -> Tuple[FlashAttentionMetadata, torch.Tensor]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -387,24 +415,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
-        num_scheduled_tokens_list: List[int] = []
+        num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
         max_num_scheduled_tokens = 0
-        all_spec_token_ids: List[int] = []
-        num_spec_tokens_list: List[int] = []
         for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
             assert req_id is not None
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            num_scheduled_tokens_list.append(num_tokens)
+            num_scheduled_tokens[i] = num_tokens
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)
-            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                req_id, [])
-            all_spec_token_ids.extend(spec_token_ids)
-            num_spec_tokens_list.append(len(spec_token_ids))
-
-        num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
-                                                    dtype=np.int32)
-        assert max_num_scheduled_tokens > 0
 
         # Get request indices.
         # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
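
The request-index expansion mentioned in the "# E.g., [2, 5, 3] -> ..." comment above boils down to a single np.repeat; a standalone sketch (plain numpy, not the runner's cached arange):

import numpy as np

num_scheduled_tokens = np.array([2, 5, 3], dtype=np.int32)
req_indices = np.repeat(np.arange(3, dtype=np.int32), num_scheduled_tokens)
assert req_indices.tolist() == [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
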
@@ -441,78 +459,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         token_indices = (positions_np +
                          req_indices * self.input_batch.token_ids_cpu.shape[1])
 
-        use_spec_decode = len(all_spec_token_ids) > 0
-        if use_spec_decode:
-
-            # 1. Write spec_token_ids to input batch.
-            # Step 1. Get req indices that perform spec decode and repeat
-            #         the req indices by the number of spec tokens. Note
-            #         for requests that don't perform spec decode, the
-            #         number of spec tokens is 0 and the req index is
-            #         repeated 0 times.
-            # E.g., num_spec_tokens_list: [3, 0, 2, 0, 1]
-            #       spec_req_indices: [0, 0, 0, 2, 2, 4]
-            spec_req_indices = np.repeat(self.arange_np[:num_reqs],
-                                         num_spec_tokens_list)
-            # spec_offsets: offsets within each spec token list.
-            # E.g., [1, 2, 3, 1, 2, 1], TODO: avoid the for loop here
-            spec_offsets = np.concatenate(
-                [self.arange_np[1:val + 1] for val in num_spec_tokens_list])
-            # spec_seq_offsets: offsets within each sequence.
-            # E.g., num_computed_tokens_cpu: [1, 4, 3, 6, 2]
-            #       after repeating: [1, 1, 1, 3, 3, 2]
-            #       spec_seq_offsets: [1, 1, 1, 3, 3, 2] + [1, 2, 3, 1, 2, 1]
-            #                       = [2, 3, 4, 4, 5, 3]
-            spec_seq_offsets = np.repeat(
-                self.input_batch.num_computed_tokens_cpu[:num_reqs],
-                num_spec_tokens_list) + spec_offsets
-            # cumsums_spec_offsets: [0, 0, 0, 2M, 2M, 4M] + [2, 3, 4, 4, 5, 3]
-            cumsums_spec_offsets = (
-                spec_seq_offsets +
-                spec_req_indices * self.input_batch.token_ids_cpu.shape[1])
-            cumsums_spec_offsets = torch.from_numpy(cumsums_spec_offsets).to(
-                torch.int64)
-            all_spec_token_ids = torch.tensor(all_spec_token_ids,
-                                              device="cpu",
-                                              dtype=self.input_ids_cpu.dtype)
-
-            # Step 2. Write spec token ids to input_ids_cpu.
-            self.input_batch.token_ids_cpu_tensor.flatten().scatter_(
-                0, cumsums_spec_offsets, all_spec_token_ids)
-
-            # 2. Get spec decode logits indices.
-            # E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
-            #       cu_num_tokens: [4, 104, 107, 207, 209]
-            #       num_spec_tokens_list: [3, 0, 2, 0, 1]
-            #       num_sampled_tokens: [4, 1, 3, 1, 2]
-            #       spec_decode_logits_indices:
-            #           [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
-            num_spec_tokens_np = np.array(num_spec_tokens_list, dtype=np.int32)
-            num_sampled_tokens = num_spec_tokens_np + 1
-            # logits_start_loc: [0, 103, 104, 206, 207]
-            logits_start_loc = cu_num_tokens - num_sampled_tokens
-            # [0, 103, 104, 206, 207] ->
-            # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
-            logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
-            # The following three lines:
-            # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
-            # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
-            cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
-            # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
-            #         -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
-            cumsums_sampled_offsets = np.repeat(
-                cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
-            # Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-            #         - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
-            #         -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
-            total_num_sampled_tokens = num_sampled_tokens.sum()
-            sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
-                              cumsums_sampled_offsets)
-
-            # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
-            # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
-            spec_decode_logits_indices = logits_start_loc + sampled_arange
-
         # NOTE(woosuk): We use torch.index_select instead of np.take here
         # because torch.index_select is much faster than np.take for large
         # tensors.
@@ -606,9 +552,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             suffix_kv_lens=suffix_kv_lens,
         )
 
+        use_spec_decode = len(
+            scheduler_output.scheduled_spec_decode_tokens) > 0
         if use_spec_decode:
-            logits_indices = torch.from_numpy(spec_decode_logits_indices).to(
-                self.device, non_blocking=True)
+            logits_indices = self._calc_spec_decode_metadata(
+                scheduler_output, cu_num_tokens)
         else:
             # NOTE(woosuk): Due to chunked prefills, the batch may contain
             # partial requests. While we should not sample any token
@@ -762,6 +710,53 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
                 mrope_pos_ptr += completion_part_len
 
+    def _calc_spec_decode_metadata(
+        self,
+        scheduler_output: "SchedulerOutput",
+        cu_num_tokens: np.ndarray,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Get the number of spec decode tokens for each request.
+        num_reqs = self.input_batch.num_reqs
+        num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32)
+        for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
+            assert req_id is not None
+            num_spec_decode_tokens[i] = len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
+
+        # Get spec decode logits indices.
+        # E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
+        #       cu_num_tokens: [4, 104, 107, 207, 209]
+        #       num_spec_tokens_list: [3, 0, 2, 0, 1]
+        #       num_sampled_tokens: [4, 1, 3, 1, 2]
+        #       spec_decode_logits_indices:
+        #           [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
+        num_sampled_tokens = num_spec_decode_tokens + 1
+        # logits_start_loc: [0, 103, 104, 206, 207]
+        logits_start_loc = cu_num_tokens - num_sampled_tokens
+        # [0, 103, 104, 206, 207] ->
+        # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
+        logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
+        # The following three lines:
+        # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
+        # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
+        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
+        # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
+        #         -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
+        cumsums_sampled_offsets = np.repeat(
+            cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
+        # Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        #         - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
+        #         -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
+        total_num_sampled_tokens = num_sampled_tokens.sum()
+        sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
+                          cumsums_sampled_offsets)
+
+        # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
+        # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
+        spec_decode_logits_indices = logits_start_loc + sampled_arange
+        return torch.from_numpy(spec_decode_logits_indices).to(
+            self.device, non_blocking=True)
+
     def _prepare_sampling(
         self,
         batch_changed: bool,
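
The worked example in the comments of _calc_spec_decode_metadata can be reproduced end to end with plain numpy; the sketch below mirrors the index arithmetic only and is not the method itself.

import numpy as np

num_scheduled_tokens = np.array([4, 100, 3, 100, 2], dtype=np.int32)
cu_num_tokens = np.cumsum(num_scheduled_tokens)          # [4, 104, 107, 207, 209]
num_spec_decode_tokens = np.array([3, 0, 2, 0, 1], dtype=np.int32)

num_sampled_tokens = num_spec_decode_tokens + 1          # [4, 1, 3, 1, 2]
logits_start_loc = cu_num_tokens - num_sampled_tokens    # [0, 103, 104, 206, 207]
logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)    # [4, 5, 8, 9, 11]
offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
                    num_sampled_tokens)
sampled_arange = np.arange(num_sampled_tokens.sum()) - offsets
spec_decode_logits_indices = logits_start_loc + sampled_arange
assert spec_decode_logits_indices.tolist() == [
    0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
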
@@ -773,7 +768,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 for req_id, req in self.requests.items()}
 
         sampling_metadata = self.input_batch.make_sampling_metadata(
-            req_id_output_token_ids, req_to_spec_token_ids, not batch_changed)
+            req_id_output_token_ids,
+            req_to_spec_token_ids,
+            skip_copy=not batch_changed)
         return sampling_metadata
 
     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
@@ -960,28 +957,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
         num_reqs = self.input_batch.num_reqs
-        request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
+        req_ids: List[str] = []
+        # Because `input_batch.req_ids` is a list of length `max_num_reqs`,
+        # we need to stop at `num_reqs`.
+        # FIXME(woosuk): This is hacky. Refactor.
         for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
             assert req_id is not None
+            req_ids.append(req_id)
             req_state = self.requests[req_id]
             seq_len = (req_state.num_computed_tokens +
                        scheduler_output.num_scheduled_tokens[req_id])
-            if seq_len >= req_state.num_tokens:
-                request_seq_lens.append((i, req_state, seq_len))
-            else:
-                # Ignore the sampled token from the partial request.
+            if seq_len < req_state.num_tokens:
+                # Ignore the sampled token.
                 # Rewind the generator state as if the token was not sampled.
                 generator = self.input_batch.generators.get(i)
                 if generator is not None:
                     # This relies on cuda-specific torch-internal impl details
                     generator.set_offset(generator.get_offset() - 4)
 
-        # num_reqs entries should be non-None
-        assert all(
-            req_id is not None for req_id in
-            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
-        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
-
         # NOTE: GPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
         logprobs_tensors = sampler_output.logprobs_tensors
@@ -994,29 +987,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             scheduler_output,
         )
 
-        # Update batch with the valid generated tokens.
+        # Get the valid generated tokens.
         sampled_token_ids = sampler_output.sampled_token_ids
         max_gen_len = sampled_token_ids.shape[-1]
         if max_gen_len == 1:
+            # No spec decode tokens.
             valid_sampled_token_ids = sampled_token_ids.tolist()
-            for i, req_state, seq_len in request_seq_lens:
-                token_id = valid_sampled_token_ids[i][0]
-                self.input_batch.token_ids_cpu[i, seq_len] = token_id
-                req_state.output_token_ids.append(token_id)
-                self.input_batch.num_tokens[i] += 1
         else:
+            # Includes spec decode tokens.
             valid_mask = sampled_token_ids != INVALID_TOKEN_ID
             gen_lens = valid_mask.sum(dim=1).tolist()
+            # TODO(woosuk): Optimize this.
             valid_sampled_token_ids = [
                 seq.tolist()
                 for seq in sampled_token_ids[valid_mask].split(gen_lens)
             ]
-            self.input_batch.num_tokens[:num_reqs] += gen_lens
-            for i, req_state, seq_len in request_seq_lens:
-                target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
-                self.input_batch.token_ids_cpu[
-                    i, target_slice] = valid_sampled_token_ids[i]
-                req_state.output_token_ids.extend(valid_sampled_token_ids[i])
 
         model_runner_output = ModelRunnerOutput(
             req_ids=req_ids,
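
The masked split in the multi-token branch above can be pictured with a tiny tensor; the values below are hypothetical and INVALID_TOKEN_ID stands in for the sampler's padding sentinel. Rejected draft positions are dropped, and the surviving tokens are regrouped per request.

import torch

INVALID_TOKEN_ID = -1   # stand-in for the padding sentinel used by the sampler
sampled_token_ids = torch.tensor([
    [11, 12, INVALID_TOKEN_ID, INVALID_TOKEN_ID],                # 2 accepted
    [21, INVALID_TOKEN_ID, INVALID_TOKEN_ID, INVALID_TOKEN_ID],  # 1 accepted
    [31, 32, 33, 34],                                            # all 4 accepted
])
valid_mask = sampled_token_ids != INVALID_TOKEN_ID
gen_lens = valid_mask.sum(dim=1).tolist()            # [2, 1, 4]
valid_sampled_token_ids = [
    seq.tolist() for seq in sampled_token_ids[valid_mask].split(gen_lens)
]
assert valid_sampled_token_ids == [[11, 12], [21], [31, 32, 33, 34]]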