mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:15:42 +08:00
Revert "[Core] Performance: Use list[np.ndarray] instead of list[list… (#28773)
This commit is contained in:
parent
edfe498189
commit
ac86bff8cb
@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
@ -22,7 +21,7 @@ def _make_model_runner_output(
|
||||
return ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index={req_id: i for i, req_id in enumerate(req_ids)},
|
||||
sampled_token_ids=[np.array([i]) for i in range(len(req_ids))],
|
||||
sampled_token_ids=[[i] for i in range(len(req_ids))],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
import dataclasses
|
||||
from unittest.mock import Mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@ -170,7 +169,7 @@ def test_schedule_partial_requests():
|
||||
req_id_to_index=req_to_index,
|
||||
# Only the first request has a sampled token id because
|
||||
# the rest requests are still being prefilled.
|
||||
sampled_token_ids=[np.array([0]), np.array([]), np.array([])],
|
||||
sampled_token_ids=[[0], [], []],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -217,7 +216,7 @@ def test_no_mm_input_chunking():
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[request.request_id for request in requests],
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[np.array([]) for _ in range(len(requests))],
|
||||
sampled_token_ids=[[] for _ in range(len(requests))],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -277,7 +276,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[request.request_id for request in requests],
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[np.array([]) for _ in range(len(requests))],
|
||||
sampled_token_ids=[[] for _ in range(len(requests))],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -301,8 +300,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[request.request_id for request in requests],
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[np.array([0]), np.array([0])]
|
||||
+ [np.array([]) for _ in range(len(requests) - 2)],
|
||||
sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -349,8 +347,8 @@ def test_stop_via_update_from_output():
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
||||
sampled_token_ids=[
|
||||
np.array([EOS_TOKEN_ID]),
|
||||
np.array([10, 11]),
|
||||
[EOS_TOKEN_ID],
|
||||
[10, 11],
|
||||
], # First request hits EOS, second continues
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
@ -394,10 +392,7 @@ def test_stop_via_update_from_output():
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
||||
sampled_token_ids=[
|
||||
np.array([10, 42, 12]),
|
||||
np.array([13, 14]),
|
||||
], # First request hits stop token
|
||||
sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -441,10 +436,7 @@ def test_stop_via_update_from_output():
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
||||
sampled_token_ids=[
|
||||
np.array([10, 11, 12]),
|
||||
np.array([13]),
|
||||
], # First request exceeds max_tokens
|
||||
sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -483,7 +475,7 @@ def test_stop_via_update_from_output():
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[np.array([EOS_TOKEN_ID, 10, 11])],
|
||||
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -624,7 +616,7 @@ def test_schedule_concurrent_batches(
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[np.array([0])],
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -641,7 +633,7 @@ def test_schedule_concurrent_batches(
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[1].request_id],
|
||||
req_id_to_index={requests[1].request_id: 0},
|
||||
sampled_token_ids=[np.array([0])],
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -678,7 +670,7 @@ def test_preempt_during_execution():
|
||||
model_runner_output0 = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[np.array([0])],
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -695,7 +687,7 @@ def test_preempt_during_execution():
|
||||
model_runner_output1 = ModelRunnerOutput(
|
||||
req_ids=[requests[1].request_id],
|
||||
req_id_to_index={requests[1].request_id: 0},
|
||||
sampled_token_ids=[np.array([42])],
|
||||
sampled_token_ids=[[42]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -712,18 +704,14 @@ def test_preempt_during_execution():
|
||||
@pytest.mark.parametrize(
|
||||
"spec_tokens,output_tokens,expected",
|
||||
[
|
||||
([[1, 2, 3]], [np.array([1, 2, 3, 4])], (1, 3, 3, [1, 1, 1])), # perfect match
|
||||
([[1, 2, 3]], [np.array([1, 5])], (1, 3, 1, [1, 0, 0])), # early mismatch
|
||||
(
|
||||
[[1, 2], [3]],
|
||||
[np.array([1, 2, 5]), np.array([3, 4])],
|
||||
(2, 3, 3, [2, 1]),
|
||||
), # multiple sequences
|
||||
([[1]], [np.array([1, 2])], (1, 1, 1, [1])), # single token sequence
|
||||
([[]], [np.array([5])], (0, 0, 0, [0])), # empty sequence
|
||||
([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match
|
||||
([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch
|
||||
([[1, 2], [3]], [[1, 2, 5], [3, 4]], (2, 3, 3, [2, 1])), # multiple sequences
|
||||
([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence
|
||||
([[]], [[5]], (0, 0, 0, [0])), # empty sequence
|
||||
(
|
||||
[[1, 2, 3], [4, 5, 6]],
|
||||
[np.array([1, 2, 7]), np.array([4, 8])],
|
||||
[[1, 2, 7], [4, 8]],
|
||||
(2, 6, 3, [2, 1, 0]),
|
||||
), # multiple mismatches
|
||||
],
|
||||
@ -757,7 +745,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[np.array([0]) for _ in range(len(requests))],
|
||||
sampled_token_ids=[[0] for _ in range(len(requests))],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -984,7 +972,7 @@ def test_kv_connector_basic(is_async: bool):
|
||||
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[np.array([1000])] * len(req_ids),
|
||||
sampled_token_ids=[[1000]] * len(req_ids),
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -1037,7 +1025,7 @@ def test_kv_connector_basic(is_async: bool):
|
||||
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[np.array([1000])] * len(req_ids),
|
||||
sampled_token_ids=[[1000]] * len(req_ids),
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -1100,7 +1088,7 @@ def test_external_prefix_cache_metrics():
|
||||
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
||||
req_ids=[r.request_id for r in requests],
|
||||
req_id_to_index={r.request_id: i for i, r in enumerate(requests)},
|
||||
sampled_token_ids=[np.array([1000])] * NUM_REQUESTS,
|
||||
sampled_token_ids=[[1000]] * NUM_REQUESTS,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -1166,7 +1154,7 @@ def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role):
|
||||
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[np.array([1000])] * len(req_ids),
|
||||
sampled_token_ids=[[1000]] * len(req_ids),
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -1251,7 +1239,7 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
|
||||
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[np.array([1000])] * len(req_ids),
|
||||
sampled_token_ids=[[1000]] * len(req_ids),
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -1344,7 +1332,7 @@ def make_output(scheduler: Scheduler):
|
||||
return ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in scheduler.running],
|
||||
req_id_to_index={req.request_id: i for i, req in enumerate(scheduler.running)},
|
||||
sampled_token_ids=[np.array([1000])] * len(scheduler.running),
|
||||
sampled_token_ids=[[1000]] * len(scheduler.running),
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -1761,7 +1749,7 @@ def test_priority_scheduling_preemption():
|
||||
req_id_to_index={
|
||||
req.request_id: i for i, req in enumerate(low_priority_requests)
|
||||
},
|
||||
sampled_token_ids=[np.array([100]) for _ in low_priority_requests],
|
||||
sampled_token_ids=[[100] for _ in low_priority_requests],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -1830,7 +1818,7 @@ def test_priority_scheduling_no_preemption_when_space_available():
|
||||
req_id_to_index={
|
||||
req.request_id: i for i, req in enumerate(low_priority_requests)
|
||||
},
|
||||
sampled_token_ids=[np.array([100]) for _ in low_priority_requests],
|
||||
sampled_token_ids=[[100] for _ in low_priority_requests],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -2076,7 +2064,7 @@ def test_priority_scheduling_heap_property():
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.req_id],
|
||||
req_id_to_index={req.req_id: 0},
|
||||
sampled_token_ids=[np.array([100])],
|
||||
sampled_token_ids=[[100]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
@ -2162,7 +2150,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[request_low.request_id],
|
||||
req_id_to_index={request_low.request_id: 0},
|
||||
sampled_token_ids=[np.array([100])],
|
||||
sampled_token_ids=[[100]],
|
||||
# spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
@ -2193,7 +2181,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
||||
sampled_token_ids=[np.array([100]) for _ in requests],
|
||||
sampled_token_ids=[[100] for _ in requests],
|
||||
# spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
@ -2219,7 +2207,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
||||
sampled_token_ids=[np.array([]), np.array([100])],
|
||||
sampled_token_ids=[[], [100]],
|
||||
# spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
|
||||
@ -7,7 +7,6 @@ from dataclasses import dataclass
|
||||
from itertools import chain, count
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
@ -229,7 +228,7 @@ def create_model_runner_output(
|
||||
|
||||
# Make sampled tokens.
|
||||
sampled_token = EOS_TOKEN_ID if use_eos else token_id
|
||||
sampled_token_ids = [np.array([sampled_token]) for _ in req_ids]
|
||||
sampled_token_ids = [[sampled_token] for _ in req_ids]
|
||||
|
||||
kv_connector_output = (
|
||||
None
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@ -113,9 +112,7 @@ def test_prepare_next_token_ids():
|
||||
sampled_token_ids_tensor = torch.tensor(
|
||||
sampled_token_ids, dtype=torch.int32, device=device
|
||||
)
|
||||
sampled_token_ids_cpu = [
|
||||
np.array([i for i in seq if i != -1]) for seq in sampled_token_ids
|
||||
]
|
||||
sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids]
|
||||
|
||||
expected_next_token_ids_cpu = [1, 4, 30, 40]
|
||||
expected_next_token_ids_tensor = torch.tensor(
|
||||
|
||||
@ -77,7 +77,7 @@ def test_ngram_proposer():
|
||||
# No match.
|
||||
token_ids_cpu = np.array([[1, 2, 3, 4, 5]])
|
||||
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
|
||||
sampled_token_ids=[np.array([0])],
|
||||
sampled_token_ids=[[0]],
|
||||
req_ids=["0"],
|
||||
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
@ -88,7 +88,7 @@ def test_ngram_proposer():
|
||||
# No match for 4-gram.
|
||||
token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
|
||||
result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose(
|
||||
sampled_token_ids=[np.array([0])],
|
||||
sampled_token_ids=[[0]],
|
||||
req_ids=["0"],
|
||||
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
@ -99,7 +99,7 @@ def test_ngram_proposer():
|
||||
# No match for 4-gram but match for 3-gram.
|
||||
token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
|
||||
result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
|
||||
sampled_token_ids=[np.array([0])],
|
||||
sampled_token_ids=[[0]],
|
||||
req_ids=["0"],
|
||||
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
@ -111,7 +111,7 @@ def test_ngram_proposer():
|
||||
# In this case, the proposer should return the 4-gram match.
|
||||
token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]])
|
||||
result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
|
||||
sampled_token_ids=[np.array([0])],
|
||||
sampled_token_ids=[[0]],
|
||||
req_ids=["0"],
|
||||
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
@ -122,7 +122,7 @@ def test_ngram_proposer():
|
||||
# Match for 2-gram and 3-gram, but not 4-gram.
|
||||
token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]])
|
||||
result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose(
|
||||
sampled_token_ids=[np.array([0])],
|
||||
sampled_token_ids=[[0]],
|
||||
req_ids=["0"],
|
||||
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
@ -133,7 +133,7 @@ def test_ngram_proposer():
|
||||
# Multiple 3-gram matched, but always pick the first one.
|
||||
token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
|
||||
result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
|
||||
sampled_token_ids=[np.array([0])],
|
||||
sampled_token_ids=[[0]],
|
||||
req_ids=["0"],
|
||||
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
@ -144,7 +144,7 @@ def test_ngram_proposer():
|
||||
# check empty input
|
||||
token_ids_cpu = np.array([[]])
|
||||
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
|
||||
sampled_token_ids=[np.array([0])],
|
||||
sampled_token_ids=[[0]],
|
||||
req_ids=["0"],
|
||||
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
@ -157,7 +157,7 @@ def test_ngram_proposer():
|
||||
# second request has 3 tokens and no match. Padded with -1 for max len 5
|
||||
token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]])
|
||||
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
|
||||
sampled_token_ids=[np.array([0]), np.array([1])],
|
||||
sampled_token_ids=[[0], [1]],
|
||||
req_ids=["0", "1"],
|
||||
num_tokens_no_spec=np.array([5, 3]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
@ -181,7 +181,7 @@ def test_ngram_proposer():
|
||||
input_2[:3] = [4, 5, 6]
|
||||
token_ids_cpu = np.array([input_1, input_2])
|
||||
result = ngram_proposer.propose(
|
||||
sampled_token_ids=[np.array([0]), np.array([1])],
|
||||
sampled_token_ids=[[0], [1]],
|
||||
req_ids=["0", "1"],
|
||||
num_tokens_no_spec=np.array([len(input_1), 3]),
|
||||
token_ids_cpu=token_ids_cpu,
|
||||
|
||||
@ -1010,8 +1010,8 @@ class Scheduler(SchedulerInterface):
|
||||
continue
|
||||
|
||||
req_index = model_runner_output.req_id_to_index[req_id]
|
||||
generated_token_ids: list[int] = (
|
||||
sampled_token_ids[req_index].tolist() if sampled_token_ids else []
|
||||
generated_token_ids = (
|
||||
sampled_token_ids[req_index] if sampled_token_ids else []
|
||||
)
|
||||
|
||||
scheduled_spec_token_ids = (
|
||||
|
||||
@ -158,7 +158,7 @@ class ModelRunnerOutput:
|
||||
# num_generated_tokens is the number of tokens
|
||||
# generated in the current step. It can be different for
|
||||
# each request due to speculative/jump decoding.
|
||||
sampled_token_ids: list[np.ndarray]
|
||||
sampled_token_ids: list[list[int]]
|
||||
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
|
||||
from dataclasses import replace
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -205,7 +204,7 @@ class RejectionSampler(nn.Module):
|
||||
def parse_output(
|
||||
output_token_ids: torch.Tensor,
|
||||
vocab_size: int,
|
||||
) -> list[np.ndarray]:
|
||||
) -> list[list[int]]:
|
||||
"""Parse the output of the rejection sampler.
|
||||
Args:
|
||||
output_token_ids: The sampled token IDs in shape
|
||||
@ -221,7 +220,10 @@ class RejectionSampler(nn.Module):
|
||||
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
|
||||
output_token_ids_np < vocab_size
|
||||
)
|
||||
return [row[valid_mask[i]] for i, row in enumerate(output_token_ids_np)]
|
||||
outputs = [
|
||||
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
|
||||
]
|
||||
return outputs
|
||||
|
||||
def apply_logits_processors(
|
||||
self,
|
||||
|
||||
@ -484,7 +484,7 @@ class EagleProposer:
|
||||
|
||||
def prepare_next_token_ids_cpu(
|
||||
self,
|
||||
sampled_token_ids: list[np.ndarray],
|
||||
sampled_token_ids: list[list[int]],
|
||||
requests: dict[str, CachedRequestState],
|
||||
gpu_input_batch: InputBatch,
|
||||
num_scheduled_tokens: dict[str, int],
|
||||
@ -499,7 +499,7 @@ class EagleProposer:
|
||||
req_ids = gpu_input_batch.req_ids
|
||||
next_token_ids: list[int] = []
|
||||
for i, token_ids in enumerate(sampled_token_ids):
|
||||
if token_ids.shape[0] > 0:
|
||||
if token_ids:
|
||||
# Common case.
|
||||
next_token_id = token_ids[-1]
|
||||
else:
|
||||
@ -510,9 +510,10 @@ class EagleProposer:
|
||||
seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
|
||||
next_token_id = req_state.get_token_id(seq_len)
|
||||
next_token_ids.append(next_token_id)
|
||||
return torch.tensor(
|
||||
next_token_ids = torch.tensor(
|
||||
next_token_ids, dtype=torch.int32, device=self.input_ids.device
|
||||
)
|
||||
return next_token_ids
|
||||
|
||||
def prepare_next_token_ids_padded(
|
||||
self,
|
||||
|
||||
@ -54,7 +54,7 @@ class NgramProposer:
|
||||
# Trigger Numba JIT compilation for N-gram proposer.
|
||||
# This usually takes less than 1 second.
|
||||
self.propose(
|
||||
[np.array([])] * 1024,
|
||||
[[]] * 1024,
|
||||
[""] * 1024,
|
||||
np.zeros(1024, dtype=np.int32),
|
||||
np.zeros((1024, self.max_model_len), dtype=np.int32),
|
||||
@ -131,7 +131,7 @@ class NgramProposer:
|
||||
|
||||
def propose(
|
||||
self,
|
||||
sampled_token_ids: list[np.ndarray],
|
||||
sampled_token_ids: list[list[int]],
|
||||
req_ids: list[str],
|
||||
num_tokens_no_spec: np.ndarray,
|
||||
token_ids_cpu: np.ndarray,
|
||||
@ -140,7 +140,7 @@ class NgramProposer:
|
||||
# find which requests need ngram proposals
|
||||
valid_ngram_requests = []
|
||||
for i, sampled_ids in enumerate(sampled_token_ids):
|
||||
num_sampled_ids = sampled_ids.shape[0]
|
||||
num_sampled_ids = len(sampled_ids)
|
||||
if not num_sampled_ids:
|
||||
# Skip speculative decoding.
|
||||
continue
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
@ -34,16 +32,16 @@ class SuffixDecodingProposer:
|
||||
def propose(
|
||||
self,
|
||||
input_batch: InputBatch,
|
||||
sampled_token_ids: list[np.ndarray],
|
||||
sampled_token_ids: list[list[int]],
|
||||
) -> list[list[int]]:
|
||||
"""
|
||||
Propose speculative tokens for each request in the input batch. Suffix Decoding
|
||||
will speculate a dynamic number of tokens for each request every decoding step,
|
||||
so each entry in the returned list may have different lengths.
|
||||
"""
|
||||
draft_token_ids: list[np.ndarray] = []
|
||||
draft_token_ids: list[list[int]] = []
|
||||
for i, sampled_ids in enumerate(sampled_token_ids):
|
||||
if sampled_ids.shape[0] == 0:
|
||||
if not sampled_ids:
|
||||
# Skip speculative decoding for partial prefills.
|
||||
draft_token_ids.append([])
|
||||
continue
|
||||
@ -72,7 +70,7 @@ class SuffixDecodingProposer:
|
||||
self.suffix_cache.start_request(req_id, prompt_token_ids)
|
||||
|
||||
# Append the newly sampled ids to the suffix cache for this request.
|
||||
self.suffix_cache.add_active_response(req_id, sampled_ids.tolist())
|
||||
self.suffix_cache.add_active_response(req_id, sampled_ids)
|
||||
|
||||
# Suffix decoding only uses the most recent tokens up to max_tree_depth, so
|
||||
# we extract the pattern from the end of the input.
|
||||
|
||||
@ -216,11 +216,9 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
|
||||
del self._logprobs_tensors
|
||||
del self._sampled_token_ids
|
||||
|
||||
valid_sampled_token_ids: list[np.ndarray] = [
|
||||
row for row in self.sampled_token_ids_cpu.numpy()
|
||||
]
|
||||
valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
|
||||
for i in self._invalid_req_indices:
|
||||
valid_sampled_token_ids[i] = np.array([])
|
||||
valid_sampled_token_ids[i].clear()
|
||||
|
||||
output = self._model_runner_output
|
||||
output.sampled_token_ids = valid_sampled_token_ids
|
||||
@ -2341,7 +2339,7 @@ class GPUModelRunner(
|
||||
) -> tuple[
|
||||
dict[str, int],
|
||||
LogprobsLists | None,
|
||||
list[np.ndarray],
|
||||
list[list[int]],
|
||||
dict[str, LogprobsTensors | None],
|
||||
list[str],
|
||||
dict[str, int],
|
||||
@ -2367,7 +2365,6 @@ class GPUModelRunner(
|
||||
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
|
||||
sampled_token_ids = sampler_output.sampled_token_ids
|
||||
invalid_req_indices = []
|
||||
valid_sampled_token_ids: list[np.ndarray]
|
||||
if not self.use_async_scheduling:
|
||||
# Get the valid generated tokens.
|
||||
max_gen_len = sampled_token_ids.shape[-1]
|
||||
@ -2382,7 +2379,7 @@ class GPUModelRunner(
|
||||
)
|
||||
# Mask out the sampled tokens that should not be sampled.
|
||||
for i in discard_sampled_tokens_req_indices:
|
||||
valid_sampled_token_ids[int(i)] = np.array([])
|
||||
valid_sampled_token_ids[int(i)].clear()
|
||||
else:
|
||||
valid_sampled_token_ids = []
|
||||
invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
|
||||
@ -2410,24 +2407,19 @@ class GPUModelRunner(
|
||||
[0] if spec_decode_metadata and logprobs_tensors else None
|
||||
)
|
||||
for req_idx in range(num_sampled_tokens):
|
||||
sampled_ids: np.ndarray | None
|
||||
if self.use_async_scheduling:
|
||||
sampled_ids = (
|
||||
np.array([-1]) if req_idx not in invalid_req_indices_set else None
|
||||
)
|
||||
sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
|
||||
else:
|
||||
sampled_ids = valid_sampled_token_ids[req_idx]
|
||||
|
||||
num_sampled_ids: int = (
|
||||
sampled_ids.shape[0] if sampled_ids is not None else 0
|
||||
)
|
||||
num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0
|
||||
|
||||
if cu_num_accepted_tokens is not None:
|
||||
cu_num_accepted_tokens.append(
|
||||
cu_num_accepted_tokens[-1] + num_sampled_ids
|
||||
)
|
||||
|
||||
if sampled_ids is None or num_sampled_ids == 0:
|
||||
if not sampled_ids:
|
||||
continue
|
||||
|
||||
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
|
||||
@ -2769,9 +2761,7 @@ class GPUModelRunner(
|
||||
with record_function_or_nullcontext("gpu_model_runner: sample"):
|
||||
sampler_output = self._sample(logits, spec_decode_metadata)
|
||||
|
||||
def propose_draft_token_ids(
|
||||
sampled_token_ids: torch.Tensor | list[np.ndarray],
|
||||
) -> None:
|
||||
def propose_draft_token_ids(sampled_token_ids):
|
||||
assert spec_decode_common_attn_metadata is not None
|
||||
with record_function_or_nullcontext("gpu_model_runner: draft"):
|
||||
self._draft_token_ids = self.propose_draft_token_ids(
|
||||
@ -2893,14 +2883,14 @@ class GPUModelRunner(
|
||||
def propose_draft_token_ids(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
sampled_token_ids: torch.Tensor | list[np.ndarray],
|
||||
sampled_token_ids: torch.Tensor | list[list[int]],
|
||||
sampling_metadata: SamplingMetadata,
|
||||
hidden_states: torch.Tensor,
|
||||
sample_hidden_states: torch.Tensor,
|
||||
aux_hidden_states: list[torch.Tensor] | None,
|
||||
spec_decode_metadata: SpecDecodeMetadata | None,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> torch.Tensor | list[list[int]]:
|
||||
) -> list[list[int]] | torch.Tensor:
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
if self.speculative_config.method == "ngram":
|
||||
assert isinstance(sampled_token_ids, list)
|
||||
@ -2932,7 +2922,7 @@ class GPUModelRunner(
|
||||
for num_draft, tokens in zip(
|
||||
spec_decode_metadata.num_draft_tokens, sampled_token_ids
|
||||
):
|
||||
indices.append(offset + tokens.shape[0] - 1)
|
||||
indices.append(offset + len(tokens) - 1)
|
||||
offset += num_draft + 1
|
||||
indices = torch.tensor(indices, device=self.device)
|
||||
hidden_states = sample_hidden_states[indices]
|
||||
@ -4872,7 +4862,7 @@ class GPUModelRunner(
|
||||
|
||||
return kv_cache_spec
|
||||
|
||||
def _to_list(self, sampled_token_ids: torch.Tensor) -> list[np.ndarray]:
|
||||
def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
|
||||
# This is a short term mitigation for issue mentioned in
|
||||
# https://github.com/vllm-project/vllm/issues/22754.
|
||||
# `tolist` would trigger a cuda wise stream sync, which
|
||||
@ -4885,4 +4875,4 @@ class GPUModelRunner(
|
||||
pinned.copy_(sampled_token_ids, non_blocking=True)
|
||||
self.transfer_event.record()
|
||||
self.transfer_event.synchronize()
|
||||
return [row for row in pinned.numpy()]
|
||||
return pinned.tolist()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user