[Speculators] Move tests + fix integration (#27308)

Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Signed-off-by: rahul-tuli <rtuli@redhat.com>
Co-authored-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Authored by Dipika Sikka on 2025-10-29 03:54:21 -04:00; committed by GitHub
parent 8b62495076
commit 413ef7a3b4
3 changed files with 97 additions and 15 deletions


@@ -121,6 +121,86 @@ def test_ngram_correctness(
    cleanup_dist_env_and_memory()


@pytest.mark.parametrize(
    "model_path",
    [
        "RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
        "RedHatAI/Qwen3-8B-speculator.eagle3",
    ],
    ids=["llama3_eagle3_speculator", "qwen3_eagle3_speculator"],
)
def test_speculators_model_integration(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_path: str,
):
"""
Test that speculators models work with the simplified integration.
This verifies the `vllm serve <speculator-model>` use case where
speculative config is automatically detected from the model config
without requiring explicit --speculative-config argument.
Tests:
1. Speculator model is correctly detected
2. Verifier model is extracted from speculator config
3. Speculative decoding is automatically enabled
4. Text generation works correctly
5. Output matches reference (non-speculative) generation
"""
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    # Generate test prompts
    test_prompts = get_test_prompts(mm_enabled=False)

    # First run: Direct speculator model (simplified integration)
    spec_llm = LLM(model=model_path, max_model_len=1024)
    spec_outputs = spec_llm.chat(test_prompts, sampling_config)

    # Verify speculative config was auto-detected
    assert spec_llm.llm_engine.vllm_config.speculative_config is not None, (
        f"Speculative config should be auto-detected for {model_path}"
    )

    spec_config = spec_llm.llm_engine.vllm_config.speculative_config
    assert spec_config.num_speculative_tokens > 0, (
        f"Expected positive speculative tokens, "
        f"got {spec_config.num_speculative_tokens}"
    )

    # Verify draft model is set to the speculator model
    assert spec_config.model == model_path, (
        f"Draft model should be {model_path}, got {spec_config.model}"
    )

    # Extract verifier model for reference run
    verifier_model = spec_llm.llm_engine.vllm_config.model_config.model

    del spec_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    # Second run: Reference without speculative decoding
    ref_llm = LLM(model=verifier_model, max_model_len=1024)
    ref_outputs = ref_llm.chat(test_prompts, sampling_config)

    del ref_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    # Compare outputs
    matches = sum(
        1
        for ref, spec in zip(ref_outputs, spec_outputs)
        if ref.outputs[0].text == spec.outputs[0].text
    )

    # Heuristic: expect at least 66% of prompts to match exactly
    assert matches >= int(0.66 * len(ref_outputs)), (
        f"Only {matches}/{len(ref_outputs)} outputs matched. "
        f"Expected at least {int(0.66 * len(ref_outputs))} matches."
    )
@pytest.mark.parametrize(
    ["model_setup", "mm_enabled"],
    [

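For reference, the path exercised by the new test is the same one used when serving a speculator checkpoint directly. Below is a minimal sketch (not part of this commit) of that use case, assuming a server was started with `vllm serve RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3` on the default port and with no --speculative-config flag:

```python
# Illustrative only (not part of this commit): querying a server started with
#   vllm serve RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3
# No --speculative-config flag is passed; the speculative config is
# auto-detected from the speculator checkpoint, as the test above verifies.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
    messages=[
        {"role": "user", "content": "Summarize speculative decoding in one sentence."}
    ],
    max_tokens=64,
    temperature=0,
)
print(response.choices[0].message.content)
```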

@@ -22,10 +22,6 @@ from vllm.model_executor.models.interfaces import supports_eagle3
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
id="qwen3-eagle3-speculator-w4a16-verifier",
),
pytest.param(
"nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
id="llama3-eagl3-multiple-layers",
),
],
)
def test_eagle3_speculators_model(


@@ -81,7 +81,7 @@ from vllm.transformers_utils.config import (
    is_interleaved,
    maybe_override_with_speculators,
)
from vllm.transformers_utils.utils import check_gguf_file
from vllm.transformers_utils.utils import check_gguf_file, is_s3
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_ip
@@ -1305,20 +1305,26 @@ class EngineArgs:
        device_config = DeviceConfig(device=cast(Device, current_platform.device_type))

        # Check if the model is a speculator and override model/tokenizer/config
        # BEFORE creating ModelConfig, so the config is created with the target model
        # Skip speculator detection for S3 models since HuggingFace cannot load
        # configs directly from S3 URLs. S3 models can still use speculators with
        # explicit --speculative-config.
        if not is_s3(self.model):
            (self.model, self.tokenizer, self.speculative_config) = (
                maybe_override_with_speculators(
                    model=self.model,
                    tokenizer=self.tokenizer,
                    revision=self.revision,
                    trust_remote_code=self.trust_remote_code,
                    vllm_speculative_config=self.speculative_config,
                )
            )

        model_config = self.create_model_config()
        self.model = model_config.model
        self.tokenizer = model_config.tokenizer
        (self.model, self.tokenizer, self.speculative_config) = (
            maybe_override_with_speculators(
                model=self.model,
                tokenizer=self.tokenizer,
                revision=self.revision,
                trust_remote_code=self.trust_remote_code,
                vllm_speculative_config=self.speculative_config,
            )
        )

        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        # and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
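The new guard depends on `is_s3` from `vllm.transformers_utils.utils`. A rough sketch of the kind of check involved, assuming detection keys off the `s3://` URI scheme (the helper name below is hypothetical and used only for illustration):

```python
# Hypothetical re-implementation for illustration; the real helper is
# vllm.transformers_utils.utils.is_s3 (assumed to key off the "s3://" scheme).
from urllib.parse import urlparse


def is_s3_like(model_or_path: str) -> bool:
    """Return True if the model reference looks like an S3 URI."""
    return urlparse(model_or_path).scheme == "s3"


# Speculator auto-detection is skipped for S3 references, e.g.:
#   is_s3_like("s3://my-bucket/llama-3.1-8b/")                      -> True
#   is_s3_like("RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3")  -> False
# S3 models can still opt in via an explicit --speculative-config.
```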