[Dynamic Spec Decoding] Minor fix for disabling speculative decoding (#5000)

commit d5a1697772 (parent 325c119961)
Author: Lily Liu
Date:   2024-05-25 10:00:14 -07:00 (committed by GitHub)
3 changed files with 63 additions and 11 deletions


@@ -170,3 +170,44 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
                                          batch_size,
                                          max_output_len=output_len,
                                          force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-68m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_model": "[ngram]",
+                             "num_speculative_tokens": 5,
+                             "ngram_prompt_lookup_max": 3,
+                             "speculative_disable_by_batch_size": 4
+                         }])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator,
+                             batch_size: int, output_len: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
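For context, here is a minimal offline-inference sketch of the option this new test exercises. It assumes the engine arguments used in the test fixtures above (speculative_model, num_speculative_tokens, ngram_prompt_lookup_max, speculative_disable_by_batch_size) can be passed directly to vllm.LLM; the prompt and sampling settings are illustrative only and are not part of this change.

from vllm import LLM, SamplingParams

# Ngram prompt-lookup speculation; speculation is automatically turned off
# for steps where the number of running requests exceeds
# speculative_disable_by_batch_size.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="[ngram]",
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=3,
    speculative_disable_by_batch_size=4,
    use_v2_block_manager=True,
    enforce_eager=True,
)

outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)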


@@ -1,4 +1,4 @@
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch

 import pytest
 import torch
@@ -13,9 +13,9 @@ from vllm.spec_decode.top1_proposer import Top1Proposer
 from .utils import create_batch, mock_worker


-@pytest.mark.parametrize('queue_size', [2, 4])
-@pytest.mark.parametrize('batch_size', [1, 2, 3, 6])
-@pytest.mark.parametrize('k', [1, 2, 5, 7, 10])
+@pytest.mark.parametrize('queue_size', [4])
+@pytest.mark.parametrize('batch_size', [1])
+@pytest.mark.parametrize('k', [1])
 @torch.inference_mode()
 def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
     """Verify that speculative tokens are disabled when the batch size
@@ -42,7 +42,11 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
         num_lookahead_slots=k,
         running_queue_size=queue_size)

-    with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(execute_model_req=execute_model_req)
+    if queue_size > disable_by_batch_size:
+        with patch.object(worker,
+                          '_run_no_spec',
+                          side_effect=ValueError(exception_secret)), \
+                pytest.raises(ValueError, match=exception_secret):
+            worker.execute_model(execute_model_req=execute_model_req)

     # When the batch size is larger than the threshold,
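The rewritten assertion above uses a standard unittest.mock pattern: patch the private method the code under test should dispatch to, give it a sentinel exception as side_effect, and check that the exception surfaces, which proves that path was taken. A self-contained sketch of that pattern follows; Worker, do_work and _slow_path are made-up names, not vLLM code.

from unittest.mock import patch

import pytest


class Worker:
    """Toy stand-in for an object whose internal dispatch we want to verify."""

    def do_work(self, n: int) -> str:
        # Dispatch to the "slow" path only for large inputs.
        if n > 3:
            return self._slow_path(n)
        return "fast"

    def _slow_path(self, n: int) -> str:
        return "slow"


def test_slow_path_is_taken():
    worker = Worker()
    sentinel = "sentinel-7c1f"
    # If do_work() really calls _slow_path(), the injected ValueError
    # propagates and pytest.raises catches it; otherwise the test fails.
    with patch.object(worker, '_slow_path',
                      side_effect=ValueError(sentinel)), \
            pytest.raises(ValueError, match=sentinel):
        worker.do_work(10)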


@@ -273,10 +273,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self._maybe_disable_speculative_tokens(
             disable_all_speculation, execute_model_req.seq_group_metadata_list)

-        # If no spec tokens, call the proposer and scorer workers normally.
-        # Used for prefill.
+        # Speculative decoding is disabled in the following cases:
+        # 1. Prefill phase: Speculative decoding is not
+        #    used during the prefill phase.
+        # 2. Auto-disable enabled: The running queue size exceeds
+        #    the specified threshold.
+        # 3. No request: There are no requests in the batch.
+        # In any of these cases, the proposer and scorer workers
+        # are called normally.
         if num_lookahead_slots == 0 or len(
-                execute_model_req.seq_group_metadata_list) == 0:
+                execute_model_req.seq_group_metadata_list
+        ) == 0 or disable_all_speculation:
             return self._run_no_spec(execute_model_req,
                                      skip_proposer=disable_all_speculation)

@@ -316,8 +323,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
     @nvtx_range("spec_decode_worker._run_no_spec")
     def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                      skip_proposer: bool) -> List[SamplerOutput]:
-        """Run a prefill step, without any speculation. The input is sent to
-        the proposer and scorer model so that the KV cache is consistent
+        """Run a single generation step without any speculation. The input is
+        sent to the proposer and scorer model so that the KV cache is consistent
         between the two. When skip_proposer is True, the proposer model is
         not called, meaning that the kv-cache in proposer for requests is not
         updated, so they cannot enable spec decode in the rest decoding.
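Condensed, the dispatch rule introduced in this file can be read as the predicate below. This is an illustrative sketch rather than the SpecDecodeWorker implementation: the function name is invented, and the three conditions simply restate the comment added in the diff.

def should_skip_speculation(num_lookahead_slots: int, num_seq_groups: int,
                            disable_all_speculation: bool) -> bool:
    """Return True when the plain (non-speculative) path should run."""
    # 1. Prefill step: no lookahead slots are allocated.
    # 2. Empty batch: there are no requests to run.
    # 3. Auto-disable: speculation was turned off for this step, e.g. because
    #    the running queue size exceeded the configured threshold.
    return (num_lookahead_slots == 0 or num_seq_groups == 0
            or disable_all_speculation)


# Example: a decode step with lookahead slots and speculation enabled keeps
# the speculative path, while a prefill step (0 lookahead slots) skips it.
assert should_skip_speculation(0, 8, False) is True
assert should_skip_speculation(5, 8, False) is False
assert should_skip_speculation(5, 8, True) is True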