[Dynamic Spec Decoding] Auto-disable by the running queue size (#4592)
Co-authored-by: Cade Daniel <edacih@gmail.com>
This commit is contained in:
parent 89579a201f
commit f942efb5a3
@@ -42,9 +42,11 @@ def mock_causal_accepted_tensor(
 @pytest.mark.parametrize(
     "which_tokens_accepted",
     ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
+@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_correct_output_format(which_tokens_accepted: str, seed: int,
+def test_correct_output_format(which_tokens_accepted: str,
+                               disable_bonus_tokens: bool, seed: int,
                                device: str):
     """Verify the output has correct format given predetermined accepted matrix.
     """

@@ -82,7 +84,8 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)

-    rejection_sampler = RejectionSampler()
+    rejection_sampler = RejectionSampler(
+        disable_bonus_tokens=disable_bonus_tokens)
     rejection_sampler.init_gpu_tensors(rank=0)
     output_token_ids = rejection_sampler._create_output(  # pylint: disable=protected-access
         accepted,

@@ -91,9 +94,11 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
         bonus_token_ids,
     )

-    # Bonus tokens are currently disabled. Verify they're set to -1.
+    expected_bonus_token_ids = bonus_token_ids.clone()
+    # If bonus tokens disabled. Verify they are set to -1.
     # See https://github.com/vllm-project/vllm/issues/4212
-    expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1
+    if disable_bonus_tokens:
+        expected_bonus_token_ids = expected_bonus_token_ids * 0 - 1

     if which_tokens_accepted == "all_tokens_accepted":
         # Expect all tokens to be equal to draft tokens.

@@ -536,6 +536,40 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,
                                          force_output_len=True)


+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-160m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "speculative_disable_by_batch_size": 2,
+    },
+])
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("output_len", [10])
+@pytest.mark.parametrize("seed", [1])
+def test_disable_speculation(baseline_llm_generator, test_llm_generator,
+                             batch_size: int, output_len: int):
+    """Verify greedy equality when all sequences disable speculation.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
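The test_llm_kwargs above also double as a usage reference. Below is a minimal sketch of setting the same knob through the offline LLM API, assuming the entrypoint forwards these engine kwargs unchanged (as the e2e conftest generators do); it is illustrative, not part of this commit.

# Hypothetical offline usage mirroring the test kwargs above.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    # Stop proposing once more than 2 requests are enqueued.
    speculative_disable_by_batch_size=2,
    use_v2_block_manager=True,
    enforce_eager=True,
)
outputs = llm.generate(["San Francisco is a"], SamplingParams(temperature=0.0))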
@@ -57,7 +57,7 @@ from .conftest import run_greedy_equality_correctness_test
 @pytest.mark.parametrize("output_len", [
     256,
 ])
-@pytest.mark.parametrize("batch_size", [1, 64])
+@pytest.mark.parametrize("batch_size", [1, 32])
 @pytest.mark.parametrize("seed", [1])
 def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
                                       test_llm_generator, batch_size: int,
tests/spec_decode/test_dynamic_spec_decode.py (new file, 77 lines)

@@ -0,0 +1,77 @@
+from unittest.mock import MagicMock
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.rejection_sampler import RejectionSampler
+from vllm.sequence import ExecuteModelRequest
+from vllm.spec_decode.metrics import AsyncMetricsCollector
+from vllm.spec_decode.multi_step_worker import MultiStepWorker
+from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
+from vllm.spec_decode.top1_proposer import Top1Proposer
+
+from .utils import create_batch, mock_worker
+
+
+@pytest.mark.parametrize('queue_size', [2, 4])
+@pytest.mark.parametrize('batch_size', [1, 2, 3, 6])
+@pytest.mark.parametrize('k', [1, 2, 5, 7, 10])
+@torch.inference_mode()
+def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
+    """Verify that speculative tokens are disabled when the batch size
+    exceeds the threshold.
+    """
+    disable_by_batch_size = 3
+
+    draft_worker = mock_worker(cls=MultiStepWorker)
+    target_worker = mock_worker()
+    rejection_sampler = MagicMock(spec=RejectionSampler)
+    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
+    worker = SpecDecodeWorker(proposer_worker=draft_worker,
+                              scorer_worker=target_worker,
+                              rejection_sampler=rejection_sampler,
+                              metrics_collector=metrics_collector,
+                              disable_by_batch_size=disable_by_batch_size)
+
+    exception_secret = 'artificial stop'
+    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
+
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k,
+        running_queue_size=queue_size)
+
+    with pytest.raises(ValueError, match=exception_secret):
+        worker.execute_model(execute_model_req=execute_model_req)
+
+    # When the batch size is larger than the threshold,
+    # we expect no speculative tokens (0).
+    expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0
+    assert seq_group_metadata_list[
+        0].num_speculative_tokens == expected_num_spec_tokens
+
+    draft_worker.sampler_output.side_effect = ValueError(exception_secret)
+
+    proposer = Top1Proposer(
+        worker=draft_worker,
+        device='cpu',  # not used
+        vocab_size=100,  # not used
+        # Must be long enough to avoid being skipped due to length.
+        max_proposal_len=1024,
+    )
+
+    if queue_size < disable_by_batch_size:
+        # Should raise exception when executing the mocked draft model.
+        with pytest.raises(ValueError, match=exception_secret):
+            proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list,
+                num_lookahead_slots=k), )
+    else:
+        # Should not execute the draft model because spec decode is disabled
+        # for all requests. Accordingly, the proposal length should be 0.
+        proposals = proposer.get_proposals(
+            execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list,
+                num_lookahead_slots=k), )
+        assert proposals.proposal_lens.tolist() == [0] * batch_size
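The assertion above pins down the per-request gating rule. For reference only, a standalone sketch of that rule (the real logic lives in SpecDecodeWorker.execute_model; the function name here is illustrative):

from typing import Optional

def num_spec_tokens_after_gating(running_queue_size: int,
                                 disable_by_batch_size: int) -> Optional[int]:
    # At or above the threshold the worker pins the request to 0 speculative
    # tokens; below it the per-request field is left untouched (None).
    return 0 if running_queue_size >= disable_by_batch_size else None

assert num_spec_tokens_after_gating(4, 3) == 0     # queue too large: disabled
assert num_spec_tokens_after_gating(2, 3) is None  # below threshold: unchanged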
@@ -692,6 +692,7 @@ class SpeculativeConfig:
         speculative_max_model_len: Optional[int],
         enable_chunked_prefill: bool,
         use_v2_block_manager: bool,
+        speculative_disable_by_batch_size: Optional[int],
         ngram_prompt_lookup_max: Optional[int],
         ngram_prompt_lookup_min: Optional[int],
     ) -> Optional["SpeculativeConfig"]:

@@ -720,6 +721,9 @@ class SpeculativeConfig:
             use_v2_block_manager (bool): Whether vLLM is configured to use the
                 v2 block manager or not. Used for raising an error since the v2
                 block manager is required with spec decode.
+            speculative_disable_by_batch_size (Optional[int]): Disable
+                speculative decoding for new incoming requests when the number
+                of enqueue requests is larger than this value, if provided.
             ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
                 window, if provided.
             ngram_prompt_lookup_min (Optional[int]): Min size of ngram token

@@ -730,7 +734,7 @@ class SpeculativeConfig:
             the necessary conditions are met, else None.
         """

-        if (speculative_model is None and num_speculative_tokens is None):
+        if speculative_model is None and num_speculative_tokens is None:
             return None

         if speculative_model is not None and num_speculative_tokens is None:

@@ -739,6 +743,12 @@ class SpeculativeConfig:
                 "num_speculative_tokens to be provided, but found "
                 f"{speculative_model=} and {num_speculative_tokens=}.")

+        if (speculative_disable_by_batch_size is not None
+                and speculative_disable_by_batch_size < 2):
+            raise ValueError("Expect the batch size threshold of disabling "
+                             "speculative decoding is > 1, but got "
+                             f"{speculative_disable_by_batch_size=}")
+
         assert (speculative_model is not None
                 and num_speculative_tokens is not None)

@@ -807,6 +817,7 @@ class SpeculativeConfig:
             draft_model_config,
             draft_parallel_config,
             num_speculative_tokens,
+            speculative_disable_by_batch_size,
             ngram_prompt_lookup_max,
             ngram_prompt_lookup_min,
         )

@@ -876,8 +887,9 @@ class SpeculativeConfig:
         draft_model_config: ModelConfig,
         draft_parallel_config: ParallelConfig,
         num_speculative_tokens: int,
-        ngram_prompt_lookup_max: int,
-        ngram_prompt_lookup_min: int,
+        speculative_disable_by_batch_size: Optional[int],
+        ngram_prompt_lookup_max: Optional[int],
+        ngram_prompt_lookup_min: Optional[int],
     ):
         """Create a SpeculativeConfig object.

@@ -886,12 +898,19 @@ class SpeculativeConfig:
             draft_parallel_config: ParallelConfig for the draft model.
             num_speculative_tokens: The number of tokens to sample from the
                 draft model before scoring with the target model.
+            speculative_disable_by_batch_size: Disable speculative
+                decoding for new incoming requests when the number of
+                enqueue requests is larger than this value.
             ngram_prompt_lookup_max: Max size of ngram token window.
             ngram_prompt_lookup_min: Min size of ngram token window.
         """
         self.draft_model_config = draft_model_config
         self.draft_parallel_config = draft_parallel_config
         self.num_speculative_tokens = num_speculative_tokens
-        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
-        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
+        self.speculative_disable_by_batch_size = \
+            speculative_disable_by_batch_size
+        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
+        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0

         self._verify_args()
@@ -83,6 +83,7 @@ class EngineArgs:
     speculative_model: Optional[str] = None
     num_speculative_tokens: Optional[int] = None
     speculative_max_model_len: Optional[int] = None
+    speculative_disable_by_batch_size: Optional[int] = None
     ngram_prompt_lookup_max: Optional[int] = None
     ngram_prompt_lookup_min: Optional[int] = None

@@ -467,6 +468,13 @@ class EngineArgs:
             'draft model. Sequences over this length will skip '
             'speculation.')

+        parser.add_argument(
+            '--speculative-disable-by-batch-size',
+            type=int,
+            default=EngineArgs.speculative_disable_by_batch_size,
+            help='Disable speculative decoding for new incoming requests '
+            'if the number of enqueue requests is larger than this value.')
+
         parser.add_argument(
             '--ngram-prompt-lookup-max',
             type=int,

@@ -547,6 +555,8 @@ class EngineArgs:
             target_dtype=self.dtype,
             speculative_model=self.speculative_model,
             num_speculative_tokens=self.num_speculative_tokens,
+            speculative_disable_by_batch_size=self.
+            speculative_disable_by_batch_size,
             speculative_max_model_len=self.speculative_max_model_len,
             enable_chunked_prefill=self.enable_chunked_prefill,
             use_v2_block_manager=self.use_v2_block_manager,
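As a quick sanity check of the new dataclass field and its default, something like the following would work (illustrative only; assumes EngineArgs keeps its dataclass form in this version):

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="JackFram/llama-160m",
                  speculative_model="JackFram/llama-68m",
                  num_speculative_tokens=5,
                  speculative_disable_by_batch_size=2)
assert args.speculative_disable_by_batch_size == 2
# The default stays None, i.e. speculation is never auto-disabled.
assert EngineArgs(model="JackFram/llama-160m").speculative_disable_by_batch_size is None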
@@ -93,6 +93,8 @@ class GPUExecutor(ExecutorBase):
             spec_decode_worker = SpecDecodeWorker.create_worker(
                 scorer_worker=target_worker,
                 draft_worker_kwargs=draft_worker_kwargs,
+                disable_by_batch_size=self.speculative_config.
+                speculative_disable_by_batch_size,
             )

             assert self.parallel_config.world_size == 1, (
@@ -12,15 +12,21 @@ class RejectionSampler(nn.Module):
     https://arxiv.org/pdf/2302.01318.pdf.
     """

-    def __init__(self, strict_mode: bool = False):
+    def __init__(self,
+                 disable_bonus_tokens: bool = True,
+                 strict_mode: bool = False):
         """Create a rejection sampler.

         Args:
+            disable_bonus_tokens: Whether or not to disable the bonus token.
+                Require when bonus tokens will cause corrupt KV cache for
+                proposal methods that require KV cache.
             strict_mode: Whether or not to perform shape/device/dtype checks
                 during sampling. This catches correctness issues but adds
                 nontrivial latency.
         """
         super().__init__()
+        self._disable_bonus_tokens = disable_bonus_tokens
         self._strict_mode = strict_mode

         # NOTE: A "bonus token" is accepted iff all proposal tokens are

@@ -312,7 +318,8 @@ class RejectionSampler(nn.Module):
         # proposal methods that require KV cache. We can fix it by "prefilling"
         # the bonus token in the proposer. The following issue tracks the fix.
         # https://github.com/vllm-project/vllm/issues/4212
-        output_with_bonus_tokens[:, -1] = -1
+        if self._disable_bonus_tokens:
+            output_with_bonus_tokens[:, -1] = -1

         # Fill the recovered token ids.
         output.mul_(~after_false_mask).add_(
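A toy illustration of the disable_bonus_tokens behaviour above: when the flag is set, the last (bonus) column of the sampler output is overwritten with -1 so bonus tokens are never emitted. The tensor values below are made up for demonstration.

import torch

disable_bonus_tokens = True
output_with_bonus_tokens = torch.tensor([[11, 12, 13, 99],
                                         [21, 22, 23, 98]])
if disable_bonus_tokens:
    # Mask out the bonus column, matching the guarded assignment above.
    output_with_bonus_tokens[:, -1] = -1
print(output_with_bonus_tokens)
# tensor([[11, 12, 13, -1],
#         [21, 22, 23, -1]])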
@@ -612,6 +612,12 @@ class SequenceGroupMetadata:
         self._token_chunk_size = token_chunk_size
         self.do_sample = do_sample

+        # The number of speculative tokens adopted in this request.
+        # None means specuative decoding is not used.
+        # Zero means speculative decoding is disabled for some reasons.
+        # TODO: We should maintain this states out of the sequence group.
+        self.num_speculative_tokens = None
+
         if self._token_chunk_size is None:
             if is_prompt:
                 self._token_chunk_size = list(seq_data.values())[0].get_len()
@@ -1,5 +1,5 @@
 from functools import cached_property
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import torch

@@ -54,7 +54,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
     def create_worker(
         cls,
         scorer_worker: WorkerBase,
-        draft_worker_kwargs,
+        draft_worker_kwargs: Dict[str, Any],
+        disable_by_batch_size: Optional[int],
     ) -> "SpecDecodeWorker":

         ngram_prompt_lookup_max = (

@@ -62,7 +63,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         ngram_prompt_lookup_min = (
             draft_worker_kwargs.pop("ngram_prompt_lookup_min"))

+        disable_bonus_tokens = True
         if ngram_prompt_lookup_max > 0:
+            disable_bonus_tokens = False
             proposer_worker = NGramWorker(**draft_worker_kwargs)
             proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                   ngram_prompt_lookup_max)

@@ -75,9 +78,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         return SpecDecodeWorker(
             proposer_worker,
             scorer_worker,
-            # TODO(cade) disable strict mode for speedup.
-            rejection_sampler=RejectionSampler(strict_mode=True),
-        )
+            disable_by_batch_size=disable_by_batch_size,
+            rejection_sampler=RejectionSampler(
+                disable_bonus_tokens=disable_bonus_tokens, ))

     def __init__(
         self,

@@ -85,6 +88,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         scorer_worker: WorkerBase,
         rejection_sampler: RejectionSampler,
         metrics_collector: Optional[AsyncMetricsCollector] = None,
+        disable_by_batch_size: Optional[int] = None,
     ):
         """
         Create a SpecDecodeWorker.

@@ -97,11 +101,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                 Worker.
             rejection_sampler: A Torch module used to perform modified rejection
                 sampling for speculative decoding.
+            disable_by_batch_size: If the batch size is larger than this,
+                disable speculative decoding for new incoming requests.
             metrics_collector: Helper class for collecting metrics; can be set
                 for testing purposes.
         """
         self.proposer_worker = proposer_worker
         self.scorer_worker = scorer_worker
+        self.disable_by_batch_size = disable_by_batch_size or float("inf")
         self.rejection_sampler = rejection_sampler

         self._metrics = AsyncMetricsCollector(

@@ -199,27 +206,41 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                 "speculative decoding "
                 "requires non-None seq_group_metadata_list")

+        # When the batch size is too large, disable speculative decoding
+        # to stop trading off throughput for latency.
+        disable_all = (execute_model_req.running_queue_size >=
+                       self.disable_by_batch_size)
+        if disable_all:
+            for seq_group_metadata in execute_model_req.seq_group_metadata_list:
+                # Once num_speculative_tokens is set to 0, the spec decode
+                # of this request will be disabled forever.
+                # TODO(comaniac): We currently store spec decoding specific
+                # state in the global data structure, but we should maintain
+                # this state within spec decode worker.
+                seq_group_metadata.num_speculative_tokens = 0
+
         # If no spec tokens, call the proposer and scorer workers normally.
-        # Used for prefill.
+        # This happens for prefill, or when the spec decode is disabled
+        # for this batch.
         if execute_model_req.num_lookahead_slots == 0 or len(
                 execute_model_req.seq_group_metadata_list) == 0:
-            return self._run_no_spec(execute_model_req)
+            return self._run_no_spec(execute_model_req,
+                                     skip_proposer=disable_all)

         return self._run_speculative_decoding_step(execute_model_req)

     @nvtx_range("spec_decode_worker._run_no_spec")
-    def _run_no_spec(
-        self,
-        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
-        """Run a prefill step, without any speculation. The input is sent to the
-        proposer and scorer model so that the KV cache is consistent between the
-        two.
+    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
+                     skip_proposer: bool) -> List[SamplerOutput]:
+        """Run a prefill step, without any speculation. The input is sent to
+        the proposer and scorer model so that the KV cache is consistent
+        between the two. When skip_proposer is True, the proposer model is
+        not called, meaning that the kv-cache in proposer for requests is not
+        updated, so they cannot enable spec decode in the rest decoding.
         """
-        #logger.info("run proposer worker no spec")
-        self.proposer_worker.execute_model(execute_model_req)
+        if not skip_proposer:
+            self.proposer_worker.execute_model(execute_model_req)

-        #logger.info("run target worker no spec")
         sampler_output = self.scorer_worker.execute_model(execute_model_req)
         assert len(sampler_output) == 1
         sampler_output = sampler_output[0]

@@ -244,22 +265,18 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             sequence.
         """

-        #logger.info("get spec proposals")
         # Generate proposals using draft worker.
         proposals = self.proposer_worker.get_spec_proposals(execute_model_req)

-        #logger.info("score proposals")
         proposal_scores = self.scorer.score_proposals(
             execute_model_req,
             proposals,
         )

-        #logger.info("verify proposals")
         accepted_token_ids, target_logprobs = self._verify_tokens(
             execute_model_req.seq_group_metadata_list, proposal_scores,
             proposals, execute_model_req.num_lookahead_slots)

-        #logger.info("create output list")
         return self._create_output_sampler_list(
             execute_model_req.seq_group_metadata_list,
             accepted_token_ids,
@@ -56,7 +56,7 @@ class Top1Proposer(SpeculativeProposer):
             proposal_lens,
             nonzero_proposal_len_seqs,
             nonzero_proposal_len_indices,
-        ) = self._split_by_max_model_len(seq_group_metadata_list, proposal_len)
+        ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)

         if nonzero_proposal_len_seqs:
             # Speculate tokens using the draft worker for the speculative

@@ -97,17 +97,27 @@ class Top1Proposer(SpeculativeProposer):

         return proposals

-    def _split_by_max_model_len(
+    def _split_by_proposal_len(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         proposal_len: int,
     ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
-        """Determine which sequences would exceed the max model length."""
+        """Split sequences by two groups:
+        1. Sequences with non-zero proposal length.
+        2. Sequences with zero proposal length (due to disabled speculation
+           or exceed the maximum model length).
+        """

         proposal_lens: List[int] = []
         nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
         nonzero_proposal_len_indices: List[int] = []
         for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+            # The speculative decoding for this request has been disabled
+            # (e.g. due to high traffic).
+            if seq_group_metadata.num_speculative_tokens == 0:
+                proposal_lens.append(0)
+                continue
+
             seq_data = next(iter(seq_group_metadata.seq_data.values()))
             seq_len = seq_data.get_len()

@@ -115,13 +125,14 @@ class Top1Proposer(SpeculativeProposer):
             # are supported.
             # If max_proposal_len is defined, then we shall no exccess this
             # quota for nonzero_proposal
+            new_k = 0
             if (self.max_proposal_len is None
                     or seq_len + proposal_len < self.max_proposal_len):
-                proposal_lens.append(proposal_len)
+                new_k = proposal_len
                 nonzero_proposal_len_seqs.append(seq_group_metadata)
                 nonzero_proposal_len_indices.append(i)
-            else:
-                proposal_lens.append(0)
+            proposal_lens.append(new_k)
+            seq_group_metadata.num_speculative_tokens = new_k

         return (
             proposal_lens,
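For reference, a self-contained sketch of the split that _split_by_proposal_len now performs (simplified: the max-model-length check is reduced to a per-request boolean; the function and parameter names are illustrative, not vLLM API):

from typing import List, Optional, Tuple

def split_by_proposal_len(
        num_spec_tokens: List[Optional[int]],  # None = unset, 0 = disabled
        fits_in_model: List[bool],
        proposal_len: int) -> Tuple[List[int], List[int]]:
    proposal_lens: List[int] = []
    nonzero_indices: List[int] = []
    for i, (n, fits) in enumerate(zip(num_spec_tokens, fits_in_model)):
        if n == 0:
            # Speculation disabled for this request (e.g. high traffic).
            proposal_lens.append(0)
            continue
        new_k = proposal_len if fits else 0
        proposal_lens.append(new_k)
        if new_k > 0:
            nonzero_indices.append(i)
    return proposal_lens, nonzero_indices

# Request 1 has speculation disabled; request 2 exceeds the draft model length.
assert split_by_proposal_len([None, 0, None], [True, True, False], 5) == ([5, 0, 0], [0])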