mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-26 11:47:07 +08:00
[Bugfix] Fix for Spec model TP + Chunked Prefill (#10232)
Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com>
Signed-off-by: Sourashis Roy <sroy@roblox.com>
Co-authored-by: Sourashis Roy <sroy@roblox.com>
This commit is contained in:
parent
1f6584ee85
commit
db66e018ea
@ -118,7 +118,7 @@ Feature x Feature
|
|||||||
-
|
-
|
||||||
-
|
-
|
||||||
* - :ref:`SD <spec_decode>`
|
* - :ref:`SD <spec_decode>`
|
||||||
- ✗
|
- ✅
|
||||||
- ✅
|
- ✅
|
||||||
- ✗
|
- ✗
|
||||||
- ✅
|
- ✅
|
||||||
|
|||||||
@ -413,6 +413,45 @@ def test_chunked_prefill_preempt():
|
|||||||
assert out.num_batched_tokens == max_num_batched_tokens
|
assert out.num_batched_tokens == max_num_batched_tokens
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_scheduler_steps", [1, 5])
|
||||||
|
def test_chunked_prefill_spec_prefill(num_scheduler_steps):
|
||||||
|
"""Verify that the num_lookahead_slots is set appropriately for an all"""
|
||||||
|
"""prefill batch depending on whether multi-step scheduling is enabled"""
|
||||||
|
"""or not"""
|
||||||
|
block_size = 4
|
||||||
|
max_seqs = 30
|
||||||
|
max_model_len = 200
|
||||||
|
max_num_batched_tokens = 30
|
||||||
|
num_lookahead_slots = 4
|
||||||
|
scheduler_config = SchedulerConfig(
|
||||||
|
"generate",
|
||||||
|
max_num_batched_tokens,
|
||||||
|
max_seqs,
|
||||||
|
max_model_len,
|
||||||
|
enable_chunked_prefill=True,
|
||||||
|
num_lookahead_slots=num_lookahead_slots,
|
||||||
|
num_scheduler_steps=num_scheduler_steps,
|
||||||
|
)
|
||||||
|
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||||
|
cache_config.num_cpu_blocks = 16
|
||||||
|
cache_config.num_gpu_blocks = 16
|
||||||
|
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||||
|
|
||||||
|
_, seq_group = create_dummy_prompt("1",
|
||||||
|
prompt_length=30,
|
||||||
|
block_size=block_size)
|
||||||
|
scheduler.add_seq_group(seq_group)
|
||||||
|
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||||
|
# The request is chunked.
|
||||||
|
# prefill scheduled now.
|
||||||
|
assert len(out.scheduled_seq_groups) == 1
|
||||||
|
assert out.num_prefill_groups == 1
|
||||||
|
assert out.num_batched_tokens == max_num_batched_tokens
|
||||||
|
print(out.num_lookahead_slots)
|
||||||
|
assert out.num_lookahead_slots == (0 if (num_scheduler_steps == 1) else
|
||||||
|
num_lookahead_slots)
|
||||||
|
|
||||||
|
|
||||||
def test_chunked_prefill_max_seqs():
|
def test_chunked_prefill_max_seqs():
|
||||||
block_size = 4
|
block_size = 4
|
||||||
max_seqs = 2
|
max_seqs = 2
|
||||||
|
|||||||
@ -50,49 +50,3 @@ def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
|
|||||||
with pytest.raises(ValueError, match="cannot be larger than"):
|
with pytest.raises(ValueError, match="cannot be larger than"):
|
||||||
get_output_from_llm_generator(test_llm_generator, prompts,
|
get_output_from_llm_generator(test_llm_generator, prompts,
|
||||||
sampling_params)
|
sampling_params)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("common_llm_kwargs",
|
|
||||||
[{
|
|
||||||
"model": "meta-llama/Llama-2-7b-chat-hf",
|
|
||||||
"speculative_model": "JackFram/llama-68m",
|
|
||||||
"num_speculative_tokens": 5,
|
|
||||||
"enable_chunked_prefill": "True",
|
|
||||||
}])
|
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
|
||||||
{
|
|
||||||
"tensor_parallel_size": 2,
|
|
||||||
"speculative_draft_tensor_parallel_size": 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"tensor_parallel_size": 4,
|
|
||||||
"speculative_draft_tensor_parallel_size": 4,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"tensor_parallel_size": 8,
|
|
||||||
"speculative_draft_tensor_parallel_size": 8,
|
|
||||||
},
|
|
||||||
])
|
|
||||||
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
|
||||||
@pytest.mark.parametrize("seed", [1])
|
|
||||||
def test_spec_decode_xfail_chunked_prefill_draft_model_tp_not_one(
|
|
||||||
test_llm_generator):
|
|
||||||
"""Verify that speculative decoding fails if chunked prefill is enabled for
|
|
||||||
draft model with tensor parallelism of more than 1.
|
|
||||||
"""
|
|
||||||
output_len = 128
|
|
||||||
temperature = 0.0
|
|
||||||
|
|
||||||
prompts = [
|
|
||||||
"Hello, my name is",
|
|
||||||
]
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
max_tokens=output_len,
|
|
||||||
ignore_eos=True,
|
|
||||||
temperature=temperature,
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="with tensor parallel size 1"):
|
|
||||||
get_output_from_llm_generator(test_llm_generator, prompts,
|
|
||||||
sampling_params)
|
|
||||||
|
|||||||
@ -115,3 +115,60 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
|
|||||||
max_output_len=32,
|
max_output_len=32,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
temperature=0.0)
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||||
|
reason="Need at least 2 GPUs to run the test.")
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[[
|
||||||
|
# Skip cuda graph recording for fast test.
|
||||||
|
"--enforce-eager",
|
||||||
|
"--tensor_parallel_size",
|
||||||
|
"2",
|
||||||
|
|
||||||
|
# precision
|
||||||
|
"--dtype",
|
||||||
|
"bfloat16",
|
||||||
|
]])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"per_test_common_llm_kwargs",
|
||||||
|
[["--enable-chunked-prefill", "False"],
|
||||||
|
[
|
||||||
|
"--enable-chunked-prefill", "True", "--max-num-batched-tokens", "4",
|
||||||
|
"--max-num-seqs", "4"
|
||||||
|
]])
|
||||||
|
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||||
|
@pytest.mark.parametrize("model, test_llm_kwargs",
|
||||||
|
[("JackFram/llama-68m", [
|
||||||
|
"--speculative-model",
|
||||||
|
"JackFram/llama-68m",
|
||||||
|
"--num_speculative-tokens",
|
||||||
|
"3",
|
||||||
|
]),
|
||||||
|
("JackFram/llama-68m", [
|
||||||
|
"--speculative-model",
|
||||||
|
"JackFram/llama-68m",
|
||||||
|
"--num_speculative-tokens",
|
||||||
|
"3",
|
||||||
|
"--speculative-draft-tensor-parallel-size",
|
||||||
|
"1",
|
||||||
|
])])
|
||||||
|
@pytest.mark.parametrize("batch_size", [2])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
|
batch_size: int, seed: int):
|
||||||
|
"""Verify spec decode works well with same and different TP size for
|
||||||
|
the draft model with chunked prefill.
|
||||||
|
"""
|
||||||
|
run_equality_correctness_test_tp(model,
|
||||||
|
common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=32,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|||||||
@ -867,7 +867,8 @@ def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
|
|||||||
target_group_metadata_list = prefill + decodes
|
target_group_metadata_list = prefill + decodes
|
||||||
execute_model_req = ExecuteModelRequest(
|
execute_model_req = ExecuteModelRequest(
|
||||||
seq_group_metadata_list=target_group_metadata_list,
|
seq_group_metadata_list=target_group_metadata_list,
|
||||||
num_lookahead_slots=k)
|
# For prefill only batches we expect num_lookahead_slots = 0.
|
||||||
|
num_lookahead_slots=k if n_decodes > 0 else 0)
|
||||||
|
|
||||||
target_token_ids = torch.randint(low=0,
|
target_token_ids = torch.randint(low=0,
|
||||||
high=vocab_size,
|
high=vocab_size,
|
||||||
|
|||||||
@ -1409,16 +1409,6 @@ class SpeculativeConfig:
|
|||||||
draft_hf_config
|
draft_hf_config
|
||||||
)
|
)
|
||||||
|
|
||||||
if (enable_chunked_prefill and \
|
|
||||||
speculative_draft_tensor_parallel_size != 1):
|
|
||||||
# TODO - Investigate why the error reported in
|
|
||||||
# https://github.com/vllm-project/vllm/pull/9291#issuecomment-2463266258
|
|
||||||
# is happening and re-enable it.
|
|
||||||
raise ValueError(
|
|
||||||
"Chunked prefill and speculative decoding can be enabled "
|
|
||||||
"simultaneously only for draft models with tensor "
|
|
||||||
"parallel size 1.")
|
|
||||||
|
|
||||||
draft_model_config.max_model_len = (
|
draft_model_config.max_model_len = (
|
||||||
SpeculativeConfig._maybe_override_draft_max_model_len(
|
SpeculativeConfig._maybe_override_draft_max_model_len(
|
||||||
speculative_max_model_len,
|
speculative_max_model_len,
|
||||||
|
|||||||
@ -1201,15 +1201,25 @@ class Scheduler:
|
|||||||
# Update swapped requests.
|
# Update swapped requests.
|
||||||
self.swapped.extend(running_scheduled.swapped_out)
|
self.swapped.extend(running_scheduled.swapped_out)
|
||||||
# Put prefills first due to Attention backend ordering assumption.
|
# Put prefills first due to Attention backend ordering assumption.
|
||||||
|
scheduled_seq_groups = (prefills.seq_groups +
|
||||||
|
running_scheduled.prefill_seq_groups +
|
||||||
|
swapped_in.prefill_seq_groups +
|
||||||
|
running_scheduled.decode_seq_groups +
|
||||||
|
swapped_in.decode_seq_groups)
|
||||||
|
num_prefill_groups = (len(prefills.seq_groups) +
|
||||||
|
len(swapped_in.prefill_seq_groups) +
|
||||||
|
len(running_scheduled.prefill_seq_groups))
|
||||||
|
# If every scheduled group is a prefill (and multi-step scheduling is
|
||||||
|
# off), set num_lookahead_slots to 0 so that we go through the
|
||||||
|
# `no_spec` path in `spec_decode_worker.py`.
|
||||||
|
all_prefills = (len(scheduled_seq_groups) == num_prefill_groups)
|
||||||
|
num_lookahead_slots = (0 if
|
||||||
|
(all_prefills
|
||||||
|
and not self.scheduler_config.is_multi_step)
|
||||||
|
else running_scheduled.num_lookahead_slots)
|
||||||
return SchedulerOutputs(
|
return SchedulerOutputs(
|
||||||
scheduled_seq_groups=(prefills.seq_groups +
|
scheduled_seq_groups=scheduled_seq_groups,
|
||||||
running_scheduled.prefill_seq_groups +
|
num_prefill_groups=num_prefill_groups,
|
||||||
swapped_in.prefill_seq_groups +
|
|
||||||
running_scheduled.decode_seq_groups +
|
|
||||||
swapped_in.decode_seq_groups),
|
|
||||||
num_prefill_groups=(len(prefills.seq_groups) +
|
|
||||||
len(swapped_in.prefill_seq_groups) +
|
|
||||||
len(running_scheduled.prefill_seq_groups)),
|
|
||||||
num_batched_tokens=budget.num_batched_tokens +
|
num_batched_tokens=budget.num_batched_tokens +
|
||||||
budget.num_cached_tokens,
|
budget.num_cached_tokens,
|
||||||
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
|
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
|
||||||
@ -1218,7 +1228,7 @@ class Scheduler:
|
|||||||
swapped_in.blocks_to_copy,
|
swapped_in.blocks_to_copy,
|
||||||
ignored_seq_groups=prefills.ignored_seq_groups +
|
ignored_seq_groups=prefills.ignored_seq_groups +
|
||||||
swapped_in.infeasible_seq_groups,
|
swapped_in.infeasible_seq_groups,
|
||||||
num_lookahead_slots=running_scheduled.num_lookahead_slots,
|
num_lookahead_slots=num_lookahead_slots,
|
||||||
running_queue_size=len(self.running),
|
running_queue_size=len(self.running),
|
||||||
preempted=(len(running_scheduled.preempted) +
|
preempted=(len(running_scheduled.preempted) +
|
||||||
len(running_scheduled.swapped_out)),
|
len(running_scheduled.swapped_out)),
|
||||||
|
|||||||
@ -408,7 +408,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
disable_all_speculation = self._should_disable_all_speculation(
|
disable_all_speculation = self._should_disable_all_speculation(
|
||||||
execute_model_req)
|
execute_model_req)
|
||||||
num_lookahead_slots = execute_model_req.num_lookahead_slots
|
num_lookahead_slots = execute_model_req.num_lookahead_slots
|
||||||
|
all_prompt = True
|
||||||
|
atleast_one_prompt = False
|
||||||
|
all_zero_spec_tokens = True
|
||||||
|
for sgm in execute_model_req.seq_group_metadata_list:
|
||||||
|
all_prompt = all_prompt and sgm.is_prompt
|
||||||
|
atleast_one_prompt = atleast_one_prompt or sgm.is_prompt
|
||||||
|
all_zero_spec_tokens = all_zero_spec_tokens and (
|
||||||
|
sgm.num_speculative_tokens == 0)
|
||||||
|
|
||||||
|
if all_prompt and execute_model_req.seq_group_metadata_list:
|
||||||
|
assert num_lookahead_slots == 0, (
|
||||||
|
"Prompt only runs should have num_lookahead_slots equal to 0. "
|
||||||
|
"This should never happen, please file a bug at "
|
||||||
|
"https://github.com/vllm-project/vllm/issues")
|
||||||
# Speculative decoding is disabled in the following cases:
|
# Speculative decoding is disabled in the following cases:
|
||||||
# 1. Prefill phase: Speculative decoding is not
|
# 1. Prefill phase: Speculative decoding is not
|
||||||
# used during the prefill phase.
|
# used during the prefill phase.
|
||||||
@ -419,11 +432,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
# In any of these cases, the proposer and scorer workers
|
# In any of these cases, the proposer and scorer workers
|
||||||
# are called normally.
|
# are called normally.
|
||||||
# We expect `num_speculative_tokens` to be None for prefills.
|
# We expect `num_speculative_tokens` to be None for prefills.
|
||||||
no_spec = all(
|
no_spec = (num_lookahead_slots == 0 or disable_all_speculation
|
||||||
sgm.is_prompt for sgm in execute_model_req.seq_group_metadata_list
|
or all_zero_spec_tokens)
|
||||||
) or num_lookahead_slots == 0 or disable_all_speculation or all(
|
|
||||||
sgm.num_speculative_tokens == 0
|
|
||||||
for sgm in execute_model_req.seq_group_metadata_list)
|
|
||||||
|
|
||||||
# Broadcast how many lookahead slots are scheduled for this step, and
|
# Broadcast how many lookahead slots are scheduled for this step, and
|
||||||
# whether all speculation is disabled, to all non-driver workers.
|
# whether all speculation is disabled, to all non-driver workers.
|
||||||
@ -442,6 +452,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
num_lookahead_slots=num_lookahead_slots,
|
num_lookahead_slots=num_lookahead_slots,
|
||||||
no_spec=no_spec,
|
no_spec=no_spec,
|
||||||
disable_all_speculation=disable_all_speculation,
|
disable_all_speculation=disable_all_speculation,
|
||||||
|
# When both chunked prefill and speculative decoding are enabled
|
||||||
|
# it is possible that the same batch contains both prefill
|
||||||
|
# and decodes. If that happens, in the scorer we run the batch
|
||||||
|
# as one single forward pass. However, in the proposer we
|
||||||
|
# run them as 2 different batches - one for prefill and
|
||||||
|
# the other for decodes. The variable indicates to the non-driver
|
||||||
|
# worker that there are prefills as part of the speculative batch
|
||||||
|
# and hence it needs to run an extra prefill forward pass.
|
||||||
|
run_spec_proposer_for_prefill=atleast_one_prompt,
|
||||||
)
|
)
|
||||||
broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
|
broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
|
||||||
|
|
||||||
@ -653,6 +672,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
|
|
||||||
if not data["no_spec"]:
|
if not data["no_spec"]:
|
||||||
self.scorer_worker.execute_model()
|
self.scorer_worker.execute_model()
|
||||||
|
if data["run_spec_proposer_for_prefill"]:
|
||||||
|
self.proposer_worker.execute_model()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user