Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-09-01 09:32:49 -07:00
parent 39a22dcaac
commit 901afda905
2 changed files with 24 additions and 15 deletions

View File

@@ -842,6 +842,7 @@ class Scheduler(SchedulerInterface):
scheduler_output: SchedulerOutput,
model_runner_output: ModelRunnerOutput,
) -> dict[int, EngineCoreOutputs]:
num_sampled_tokens = model_runner_output.num_sampled_tokens
sampled_token_ids = model_runner_output.sampled_token_ids
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
@@ -867,14 +868,17 @@ class Scheduler(SchedulerInterface):
continue
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[
req_index] if sampled_token_ids else []
num_sampled = num_sampled_tokens[req_index]
if num_sampled > 0:
generated_token_ids = sampled_token_ids[:num_sampled].tolist()
else:
generated_token_ids = []
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
if scheduled_spec_token_ids:
num_draft_tokens = len(scheduled_spec_token_ids)
num_accepted = len(generated_token_ids) - 1
num_accepted = num_sampled - 1
num_rejected = num_draft_tokens - num_accepted
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled

View File

@@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import NamedTuple, Optional
import numpy as np
import torch
@@ -88,11 +89,12 @@ class ModelRunnerOutput:
# req_id -> index
req_id_to_index: dict[str, int]
# num_reqs x num_generated_tokens
# num_generated_tokens is the number of tokens
# generated in the current step. It can be different for
# each request due to speculative/jump decoding.
sampled_token_ids: list[list[int]]
# [num_reqs]
# Number of tokens sampled in the current step. Each request may generate
# a different number of tokens due to chunked prefilling and spec decoding.
num_sampled_tokens: np.ndarray
# [num_reqs, max_num_sampled_tokens]
sampled_token_ids: np.ndarray
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
@@ -123,10 +125,13 @@ class DraftTokenIds:
draft_token_ids: list[list[int]]
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
num_nans_in_logits=None)
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[],
req_id_to_index={},
num_sampled_tokens=np.empty(0, dtype=np.int32),
sampled_token_ids=np.empty((0, 0), dtype=np.int32),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
num_nans_in_logits=None,
)