[Model][Speculative Decoding] DeepSeek MTP spec decode (#12755)

Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>

Parent commit: 983a40a8bb
This commit:  f525c0be8b
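As a usage sketch of the feature this commit adds: the checkpoints below are the random-weight test models referenced by the new tests, and the exact engine keyword arguments are assumptions based on the kwargs those tests pass through vllm_runner, not a documented invocation.

# Hypothetical offline-inference sketch for DeepSeek MTP speculative decoding.
from vllm import LLM, SamplingParams

llm = LLM(
    model="luccafong/deepseek_mtp_main_random",               # target model (random-weight test checkpoint)
    speculative_model="luccafong/deepseek_mtp_draft_random",  # MTP draft weights
    num_speculative_tokens=1,   # matches num_nextn_predict_layers in the draft config
    trust_remote_code=True,
    enforce_eager=True,
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)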
@@ -2,7 +2,7 @@
 # adding a new command to an existing step. See different options here for examples.
 # This script will be feed into Jinja template in `test-template-aws.j2` at
 # https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2
 # to generate the final pipeline yaml file.

 # Documentation
@@ -15,7 +15,7 @@
 # mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd]
 # gpu(str): override the GPU selection for the test. default is on L4 GPUs. currently only supports a100
 # num_gpus(int): override the number of GPUs for the test. default to 1 GPU. currently support 2,4.
 # num_nodes(int): whether to simulate multi-node setup by launch multiple containers on one host,
 # in this case, commands must be specified. the first command runs on first host, the second
 # command runs on the second host.
 # working_dir(str): specify the place where command should execute, default to /vllm-workspace/tests
@@ -24,8 +24,8 @@
 # When adding a test
 # - If the test belong to an existing group, add it there
 # - If the test is short, add to any existing step
 # - If the test takes more than 10min, then it is okay to create a new step.
 # Note that all steps execute in parallel.

 steps:
 ##### fast check tests #####
@@ -145,14 +145,14 @@ steps:
   - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py

 - label: Metrics, Tracing Test # 10min
   num_gpus: 2
   fast_check: true
   source_file_dependencies:
   - vllm/
   - tests/metrics
   - tests/tracing
   commands:
   - pytest -v -s metrics
   - "pip install \
     'opentelemetry-sdk>=1.26.0,<1.27.0' \
     'opentelemetry-api>=1.26.0,<1.27.0' \
@@ -254,7 +254,7 @@ steps:
   - vllm/model_executor/guided_decoding
   - tests/test_logits_processor
   - tests/model_executor/test_guided_processors
   commands:
   - pytest -v -s test_logits_processor.py
   - pytest -v -s model_executor/test_guided_processors.py

@@ -265,7 +265,7 @@ steps:
   - vllm/model_executor/models/eagle.py
   commands:
   - pytest -v -s spec_decode/e2e/test_multistep_correctness.py
-  - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
+  - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py --ignore=spec_decode/e2e/test_mtp_correctness.py
   - pytest -v -s spec_decode/e2e/test_eagle_correctness.py

 - label: LoRA Test %N # 15min each
@@ -580,7 +580,7 @@ steps:
   - export VLLM_WORKER_MULTIPROC_METHOD=spawn
   # This test runs llama 13B, so it is required to run on 4 GPUs.
   - pytest -v -s -x lora/test_long_context.py
   # There is some Tensor Parallelism related processing logic in LoRA that
   # requires multi-GPU testing for validation.
   - pytest -v -s -x lora/test_chatglm3_tp.py
   - pytest -v -s -x lora/test_llama_tp.py
@@ -605,7 +605,7 @@ steps:
   - vllm/
   - tests/weight_loading
   commands:
   - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt


 ##### multi gpus test #####
@@ -617,7 +617,7 @@ steps:
   num_gpus: 4
   source_file_dependencies:
   - vllm/
   commands:
   # NOTE: don't test llama model here, it seems hf implementation is buggy
   # see https://github.com/vllm-project/vllm/pull/5689 for details
   - pytest -v -s distributed/test_custom_all_reduce.py
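The only change visible in the CI hunks above is that the generic spec-decode step now excludes the new MTP suite. A minimal Python sketch for running that suite on its own, assuming the /vllm-workspace/tests working directory mentioned in the comments above:

# Sketch: invoke the new MTP correctness suite directly through pytest's API
# (equivalent to `pytest -v -s spec_decode/e2e/test_mtp_correctness.py`).
import pytest

ret = pytest.main(["-v", "-s", "spec_decode/e2e/test_mtp_correctness.py"])
raise SystemExit(ret)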
@@ -296,6 +296,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
                              speculative_model="abhigoyal/vllm-medusa-llama-68m-random"),  # noqa: E501
     "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m",
                              speculative_model="ibm-ai-platform/llama-160m-accelerator"),  # noqa: E501
+    "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random",
+                            speculative_model="luccafong/deepseek_mtp_draft_random",  # noqa: E501
+                            trust_remote_code=True),
 }

 _FALLBACK_MODEL = {
tests/spec_decode/e2e/test_mtp_correctness.py (new file, 318 lines)
@@ -0,0 +1,318 @@
# SPDX-License-Identifier: Apache-2.0
"""This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.

However, we still need to verify below scenario could be passed:
    * Batch size 1 greedy equality
    * Batch size >1 greedy equality
    * Test greedy equality under preemption
    * Test greedy equality under various number of speculative tokens.

With those tests, we can say at least, mtp would not break the
correctess for the target model outputs.
"""

import pytest

from .conftest import run_equality_correctness_test

# main model
MAIN_MODEL = "luccafong/deepseek_mtp_main_random"

# max. number of speculative tokens: this corresponds to
# num_nextn_predict_layers in the config.json of the speculator model.
MAX_SPEC_TOKENS = 1

# precision
PRECISION = "bfloat16"


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.85
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                    per_test_common_llm_kwargs,
                                    baseline_llm_kwargs, test_llm_kwargs,
                                    batch_size: int, output_len: int,
                                    seed: int):

    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.85
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_logprobs_during_spec_decoding": False,
    },
    {
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_logprobs_during_spec_decoding": True,
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                 per_test_common_llm_kwargs,
                                 baseline_llm_kwargs, test_llm_kwargs,
                                 batch_size: int, output_len: int, seed: int,
                                 logprobs: int):

    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  logprobs=logprobs,
                                  prompt_logprobs=logprobs,
                                  disable_logprobs=test_llm_kwargs[
                                      'disable_logprobs_during_spec_decoding'])


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": False,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
        "gpu_memory_utilization": 0.85
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
                                               per_test_common_llm_kwargs,
                                               baseline_llm_kwargs,
                                               test_llm_kwargs,
                                               batch_size: int,
                                               output_len: int, seed: int):
    """Verify greedy equality with cuda graph enabled and different
    batch sizes."""
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # 2 for small prompt, 256//8 for generated.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.9
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        128,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_with_preemption(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Verify greedy equality, even when some sequences are preempted mid-
    generation.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.9
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "num_speculative_tokens": k,
        }
        # Try a range of num. speculative tokens
        for k in range(1, 1 + MAX_SPEC_TOKENS)
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_mtp_different_k(vllm_runner, common_llm_kwargs,
                         per_test_common_llm_kwargs, baseline_llm_kwargs,
                         test_llm_kwargs, batch_size: int, output_len: int,
                         seed: int):
    """Verify that mtp speculative decoding produces exact equality
    to without spec decode with different values of num_speculative_tokens.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.9
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "num_speculative_tokens": MAX_SPEC_TOKENS,
                             "speculative_disable_by_batch_size": 4
                         }])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_mtp_disable_queue(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int):
    """Verify that mtp speculative decoding produces exact equality
    to without spec decode when speculation is disabled for large
    batch sizes.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


if __name__ == "__main__":
    import pytest
    pytest.main([__file__])
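For intuition, a standalone sketch of the "greedy equality" property the suite above asserts. The real logic lives in the conftest helper run_equality_correctness_test; the checkpoints, engine kwargs, and sequential engine construction below are illustrative assumptions, not the test harness itself.

# Minimal illustration of greedy equality between a plain engine and an
# MTP-speculative engine (sketch; real tests reuse a shared vllm_runner).
from vllm import LLM, SamplingParams

prompts = ["Speculative decoding should not change greedy outputs."]
greedy = SamplingParams(temperature=0.0, max_tokens=32)

baseline = LLM(model="luccafong/deepseek_mtp_main_random",
               trust_remote_code=True, enforce_eager=True)
baseline_text = baseline.generate(prompts, greedy)[0].outputs[0].text
del baseline  # free the GPU before starting the speculative engine

spec = LLM(model="luccafong/deepseek_mtp_main_random",
           speculative_model="luccafong/deepseek_mtp_draft_random",
           num_speculative_tokens=1,
           trust_remote_code=True, enforce_eager=True)
spec_text = spec.generate(prompts, greedy)[0].outputs[0].text

# Rejection sampling preserves the target distribution, so greedy outputs match.
assert baseline_text == spec_text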
@@ -763,7 +763,7 @@ class ModelConfig:
     def is_deepseek_mla(self) -> bool:
         return (hasattr(self.hf_text_config, "model_type")) \
             and (self.hf_text_config.model_type in \
-                    ('deepseek_v2', 'deepseek_v3'))\
+                    ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'))\
             and (self.hf_text_config.kv_lora_rank is not None)

     def get_head_size(self) -> int:
@@ -856,8 +856,12 @@ class ModelConfig:
     def get_layers_start_end_indices(
             self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
         from vllm.distributed.utils import get_pp_indices
-        total_num_hidden_layers = getattr(self.hf_text_config,
-                                          "num_hidden_layers", 0)
+        if self.hf_text_config.model_type == "deepseek_mtp":
+            total_num_hidden_layers = getattr(self.hf_text_config,
+                                              "num_nextn_predict_layers", 0)
+        else:
+            total_num_hidden_layers = getattr(self.hf_text_config,
+                                              "num_hidden_layers", 0)
         pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
         pp_size = parallel_config.pipeline_parallel_size
         start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
@@ -1689,6 +1693,18 @@ class SpeculativeConfig:
         hash_str = hashlib.md5(str(factors).encode()).hexdigest()
         return hash_str

+    @staticmethod
+    def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
+        if hf_config.model_type == "deepseek_v3":
+            hf_config.model_type = "deepseek_mtp"
+        if hf_config.model_type == "deepseek_mtp":
+            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
+            hf_config.update({
+                "n_predict": n_predict,
+                "architectures": ["DeepSeekMTPModel"]
+            })
+        return hf_config
+
     @staticmethod
     def maybe_create_spec_config(
         target_model_config: ModelConfig,
@@ -1771,12 +1787,18 @@ class SpeculativeConfig:
             Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                 the necessary conditions are met, else None.
         """

         if speculative_model is None:
             if num_speculative_tokens is not None:
-                raise ValueError("num_speculative_tokens was provided without "
-                                 "speculative_model.")
-            return None
+                if target_model_config.hf_text_config.model_type \
+                        == "deepseek_v3":
+                    # use the draft model from the same model:
+                    speculative_model = target_model_config.model
+                else:
+                    raise ValueError(
+                        "num_speculative_tokens was provided without "
+                        "speculative_model.")
+            else:
+                return None

         if (speculative_disable_by_batch_size is not None
                 and speculative_disable_by_batch_size < 2):
@@ -1830,6 +1852,7 @@ class SpeculativeConfig:
             max_seq_len_to_capture=target_model_config.
             max_seq_len_to_capture,
             max_logprobs=target_model_config.max_logprobs,
+            hf_overrides=SpeculativeConfig.hf_config_override,
         )

         draft_hf_config = draft_model_config.hf_config
@@ -1846,7 +1869,6 @@ class SpeculativeConfig:
         if (num_speculative_tokens is not None
                 and hasattr(draft_hf_config, "num_lookahead_tokens")):
             draft_hf_config.num_lookahead_tokens = num_speculative_tokens
-
         n_predict = getattr(draft_hf_config, "n_predict", None)
         if n_predict is not None:
             if num_speculative_tokens is None:
@@ -1960,8 +1982,9 @@ class SpeculativeConfig:
                 speculative_draft_tensor_parallel_size = 1
                 if target_parallel_config.tensor_parallel_size > 1:
                     logger.warning(
-                        "MLPSpeculator cannot currently be run with tp>1; "
-                        "setting speculative_draft_tensor_parallel_size=1")
+                        "%s cannot currently be run with tp>1; "
+                        "setting speculative_draft_tensor_parallel_size=1",
+                        draft_hf_config.model_type)
             else:
                 speculative_draft_tensor_parallel_size = \
                     target_parallel_config.tensor_parallel_size
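A small self-contained sketch of what the new SpeculativeConfig.hf_config_override does to a DeepSeek-V3 style config. It re-applies the same steps on a bare transformers PretrainedConfig rather than going through vLLM, and assumes (as PretrainedConfig does) that unknown constructor kwargs are stored as attributes.

# Sketch: the hf_config_override transformation, stand-alone.
from transformers import PretrainedConfig

cfg = PretrainedConfig(model_type="deepseek_v3", num_nextn_predict_layers=1)

# Same steps as SpeculativeConfig.hf_config_override above:
if cfg.model_type == "deepseek_v3":
    cfg.model_type = "deepseek_mtp"
if cfg.model_type == "deepseek_mtp":
    n_predict = getattr(cfg, "num_nextn_predict_layers", None)
    cfg.update({"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]})

print(cfg.model_type)     # deepseek_mtp
print(cfg.n_predict)      # 1
print(cfg.architectures)  # ['DeepSeekMTPModel']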
vllm/model_executor/models/deepseek_mtp.py (new file, 284 lines)
@@ -0,0 +1,284 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Set, Tuple

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .deepseek_v2 import (DeepseekV2DecoderLayer,
                          get_spec_layer_idx_from_weight_name)
from .utils import maybe_prefix


class SharedHead(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.head = ParallelLMHead(config.vocab_size,
                                   config.hidden_size,
                                   quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.norm(hidden_states)


class DeepSeekMultiTokenPredictorLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        model_config: ModelConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.eh_proj = nn.Linear(config.hidden_size * 2,
                                 config.hidden_size,
                                 bias=False)
        self.shared_head = SharedHead(config=config, quant_config=quant_config)
        self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
                                                cache_config, quant_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        assert inputs_embeds is not None
        # masking inputs at position 0, as not needed by MTP
        inputs_embeds[positions == 0] = 0
        inputs_embeds = self.enorm(inputs_embeds)
        previous_hidden_states = self.hnorm(previous_hidden_states)

        hidden_states = self.eh_proj(
            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))

        hidden_states, residual = self.mtp_block(positions=positions,
                                                 hidden_states=hidden_states,
                                                 kv_cache=kv_cache,
                                                 attn_metadata=attn_metadata,
                                                 residual=None)
        hidden_states = residual + hidden_states
        return self.shared_head(hidden_states)


class DeepSeekMultiTokenPredictor(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers
        # to map the exact layer index from weights
        self.layers = torch.nn.ModuleDict({
            str(idx):
            DeepSeekMultiTokenPredictorLayer(
                config,
                f"{prefix}.layers.{idx}",
                model_config=vllm_config.model_config,
                cache_config=vllm_config.cache_config,
                quant_config=vllm_config.quant_config,
            )
            for idx in range(self.mtp_start_layer_idx,
                             self.mtp_start_layer_idx + self.num_mtp_layers)
        })

        self.logits_processor = LogitsProcessor(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
            input_ids,
            positions,
            kv_caches[spec_step_idx],
            attn_metadata,
            previous_hidden_states,
            inputs_embeds,
            spec_step_idx,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
        logits = self.logits_processor(mtp_layer.shared_head.head,
                                       hidden_states, sampling_metadata)
        return logits


class DeepSeekMTP(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
                                                 prefix=maybe_prefix(
                                                     prefix, "model"))

        self.sampler = get_sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, previous_hidden_states,
                                   inputs_embeds, spec_step_idx)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        spec_step_idx: int = 0,
    ) -> Optional[torch.Tensor]:
        return self.model.compute_logits(hidden_states, sampling_metadata,
                                         spec_step_idx)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is None:
                continue
            name = self._rewrite_spec_layer_name(spec_layer, name)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
        """
        Rewrite the weight name to match the format of the original model.
        Add .mtp_block for modules in transformer layer block for spec layer
        """
        spec_layer_weight_names = [
            "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
        ]
        spec_layer_weight = False
        for weight_name in spec_layer_weight_names:
            if weight_name in name:
                spec_layer_weight = True
                break
        if not spec_layer_weight:
            # treat rest weights as weights for transformer layer block
            name = name.replace(f"model.layers.{spec_layer}.",
                                f"model.layers.{spec_layer}.mtp_block.")
        return name
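To make the weight remapping in load_weights/_rewrite_spec_layer_name above concrete, a standalone sketch; the layer index 61 is an assumption mirroring DeepSeek-V3's num_hidden_layers, and the helper below simply restates the method above outside the class.

# Sketch: how MTP checkpoint names map onto the new module tree.
def rewrite_spec_layer_name(spec_layer: int, name: str) -> str:
    spec_layer_weight_names = [
        "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
    ]
    if not any(w in name for w in spec_layer_weight_names):
        # everything else belongs to the wrapped decoder layer (mtp_block)
        name = name.replace(f"model.layers.{spec_layer}.",
                            f"model.layers.{spec_layer}.mtp_block.")
    return name

print(rewrite_spec_layer_name(61, "model.layers.61.enorm.weight"))
# -> model.layers.61.enorm.weight            (kept: owned by the MTP layer itself)
print(rewrite_spec_layer_name(61, "model.layers.61.self_attn.q_proj.weight"))
# -> model.layers.61.mtp_block.self_attn.q_proj.weight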
@@ -732,13 +732,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
             if "rotary_emb.inv_freq" in name:
                 continue

-            # TODO(simon): support nextn predict layers
-            if hasattr(self.config, "num_nextn_predict_layers"
-                       ) and self.config.num_nextn_predict_layers > 0:
-                assert self.config.num_nextn_predict_layers == 1
-                layer_idx = self.config.num_hidden_layers
-                if name.startswith(f"model.layers.{layer_idx}"):
-                    continue
+            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+            if spec_layer is not None:
+                continue  # skip spec decode layers for main model

             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -805,3 +801,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):

 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
+
+
+def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
+                                        weight_name: str) -> Optional[int]:
+    if hasattr(config,
+               "num_nextn_predict_layers") and (config.num_nextn_predict_layers
+                                                > 0):
+        layer_idx = config.num_hidden_layers
+        for i in range(config.num_nextn_predict_layers):
+            if weight_name.startswith(f"model.layers.{layer_idx+i}."):
+                return layer_idx + i
+    return None
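A quick standalone check of how the new get_spec_layer_idx_from_weight_name helper classifies weight names; the config values are assumptions mirroring DeepSeek-V3 (61 hidden layers, one MTP layer), and the function body restates the helper above on a plain namespace.

# Sketch: MTP weights get a layer index, main-model weights get None.
from types import SimpleNamespace

config = SimpleNamespace(num_hidden_layers=61, num_nextn_predict_layers=1)

def spec_layer_idx(config, weight_name):
    if getattr(config, "num_nextn_predict_layers", 0) > 0:
        layer_idx = config.num_hidden_layers
        for i in range(config.num_nextn_predict_layers):
            if weight_name.startswith(f"model.layers.{layer_idx + i}."):
                return layer_idx + i
    return None

print(spec_layer_idx(config, "model.layers.61.eh_proj.weight"))   # 61 -> MTP weight
print(spec_layer_idx(config, "model.layers.10.mlp.gate.weight"))  # None -> main model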
@@ -187,6 +187,7 @@ _MULTIMODAL_MODELS = {

 _SPECULATIVE_DECODING_MODELS = {
     "EAGLEModel": ("eagle", "EAGLE"),
+    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
     "MedusaModel": ("medusa", "Medusa"),
     "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
 }
@@ -1307,6 +1307,8 @@ class ExecuteModelRequest(
     previous_hidden_states: Optional[HiddenStates] = None
     # The number of forward steps to run.
     num_steps: int = 1
+    # The step index for spec model input.
+    spec_step_idx: Optional[int] = None
     # Finished request ids since last step.
     finished_requests_ids: List[str] = msgspec.field(default_factory=list)
     # The last sampled token ids for multi step decoding.
@@ -153,7 +153,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
             return False

         # TODO: Add support for other attn backends
-        if self.attn_backend.get_name() != "FLASH_ATTN":
+        if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
             return False

         # TODO: Add support for LORA
@@ -175,6 +175,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
         previous_hidden_states: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
+        **kwargs,
     ) -> Optional[List[SamplerOutput]]:
         """Executes num_steps forward passes with advacement of input tensors
         on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
@@ -271,10 +272,17 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
         for step in range(num_steps):
             multi_modal_kwargs = model_input.multi_modal_kwargs or {}

-            kwargs = {"previous_hidden_states": hidden_states} \
+            model_execute_kwargs = {"previous_hidden_states": hidden_states} \
                 if previous_hidden_states is not None else {}

+            compute_logits_kwargs = {}
             # Run model
+            if hasattr(self.model.config, "num_nextn_predict_layers"):
+                # for DeepSeek MTP only to use the corresponding layer for
+                # each step
+                spec_step_idx = kwargs.get("spec_step_idx", step)
+                model_execute_kwargs["spec_step_idx"] = spec_step_idx
+                compute_logits_kwargs["spec_step_idx"] = spec_step_idx
             with set_forward_context(model_input.attn_metadata,
                                      self.vllm_config):
                 hidden_states = model_executable(
@@ -285,13 +293,15 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
                     intermediate_tensors=intermediate_tensors,
                     **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                                  device=self.device),
-                    **kwargs,
+                    **model_execute_kwargs,
                 )

             # Compute the logits.
             logits = self.model.compute_logits(hidden_states,
-                                               model_input.sampling_metadata)
+                                               model_input.sampling_metadata,
+                                               **compute_logits_kwargs)
+            if not self.is_driver_worker:
+                return []
             # Sample the next token.
             output = self.model.sample(
                 logits=logits,
@@ -108,6 +108,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
         typical_acceptance_sampler_posterior_alpha,
         disable_logprobs=speculative_config.disable_logprobs,
         disable_log_stats=speculative_config.disable_log_stats,
+        num_speculative_tokens=speculative_config.num_speculative_tokens,
     )

     return spec_decode_worker
@@ -153,10 +154,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         typical_acceptance_sampler_posterior_alpha: float,
         disable_logprobs: bool,
         disable_log_stats: bool,
+        num_speculative_tokens: int,
     ) -> "SpecDecodeWorker":

         allow_zero_draft_token_step = True
         enable_lm_head_weight_load = False
+        num_spec_prefill_steps = 1
         ngram_prompt_lookup_max = (
             draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
         ngram_prompt_lookup_min = (
@@ -179,14 +182,16 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             elif draft_model_config.hf_config.model_type == "medusa":
                 proposer_worker = MedusaWorker(**draft_worker_kwargs)
             else:
-                if draft_tp == 1:
+                if draft_tp == 1 or draft_model_config.hf_config.model_type ==\
+                        "deepseek_mtp":
                     if current_platform.is_cuda_alike():
                         draft_worker_kwargs[
                             "model_runner_cls"] = TP1DraftModelRunner
                 else:
                     if draft_model_config.hf_config.model_type == "eagle":
                         raise NotImplementedError(
-                            "EAGLE does not support TP > 1 yet")
+                            f"{draft_model_config.hf_config.model_type} "
+                            "does not support TP > 1 yet")

                     allow_zero_draft_token_step = False

@@ -195,6 +200,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                     enable_lm_head_weight_load = True

                 proposer_worker = MultiStepWorker(**draft_worker_kwargs)
+                if draft_model_config.hf_config.model_type == "deepseek_mtp":
+                    num_spec_prefill_steps = num_speculative_tokens

             proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                 proposer_worker, draft_tp, target_tp)
@@ -247,7 +254,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             disable_by_batch_size=disable_by_batch_size,
             spec_decode_sampler=spec_decode_sampler,
             allow_zero_draft_token_step=allow_zero_draft_token_step,
-            enable_lm_head_weight_load=enable_lm_head_weight_load)
+            enable_lm_head_weight_load=enable_lm_head_weight_load,
+            num_spec_prefill_steps=num_spec_prefill_steps)

     def __init__(
         self,
@@ -261,6 +269,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         disable_by_batch_size: Optional[int] = None,
         allow_zero_draft_token_step: Optional[bool] = True,
         enable_lm_head_weight_load: Optional[bool] = False,
+        num_spec_prefill_steps: int = 1,
     ):
         """
         Create a SpecDecodeWorker.
@@ -293,6 +302,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                 draft model is larger than 1 (TODO: #5814)
             enable_lm_head_weight_load: whether to load lm_head weight for
                 draft models like eagle.
+            num_spec_prefill_steps: number of speculative prefill steps to run
+                before the speculative decoding starts. This is only used when
+                the draft model is a deepseek_mtp model that requires prefill
+                kv cache separately for each MTP layer.
         """
         self.proposer_worker = proposer_worker
         self.scorer_worker = scorer_worker
@@ -326,6 +339,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self.previous_hidden_states: Optional[HiddenStates] = None
         self._disable_logprobs = disable_logprobs
         self._disable_log_stats = disable_log_stats
+        self._num_spec_prefill_steps = num_spec_prefill_steps

     def init_device(self) -> None:
         """Initialize both scorer and proposer models.
@@ -685,8 +699,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             execute_model_req.previous_hidden_states = \
                 prepare_prefill_hidden_states(
                     sampler_output.prefill_hidden_states)
-        self.proposer_worker.execute_model(execute_model_req)
+        for i in range(self._num_spec_prefill_steps):
+            execute_model_req.spec_step_idx = i
+            self.proposer_worker.execute_model(execute_model_req)

         sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
             execute_model_req=execute_model_req, sampler_output=sampler_output)
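The behavioural change above: during prefill the proposer is now executed once per MTP layer, with spec_step_idx selecting the layer whose KV cache gets filled. A minimal sketch with stand-in objects (the real ones are vLLM's proposer worker and ExecuteModelRequest; names here are made up for illustration):

# Sketch of the new per-MTP-layer prefill loop.
from dataclasses import dataclass

@dataclass
class FakeExecuteModelRequest:
    spec_step_idx: int = 0

class FakeProposerWorker:
    def execute_model(self, req):
        print(f"prefill KV cache for MTP layer at spec_step_idx={req.spec_step_idx}")

num_spec_prefill_steps = 2  # equals num_speculative_tokens for a deepseek_mtp draft
worker, req = FakeProposerWorker(), FakeExecuteModelRequest()
for i in range(num_spec_prefill_steps):
    req.spec_step_idx = i
    worker.execute_model(req)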
@@ -99,6 +99,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
     virtual_engine: int = 0
     async_callback: Optional[Callable] = None
     scheduler_outputs: Optional[SchedulerOutputs] = None
+    previous_hidden_states: Optional[torch.Tensor] = None

     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
@@ -1649,6 +1650,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         kv_caches: List[torch.Tensor],
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
+        **kwargs,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
         if num_steps > 1:
             raise ValueError("num_steps > 1 is not supported in ModelRunner")
@@ -1706,6 +1708,10 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
             "finished_requests_ids": model_input.finished_requests_ids,
             "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
         } if self.has_inner_state else {}
+        previous_hidden_states = kwargs.get("previous_hidden_states")
+        model_kwargs = {}
+        if previous_hidden_states is not None:
+            model_kwargs["previous_hidden_states"] = previous_hidden_states
         if (self.observability_config is not None
                 and self.observability_config.collect_model_forward_time):
             model_forward_start = torch.cuda.Event(enable_timing=True)
@@ -1723,7 +1729,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
                 intermediate_tensors=intermediate_tensors,
                 **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                              device=self.device),
-                **seqlen_agnostic_kwargs)
+                **seqlen_agnostic_kwargs,
+                **model_kwargs,
+            )

         if (self.observability_config is not None
                 and self.observability_config.collect_model_forward_time):
@@ -1815,7 +1823,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         1. current vLLM instance is KV cache consumer/decode vLLM instance
         2. this batch is not a profiling run
         3. this batch is a prefill run

         Args:
             model_input: input to the model executable
             kv_caches: vLLM's paged memory
@@ -1840,7 +1848,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         1. current vLLM instance is KV cache producer/prefill vLLM instance
         2. this batch is not a profiling run
         3. this batch is a prefill run

         Args:
             model_input: input to the model executable
             kv_caches: vLLM's paged memory
@@ -1976,7 +1984,11 @@ class CUDAGraphRunner(nn.Module):
         # Copy the input tensors to the input buffers.
         self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
         if positions is not None:
-            self.input_buffers["positions"].copy_(positions, non_blocking=True)
+            # in some case like MLA, it will reuse positions in metadata
+            # but truncate them to the original size
+            # so the shape is not padded, we need to copy partial only
+            self.input_buffers["positions"][:positions.shape[0]].copy_(
+                positions, non_blocking=True)

         if self.backend_name != "NO_ATTENTION":
             self.input_buffers["slot_mapping"].copy_(
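The CUDAGraphRunner change above copies only the leading slice of the padded positions buffer, because with MLA the incoming positions tensor can be shorter than the CUDA-graph input buffer. A tiny sketch of that slice-copy (buffer sizes are made up for illustration):

# Sketch: partial copy into a padded CUDA-graph input buffer.
import torch

input_buffer = torch.zeros(8, dtype=torch.long)   # padded graph buffer
positions = torch.tensor([5, 6, 7])               # un-padded positions

input_buffer[:positions.shape[0]].copy_(positions)
print(input_buffer)  # tensor([5, 6, 7, 0, 0, 0, 0, 0])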
@@ -46,7 +46,10 @@ def _init_attn_metadata_from_tensor_dict(
     valid_attn_kwargs = {}
     for field in dataclasses.fields(attn_backend.get_metadata_cls()):
         if field.name in tensor_dict:
-            valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)
+            if field.name == "input_positions":
+                valid_attn_kwargs[field.name] = tensor_dict[field.name]
+            else:
+                valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)

     attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
     tensor_dict["attn_metadata"] = attn_metadata
@@ -68,10 +68,10 @@ class Worker(LocalOrDistributedWorkerBase):
         speculative_config = self.speculative_config
         model_config = self.model_config
         speculative_args = {} if speculative_config is None \
-            or (speculative_config.draft_model_config.model ==
-                model_config.model) \
+            or (speculative_config.draft_model_config.hf_config.model_type ==
+                model_config.hf_config.model_type) \
             or (speculative_config.draft_model_config.hf_config.model_type
-                not in ["medusa", "mlp_speculator", "eagle"]) \
+                not in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")) \
             else {"return_hidden_states": True}

         ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
@@ -397,6 +397,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):

         model_input, worker_input, kwargs = inputs
         num_steps = worker_input.num_steps
+        if (execute_model_req is not None and execute_model_req.spec_step_idx):
+            kwargs["spec_step_idx"] = execute_model_req.spec_step_idx

         self.execute_worker(worker_input)
