[Spec Decode] Add Batch Parallel Ngram. Up to 8x lower overhead. (#24986)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent b558c3a8b7
commit f3a478b55e
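
For context, the n-gram (prompt-lookup) drafter that this commit batches is enabled through vLLM's speculative decoding configuration. The snippet below is an editor's illustration rather than part of the diff: the dict keys mirror the SpeculativeConfig fields exercised later in this change (method, num_speculative_tokens, prompt_lookup_min/max), the speculative_config keyword argument is assumed to be available in a recent vLLM release, and the model name is only a placeholder.

from vllm import LLM, SamplingParams

# Hypothetical usage: n-gram speculative decoding, whose per-step proposer
# overhead is what the batched implementation in this commit reduces.
llm = LLM(
    model="facebook/opt-125m",  # placeholder model
    speculative_config={
        "method": "ngram",
        "num_speculative_tokens": 3,
        "prompt_lookup_min": 2,
        "prompt_lookup_max": 5,
    },
)
out = llm.generate(["The quick brown fox"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)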
@@ -1,17 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import time
from unittest import mock

import numpy as np
from tabulate import tabulate

from benchmark_utils import TimeCollector
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.config import (
    CacheConfig,
    DeviceConfig,
    LoadConfig,
    ModelConfig,
    ParallelConfig,
    SchedulerConfig,
    SpeculativeConfig,
    VllmConfig,
)
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner


def main(args):
def benchmark_propose(args):
    rows = []
    for max_ngram in args.max_ngram:
        collector = TimeCollector(TimeCollector.US)
@@ -69,10 +83,88 @@ def main(args):
        )


def benchmark_batched_propose(args):
    NUM_SPECULATIVE_TOKENS_NGRAM = 10
    PROMPT_LOOKUP_MIN = 5
    PROMPT_LOOKUP_MAX = 15
    MAX_MODEL_LEN = int(1e7)
    DEVICE = current_platform.device_type

    model_config = ModelConfig(model="facebook/opt-125m", runner="generate")

    speculative_config = SpeculativeConfig(
        target_model_config=model_config,
        target_parallel_config=ParallelConfig(),
        method="ngram",
        num_speculative_tokens=NUM_SPECULATIVE_TOKENS_NGRAM,
        prompt_lookup_max=PROMPT_LOOKUP_MAX,
        prompt_lookup_min=PROMPT_LOOKUP_MIN,
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        speculative_config=speculative_config,
        device_config=DeviceConfig(device=current_platform.device_type),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=SchedulerConfig(),
    )

    # monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group
    mock_pp_group = mock.MagicMock()
    mock_pp_group.world_size = 1
    with mock.patch(
        "vllm.v1.worker.gpu_model_runner.get_pp_group", return_value=mock_pp_group
    ):
        runner = GPUModelRunner(vllm_config, DEVICE)

    # hack max model len
    runner.max_model_len = MAX_MODEL_LEN
    runner.drafter.max_model_len = MAX_MODEL_LEN

    dummy_input_batch = InputBatch(
        max_num_reqs=args.num_req,
        max_model_len=MAX_MODEL_LEN,
        max_num_batched_tokens=args.num_req * args.num_token,
        device=DEVICE,
        pin_memory=False,
        vocab_size=256000,
        block_sizes=[16],
    )
    dummy_input_batch._req_ids = list(str(id) for id in range(args.num_req))
    dummy_input_batch.spec_decode_unsupported_reqs = ()
    dummy_input_batch.num_tokens_no_spec = [args.num_token] * args.num_req
    dummy_input_batch.token_ids_cpu = np.random.randint(
        0, 20, (args.num_req, args.num_token)
    )

    runner.input_batch = dummy_input_batch

    sampled_token_ids = [[0]] * args.num_req

    print("Starting benchmark")
    # first run is warmup so ignore it
    for _ in range(args.num_iteration):
        start = time.time()
        runner.drafter.propose(
            sampled_token_ids,
            dummy_input_batch.req_ids,
            dummy_input_batch.num_tokens_no_spec,
            dummy_input_batch.token_ids_cpu,
            dummy_input_batch.spec_decode_unsupported_reqs,
        )
        end = time.time()
        print(f"Iteration time (s): {end - start}")


def invoke_main() -> None:
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of N-gram speculative decode drafting"
    )
    parser.add_argument(
        "--batched", action="store_true", help="consider time to prepare batch"
    )  # noqa: E501
    parser.add_argument(
        "--num-iteration",
        type=int,
@@ -105,8 +197,17 @@ def invoke_main() -> None:
        help="Number of speculative tokens to generate",
    )
    args = parser.parse_args()
    main(args)

    if not args.batched:
        benchmark_propose(args)
    else:
        benchmark_batched_propose(args)


"""
# Example command lines:
# time python3 benchmarks/benchmark_ngram_proposer.py
# time python3 benchmarks/benchmark_ngram_proposer.py --batched --num-iteration 4 --num-token 1000000 --num-req 128
"""  # noqa: E501
if __name__ == "__main__":
    invoke_main()  # pragma: no cover
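
The benchmark above can also be driven without the CLI wrapper. A minimal sketch, assuming the argument names read by benchmark_batched_propose (num_req, num_token, num_iteration) and that benchmark_ngram_proposer.py is importable from the benchmarks directory:

from types import SimpleNamespace

from benchmark_ngram_proposer import benchmark_batched_propose

# Equivalent to: --batched --num-iteration 4 --num-token 1000000 --num-req 128
args = SimpleNamespace(num_req=128, num_token=1_000_000, num_iteration=4)
benchmark_batched_propose(args)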
@@ -9,11 +9,13 @@ from vllm.v1.spec_decode.ngram_proposer import (

def test_find_longest_matched_ngram_and_propose_tokens():
    tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
    assert _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
                                                          min_ngram=2,
                                                          max_ngram=2,
                                                          max_model_len=1024,
                                                          k=2) is None
    result = _find_longest_matched_ngram_and_propose_tokens(
        origin_tokens=tokens,
        min_ngram=2,
        max_ngram=2,
        max_model_len=1024,
        k=2)
    assert len(result) == 0

    tokens = np.array([1, 2, 3, 4, 1, 2, 3])
    np.testing.assert_array_equal(
@@ -62,7 +64,7 @@ def test_find_longest_matched_ngram_and_propose_tokens():

def test_ngram_proposer():

    def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
    def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
        # Dummy model config. Just to set max_model_len.
        model_config = ModelConfig(model="facebook/opt-125m")
        return NgramProposer(
@@ -75,36 +77,120 @@ def test_ngram_proposer():
                )))

    # No match.
    result = ngram_proposer(
        min_n=2, max_n=2,
        k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
    assert result is None
    token_ids_cpu = np.array([[1, 2, 3, 4, 5]])
    result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 0

    # No match for 4-gram.
    result = ngram_proposer(
        min_n=4, max_n=4,
        k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
    assert result is None
    token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
    result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 0

    # No match for 4-gram but match for 3-gram.
    result = ngram_proposer(
        min_n=3, max_n=4,
        k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
    assert np.array_equal(result, np.array([4, 1]))
    token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
    result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert np.array_equal(result, np.array([[4, 1]]))

    # Match for both 4-gram and 3-gram.
    # In this case, the proposer should return the 4-gram match.
    result = ngram_proposer(min_n=3, max_n=4, k=2).propose(
        context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
    assert np.array_equal(result, np.array([1, 2]))  # Not [5, 1]
    token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]])
    result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert np.array_equal(result, np.array([[1, 2]]))  # Not [5, 1]

    # Match for 2-gram and 3-gram, but not 4-gram.
    result = ngram_proposer(min_n=2, max_n=4, k=2).propose(
        context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
    assert np.array_equal(result, np.array([1, 2]))  # Not [5, 2]
    token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]])
    result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert np.array_equal(result, np.array([[1, 2]]))  # Not [5, 2]

    # Multiple 3-gram matched, but always pick the first one.
    result = ngram_proposer(
        min_n=3, max_n=3, k=2).propose(context_token_ids=np.array(
            [1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]))
    assert np.array_equal(result, np.array([100, 1]))
    token_ids_cpu = np.array(
        [[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
    result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert np.array_equal(result, np.array([[100, 1]]))

    # check empty input
    token_ids_cpu = np.array([[]])
    result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 0

    # check multibatch input
    # first request has 5 tokens and a match
    # second request has 3 tokens and no match. Padded with -1 for max len 5
    token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]])
    result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
        sampled_token_ids=[[0], [1]],
        req_ids=["0", "1"],
        num_tokens_no_spec=np.array([5, 3]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 2
    assert np.array_equal(result[0], np.array([3, 1]))
    assert np.array_equal(result[1], np.array([]))

    # test if 0 threads available: can happen if TP size > CPU count
    ngram_proposer = get_ngram_proposer(min_n=2, max_n=2, k=2)
    ngram_proposer.num_numba_thread_available = 0
    # set max_model_len to 2 * threshold to ensure multithread is used
    num_tokens_threshold = ngram_proposer.num_tokens_threshold
    ngram_proposer.max_model_len = 2 * num_tokens_threshold
    # using multibatch test
    middle_integer = num_tokens_threshold // 2
    input_1 = [_ for _ in range(num_tokens_threshold)]
    input_1 += [middle_integer, middle_integer + 1]
    input_2 = [-1] * len(input_1)
    input_2[:3] = [4, 5, 6]
    token_ids_cpu = np.array([input_1, input_2])
    result = ngram_proposer.propose(
        sampled_token_ids=[[0], [1]],
        req_ids=["0", "1"],
        num_tokens_no_spec=np.array([len(input_1), 3]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 2
    assert np.array_equal(result[0],
                          np.array([middle_integer + 2, middle_integer + 3]))
    assert np.array_equal(result[1], np.array([]))
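
The multibatch tests above feed propose() a dense (num_reqs, max_len) token array plus a per-request count of valid tokens; shorter rows are padded (here with -1), and the padding value never matters because only the first num_tokens_no_spec[i] entries of each row are read. A small sketch, assuming nothing beyond numpy, of how such inputs can be built from ragged per-request prompts:

import numpy as np

prompts = [[1, 2, 3, 1, 2], [4, 5, 6]]  # ragged per-request token IDs
max_len = max(len(p) for p in prompts)

# Dense array with -1 padding, matching the layout used in the tests above.
token_ids_cpu = np.full((len(prompts), max_len), -1, dtype=np.int32)
for i, p in enumerate(prompts):
    token_ids_cpu[i, :len(p)] = p
num_tokens_no_spec = np.array([len(p) for p in prompts], dtype=np.int32)

print(token_ids_cpu)       # rows: [1 2 3 1 2] and [4 5 6 -1 -1]
print(num_tokens_no_spec)  # [5 3]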
@@ -17,7 +17,7 @@ PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32
MAX_SPEC_LEN = 128


class RejectionSampler(nn.Module):
@@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import os

import numpy as np
from numba import jit
from numba import get_num_threads, jit, njit, prange, set_num_threads

from vllm.config import VllmConfig
@@ -26,55 +26,174 @@ class NgramProposer:
        # Maximum length of the model.
        self.max_model_len = vllm_config.model_config.max_model_len

        # Pre-allocate buffers for numba batch propose.
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.valid_ngram_draft = np.zeros((max_num_seqs, self.k),
                                          dtype=np.int32)
        self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32)

        # Threshold of total number of tokens in the batch to enable
        # multi-threading in numba batch propose.
        self.num_tokens_threshold = 8192
        tp_size = vllm_config.parallel_config.tensor_parallel_size
        cpu_count = os.cpu_count()
        # Max number of threads for numba parallel processing.
        if cpu_count:
            # Divide by 2 to use physical cores
            # and not logical cores (hyper-threading).
            # Cap the number of threads to 8 to avoid using too many threads
            # since other components like frontend (incl tokenization)
            # and Structured Outputs also use multiple threads.
            # TODO(ekagra-ranjan): bump up the cap from 1 to 8
            # when TP parallelization for ngram is implemented.
            self.num_numba_thread_available = min(1, (cpu_count // 2))
            # Divide by tp_size to ensure each tensor parallel rank
            # has some threads since all ranks will run this.
            self.num_numba_thread_available //= tp_size
        else:
            self.num_numba_thread_available = 1

        # Trigger Numba JIT compilation for N-gram proposer.
        # This usually takes less than 1 second.
        self.propose(np.zeros(1024, dtype=np.int32))
        self.propose([[]] * 1024, [""] * 1024, np.zeros(1024, dtype=np.int32),
                     np.zeros((1024, self.max_model_len), dtype=np.int32),
                     set())
    def batch_propose(
        self,
        num_requests: int,
        valid_ngram_requests: list,
        num_tokens_no_spec: np.ndarray,
        token_ids_cpu: np.ndarray,
    ) -> list[list[int]]:
        """Batch version of ngram proposer using numba for acceleration.

        Args:
            valid_ngram_requests:
                Set of indices of requests that need ngram proposals.
            num_tokens_no_spec:
                Numpy array of shape (batch_size,) representing the number
                of tokens without speculative tokens for each request.
            token_ids_cpu:
                Numpy array of shape (batch_size, max_model_len)
                representing the token IDs for each request.

        Returns:
            list[list[int]]:
                A list where each element is a list of proposed
                token IDs for the corresponding request.
        """
        draft_token_ids: list[list[int]] = []

        # Only run batch propose if there are requests needing ngram proposals.
        # Calling the numba function with an empty list would raise
        # "ValueError: cannot compute fingerprint of empty list".
        if num_ngram_requests := len(valid_ngram_requests):
            original_num_numba_threads = get_num_threads()
            # Ensure we use at least one thread.
            # If the total token count is small, using multiple threads
            # may slow things down due to threading overhead.
            total_tokens = np.sum(num_tokens_no_spec)
            if total_tokens >= self.num_tokens_threshold:
                final_num_threads = max(
                    1, min(self.num_numba_thread_available,
                           num_ngram_requests))
                set_num_threads(final_num_threads)
            else:
                set_num_threads(1)

            batch_propose_numba(valid_ngram_requests, num_tokens_no_spec,
                                token_ids_cpu, self.min_n, self.max_n,
                                self.max_model_len, self.k,
                                self.valid_ngram_draft,
                                self.valid_ngram_num_drafts)

            # Restore the original number of threads.
            set_num_threads(original_num_numba_threads)

        for i in range(num_requests):
            if i in valid_ngram_requests and \
                    self.valid_ngram_num_drafts[i] > 0:
                draft_token_ids.append(self.valid_ngram_draft[
                    i, :self.valid_ngram_num_drafts[i]].tolist())
            else:
                draft_token_ids.append([])

        return draft_token_ids
    def propose(
        self,
        context_token_ids: np.ndarray,
    ) -> Optional[np.ndarray]:
        """Proposes the next sequence of tokens based on n-gram pattern
        matching in the context. The function finds matches of the last n
        tokens in the previous context, and returns k tokens that followed
        that match.

        Args:
            context_token_ids: Numpy array of token IDs representing the
                context sequence.

        Returns:
            np.ndarray: The sequence of tokens that followed
                the matched n-gram in the context.
            None: If no matching n-gram pattern is found.

        Example:
            If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and
            k = 4:
            - The last 3 (= max_n) tokens [4,2,3] cannot find a match.
            - The last 2 tokens [2,3] will be matched against the previous
              4 tokens [1,2,3,4].
            - Finding a match of [2,3] would return the tokens that
              followed that pattern. Here we will return [4,2,3] because
              we only have three tokens after the match.
        """
        # TODO(woosuk): Optimize this.
        return _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=context_token_ids,
            min_ngram=self.min_n,
            max_ngram=self.max_n,
            max_model_len=self.max_model_len,
            k=self.k)
        sampled_token_ids: list[list[int]],
        req_ids: list[str],
        num_tokens_no_spec: np.ndarray,
        token_ids_cpu: np.ndarray,
        spec_decode_unsupported_reqs: set,
    ) -> list[list[int]]:
        # find which requests need ngram proposals
        valid_ngram_requests = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            num_sampled_ids = len(sampled_ids)
            if not num_sampled_ids:
                # Skip speculative decoding.
                continue

            # Skip requests that require sampling parameters that are not
            # supported with speculative decoding.
            req_id = req_ids[i]
            if req_id in spec_decode_unsupported_reqs:
                continue

            num_tokens = num_tokens_no_spec[i]
            if num_tokens >= self.max_model_len:
                # Skip requests that have already reached the max model length.
                continue

            valid_ngram_requests.append(i)

        draft_token_ids = self.batch_propose(
            len(sampled_token_ids),
            valid_ngram_requests,
            num_tokens_no_spec,
            token_ids_cpu,
        )

        return draft_token_ids

    def load_model(self, *args, **kwargs):
        # No model to load.
        pass


@njit(parallel=True)
def batch_propose_numba(valid_ngram_requests: list,
                        num_tokens_no_spec: np.ndarray,
                        token_ids_cpu: np.ndarray, min_n: int, max_n: int,
                        max_model_len: int, k: int,
                        valid_ngram_draft: np.ndarray,
                        valid_ngram_num_drafts: np.ndarray):
    for i in prange(len(valid_ngram_requests)):
        idx = valid_ngram_requests[i]
        num_tokens = num_tokens_no_spec[idx]
        context_token_ids = token_ids_cpu[idx, :num_tokens]
        drafter_output = _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=context_token_ids,
            min_ngram=min_n,
            max_ngram=max_n,
            max_model_len=max_model_len,
            k=k)

        valid_ngram_num_drafts[i] = drafter_output.shape[0]
        if len(drafter_output):
            valid_ngram_draft[i, :drafter_output.shape[0]] = drafter_output


@jit(nopython=True)
def _find_longest_matched_ngram_and_propose_tokens(
        origin_tokens: np.ndarray, min_ngram: int, max_ngram: int,
        max_model_len: int, k: int) -> Optional[np.ndarray]:
def _find_longest_matched_ngram_and_propose_tokens(origin_tokens: np.ndarray,
                                                   min_ngram: int,
                                                   max_ngram: int,
                                                   max_model_len: int,
                                                   k: int) -> np.ndarray:
    """
    Find the longest n-gram which matches the suffix of the given tokens
    whose length is within [min_ngram, max_ngram] (inclusive).
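
To make the thread-budget logic in the constructor and batch_propose above concrete, here is a small standalone arithmetic sketch (plain Python, no numba); the variable names follow the code, while the cpu_count, tp_size, and batch figures are made-up example values:

cpu_count = 16        # example os.cpu_count()
tp_size = 4           # example tensor_parallel_size
num_ngram_requests = 32
num_tokens_threshold = 8192
total_tokens = 20000  # example sum of num_tokens_no_spec

# __init__: physical-core heuristic, currently capped at 1 thread,
# then divided across TP ranks (which can round down to 0).
num_numba_thread_available = min(1, cpu_count // 2) // tp_size  # -> 0

# batch_propose: go parallel only for large batches, never below 1 thread.
if total_tokens >= num_tokens_threshold:
    final_num_threads = max(1, min(num_numba_thread_available,
                                   num_ngram_requests))  # -> 1
else:
    final_num_threads = 1
print(final_num_threads)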
@@ -84,12 +203,12 @@ def _find_longest_matched_ngram_and_propose_tokens(
    # Do not generate draft tokens if the context is shorter than the minimum n-gram.
    total_token = origin_tokens.shape[0]
    if total_token < min_ngram:
        return None
        return np.empty((0, ), dtype=origin_tokens.dtype)

    # Do not generate draft tokens beyond the max model length.
    k = min(k, max_model_len - total_token)
    if k <= 0:
        return None
        return np.empty((0, ), dtype=origin_tokens.dtype)

    # Flip the tokens, so the goal becomes finding the longest ngram
    # at the rightmost position which matches the prefix with
@@ -146,7 +265,7 @@ def _find_longest_matched_ngram_and_propose_tokens(

    if longest_ngram < min_ngram:
        # No valid ngram is found
        return None
        return np.empty((0, ), dtype=origin_tokens.dtype)

    # Flip the position back, so in origin_tokens,
    # origin_tokens[total_token-1-position:total_token-1-position+longest_ngram]
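
For readers who prefer not to trace the flipped-array search, the rule implemented above is: take the longest suffix of length within [min_ngram, max_ngram] that also occurs earlier in the context, prefer the earliest such occurrence, and propose up to k tokens that follow it, never extending past max_model_len. Below is a slow pure-numpy reference sketch of that behaviour, written for illustration only; the function name is the editor's, and the numba kernel above remains the actual implementation:

import numpy as np


def ngram_propose_reference(tokens: np.ndarray, min_ngram: int, max_ngram: int,
                            max_model_len: int, k: int) -> np.ndarray:
    total = tokens.shape[0]
    # Cap proposals so the sequence never exceeds the max model length.
    k = min(k, max_model_len - total)
    if total < min_ngram or k <= 0:
        return np.empty((0, ), dtype=tokens.dtype)
    # Prefer the longest suffix n-gram; among matches, take the earliest one.
    for n in range(min(max_ngram, total - 1), min_ngram - 1, -1):
        suffix = tokens[total - n:]
        for start in range(total - n):
            if np.array_equal(tokens[start:start + n], suffix):
                return tokens[start + n:start + n + k]
    return np.empty((0, ), dtype=tokens.dtype)


# Docstring example from propose(): min_n=2, max_n=3, k=4 -> [4 2 3]
print(ngram_propose_reference(np.array([1, 2, 3, 4, 2, 3]), 2, 3, 1024, 4))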
@@ -2404,8 +2404,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        if self.speculative_config.method == "ngram":
            assert isinstance(sampled_token_ids, list)
            assert isinstance(self.drafter, NgramProposer)
            draft_token_ids = self.propose_ngram_draft_token_ids(
                sampled_token_ids)
            draft_token_ids = self.drafter.propose(
                sampled_token_ids, self.input_batch.req_ids,
                self.input_batch.num_tokens_no_spec,
                self.input_batch.token_ids_cpu,
                self.input_batch.spec_decode_unsupported_reqs)
        elif self.speculative_config.method == "medusa":
            assert isinstance(sampled_token_ids, list)
            assert isinstance(self.drafter, MedusaProposer)
@@ -2515,41 +2518,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        )
        return draft_token_ids

    def propose_ngram_draft_token_ids(
        self,
        sampled_token_ids: list[list[int]],
    ) -> list[list[int]]:
        # TODO(woosuk): Optimize.
        req_ids = self.input_batch.req_ids
        draft_token_ids: list[list[int]] = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            num_sampled_ids = len(sampled_ids)
            if not num_sampled_ids:
                # Skip speculative decoding.
                draft_token_ids.append([])
                continue

            # Skip requests that require sampling parameters that are not
            # supported with speculative decoding.
            req_id = req_ids[i]
            if req_id in self.input_batch.spec_decode_unsupported_reqs:
                draft_token_ids.append([])
                continue

            num_tokens = self.input_batch.num_tokens_no_spec[i]
            if num_tokens >= self.max_model_len:
                # Skip requests that have already reached the max model length.
                draft_token_ids.append([])
                continue

            drafter_output = self.drafter.propose(
                self.input_batch.token_ids_cpu[i, :num_tokens])
            if drafter_output is None or len(drafter_output) == 0:
                draft_token_ids.append([])
            else:
                draft_token_ids.append(drafter_output.tolist())
        return draft_token_ids

    def update_config(self, overrides: dict[str, Any]) -> None:
        allowed_config_names = {"load_config", "model_config"}
        for config_name, config_overrides in overrides.items():