Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Cyrus Leung 2025-11-15 14:47:41 +08:00 committed by GitHub
parent 6965ef436f
commit 98b4d389ed
15 changed files with 122 additions and 91 deletions

View File

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections import deque
+import numpy as np
 import pytest
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -21,7 +22,7 @@ def _make_model_runner_output(
     return ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index={req_id: i for i, req_id in enumerate(req_ids)},
-        sampled_token_ids=[[i] for i in range(len(req_ids))],
+        sampled_token_ids=[np.array([i]) for i in range(len(req_ids))],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],

View File

@@ -3,6 +3,7 @@
 import random
 import uuid
+import numpy as np
 import pytest
 from vllm.config import VllmConfig
@@ -99,8 +100,7 @@ def _mock_execute_model(
         random.randint(*num_output_tokens_range) for _ in range(len(request_ids))
     ]
     sampled_token_ids = [
-        [random.randint(0, 100) for _ in range(num_tokens)]
-        for num_tokens in num_output_tokens
+        np.random.randint(0, 100, size=num_tokens) for num_tokens in num_output_tokens
     ]
     return ModelRunnerOutput(
@@ -196,6 +196,8 @@ def test_priority_scheduling_blast(
     num_blocks: int,
 ):
     random.seed(42)
+    np.random.seed(42)
     seen_request_prompt_length = dict[str, int]()
     seen_request_ids = set[str]()
     seen_mm_hashes = set[str]()

View File

@@ -3,6 +3,7 @@
 import dataclasses
 from unittest.mock import Mock
+import numpy as np
 import pytest
 import torch
@@ -169,7 +170,7 @@ def test_schedule_partial_requests():
         req_id_to_index=req_to_index,
         # Only the first request has a sampled token id because
         # the rest requests are still being prefilled.
-        sampled_token_ids=[[0], [], []],
+        sampled_token_ids=[np.array([0]), np.array([]), np.array([])],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -216,7 +217,7 @@ def test_no_mm_input_chunking():
     model_runner_output = ModelRunnerOutput(
         req_ids=[request.request_id for request in requests],
         req_id_to_index=req_to_index,
-        sampled_token_ids=[[] for _ in range(len(requests))],
+        sampled_token_ids=[np.array([]) for _ in range(len(requests))],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -276,7 +277,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
     model_runner_output = ModelRunnerOutput(
         req_ids=[request.request_id for request in requests],
         req_id_to_index=req_to_index,
-        sampled_token_ids=[[] for _ in range(len(requests))],
+        sampled_token_ids=[np.array([]) for _ in range(len(requests))],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -300,7 +301,8 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
     model_runner_output = ModelRunnerOutput(
         req_ids=[request.request_id for request in requests],
         req_id_to_index=req_to_index,
-        sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
+        sampled_token_ids=[np.array([0]), np.array([0])]
+        + [np.array([]) for _ in range(len(requests) - 2)],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -347,8 +349,8 @@ def test_stop_via_update_from_output():
         req_ids=[req.request_id for req in requests],
         req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
         sampled_token_ids=[
-            [EOS_TOKEN_ID],
-            [10, 11],
+            np.array([EOS_TOKEN_ID]),
+            np.array([10, 11]),
         ],  # First request hits EOS, second continues
         logprobs=None,
         prompt_logprobs_dict={},
@@ -392,7 +394,10 @@ def test_stop_via_update_from_output():
     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
         req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
-        sampled_token_ids=[[10, 42, 12], [13, 14]],  # First request hits stop token
+        sampled_token_ids=[
+            np.array([10, 42, 12]),
+            np.array([13, 14]),
+        ],  # First request hits stop token
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -436,7 +441,10 @@ def test_stop_via_update_from_output():
     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
         req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
-        sampled_token_ids=[[10, 11, 12], [13]],  # First request exceeds max_tokens
+        sampled_token_ids=[
+            np.array([10, 11, 12]),
+            np.array([13]),
+        ],  # First request exceeds max_tokens
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -475,7 +483,7 @@ def test_stop_via_update_from_output():
     model_output = ModelRunnerOutput(
         req_ids=[requests[0].request_id],
         req_id_to_index={requests[0].request_id: 0},
-        sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
+        sampled_token_ids=[np.array([EOS_TOKEN_ID, 10, 11])],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
@@ -616,7 +624,7 @@ def test_schedule_concurrent_batches(
     model_runner_output = ModelRunnerOutput(
         req_ids=[requests[0].request_id],
         req_id_to_index={requests[0].request_id: 0},
-        sampled_token_ids=[[0]],
+        sampled_token_ids=[np.array([0])],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -633,7 +641,7 @@ def test_schedule_concurrent_batches(
     model_runner_output = ModelRunnerOutput(
         req_ids=[requests[1].request_id],
         req_id_to_index={requests[1].request_id: 0},
-        sampled_token_ids=[[0]],
+        sampled_token_ids=[np.array([0])],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -670,7 +678,7 @@ def test_preempt_during_execution():
     model_runner_output0 = ModelRunnerOutput(
         req_ids=[requests[0].request_id],
         req_id_to_index={requests[0].request_id: 0},
-        sampled_token_ids=[[0]],
+        sampled_token_ids=[np.array([0])],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -687,7 +695,7 @@ def test_preempt_during_execution():
     model_runner_output1 = ModelRunnerOutput(
         req_ids=[requests[1].request_id],
         req_id_to_index={requests[1].request_id: 0},
-        sampled_token_ids=[[42]],
+        sampled_token_ids=[np.array([42])],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -704,14 +712,18 @@
 @pytest.mark.parametrize(
     "spec_tokens,output_tokens,expected",
     [
-        ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])),  # perfect match
-        ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])),  # early mismatch
-        ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (2, 3, 3, [2, 1])),  # multiple sequences
-        ([[1]], [[1, 2]], (1, 1, 1, [1])),  # single token sequence
-        ([[]], [[5]], (0, 0, 0, [0])),  # empty sequence
+        ([[1, 2, 3]], [np.array([1, 2, 3, 4])], (1, 3, 3, [1, 1, 1])),  # perfect match
+        ([[1, 2, 3]], [np.array([1, 5])], (1, 3, 1, [1, 0, 0])),  # early mismatch
+        (
+            [[1, 2], [3]],
+            [np.array([1, 2, 5]), np.array([3, 4])],
+            (2, 3, 3, [2, 1]),
+        ),  # multiple sequences
+        ([[1]], [np.array([1, 2])], (1, 1, 1, [1])),  # single token sequence
+        ([[]], [np.array([5])], (0, 0, 0, [0])),  # empty sequence
         (
             [[1, 2, 3], [4, 5, 6]],
-            [[1, 2, 7], [4, 8]],
+            [np.array([1, 2, 7]), np.array([4, 8])],
             (2, 6, 3, [2, 1, 0]),
         ),  # multiple mismatches
     ],
@@ -745,7 +757,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
     model_runner_output = ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index=req_to_index,
-        sampled_token_ids=[[0] for _ in range(len(requests))],
+        sampled_token_ids=[np.array([0]) for _ in range(len(requests))],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -972,7 +984,7 @@ def test_kv_connector_basic(is_async: bool):
     MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index=req_to_index,
-        sampled_token_ids=[[1000]] * len(req_ids),
+        sampled_token_ids=[np.array([1000])] * len(req_ids),
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -1025,7 +1037,7 @@ def test_kv_connector_basic(is_async: bool):
     MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index=req_to_index,
-        sampled_token_ids=[[1000]] * len(req_ids),
+        sampled_token_ids=[np.array([1000])] * len(req_ids),
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -1088,7 +1100,7 @@ def test_external_prefix_cache_metrics():
     MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
         req_ids=[r.request_id for r in requests],
         req_id_to_index={r.request_id: i for i, r in enumerate(requests)},
-        sampled_token_ids=[[1000]] * NUM_REQUESTS,
+        sampled_token_ids=[np.array([1000])] * NUM_REQUESTS,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -1154,7 +1166,7 @@ def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role):
     MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index=req_to_index,
-        sampled_token_ids=[[1000]] * len(req_ids),
+        sampled_token_ids=[np.array([1000])] * len(req_ids),
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -1239,7 +1251,7 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
     MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index=req_to_index,
-        sampled_token_ids=[[1000]] * len(req_ids),
+        sampled_token_ids=[np.array([1000])] * len(req_ids),
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -1332,7 +1344,7 @@ def make_output(scheduler: Scheduler):
     return ModelRunnerOutput(
         req_ids=[req.request_id for req in scheduler.running],
         req_id_to_index={req.request_id: i for i, req in enumerate(scheduler.running)},
-        sampled_token_ids=[[1000]] * len(scheduler.running),
+        sampled_token_ids=[np.array([1000])] * len(scheduler.running),
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -1749,7 +1761,7 @@ def test_priority_scheduling_preemption():
         req_id_to_index={
             req.request_id: i for i, req in enumerate(low_priority_requests)
         },
-        sampled_token_ids=[[100] for _ in low_priority_requests],
+        sampled_token_ids=[np.array([100]) for _ in low_priority_requests],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -1818,7 +1830,7 @@ def test_priority_scheduling_no_preemption_when_space_available():
         req_id_to_index={
             req.request_id: i for i, req in enumerate(low_priority_requests)
         },
-        sampled_token_ids=[[100] for _ in low_priority_requests],
+        sampled_token_ids=[np.array([100]) for _ in low_priority_requests],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -2064,7 +2076,7 @@ def test_priority_scheduling_heap_property():
     model_output = ModelRunnerOutput(
         req_ids=[req.req_id],
         req_id_to_index={req.req_id: 0},
-        sampled_token_ids=[[100]],
+        sampled_token_ids=[np.array([100])],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -2150,7 +2162,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
     model_output = ModelRunnerOutput(
         req_ids=[request_low.request_id],
         req_id_to_index={request_low.request_id: 0},
-        sampled_token_ids=[[100]],
+        sampled_token_ids=[np.array([100])],
         # spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
@@ -2181,7 +2193,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
         req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
-        sampled_token_ids=[[100] for _ in requests],
+        sampled_token_ids=[np.array([100]) for _ in requests],
         # spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
@@ -2207,7 +2219,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
         req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
-        sampled_token_ids=[[], [100]],
+        sampled_token_ids=[np.array([]), np.array([100])],
         # spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
@@ -2624,7 +2636,7 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
     model_output = ModelRunnerOutput(
         req_ids=[request1.request_id],
         req_id_to_index={request1.request_id: 0},
-        sampled_token_ids=[[100]],
+        sampled_token_ids=[np.array([100])],
         # spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
@@ -2830,7 +2842,7 @@ def test_ec_connector_unable_to_allocate(use_kv_connector):
     MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index=req_to_index,
-        sampled_token_ids=[[1000]] * len(req_ids),
+        sampled_token_ids=[np.array([1000])] * len(req_ids),
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -2943,7 +2955,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
     model_output = ModelRunnerOutput(
         req_ids=[request_low.request_id],
         req_id_to_index={request_low.request_id: 0},
-        sampled_token_ids=[[100]],
+        sampled_token_ids=[np.array([100])],
         # spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
@@ -2994,7 +3006,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
         req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
-        sampled_token_ids=[[100] for _ in requests],
+        sampled_token_ids=[np.array([100]) for _ in requests],
         # spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
@@ -3029,7 +3041,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
         req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
-        sampled_token_ids=[[100], [100, 200]],
+        sampled_token_ids=[np.array([100]), np.array([100, 200])],
         # spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
@@ -3215,7 +3227,7 @@ def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connector
     model_output = ModelRunnerOutput(
         req_ids=[request1.request_id, request2.request_id],
         req_id_to_index={request1.request_id: 0, request2.request_id: 1},
-        sampled_token_ids=[[100], [121]],
+        sampled_token_ids=[np.array([100]), np.array([121])],
         # spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},

View File

@@ -11,6 +11,7 @@ import uuid
 from collections import defaultdict
 from unittest.mock import patch
+import numpy as np
 import pytest
 import ray
 import torch
@@ -826,7 +827,7 @@ def test_kv_connector_stats_aggregation():
     output = ModelRunnerOutput(
         req_ids=[f"req_{i}"],
         req_id_to_index={f"req_{i}": 0},
-        sampled_token_ids=[[123]],  # dummy token
+        sampled_token_ids=[np.array([123])],  # dummy token
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[None],
@@ -907,7 +908,7 @@ def test_multi_kv_connector_stats_aggregation():
     output = ModelRunnerOutput(
         req_ids=[f"req_{i}"],
         req_id_to_index={f"req_{i}": 0},
-        sampled_token_ids=[[123]],
+        sampled_token_ids=[np.array([123])],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[None],
@@ -965,7 +966,7 @@ def test_scheduler_kv_connector_stats_aggregation():
     model_output = ModelRunnerOutput(
         req_ids=["req_0"],
         req_id_to_index={"req_0": 0},
-        sampled_token_ids=[[123]],
+        sampled_token_ids=[np.array([123])],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[None],

View File

@@ -7,6 +7,7 @@ from dataclasses import dataclass
 from itertools import chain, count
 from typing import Any
+import numpy as np
 import torch
 from vllm import SamplingParams
@@ -228,7 +229,7 @@ def create_model_runner_output(
     # Make sampled tokens.
     sampled_token = EOS_TOKEN_ID if use_eos else token_id
-    sampled_token_ids = [[sampled_token] for _ in req_ids]
+    sampled_token_ids = [np.array([sampled_token]) for _ in req_ids]
     kv_connector_output = (
         None

View File

@@ -3,6 +3,7 @@
 from unittest import mock
+import numpy as np
 import pytest
 import torch
@@ -112,7 +113,9 @@ def test_prepare_next_token_ids():
     sampled_token_ids_tensor = torch.tensor(
         sampled_token_ids, dtype=torch.int32, device=device
     )
-    sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids]
+    sampled_token_ids_cpu = [
+        np.array([i for i in seq if i != -1]) for seq in sampled_token_ids
+    ]
     expected_next_token_ids_cpu = [1, 4, 30, 40]
     expected_next_token_ids_tensor = torch.tensor(

View File

@@ -77,7 +77,7 @@ def test_ngram_proposer():
     # No match.
     token_ids_cpu = np.array([[1, 2, 3, 4, 5]])
     result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
-        sampled_token_ids=[[0]],
+        sampled_token_ids=[np.array([0])],
         req_ids=["0"],
         num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
         token_ids_cpu=token_ids_cpu,
@@ -88,7 +88,7 @@ def test_ngram_proposer():
     # No match for 4-gram.
     token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
     result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose(
-        sampled_token_ids=[[0]],
+        sampled_token_ids=[np.array([0])],
         req_ids=["0"],
         num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
         token_ids_cpu=token_ids_cpu,
@@ -99,7 +99,7 @@ def test_ngram_proposer():
     # No match for 4-gram but match for 3-gram.
     token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
     result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
-        sampled_token_ids=[[0]],
+        sampled_token_ids=[np.array([0])],
         req_ids=["0"],
         num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
         token_ids_cpu=token_ids_cpu,
@@ -111,7 +111,7 @@ def test_ngram_proposer():
     # In this case, the proposer should return the 4-gram match.
     token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]])
     result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
-        sampled_token_ids=[[0]],
+        sampled_token_ids=[np.array([0])],
         req_ids=["0"],
         num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
         token_ids_cpu=token_ids_cpu,
@@ -122,7 +122,7 @@ def test_ngram_proposer():
     # Match for 2-gram and 3-gram, but not 4-gram.
     token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]])
     result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose(
-        sampled_token_ids=[[0]],
+        sampled_token_ids=[np.array([0])],
         req_ids=["0"],
         num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
         token_ids_cpu=token_ids_cpu,
@@ -133,7 +133,7 @@ def test_ngram_proposer():
     # Multiple 3-gram matched, but always pick the first one.
     token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
     result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
-        sampled_token_ids=[[0]],
+        sampled_token_ids=[np.array([0])],
         req_ids=["0"],
         num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
         token_ids_cpu=token_ids_cpu,
@@ -144,7 +144,7 @@ def test_ngram_proposer():
     # check empty input
     token_ids_cpu = np.array([[]])
     result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
-        sampled_token_ids=[[0]],
+        sampled_token_ids=[np.array([0])],
         req_ids=["0"],
         num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
         token_ids_cpu=token_ids_cpu,
@@ -157,7 +157,7 @@ def test_ngram_proposer():
     # second request has 3 tokens and no match. Padded with -1 for max len 5
     token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]])
     result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
-        sampled_token_ids=[[0], [1]],
+        sampled_token_ids=[np.array([0]), np.array([1])],
         req_ids=["0", "1"],
         num_tokens_no_spec=np.array([5, 3]),
         token_ids_cpu=token_ids_cpu,
@@ -181,7 +181,7 @@ def test_ngram_proposer():
     input_2[:3] = [4, 5, 6]
     token_ids_cpu = np.array([input_1, input_2])
     result = ngram_proposer.propose(
-        sampled_token_ids=[[0], [1]],
+        sampled_token_ids=[np.array([0]), np.array([1])],
         req_ids=["0", "1"],
         num_tokens_no_spec=np.array([len(input_1), 3]),
         token_ids_cpu=token_ids_cpu,

View File

@@ -1010,8 +1010,8 @@ class Scheduler(SchedulerInterface):
                 continue
             req_index = model_runner_output.req_id_to_index[req_id]
-            generated_token_ids = (
-                sampled_token_ids[req_index] if sampled_token_ids else []
+            generated_token_ids: list[int] = (
+                sampled_token_ids[req_index].tolist() if sampled_token_ids else []
             )
             scheduled_spec_token_ids = (
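
For context on the scheduler-side change above: each request's entry is now a 1-D NumPy array, and `.tolist()` turns it back into plain Python ints before the scheduler does its bookkeeping. A minimal standalone sketch, not vLLM code, with made-up request ids and token values:

    import numpy as np

    # One 1-D array per request, as produced by the model runner.
    sampled_token_ids = [np.array([11, 12, 13]), np.array([], dtype=np.int64)]
    req_id_to_index = {"req-a": 0, "req-b": 1}

    for req_id, req_index in req_id_to_index.items():
        # Mirrors the new logic: convert back to list[int] on the scheduler side.
        generated_token_ids: list[int] = (
            sampled_token_ids[req_index].tolist() if sampled_token_ids else []
        )
        print(req_id, generated_token_ids)  # req-a [11, 12, 13], then req-b []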

View File

@@ -158,7 +158,7 @@ class ModelRunnerOutput:
     # num_generated_tokens is the number of tokens
     # generated in the current step. It can be different for
     # each request due to speculative/jump decoding.
-    sampled_token_ids: list[list[int]]
+    sampled_token_ids: list[np.ndarray]
     # [num_reqs, max_num_logprobs + 1]
     # [num_reqs, max_num_logprobs + 1]
@@ -220,7 +220,7 @@ def make_empty_encoder_model_runner_output(
     req_id_to_index: dict[str, int] = {rid: idx for idx, rid in enumerate(req_ids)}
     # No tokens generated yet ⇒ one empty list per request
-    sampled_token_ids: list[list[int]] = [[0] for _ in req_ids]
+    sampled_token_ids: list[list[int]] = [np.array([0]) for _ in req_ids]
     # Pooler outputs are not available yet ⇒ use None placeholders
     pooler_output: list[torch.Tensor | None] = [None for _ in req_ids]
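
A short, hypothetical sketch of what the retyped field holds after this commit: one variable-length 1-D integer array per request rather than a nested Python list (the values below are made up):

    import numpy as np

    sampled_token_ids: list[np.ndarray] = [
        np.array([101]),               # request 0: one token sampled this step
        np.array([7, 8, 9]),           # request 1: several tokens (e.g. spec decode)
        np.array([], dtype=np.int64),  # request 2: still prefilling, nothing sampled
    ]
    assert [len(ids) for ids in sampled_token_ids] == [1, 3, 0]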

View File

@@ -3,6 +3,7 @@
 from dataclasses import replace
+import numpy as np
 import torch
 import torch.nn as nn
@@ -204,7 +205,7 @@ class RejectionSampler(nn.Module):
     def parse_output(
         output_token_ids: torch.Tensor,
         vocab_size: int,
-    ) -> list[list[int]]:
+    ) -> list[np.ndarray]:
         """Parse the output of the rejection sampler.
         Args:
             output_token_ids: The sampled token IDs in shape
@@ -220,10 +221,7 @@ class RejectionSampler(nn.Module):
         valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
             output_token_ids_np < vocab_size
         )
-        outputs = [
-            row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
-        ]
-        return outputs
+        return [row[valid_mask[i]] for i, row in enumerate(output_token_ids_np)]

     def apply_logits_processors(
         self,
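
A standalone sketch of the masking step above, showing why `parse_output` can now return the masked NumPy rows directly instead of calling `.tolist()` on each; the `PLACEHOLDER_TOKEN_ID` value of -1 is an assumption made for illustration:

    import numpy as np

    PLACEHOLDER_TOKEN_ID = -1  # assumed value; marks rejected draft positions
    vocab_size = 32000

    output_token_ids_np = np.array([
        [5, 9, PLACEHOLDER_TOKEN_ID],                     # last draft token rejected
        [3, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],  # only one token accepted
    ])
    valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
        output_token_ids_np < vocab_size
    )
    # One variable-length 1-D array per request.
    per_request = [row[valid_mask[i]] for i, row in enumerate(output_token_ids_np)]
    assert [r.tolist() for r in per_request] == [[5, 9], [3]]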

View File

@@ -484,7 +484,7 @@ class EagleProposer:
     def prepare_next_token_ids_cpu(
         self,
-        sampled_token_ids: list[list[int]],
+        sampled_token_ids: list[np.ndarray],
         requests: dict[str, CachedRequestState],
         gpu_input_batch: InputBatch,
         num_scheduled_tokens: dict[str, int],
@@ -499,7 +499,7 @@
         req_ids = gpu_input_batch.req_ids
         next_token_ids: list[int] = []
         for i, token_ids in enumerate(sampled_token_ids):
-            if token_ids:
+            if token_ids.shape[0] > 0:
                 # Common case.
                 next_token_id = token_ids[-1]
             else:
@@ -510,10 +510,9 @@
                 seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
                 next_token_id = req_state.get_token_id(seq_len)
             next_token_ids.append(next_token_id)
-        next_token_ids = torch.tensor(
+        return torch.tensor(
             next_token_ids, dtype=torch.int32, device=self.input_ids.device
         )
-        return next_token_ids

     def prepare_next_token_ids_padded(
         self,
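
A hedged sketch of the "last sampled token, if any" selection that the ndarray-based loop above performs; the fallback values here stand in for the `req_state.get_token_id(...)` lookup and are made up:

    import numpy as np
    import torch

    sampled_token_ids = [np.array([11, 12]), np.array([])]
    fallback_next_tokens = [99, 100]  # placeholder for req_state.get_token_id(seq_len)

    next_token_ids = [
        int(ids[-1]) if ids.shape[0] > 0 else fallback_next_tokens[i]
        for i, ids in enumerate(sampled_token_ids)
    ]
    assert torch.tensor(next_token_ids, dtype=torch.int32).tolist() == [12, 100]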

View File

@@ -54,7 +54,7 @@
         # Trigger Numba JIT compilation for N-gram proposer.
         # This usually takes less than 1 second.
         self.propose(
-            [[]] * 1024,
+            [np.array([])] * 1024,
             [""] * 1024,
             np.zeros(1024, dtype=np.int32),
             np.zeros((1024, self.max_model_len), dtype=np.int32),
@@ -131,7 +131,7 @@
     def propose(
         self,
-        sampled_token_ids: list[list[int]],
+        sampled_token_ids: list[np.ndarray],
         req_ids: list[str],
         num_tokens_no_spec: np.ndarray,
         token_ids_cpu: np.ndarray,
@@ -140,7 +140,7 @@
         # find which requests need ngram proposals
         valid_ngram_requests = []
         for i, sampled_ids in enumerate(sampled_token_ids):
-            num_sampled_ids = len(sampled_ids)
+            num_sampled_ids = sampled_ids.shape[0]
            if not num_sampled_ids:
                # Skip speculative decoding.
                continue
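
A minimal illustration of why the emptiness check changed: truth-testing a NumPy array with more than one element raises `ValueError`, so the proposer inspects `shape[0]` instead (the example arrays are made up):

    import numpy as np

    for sampled_ids in [np.array([]), np.array([42]), np.array([1, 2, 3])]:
        num_sampled_ids = sampled_ids.shape[0]
        if not num_sampled_ids:
            print("skip speculative decoding for this request")
        else:
            print(f"propose n-grams from {num_sampled_ids} sampled token(s)")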

View File

@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import numpy as np
 from vllm.config import VllmConfig
 from vllm.v1.worker.gpu_input_batch import InputBatch
@@ -32,16 +34,16 @@ class SuffixDecodingProposer:
     def propose(
         self,
         input_batch: InputBatch,
-        sampled_token_ids: list[list[int]],
+        sampled_token_ids: list[np.ndarray],
     ) -> list[list[int]]:
         """
         Propose speculative tokens for each request in the input batch. Suffix Decoding
         will speculate a dynamic number of tokens for each request every decoding step,
         so each entry in the returned list may have different lengths.
         """
-        draft_token_ids: list[list[int]] = []
+        draft_token_ids: list[np.ndarray] = []
         for i, sampled_ids in enumerate(sampled_token_ids):
-            if not sampled_ids:
+            if sampled_ids.shape[0] == 0:
                 # Skip speculative decoding for partial prefills.
                 draft_token_ids.append([])
                 continue
@@ -70,7 +72,7 @@ class SuffixDecodingProposer:
             self.suffix_cache.start_request(req_id, prompt_token_ids)
         # Append the newly sampled ids to the suffix cache for this request.
-        self.suffix_cache.add_active_response(req_id, sampled_ids)
+        self.suffix_cache.add_active_response(req_id, sampled_ids.tolist())
         # Suffix decoding only uses the most recent tokens up to max_tree_depth, so
         # we extract the pattern from the end of the input.

View File

@@ -216,9 +216,11 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
         del self._logprobs_tensors
         del self._sampled_token_ids
-        valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
+        valid_sampled_token_ids: list[np.ndarray] = [
+            row for row in self.sampled_token_ids_cpu.numpy()
+        ]
         for i in self._invalid_req_indices:
-            valid_sampled_token_ids[i].clear()
+            valid_sampled_token_ids[i] = np.array([])

         output = self._model_runner_output
         output.sampled_token_ids = valid_sampled_token_ids
@@ -2339,7 +2341,7 @@
     ) -> tuple[
         dict[str, int],
         LogprobsLists | None,
-        list[list[int]],
+        list[np.ndarray],
         dict[str, LogprobsTensors | None],
         list[str],
         dict[str, int],
@@ -2365,6 +2367,7 @@
         num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
         sampled_token_ids = sampler_output.sampled_token_ids
         invalid_req_indices = []
+        valid_sampled_token_ids: list[np.ndarray]
         if not self.use_async_scheduling:
             # Get the valid generated tokens.
             max_gen_len = sampled_token_ids.shape[-1]
@@ -2379,7 +2382,7 @@
             )
             # Mask out the sampled tokens that should not be sampled.
             for i in discard_sampled_tokens_req_indices:
-                valid_sampled_token_ids[int(i)].clear()
+                valid_sampled_token_ids[int(i)] = np.array([])
         else:
             valid_sampled_token_ids = []
             invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
@@ -2407,19 +2410,24 @@
             [0] if spec_decode_metadata and logprobs_tensors else None
         )
         for req_idx in range(num_sampled_tokens):
+            sampled_ids: np.ndarray | None
             if self.use_async_scheduling:
-                sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
+                sampled_ids = (
+                    np.array([-1]) if req_idx not in invalid_req_indices_set else None
+                )
             else:
                 sampled_ids = valid_sampled_token_ids[req_idx]
-            num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0
+            num_sampled_ids: int = (
+                sampled_ids.shape[0] if sampled_ids is not None else 0
+            )
             if cu_num_accepted_tokens is not None:
                 cu_num_accepted_tokens.append(
                     cu_num_accepted_tokens[-1] + num_sampled_ids
                 )
-            if not sampled_ids:
+            if sampled_ids is None or num_sampled_ids == 0:
                 continue
             start_idx = self.input_batch.num_tokens_no_spec[req_idx]
@@ -2761,7 +2769,9 @@
         with record_function_or_nullcontext("gpu_model_runner: sample"):
             sampler_output = self._sample(logits, spec_decode_metadata)

-        def propose_draft_token_ids(sampled_token_ids):
+        def propose_draft_token_ids(
+            sampled_token_ids: torch.Tensor | list[np.ndarray],
+        ) -> None:
             assert spec_decode_common_attn_metadata is not None
             with record_function_or_nullcontext("gpu_model_runner: draft"):
                 self._draft_token_ids = self.propose_draft_token_ids(
@@ -2883,14 +2893,14 @@
     def propose_draft_token_ids(
         self,
         scheduler_output: "SchedulerOutput",
-        sampled_token_ids: torch.Tensor | list[list[int]],
+        sampled_token_ids: torch.Tensor | list[np.ndarray],
         sampling_metadata: SamplingMetadata,
         hidden_states: torch.Tensor,
         sample_hidden_states: torch.Tensor,
         aux_hidden_states: list[torch.Tensor] | None,
         spec_decode_metadata: SpecDecodeMetadata | None,
         common_attn_metadata: CommonAttentionMetadata,
-    ) -> list[list[int]] | torch.Tensor:
+    ) -> torch.Tensor | list[list[int]]:
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if self.speculative_config.method == "ngram":
             assert isinstance(sampled_token_ids, list)
@@ -2922,7 +2932,7 @@
             for num_draft, tokens in zip(
                 spec_decode_metadata.num_draft_tokens, sampled_token_ids
             ):
-                indices.append(offset + len(tokens) - 1)
+                indices.append(offset + tokens.shape[0] - 1)
                 offset += num_draft + 1
             indices = torch.tensor(indices, device=self.device)
             hidden_states = sample_hidden_states[indices]
@@ -4862,7 +4872,7 @@
         return kv_cache_spec

-    def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
+    def _to_list(self, sampled_token_ids: torch.Tensor) -> list[np.ndarray]:
         # This is a short term mitigation for issue mentioned in
         # https://github.com/vllm-project/vllm/issues/22754.
         # `tolist` would trigger a cuda wise stream sync, which
@@ -4875,4 +4885,4 @@
         pinned.copy_(sampled_token_ids, non_blocking=True)
         self.transfer_event.record()
         self.transfer_event.synchronize()
-        return pinned.tolist()
+        return [row for row in pinned.numpy()]
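
A standalone sketch of the new per-row representation used above: the 2-D sampled-token tensor is split into one NumPy row per request, and discarded requests get an empty array instead of `list.clear()` (the tensor values and the discarded index are made up):

    import numpy as np
    import torch

    sampled_token_ids = torch.tensor([[101], [102], [103]], dtype=torch.int32)
    discard_sampled_tokens_req_indices = [1]

    valid_sampled_token_ids: list[np.ndarray] = [
        row for row in sampled_token_ids.numpy()
    ]
    for i in discard_sampled_tokens_req_indices:
        valid_sampled_token_ids[i] = np.array([])

    assert [r.tolist() for r in valid_sampled_token_ids] == [[101], [], [103]]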

View File

@@ -1254,13 +1254,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         max_gen_len = selected_token_ids.shape[-1]
         if max_gen_len == 1:
-            valid_sampled_token_ids = selected_token_ids.tolist()
+            valid_sampled_token_ids: list[np.ndarray] = [
+                row for row in selected_token_ids.numpy()
+            ]

             # Mask out the sampled tokens that should not be sampled.
             # TODO: Keep in sync with gpu_model_runner.py, in particular
             # the "else" case here
             for i in discard_sampled_tokens_req_indices:
-                valid_sampled_token_ids[i].clear()
+                valid_sampled_token_ids[i] = np.array([])

             # Append sampled tokens
             for i, req_state, seq_len in request_seq_lens:
@@ -1273,7 +1275,7 @@
             valid_mask = selected_token_ids != INVALID_TOKEN_ID
             gen_lens = valid_mask.sum(dim=1).tolist()
             valid_sampled_token_ids = [
-                seq.tolist() for seq in selected_token_ids[valid_mask].split(gen_lens)
+                seq.numpy() for seq in selected_token_ids[valid_mask].split(gen_lens)
             ]
             self.input_batch.num_tokens[:num_reqs] += gen_lens
             for i, req_state, seq_len in request_seq_lens:
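
A standalone sketch of the multi-token path above; `INVALID_TOKEN_ID` is assumed to be -1 here, padding positions that were not actually sampled:

    import torch

    INVALID_TOKEN_ID = -1  # assumed padding value for illustration
    selected_token_ids = torch.tensor([
        [5, 6, INVALID_TOKEN_ID],
        [7, INVALID_TOKEN_ID, INVALID_TOKEN_ID],
    ])
    valid_mask = selected_token_ids != INVALID_TOKEN_ID
    gen_lens = valid_mask.sum(dim=1).tolist()  # [2, 1]
    valid_sampled_token_ids = [
        seq.numpy() for seq in selected_token_ids[valid_mask].split(gen_lens)
    ]
    assert [s.tolist() for s in valid_sampled_token_ids] == [[5, 6], [7]]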