[BUGFIX] Raise an error for no draft token case when draft_tp>1 (#6369)

Woo-Yeon Lee 2024-07-19 22:01:09 +09:00 committed by GitHub
parent 6366efc67b
commit a921e86392
4 changed files with 85 additions and 5 deletions


@@ -58,3 +58,65 @@ def test_draft_model_tp_lt_target_model_tp4(test_llm_generator,
batch_size,
max_output_len=32,
force_output_len=True)
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 4,
# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with an internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Artificially limit the draft model's max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32 - k tokens,
# where k is num_speculative_tokens.
"speculative_max_model_len": 32,
},
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# This must be a good bit larger than speculative_max_model_len so that
# we can test the case where all seqs are skipped, but still small
# enough to keep the test fast.
64,
])
@pytest.mark.parametrize("seed", [1])
def test_skip_speculation(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify job failure with RuntimeError when all sequences skip speculation.
We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they
are skipped in speculative decoding.
TODO: fix it to pass without raising Error. (#5814)
"""
with pytest.raises(RuntimeError):
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
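
The scenario this test drives through its generator fixtures can also be sketched against the offline entry point. The snippet below is an illustration only, assuming 4 GPUs and that vllm.LLM accepts the same engine arguments used above; since no separate draft tensor parallel size is given, the draft model inherits tensor_parallel_size=4, so once every sequence outgrows speculative_max_model_len the new RuntimeError is expected.

from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-160m",
    tensor_parallel_size=4,
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    # Force speculation to be skipped once sequences exceed 32 tokens.
    speculative_max_model_len=32,
    use_v2_block_manager=True,
    enforce_eager=True,
)

# Expected outcome (assumption, per this commit): RuntimeError("Cannot handle
# cases where distributed draft workers generate no tokens") once all
# sequences grow past speculative_max_model_len.
outputs = llm.generate(["San Francisco is a"],
                       SamplingParams(temperature=0.0, max_tokens=64))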


@@ -22,6 +22,9 @@ class SpeculativeProposals:
# The valid length of each proposal; can be zero.
proposal_lens: torch.Tensor
# A flag to mark that there are no available proposals.
no_proposals: bool = False
def __repr__(self):
return (f"SpeculativeProposals("
f"proposal_token_ids={self.proposal_token_ids}, "


@@ -109,6 +109,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_alpha: float,
) -> "SpecDecodeWorker":
# Zero-draft-token steps are tolerated by default; this is switched off
# below when the draft model runs with tensor parallelism > 1.
allow_zero_draft_token_step = True
ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
@@ -133,6 +134,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if draft_tp == 1:
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
else:
allow_zero_draft_token_step = False
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
@@ -155,10 +158,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
logger.info("Configuring SpecDecodeWorker with sampler=%s",
type(spec_decode_sampler))
return SpecDecodeWorker(proposer_worker,
scorer_worker,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler)
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step)
def __init__(
self,
@@ -167,6 +172,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
spec_decode_sampler: SpecDecodeBaseSampler,
metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True,
):
"""
Create a SpecDecodeWorker.
@@ -187,11 +193,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
for testing purposes.
allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft tokens; this should be disallowed when the
draft model's tensor parallel size is larger than 1 (TODO: #5814).
"""
self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker
self.disable_by_batch_size = disable_by_batch_size or float("inf")
self.spec_decode_sampler = spec_decode_sampler
self._allow_zero_draft_token_step = allow_zero_draft_token_step
self._metrics = AsyncMetricsCollector(
self.spec_decode_sampler
) if metrics_collector is None else metrics_collector
@@ -461,6 +471,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals = self.proposer_worker.get_spec_proposals(
execute_model_req, self._seq_with_bonus_token_in_last_step)
if not self._allow_zero_draft_token_step and proposals.no_proposals:
# TODO: Fix it (#5814).
raise RuntimeError("Cannot handle cases where distributed draft "
"workers generate no tokens")
proposal_scores = self.scorer.score_proposals(
execute_model_req,
proposals,
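
Until #5814 is resolved, the error path above can be avoided by keeping the draft model at tensor parallel size 1, so that TP1DraftModelRunner is selected and allow_zero_draft_token_step stays True. A minimal configuration sketch, assuming the speculative_draft_tensor_parallel_size engine argument (the one exercised by the draft-TP-smaller-than-target test at the top of this diff) is available:

from vllm import LLM

llm = LLM(
    model="JackFram/llama-160m",
    tensor_parallel_size=4,
    speculative_model="JackFram/llama-68m",
    # Pin the draft model to TP=1 so zero-draft-token steps stay allowed.
    speculative_draft_tensor_parallel_size=1,
    num_speculative_tokens=5,
    use_v2_block_manager=True,
)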


@@ -108,7 +108,7 @@ class Top1Proposer(SpeculativeProposer):
proposal_token_ids=proposal_tokens,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens,
)
no_proposals=maybe_sampler_output is None)
return proposals
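
Taken together, Top1Proposer now raises the no_proposals flag whenever the draft sampler returns nothing, and SpecDecodeWorker either tolerates that step (draft TP == 1) or fails fast (draft TP > 1). The toy sketch below mirrors that control flow in isolation; it is not vLLM code.

from dataclasses import dataclass


@dataclass
class ToyProposals:
    no_proposals: bool = False


def toy_spec_step(draft_tp: int, sampler_output_available: bool) -> str:
    # Mirrors create_worker(): only a TP=1 draft worker may emit empty steps.
    allow_zero_draft_token_step = (draft_tp == 1)
    # Mirrors Top1Proposer: a missing sampler output marks the step as empty.
    proposals = ToyProposals(no_proposals=not sampler_output_available)
    if not allow_zero_draft_token_step and proposals.no_proposals:
        raise RuntimeError("Cannot handle cases where distributed draft "
                           "workers generate no tokens")
    return "skipped" if proposals.no_proposals else "scored"


print(toy_spec_step(draft_tp=1, sampler_output_available=False))  # skipped
print(toy_spec_step(draft_tp=4, sampler_output_available=True))   # scored
toy_spec_step(draft_tp=4, sampler_output_available=False)         # raises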