wip

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

parent 39a22dcaac
commit 901afda905
@@ -842,6 +842,7 @@ class Scheduler(SchedulerInterface):
         scheduler_output: SchedulerOutput,
         model_runner_output: ModelRunnerOutput,
     ) -> dict[int, EngineCoreOutputs]:
+        num_sampled_tokens = model_runner_output.num_sampled_tokens
         sampled_token_ids = model_runner_output.sampled_token_ids
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
@@ -867,14 +868,17 @@ class Scheduler(SchedulerInterface):
                 continue

             req_index = model_runner_output.req_id_to_index[req_id]
-            generated_token_ids = sampled_token_ids[
-                req_index] if sampled_token_ids else []
+            num_sampled = num_sampled_tokens[req_index]
+            if num_sampled > 0:
+                generated_token_ids = sampled_token_ids[req_index, :num_sampled].tolist()
+            else:
+                generated_token_ids = []

             scheduled_spec_token_ids = (
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id))
             if scheduled_spec_token_ids:
                 num_draft_tokens = len(scheduled_spec_token_ids)
-                num_accepted = len(generated_token_ids) - 1
+                num_accepted = num_sampled - 1
                 num_rejected = num_draft_tokens - num_accepted
                 # num_computed_tokens represents the number of tokens
                 # processed in the current step, considering scheduled
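For context on the accounting in this hunk: with speculative decoding, the runner samples one token per accepted draft token plus one bonus token, so num_sampled - 1 recovers the accepted-draft count. A minimal sketch, not part of this commit (the helper name is hypothetical):

def spec_decode_stats(num_sampled: int, num_draft_tokens: int) -> tuple[int, int]:
    # One sampled token per accepted draft, plus one bonus token sampled
    # after the last accepted draft, so num_sampled = num_accepted + 1.
    num_accepted = num_sampled - 1
    num_rejected = num_draft_tokens - num_accepted
    return num_accepted, num_rejected

# 3 drafts scheduled, 2 tokens sampled: 1 draft accepted, 2 rejected.
assert spec_decode_stats(num_sampled=2, num_draft_tokens=3) == (1, 2)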
@@ -4,6 +4,7 @@
 from dataclasses import dataclass
 from typing import NamedTuple, Optional

+import numpy as np
 import torch

@@ -88,11 +89,12 @@ class ModelRunnerOutput:
     # req_id -> index
     req_id_to_index: dict[str, int]

-    # num_reqs x num_generated_tokens
-    # num_generated_tokens is the number of tokens
-    # generated in the current step. It can be different for
-    # each request due to speculative/jump decoding.
-    sampled_token_ids: list[list[int]]
+    # [num_reqs]
+    # Number of tokens sampled in the current step. Each request may generate
+    # a different number of tokens due to chunked prefilling and spec decoding.
+    num_sampled_tokens: np.ndarray
+    # [num_reqs, max_num_sampled_tokens]
+    sampled_token_ids: np.ndarray

     # [num_reqs, max_num_logprobs + 1]
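The switch from list[list[int]] to a padded 2-D array means consumers must use num_sampled_tokens to mask out the padding in each row. A minimal sketch of reading the new layout (the token values below are made up for illustration):

import numpy as np

# Two requests, padded to max_num_sampled_tokens == 3; only the first
# num_sampled_tokens[i] entries of row i are valid.
num_sampled_tokens = np.array([1, 3], dtype=np.int32)
sampled_token_ids = np.array([[101, 0, 0],
                              [7, 8, 9]], dtype=np.int32)

for req_index in range(len(num_sampled_tokens)):
    n = num_sampled_tokens[req_index]
    print(req_index, sampled_token_ids[req_index, :n].tolist())
# 0 [101]
# 1 [7, 8, 9]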
@@ -123,10 +125,13 @@ class DraftTokenIds:
     draft_token_ids: list[list[int]]


-EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
-                                              req_id_to_index={},
-                                              sampled_token_ids=[],
-                                              logprobs=None,
-                                              prompt_logprobs_dict={},
-                                              pooler_output=[],
-                                              num_nans_in_logits=None)
+EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
+    req_ids=[],
+    req_id_to_index={},
+    num_sampled_tokens=np.empty(0, dtype=np.int32),
+    sampled_token_ids=np.empty((0, 0), dtype=np.int32),
+    logprobs=None,
+    prompt_logprobs_dict={},
+    pooler_output=[],
+    num_nans_in_logits=None,
+)
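The empty sentinel now carries zero-length arrays rather than empty lists so it matches the new ndarray field types. A quick check of what that buys, for illustration only:

import numpy as np

num_sampled_tokens = np.empty(0, dtype=np.int32)
sampled_token_ids = np.empty((0, 0), dtype=np.int32)
assert num_sampled_tokens.shape == (0,)
assert sampled_token_ids.shape == (0, 0)
# Per-request loops over the sentinel are trivially empty.
assert [i for i in range(len(num_sampled_tokens))] == []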