Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-09-01 09:32:49 -07:00
parent 39a22dcaac
commit 901afda905
2 changed files with 24 additions and 15 deletions


@@ -842,6 +842,7 @@ class Scheduler(SchedulerInterface):
         scheduler_output: SchedulerOutput,
         model_runner_output: ModelRunnerOutput,
     ) -> dict[int, EngineCoreOutputs]:
+        num_sampled_tokens = model_runner_output.num_sampled_tokens
         sampled_token_ids = model_runner_output.sampled_token_ids
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
@@ -867,14 +868,17 @@ class Scheduler(SchedulerInterface):
                 continue
             req_index = model_runner_output.req_id_to_index[req_id]
-            generated_token_ids = sampled_token_ids[
-                req_index] if sampled_token_ids else []
+            num_sampled = num_sampled_tokens[req_index]
+            if num_sampled > 0:
+                # Take this request's row of the padded array, trimmed
+                # to its valid tokens.
+                generated_token_ids = sampled_token_ids[
+                    req_index, :num_sampled].tolist()
+            else:
+                generated_token_ids = []
             scheduled_spec_token_ids = (
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id))
             if scheduled_spec_token_ids:
                 num_draft_tokens = len(scheduled_spec_token_ids)
-                num_accepted = len(generated_token_ids) - 1
+                num_accepted = num_sampled - 1
                 num_rejected = num_draft_tokens - num_accepted
             # num_computed_tokens represents the number of tokens
             # processed in the current step, considering scheduled
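For context, here is a minimal standalone sketch of the acceptance accounting in the hunk above. It is illustrative only: the request ids, draft tokens, and sampled counts are made up, and only the num_accepted / num_rejected arithmetic mirrors the committed code (the runner emits one token per accepted draft plus one bonus token, hence num_sampled - 1).

import numpy as np

# Made-up per-request sampled counts and scheduled draft tokens.
num_sampled_tokens = np.array([3, 1, 4], dtype=np.int32)
scheduled_spec_decode_tokens = {
    "req-0": [11, 12],      # 2 drafts scheduled
    "req-1": [21, 22, 23],  # 3 drafts scheduled
    "req-2": [31, 32, 33],  # 3 drafts scheduled
}

for req_index, (req_id, drafts) in enumerate(
        scheduled_spec_decode_tokens.items()):
    num_sampled = int(num_sampled_tokens[req_index])
    num_draft_tokens = len(drafts)
    # One sampled token per accepted draft, plus one bonus token.
    num_accepted = num_sampled - 1
    num_rejected = num_draft_tokens - num_accepted
    print(f"{req_id}: accepted={num_accepted}, rejected={num_rejected}")
# req-0: accepted=2, rejected=0
# req-1: accepted=0, rejected=3
# req-2: accepted=3, rejected=0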


@@ -4,6 +4,7 @@
 from dataclasses import dataclass
 from typing import NamedTuple, Optional

+import numpy as np
 import torch
@@ -88,11 +89,12 @@ class ModelRunnerOutput:
     # req_id -> index
     req_id_to_index: dict[str, int]
-    # num_reqs x num_generated_tokens
-    # num_generated_tokens is the number of tokens
-    # generated in the current step. It can be different for
-    # each request due to speculative/jump decoding.
-    sampled_token_ids: list[list[int]]
+    # [num_reqs]
+    # Number of tokens sampled in the current step. Each request may generate
+    # a different number of tokens due to chunked prefilling and spec decoding.
+    num_sampled_tokens: np.ndarray
+    # [num_reqs, max_num_sampled_tokens]
+    sampled_token_ids: np.ndarray
     # [num_reqs, max_num_logprobs + 1]
     # [num_reqs, max_num_logprobs + 1]
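To make the layout change concrete, the sketch below builds the new padded representation from the old ragged one. It is an assumption-laden illustration, not code from this commit: the diff does not show how the runner pads the array, so -1 is used here purely as a stand-in padding value.

import numpy as np

# Old ragged form: one variable-length list of token ids per request.
ragged = [[7, 8, 9], [5], [1, 2]]

# New form: a per-request count plus a padded 2-D array.
num_sampled_tokens = np.array([len(r) for r in ragged], dtype=np.int32)
max_num_sampled_tokens = int(num_sampled_tokens.max())
sampled_token_ids = np.full(
    (len(ragged), max_num_sampled_tokens), -1, dtype=np.int32)  # -1: assumed pad
for i, row in enumerate(ragged):
    sampled_token_ids[i, :len(row)] = row

# Recovering one request's tokens mirrors the scheduler change above.
req_index = 0
num_sampled = num_sampled_tokens[req_index]
assert sampled_token_ids[req_index, :num_sampled].tolist() == [7, 8, 9]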
@@ -123,10 +125,13 @@ class DraftTokenIds:
     draft_token_ids: list[list[int]]


-EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
-                                              req_id_to_index={},
-                                              sampled_token_ids=[],
-                                              logprobs=None,
-                                              prompt_logprobs_dict={},
-                                              pooler_output=[],
-                                              num_nans_in_logits=None)
+EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
+    req_ids=[],
+    req_id_to_index={},
+    num_sampled_tokens=np.empty(0, dtype=np.int32),
+    sampled_token_ids=np.empty((0, 0), dtype=np.int32),
+    logprobs=None,
+    prompt_logprobs_dict={},
+    pooler_output=[],
+    num_nans_in_logits=None,
+)
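As a small sanity check (again illustrative, not part of the commit), the empty sentinel's arrays stay shape-consistent with the slicing done in the scheduler: zero requests and zero columns, so the per-request slice path is never reached.

import numpy as np

num_sampled_tokens = np.empty(0, dtype=np.int32)
sampled_token_ids = np.empty((0, 0), dtype=np.int32)
assert sampled_token_ids.shape[0] == num_sampled_tokens.shape[0] == 0
# With zero rows, no req_index is ever looked up in req_id_to_index.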