[Speculators] Move tests + fix integration (#27308)
Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Signed-off-by: rahul-tuli <rtuli@redhat.com>
Co-authored-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 8b62495076
commit 413ef7a3b4
@@ -121,6 +121,86 @@ def test_ngram_correctness(
     cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize(
+    "model_path",
+    [
+        "RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
+        "RedHatAI/Qwen3-8B-speculator.eagle3",
+    ],
+    ids=["llama3_eagle3_speculator", "qwen3_eagle3_speculator"],
+)
+def test_speculators_model_integration(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_path: str,
+):
+    """
+    Test that speculators models work with the simplified integration.
+
+    This verifies the `vllm serve <speculator-model>` use case where
+    speculative config is automatically detected from the model config
+    without requiring an explicit --speculative-config argument.
+
+    Tests:
+    1. Speculator model is correctly detected
+    2. Verifier model is extracted from speculator config
+    3. Speculative decoding is automatically enabled
+    4. Text generation works correctly
+    5. Output matches reference (non-speculative) generation
+    """
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+
+    # Generate test prompts
+    test_prompts = get_test_prompts(mm_enabled=False)
+
+    # First run: direct speculator model (simplified integration)
+    spec_llm = LLM(model=model_path, max_model_len=1024)
+    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+
+    # Verify speculative config was auto-detected
+    assert spec_llm.llm_engine.vllm_config.speculative_config is not None, (
+        f"Speculative config should be auto-detected for {model_path}"
+    )
+
+    spec_config = spec_llm.llm_engine.vllm_config.speculative_config
+    assert spec_config.num_speculative_tokens > 0, (
+        f"Expected positive speculative tokens, "
+        f"got {spec_config.num_speculative_tokens}"
+    )
+
+    # Verify draft model is set to the speculator model
+    assert spec_config.model == model_path, (
+        f"Draft model should be {model_path}, got {spec_config.model}"
+    )
+
+    # Extract verifier model for reference run
+    verifier_model = spec_llm.llm_engine.vllm_config.model_config.model
+
+    del spec_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+    # Second run: reference without speculative decoding
+    ref_llm = LLM(model=verifier_model, max_model_len=1024)
+    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+    del ref_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+    # Compare outputs
+    matches = sum(
+        1
+        for ref, spec in zip(ref_outputs, spec_outputs)
+        if ref.outputs[0].text == spec.outputs[0].text
+    )
+
+    # Heuristic: expect at least 66% of prompts to match exactly
+    assert matches >= int(0.66 * len(ref_outputs)), (
+        f"Only {matches}/{len(ref_outputs)} outputs matched. "
+        f"Expected at least {int(0.66 * len(ref_outputs))} matches."
+    )
 
 
 @pytest.mark.parametrize(
     ["model_setup", "mm_enabled"],
     [
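In offline terms, the simplified integration this new test exercises amounts to passing the speculator checkpoint directly as the model. The minimal sketch below contrasts that with the explicit route mentioned in the test docstring; the speculator name comes from the test parametrization above, while the verifier name and the keys of the explicit speculative_config dict are illustrative assumptions rather than values read from the checkpoint.

from vllm import LLM, SamplingParams

params = SamplingParams(temperature=0, max_tokens=64)
conversation = [[{"role": "user", "content": "What is speculative decoding?"}]]

# Simplified integration: the speculator checkpoint is the only model argument;
# the verifier model and speculative settings are auto-detected from its config.
spec_llm = LLM(model="RedHatAI/Qwen3-8B-speculator.eagle3", max_model_len=1024)
print(spec_llm.chat(conversation, params)[0].outputs[0].text)

# Explicit route (illustrative values): the verifier is named directly and the
# speculator is supplied via speculative_config, mirroring --speculative-config.
explicit_llm = LLM(
    model="Qwen/Qwen3-8B",  # assumed verifier for this speculator
    speculative_config={
        "model": "RedHatAI/Qwen3-8B-speculator.eagle3",
        "num_speculative_tokens": 3,  # illustrative value
    },
    max_model_len=1024,
)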
@@ -22,10 +22,6 @@ from vllm.model_executor.models.interfaces import supports_eagle3
             "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
             id="qwen3-eagle3-speculator-w4a16-verifier",
         ),
-        pytest.param(
-            "nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
-            id="llama3-eagl3-multiple-layers",
-        ),
     ],
 )
 def test_eagle3_speculators_model(
@@ -81,7 +81,7 @@ from vllm.transformers_utils.config import (
     is_interleaved,
     maybe_override_with_speculators,
 )
-from vllm.transformers_utils.utils import check_gguf_file
+from vllm.transformers_utils.utils import check_gguf_file, is_s3
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.utils.mem_constants import GiB_bytes
 from vllm.utils.network_utils import get_ip
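The only change here is importing is_s3 alongside check_gguf_file; it is used in the EngineArgs hunk below to skip speculator auto-detection for object-store paths. As a rough sketch of the kind of check involved (an assumption about the helper, not its actual implementation), an S3 path is recognized by its scheme:

# Hypothetical stand-in illustrating the check; the real is_s3 lives in
# vllm.transformers_utils.utils and may differ in detail.
def looks_like_s3(model_or_path: str) -> bool:
    return model_or_path.lower().startswith("s3://")

assert looks_like_s3("s3://my-bucket/llama-3.1-8b")
assert not looks_like_s3("RedHatAI/Qwen3-8B-speculator.eagle3")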
@@ -1305,20 +1305,26 @@ class EngineArgs:
 
         device_config = DeviceConfig(device=cast(Device, current_platform.device_type))
 
+        # Check if the model is a speculator and override model/tokenizer/config
+        # BEFORE creating ModelConfig, so the config is created with the target model.
+        # Skip speculator detection for S3 models since HuggingFace cannot load
+        # configs directly from S3 URLs. S3 models can still use speculators with
+        # explicit --speculative-config.
+        if not is_s3(self.model):
+            (self.model, self.tokenizer, self.speculative_config) = (
+                maybe_override_with_speculators(
+                    model=self.model,
+                    tokenizer=self.tokenizer,
+                    revision=self.revision,
+                    trust_remote_code=self.trust_remote_code,
+                    vllm_speculative_config=self.speculative_config,
+                )
+            )
+
         model_config = self.create_model_config()
         self.model = model_config.model
         self.tokenizer = model_config.tokenizer
 
-        (self.model, self.tokenizer, self.speculative_config) = (
-            maybe_override_with_speculators(
-                model=self.model,
-                tokenizer=self.tokenizer,
-                revision=self.revision,
-                trust_remote_code=self.trust_remote_code,
-                vllm_speculative_config=self.speculative_config,
-            )
-        )
-
         # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
         # and fall back to V0 for experimental or unsupported features.
         # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
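The effect of this hunk is that speculator detection now runs before ModelConfig is built, so the engine config is constructed against the verifier while speculative_config records the speculator as the draft. A small sketch of how that surfaces through EngineArgs, mirroring the assertions in the new test above; it assumes a working vLLM install with an available device and network access to fetch the checkpoint config.

from vllm.engine.arg_utils import EngineArgs

speculator = "RedHatAI/Qwen3-8B-speculator.eagle3"  # checkpoint from the new test
args = EngineArgs(model=speculator, max_model_len=1024)
vllm_config = args.create_engine_config()

# The served model has been overridden with the verifier extracted from the
# speculator config, and the draft model points back at the speculator.
assert vllm_config.speculative_config is not None
assert vllm_config.speculative_config.model == speculator
assert vllm_config.speculative_config.num_speculative_tokens > 0
print("verifier:", vllm_config.model_config.model)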