[V1] Get input tokens from scheduler (#13339)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Parent: ce77eb9410
Commit: 4c21ce9eba
@@ -154,6 +154,7 @@ def test_update_states_request_resumed(model_runner):
     cached_req_data = CachedRequestData(
         req_id=req_id,
         resumed_from_preemption=False,
+        new_token_ids=[],
         new_block_ids=[],
         num_computed_tokens=0,
     )
@@ -121,6 +121,8 @@ class Scheduler:
         encoder_budget = self.max_num_encoder_input_tokens
+        # Spec decode-related.
+        scheduled_spec_decode_tokens: Dict[str, List[int]] = {}

         # For logging.
         scheduled_timestamp = time.monotonic()

         # First, schedule the RUNNING requests.
@@ -187,6 +189,15 @@
             token_budget -= num_new_tokens
             req_index += 1

+            # Speculative decode related.
+            if request.spec_token_ids:
+                num_scheduled_spec_tokens = (num_new_tokens +
+                                             request.num_computed_tokens -
+                                             request.num_tokens)
+                if num_scheduled_spec_tokens > 0:
+                    scheduled_spec_decode_tokens[request.request_id] = (
+                        request.spec_token_ids[:num_scheduled_spec_tokens])
+
             # Encoder-related.
             if encoder_inputs_to_schedule:
                 scheduled_encoder_inputs[request.request_id] = (
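
A worked example of the truncation above, with made-up numbers (none of these values come from the diff): num_tokens counts the prompt plus verified output tokens, so any scheduled tokens beyond num_tokens - num_computed_tokens must be drawn from the draft proposals.

    # Hypothetical values, only to illustrate the formula in the hunk above.
    num_tokens = 20                    # prompt + verified output tokens
    num_computed_tokens = 18           # tokens already computed
    num_new_tokens = 4                 # tokens granted by the token budget this step
    spec_token_ids = [111, 222, 333]   # draft proposals

    num_scheduled_spec_tokens = num_new_tokens + num_computed_tokens - num_tokens
    assert num_scheduled_spec_tokens == 2
    # Only the first two of the three proposals fit into this step's budget.
    assert spec_token_ids[:num_scheduled_spec_tokens] == [111, 222]
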
@@ -196,11 +207,6 @@
                         self.encoder_cache_manager.allocate(request, i)
                     encoder_budget = new_encoder_budget

-            # Speculative decode related.
-            if request.spec_token_ids:
-                scheduled_spec_decode_tokens[
-                    request.request_id] = request.spec_token_ids
-
         # Record the LoRAs in scheduled_running_reqs
         requested_loras: Set[int] = set()
         if self.lora_config:
@@ -324,23 +330,24 @@
         # Construct the scheduler output.
         new_reqs_data = [
             NewRequestData.from_request(req,
-                                        req_to_new_block_ids[req.request_id],
-                                        req.num_computed_tokens)
+                                        req_to_new_block_ids[req.request_id])
             for req in scheduled_new_reqs
         ]
         resumed_reqs_data = [
             self._make_cached_request_data(
                 req,
+                num_scheduled_tokens[req.request_id],
+                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
                 req_to_new_block_ids[req.request_id],
-                req.num_computed_tokens,
                 resumed_from_preemption=True,
             ) for req in scheduled_resumed_reqs
         ]
         running_reqs_data = [
             self._make_cached_request_data(
                 req,
+                num_scheduled_tokens[req.request_id],
+                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
                 req_to_new_block_ids[req.request_id],
-                req.num_computed_tokens,
                 resumed_from_preemption=False,
             ) for req in scheduled_running_reqs
         ]
@@ -349,8 +356,8 @@
             scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
             num_scheduled_tokens=num_scheduled_tokens,
             total_num_scheduled_tokens=total_num_scheduled_tokens,
-            scheduled_encoder_inputs=scheduled_encoder_inputs,
             scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
+            scheduled_encoder_inputs=scheduled_encoder_inputs,
             num_common_prefix_blocks=num_common_prefix_blocks,
             # finished_req_ids is an existing state in the scheduler,
             # instead of being newly scheduled in this step.
@@ -366,22 +373,28 @@
     def _make_cached_request_data(
         self,
         request: Request,
+        num_scheduled_tokens: int,
+        num_scheduled_spec_tokens: int,
         new_block_ids: List[int],
-        num_computed_tokens: int,
         resumed_from_preemption: bool,
     ) -> "CachedRequestData":
         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
         # them at each scheduling step.
-        if request.request_id in self._cached_reqs_data:
-            req_data = self._cached_reqs_data[request.request_id]
+        num_computed_tokens = request.num_computed_tokens
+        num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
+        new_token_ids = request.all_token_ids[
+            num_computed_tokens:num_computed_tokens + num_regular_tokens]
+        req_data = self._cached_reqs_data.get(request.request_id)
+        if req_data is not None:
             req_data.resumed_from_preemption = resumed_from_preemption
+            req_data.new_token_ids = new_token_ids
             req_data.new_block_ids = new_block_ids
             req_data.num_computed_tokens = num_computed_tokens
         else:
             req_data = CachedRequestData.from_request(request,
                                                       resumed_from_preemption,
-                                                      new_block_ids,
-                                                      num_computed_tokens)
+                                                      new_token_ids,
+                                                      new_block_ids)
             self._cached_reqs_data[request.request_id] = req_data
         return req_data
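
The slice in _make_cached_request_data is the heart of the change: the scheduler now forwards the "regular" (non-draft) token IDs that the runner has not materialized yet. A minimal sketch with invented numbers, reusing the variable names from the hunk above:

    # Hypothetical request state.
    all_token_ids = list(range(100, 125))   # 25 prompt/verified tokens
    num_computed_tokens = 22
    num_scheduled_tokens = 5                # total tokens scheduled this step
    num_scheduled_spec_tokens = 2           # draft tokens among them

    num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
    new_token_ids = all_token_ids[num_computed_tokens:
                                  num_computed_tokens + num_regular_tokens]
    assert new_token_ids == [122, 123, 124]
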
@@ -30,7 +30,6 @@ class NewRequestData:
         cls,
         request: "Request",
         block_ids: List[int],
-        num_computed_tokens: int,
     ) -> "NewRequestData":
         return cls(
             req_id=request.request_id,
@@ -41,7 +40,7 @@ class NewRequestData:
             mm_positions=request.mm_positions,
             sampling_params=request.sampling_params,
             block_ids=block_ids,
-            num_computed_tokens=num_computed_tokens,
+            num_computed_tokens=request.num_computed_tokens,
             lora_request=request.lora_request,
         )
@@ -54,6 +53,7 @@ class CachedRequestData:
     # the request's block IDs. If True, new_block_ids will be used as the
     # request's block IDs instead of appending to the existing block IDs.
     resumed_from_preemption: bool
+    new_token_ids: List[int]
     new_block_ids: List[int]
     num_computed_tokens: int
@@ -62,14 +62,15 @@ class CachedRequestData:
         cls,
         request: "Request",
         resumed_from_preemption: bool,
+        new_token_ids: List[int],
         new_block_ids: List[int],
-        num_computed_tokens: int,
     ) -> "CachedRequestData":
         return cls(
             req_id=request.request_id,
             resumed_from_preemption=resumed_from_preemption,
+            new_token_ids=new_token_ids,
             new_block_ids=new_block_ids,
-            num_computed_tokens=num_computed_tokens,
+            num_computed_tokens=request.num_computed_tokens,
         )
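
For reference, a stripped-down stand-in for the per-request payload shape after this change; this sketch is not the real class, just an illustration of the fields involved:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class CachedRequestDataSketch:
        req_id: str
        resumed_from_preemption: bool
        new_token_ids: List[int]   # newly added: tokens the runner has not seen yet
        new_block_ids: List[int]
        num_computed_tokens: int

    req_data = CachedRequestDataSketch(
        req_id="req-0",
        resumed_from_preemption=False,
        new_token_ids=[17, 42],    # e.g. the last verified sampled tokens
        new_block_ids=[],
        num_computed_tokens=96,
    )
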
@@ -91,9 +92,9 @@ class SchedulerOutput:
     # Total number of tokens scheduled for all requests.
     # Equal to sum(num_scheduled_tokens.values())
     total_num_scheduled_tokens: int
-    # req_id -> spec_decode_tokens
-    # If a request does not have any spec decode tokens, it will
-    # not be included in the dictionary.
+    # req_id -> spec_token_ids
+    # If a request does not have any spec decode tokens, it will not be
+    # included in the dictionary.
     scheduled_spec_decode_tokens: Dict[str, List[int]]
     # req_id -> encoder input indices that need processing.
     # E.g., if a request has [0, 1], it could mean the vision encoder needs
@@ -2,7 +2,7 @@

 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -184,7 +184,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                      self.max_model_len,
                                      self.max_num_tokens),
                                  dtype=np.int32)
-        self.arange_cpu = torch.from_numpy(self.arange_np)
         # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
         # a faster version of creating a new tensor every time. Thus, we should
         # not make any assumptions about the values in these tensors.
@@ -327,7 +326,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             req_state = self.requests[req_id]

             # Update the cached states.
-            req_state.num_computed_tokens = req_data.num_computed_tokens
+            num_computed_tokens = req_data.num_computed_tokens
+            req_state.num_computed_tokens = num_computed_tokens
+            # Add the sampled token(s) from the previous step (if any).
+            # This doesn't include "unverified" tokens like spec decode tokens.
+            num_new_tokens = (num_computed_tokens +
+                              len(req_data.new_token_ids) -
+                              req_state.num_tokens)
+            new_token_ids = (req_data.new_token_ids[-num_new_tokens:]
+                             if num_new_tokens > 0 else [])
+            req_state.output_token_ids.extend(new_token_ids)
             # Update the block IDs.
             if not req_data.resumed_from_preemption:
                 # Append the new blocks to the existing block IDs.
                 req_state.block_ids.extend(req_data.new_block_ids)
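
The num_new_tokens guard above keeps _update_states from double-counting: the new_token_ids sent by the scheduler may overlap tokens the runner already tracks, so only the tail past req_state.num_tokens is appended. A small numeric sketch (all values are made up):

    # Hypothetical values; the names follow the hunk above.
    req_state_num_tokens = 24      # tokens the runner already tracks
    num_computed_tokens = 22       # reported by the scheduler
    new_token_ids = [7, 8, 9]      # covers positions 22, 23, 24

    num_new_tokens = num_computed_tokens + len(new_token_ids) - req_state_num_tokens
    assert num_new_tokens == 1
    appended = new_token_ids[-num_new_tokens:] if num_new_tokens > 0 else []
    assert appended == [9]         # only the genuinely new token is appended
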
@@ -346,12 +355,30 @@ class GPUModelRunner(LoRAModelRunnerMixin):

             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
-                req_data.num_computed_tokens)
-            start_index = len(req_state.block_ids) - len(
-                req_data.new_block_ids)
+                num_computed_tokens)
+            start_index = (len(req_state.block_ids) -
+                           len(req_data.new_block_ids))
             self.input_batch.block_table.append_row(req_index, start_index,
                                                     req_data.new_block_ids)
+            # Add new_token_ids to token_ids_cpu.
+            start_token_index = num_computed_tokens
+            end_token_index = num_computed_tokens + len(req_data.new_token_ids)
+            self.input_batch.token_ids_cpu[
+                req_index,
+                start_token_index:end_token_index] = req_data.new_token_ids
+            # Add spec_token_ids to token_ids_cpu.
+            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
+                req_id, [])
+            if spec_token_ids:
+                start_index = end_token_index
+                end_token_index += len(spec_token_ids)
+                self.input_batch.token_ids_cpu[
+                    req_index, start_index:end_token_index] = spec_token_ids
+            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
+            self.input_batch.num_tokens[req_index] = end_token_index

         # Check if the batch has changed. If not, we can skip copying the
         # sampling metadata from CPU to GPU.
         batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0

         # Add the new or resumed requests to the persistent batch.
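
A minimal NumPy sketch of the token_ids_cpu bookkeeping above, writing first the scheduler-provided tokens and then the draft tokens into one request row; the row size and token values are invented:

    import numpy as np

    token_ids_row = np.zeros(32, dtype=np.int32)   # one row of token_ids_cpu
    num_computed_tokens = 10
    new_token_ids = [7, 8]        # from CachedRequestData.new_token_ids
    spec_token_ids = [9, 9, 9]    # from scheduled_spec_decode_tokens

    start = num_computed_tokens
    end = start + len(new_token_ids)
    token_ids_row[start:end] = new_token_ids
    if spec_token_ids:
        start, end = end, end + len(spec_token_ids)
        token_ids_row[start:end] = spec_token_ids
    num_tokens = end              # may include spec decode tokens
    assert token_ids_row[10:15].tolist() == [7, 8, 9, 9, 9]
    assert num_tokens == 15
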
@@ -374,7 +401,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         return batch_changed

     def _prepare_inputs(
-        self, scheduler_output: "SchedulerOutput"
+        self,
+        scheduler_output: "SchedulerOutput",
     ) -> Tuple[FlashAttentionMetadata, torch.Tensor]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -387,24 +415,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):

         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
-        num_scheduled_tokens_list: List[int] = []
+        num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
         max_num_scheduled_tokens = 0
-        all_spec_token_ids: List[int] = []
-        num_spec_tokens_list: List[int] = []
         for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
             assert req_id is not None
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            num_scheduled_tokens_list.append(num_tokens)
+            num_scheduled_tokens[i] = num_tokens
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)
-            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                req_id, [])
-            all_spec_token_ids.extend(spec_token_ids)
-            num_spec_tokens_list.append(len(spec_token_ids))
-
-        num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
-                                                    dtype=np.int32)
         assert max_num_scheduled_tokens > 0

         # Get request indices.
         # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
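
The request-index expansion described in the comment above can be reproduced with np.repeat; a standalone sketch (the real code uses a preallocated arange buffer):

    import numpy as np

    num_scheduled_tokens = np.array([2, 5, 3], dtype=np.int32)
    req_indices = np.repeat(np.arange(3, dtype=np.int32), num_scheduled_tokens)
    assert req_indices.tolist() == [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
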
@@ -441,78 +459,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         token_indices = (positions_np +
                          req_indices * self.input_batch.token_ids_cpu.shape[1])

-        use_spec_decode = len(all_spec_token_ids) > 0
-        if use_spec_decode:
-
-            # 1. Write spec_token_ids to input batch.
-            # Step 1. Get req indices that perform spec decode and repeat
-            #         the req indices by the number of spec tokens. Note
-            #         for requests that don't perform spec decode, the
-            #         number of spec tokens is 0 and the req index is
-            #         repeated 0 times.
-            # E.g., num_spec_tokens_list: [3, 0, 2, 0, 1]
-            #       spec_req_indices: [0, 0, 0, 2, 2, 4]
-            spec_req_indices = np.repeat(self.arange_np[:num_reqs],
-                                         num_spec_tokens_list)
-            # spec_offsets: offsets within each spec token list.
-            # E.g., [1, 2, 3, 1, 2, 1], TODO: avoid the for loop here
-            spec_offsets = np.concatenate(
-                [self.arange_np[1:val + 1] for val in num_spec_tokens_list])
-            # spec_seq_offsets: offsets within each sequence.
-            # E.g., num_computed_tokens_cpu: [1, 4, 3, 6, 2]
-            #       after repeating: [1, 1, 1, 3, 3, 2]
-            #       spec_seq_offsets: [1, 1, 1, 3, 3, 2] + [1, 2, 3, 1, 2, 1]
-            #                         = [2, 3, 4, 4, 5, 3]
-            spec_seq_offsets = np.repeat(
-                self.input_batch.num_computed_tokens_cpu[:num_reqs],
-                num_spec_tokens_list) + spec_offsets
-            # cumsums_spec_offsets: [0, 0, 0, 2M, 2M, 4M] + [2, 3, 4, 4, 5, 3]
-            cumsums_spec_offsets = (
-                spec_seq_offsets +
-                spec_req_indices * self.input_batch.token_ids_cpu.shape[1])
-            cumsums_spec_offsets = torch.from_numpy(cumsums_spec_offsets).to(
-                torch.int64)
-            all_spec_token_ids = torch.tensor(all_spec_token_ids,
-                                              device="cpu",
-                                              dtype=self.input_ids_cpu.dtype)
-
-            # Step 2. Write spec token ids to input_ids_cpu.
-            self.input_batch.token_ids_cpu_tensor.flatten().scatter_(
-                0, cumsums_spec_offsets, all_spec_token_ids)
-
-            # 2. Get spec decode logits indices.
-            # E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
-            #       cu_num_tokens: [4, 104, 107, 207, 209]
-            #       num_spec_tokens_list: [3, 0, 2, 0, 1]
-            #       num_sampled_tokens: [4, 1, 3, 1, 2]
-            #       spec_decode_logits_indices:
-            #           [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
-            num_spec_tokens_np = np.array(num_spec_tokens_list, dtype=np.int32)
-            num_sampled_tokens = num_spec_tokens_np + 1
-            # logits_start_loc: [0, 103, 104, 206, 207]
-            logits_start_loc = cu_num_tokens - num_sampled_tokens
-            # [0, 103, 104, 206, 207] ->
-            # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
-            logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
-            # The following three lines:
-            # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
-            # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
-            cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
-            # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
-            #         -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
-            cumsums_sampled_offsets = np.repeat(
-                cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
-            # Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-            #         - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
-            #         -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
-            total_num_sampled_tokens = num_sampled_tokens.sum()
-            sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
-                              cumsums_sampled_offsets)
-
-            # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
-            # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
-            spec_decode_logits_indices = logits_start_loc + sampled_arange
-
         # NOTE(woosuk): We use torch.index_select instead of np.take here
         # because torch.index_select is much faster than np.take for large
         # tensors.
@@ -606,9 +552,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 suffix_kv_lens=suffix_kv_lens,
             )

+        use_spec_decode = len(
+            scheduler_output.scheduled_spec_decode_tokens) > 0
         if use_spec_decode:
-            logits_indices = torch.from_numpy(spec_decode_logits_indices).to(
-                self.device, non_blocking=True)
+            logits_indices = self._calc_spec_decode_metadata(
+                scheduler_output, cu_num_tokens)
         else:
             # NOTE(woosuk): Due to chunked prefills, the batch may contain
             # partial requests. While we should not sample any token
@@ -762,6 +710,53 @@ class GPUModelRunner(LoRAModelRunnerMixin):

                 mrope_pos_ptr += completion_part_len

+    def _calc_spec_decode_metadata(
+        self,
+        scheduler_output: "SchedulerOutput",
+        cu_num_tokens: np.ndarray,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Get the number of spec decode tokens for each request.
+        num_reqs = self.input_batch.num_reqs
+        num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32)
+        for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
+            assert req_id is not None
+            num_spec_decode_tokens[i] = len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
+
+        # Get spec decode logits indices.
+        # E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
+        #       cu_num_tokens: [4, 104, 107, 207, 209]
+        #       num_spec_tokens_list: [3, 0, 2, 0, 1]
+        #       num_sampled_tokens: [4, 1, 3, 1, 2]
+        #       spec_decode_logits_indices:
+        #           [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
+        num_sampled_tokens = num_spec_decode_tokens + 1
+        # logits_start_loc: [0, 103, 104, 206, 207]
+        logits_start_loc = cu_num_tokens - num_sampled_tokens
+        # [0, 103, 104, 206, 207] ->
+        # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
+        logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
+        # The following three lines:
+        # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
+        # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
+        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
+        # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
+        #         -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
+        cumsums_sampled_offsets = np.repeat(
+            cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
+        # Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        #         - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
+        #         -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
+        total_num_sampled_tokens = num_sampled_tokens.sum()
+        sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
+                          cumsums_sampled_offsets)
+
+        # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
+        # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
+        spec_decode_logits_indices = logits_start_loc + sampled_arange
+        return torch.from_numpy(spec_decode_logits_indices).to(
+            self.device, non_blocking=True)
+
     def _prepare_sampling(
         self,
         batch_changed: bool,
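
The index arithmetic in _calc_spec_decode_metadata can be checked against the example from its own comments; this standalone NumPy sketch reproduces the [0, 1, 2, 3, 103, ...] result:

    import numpy as np

    num_scheduled_tokens = np.array([4, 100, 3, 100, 2])
    num_spec_decode_tokens = np.array([3, 0, 2, 0, 1])

    cu_num_tokens = np.cumsum(num_scheduled_tokens)        # [4, 104, 107, 207, 209]
    num_sampled_tokens = num_spec_decode_tokens + 1        # [4, 1, 3, 1, 2]
    logits_start_loc = cu_num_tokens - num_sampled_tokens  # [0, 103, 104, 206, 207]
    logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
    cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)  # [4, 5, 8, 9, 11]
    cumsums_sampled_offsets = np.repeat(
        cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
    sampled_arange = np.arange(num_sampled_tokens.sum()) - cumsums_sampled_offsets
    indices = logits_start_loc + sampled_arange
    assert indices.tolist() == [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
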
@@ -773,7 +768,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                   for req_id, req in self.requests.items()}

         sampling_metadata = self.input_batch.make_sampling_metadata(
-            req_id_output_token_ids, req_to_spec_token_ids, not batch_changed)
+            req_id_output_token_ids,
+            req_to_spec_token_ids,
+            skip_copy=not batch_changed)
         return sampling_metadata

     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
@@ -960,28 +957,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
         num_reqs = self.input_batch.num_reqs
-        request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
+        req_ids: List[str] = []
+        # Because `input_batch.req_ids` is a list of length `max_num_reqs`,
+        # we need to stop at `num_reqs`.
+        # FIXME(woosuk): This is hacky. Refactor.
         for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
             assert req_id is not None
+            req_ids.append(req_id)
             req_state = self.requests[req_id]
             seq_len = (req_state.num_computed_tokens +
                        scheduler_output.num_scheduled_tokens[req_id])
-            if seq_len >= req_state.num_tokens:
-                request_seq_lens.append((i, req_state, seq_len))
-            else:
-                # Ignore the sampled token from the partial request.
+            if seq_len < req_state.num_tokens:
+                # Ignore the sampled token.
                 # Rewind the generator state as if the token was not sampled.
                 generator = self.input_batch.generators.get(i)
                 if generator is not None:
                     # This relies on cuda-specific torch-internal impl details
                     generator.set_offset(generator.get_offset() - 4)

-        # num_reqs entries should be non-None
-        assert all(
-            req_id is not None for req_id in
-            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
-        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
-
         # NOTE: GPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
         logprobs_tensors = sampler_output.logprobs_tensors
@@ -994,29 +987,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             scheduler_output,
         )

-        # Update batch with the valid generated tokens.
+        # Get the valid generated tokens.
         sampled_token_ids = sampler_output.sampled_token_ids
         max_gen_len = sampled_token_ids.shape[-1]
         if max_gen_len == 1:
             # No spec decode tokens.
             valid_sampled_token_ids = sampled_token_ids.tolist()
-            for i, req_state, seq_len in request_seq_lens:
-                token_id = valid_sampled_token_ids[i][0]
-                self.input_batch.token_ids_cpu[i, seq_len] = token_id
-                req_state.output_token_ids.append(token_id)
-                self.input_batch.num_tokens[i] += 1
         else:
             # Includes spec decode tokens.
             valid_mask = sampled_token_ids != INVALID_TOKEN_ID
             gen_lens = valid_mask.sum(dim=1).tolist()
             # TODO(woosuk): Optimize this.
             valid_sampled_token_ids = [
                 seq.tolist()
                 for seq in sampled_token_ids[valid_mask].split(gen_lens)
             ]
-            self.input_batch.num_tokens[:num_reqs] += gen_lens
-            for i, req_state, seq_len in request_seq_lens:
-                target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
-                self.input_batch.token_ids_cpu[
-                    i, target_slice] = valid_sampled_token_ids[i]
-                req_state.output_token_ids.extend(valid_sampled_token_ids[i])

         model_runner_output = ModelRunnerOutput(
             req_ids=req_ids,
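
A small PyTorch sketch of the masking and splitting used above to recover the per-request accepted tokens; INVALID_TOKEN_ID and the token values here are placeholders, not taken from the diff:

    import torch

    INVALID_TOKEN_ID = -1   # placeholder for the real sentinel
    # Two requests, up to three sampled tokens each (1 verified + up to 2 spec).
    sampled_token_ids = torch.tensor([[11, 12, INVALID_TOKEN_ID],
                                      [21, INVALID_TOKEN_ID, INVALID_TOKEN_ID]])
    valid_mask = sampled_token_ids != INVALID_TOKEN_ID
    gen_lens = valid_mask.sum(dim=1).tolist()              # [2, 1]
    valid_sampled_token_ids = [
        seq.tolist() for seq in sampled_token_ids[valid_mask].split(gen_lens)
    ]
    assert valid_sampled_token_ids == [[11, 12], [21]]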