Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 03:45:00 +08:00)
[Speculative Decoding] MLPSpeculator Tensor Parallel support (1/2) (#6050)

Co-authored-by: Sirej Dua <sirej.dua@databricks.com>

Parent: 31354e563f
Commit: 15aba081f3
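
This change (part 1 of 2) lets speculative decoding with an MLPSpeculator draft model run against a target model that uses tensor parallelism greater than 1: the draft is kept at TP=1 and wrapped in SmallerTpProposerWorker (see the SpecDecodeWorker hunk below). As a rough usage sketch, the flag names below are taken from the new test kwargs in this diff; tensor_parallel_size=2 is an assumption inferred from the "tp2" test name, and additional engine flags may be required in this vLLM version:

    from vllm import LLM

    # Sketch only: kwargs mirror the test parametrization in this diff;
    # treat the exact LLM() surface for this vLLM version as an assumption.
    llm = LLM(
        model="ibm-granite/granite-3b-code-instruct",
        tensor_parallel_size=2,  # assumed target TP, per the "tp2" test name
        speculative_model="ibm-granite/granite-3b-code-instruct-accelerator",
        num_speculative_tokens=5,
        speculative_draft_tensor_parallel_size=1,  # draft stays on one GPU
    )
    outputs = llm.generate("The capital of France is")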
@@ -70,10 +70,6 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        # Use a small model for a fast test.
-        # Note this is repeated in the test body; to initialize a tokenizer.
-        "model": "JackFram/llama-68m",
-
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
 
@@ -88,15 +84,31 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
         # second run of the test to fail with internal NCCL error.
         "use_async": True,
     }])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [
-    {
-        "speculative_model": "JackFram/llama-68m",
-        "num_speculative_tokens": 5,
-        "speculative_draft_tensor_parallel_size": 1,
-    },
-])
+@pytest.mark.parametrize(
+    "per_test_common_llm_kwargs, test_llm_kwargs",
+    [
+        (
+            {
+                # Use a small model for a fast test.
+                # Note this is repeated in the test body; to initialize a
+                # tokenizer.
+                "model": "JackFram/llama-68m",
+            },
+            {
+                "speculative_model": "JackFram/llama-68m",
+                "num_speculative_tokens": 5,
+                "speculative_draft_tensor_parallel_size": 1,
+            }),
+        ({
+            "model": "ibm-granite/granite-3b-code-instruct",
+        }, {
+            "speculative_model":
+            "ibm-granite/granite-3b-code-instruct-accelerator",
+            "num_speculative_tokens": 5,
+            "speculative_draft_tensor_parallel_size": 1,
+        })
+    ])
 @pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize("seed", [1])
 def test_draft_model_tp_lt_target_model_tp2(test_llm_generator,
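The reparametrization above pairs each target model with a compatible draft, instead of letting independent parametrize decorators form a cross product (which would have combined the llama target with the granite speculator). A generic, self-contained illustration of this pytest pairing pattern, with toy values rather than the vLLM test's:

    import pytest

    # Tuples in a single parametrize keep related kwargs together, so
    # only valid target/draft combinations enter the test matrix.
    @pytest.mark.parametrize(
        "target_kwargs, draft_kwargs",
        [
            ({"model": "target-a"}, {"speculative_model": "draft-a"}),
            ({"model": "target-b"}, {"speculative_model": "draft-b"}),
        ])
    def test_pairing(target_kwargs, draft_kwargs):
        # Each target is matched with the draft from the same family.
        assert (target_kwargs["model"].split("-")[-1] ==
                draft_kwargs["speculative_model"].split("-")[-1])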
@@ -957,12 +957,6 @@ class SpeculativeConfig:
         )
 
         draft_hf_config = draft_model_config.hf_config
-        if (draft_hf_config.model_type == "mlp_speculator"
-                and target_parallel_config.world_size != 1):
-            # MLPSpeculator TP support will be added very soon
-            raise ValueError(
-                "Speculative decoding with mlp_speculator models does not "
-                "yet support distributed inferencing (TP > 1).")
 
         if (num_speculative_tokens is not None
                 and hasattr(draft_hf_config, "num_lookahead_tokens")):
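The guard removed above was a stopgap ("MLPSpeculator TP support will be added very soon"); with this commit, an mlp_speculator draft no longer forces a single-GPU target. For readers unfamiliar with this draft type, a toy sketch of the MLPSpeculator idea follows. It is a deliberate simplification and not vLLM's model code: small MLP heads propose the next k tokens from the target's last hidden state and the previously sampled token.

    import torch
    import torch.nn as nn

    class TinyMLPSpeculator(nn.Module):
        """Toy k-step draft head (illustrative only)."""

        def __init__(self, hidden: int, vocab: int, k: int):
            super().__init__()
            self.embed = nn.Embedding(vocab, hidden)
            self.heads = nn.ModuleList(
                nn.Linear(2 * hidden, vocab) for _ in range(k))

        def forward(self, h: torch.Tensor, last_tok: torch.Tensor):
            # h: [batch, hidden] final hidden state; last_tok: [batch]
            draft = []
            for head in self.heads:
                logits = head(torch.cat([h, self.embed(last_tok)], dim=-1))
                last_tok = logits.argmax(dim=-1)  # greedy draft token
                draft.append(last_tok)
            return torch.stack(draft, dim=1)  # [batch, k] proposed tokens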
@@ -113,24 +113,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
 
         disable_bonus_tokens = True
+
         if ngram_prompt_lookup_max > 0:
             disable_bonus_tokens = False
             proposer_worker = NGramWorker(**draft_worker_kwargs)
             proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                   ngram_prompt_lookup_max)
-        elif draft_worker_kwargs[
-                "model_config"].hf_config.model_type == "mlp_speculator":
-            proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
-            disable_bonus_tokens = False
         else:
             draft_parallel_config: ParallelConfig = draft_worker_kwargs[
                 'parallel_config']
             draft_tp = draft_parallel_config.tensor_parallel_size
             target_tp = scorer_worker.parallel_config.tensor_parallel_size
 
-            if draft_tp == 1:
-                draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
-            proposer_worker = MultiStepWorker(**draft_worker_kwargs)
+            if draft_worker_kwargs[
+                    "model_config"].hf_config.model_type == "mlp_speculator":
+                disable_bonus_tokens = False
+                proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
+            else:
+                if draft_tp == 1:
+                    draft_worker_kwargs[
+                        "model_runner_cls"] = TP1DraftModelRunner
+                proposer_worker = MultiStepWorker(**draft_worker_kwargs)
+
             proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                 proposer_worker, draft_tp, target_tp)
 
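In the new dispatch, an mlp_speculator draft takes the same path as a regular draft model: it is built inside the else branch and then passed through SmallerTpProposerWorker.maybe_wrap_worker, which is what actually reconciles a TP=1 draft with a TP>1 target. A minimal sketch of that wrap-if-smaller pattern, using hypothetical stand-in classes rather than vLLM's workers, and with the wrap condition assumed from the call site above:

    class Worker:
        """Stand-in for a draft proposer worker."""

        def __init__(self, name: str):
            self.name = name

    class SmallerTpProposerWorker:
        """Runs a draft worker on a subset of the target's TP ranks."""

        def __init__(self, worker: Worker, draft_tp: int, target_tp: int):
            self.worker = worker
            self.draft_tp, self.target_tp = draft_tp, target_tp

        @classmethod
        def maybe_wrap_worker(cls, worker, draft_tp, target_tp):
            # Assumed semantics: wrap only when the draft uses fewer TP
            # ranks than the target; otherwise return it unchanged.
            if draft_tp == target_tp:
                return worker
            return cls(worker, draft_tp, target_tp)

    proposer = SmallerTpProposerWorker.maybe_wrap_worker(
        Worker("mlp_speculator_draft"), draft_tp=1, target_tp=2)
    print(type(proposer).__name__)  # SmallerTpProposerWorker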