[Speculators] Move tests + fix integration (#27308)
Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Signed-off-by: rahul-tuli <rtuli@redhat.com>
Co-authored-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
parent 8b62495076
commit 413ef7a3b4
@@ -121,6 +121,86 @@ def test_ngram_correctness(
     cleanup_dist_env_and_memory()
 
 
+@pytest.mark.parametrize(
+    "model_path",
+    [
+        "RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
+        "RedHatAI/Qwen3-8B-speculator.eagle3",
+    ],
+    ids=["llama3_eagle3_speculator", "qwen3_eagle3_speculator"],
+)
+def test_speculators_model_integration(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_path: str,
+):
+    """
+    Test that speculators models work with the simplified integration.
+
+    This verifies the `vllm serve <speculator-model>` use case where
+    speculative config is automatically detected from the model config
+    without requiring explicit --speculative-config argument.
+
+    Tests:
+    1. Speculator model is correctly detected
+    2. Verifier model is extracted from speculator config
+    3. Speculative decoding is automatically enabled
+    4. Text generation works correctly
+    5. Output matches reference (non-speculative) generation
+    """
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+
+    # Generate test prompts
+    test_prompts = get_test_prompts(mm_enabled=False)
+
+    # First run: Direct speculator model (simplified integration)
+    spec_llm = LLM(model=model_path, max_model_len=1024)
+    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+
+    # Verify speculative config was auto-detected
+    assert spec_llm.llm_engine.vllm_config.speculative_config is not None, (
+        f"Speculative config should be auto-detected for {model_path}"
+    )
+
+    spec_config = spec_llm.llm_engine.vllm_config.speculative_config
+    assert spec_config.num_speculative_tokens > 0, (
+        f"Expected positive speculative tokens, "
+        f"got {spec_config.num_speculative_tokens}"
+    )
+
+    # Verify draft model is set to the speculator model
+    assert spec_config.model == model_path, (
+        f"Draft model should be {model_path}, got {spec_config.model}"
+    )
+
+    # Extract verifier model for reference run
+    verifier_model = spec_llm.llm_engine.vllm_config.model_config.model
+
+    del spec_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+    # Second run: Reference without speculative decoding
+    ref_llm = LLM(model=verifier_model, max_model_len=1024)
+    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+    del ref_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+    # Compare outputs
+    matches = sum(
+        1
+        for ref, spec in zip(ref_outputs, spec_outputs)
+        if ref.outputs[0].text == spec.outputs[0].text
+    )
+
+    # Heuristic: expect at least 66% of prompts to match exactly
+    assert matches >= int(0.66 * len(ref_outputs)), (
+        f"Only {matches}/{len(ref_outputs)} outputs matched. "
+        f"Expected at least {int(0.66 * len(ref_outputs))} matches."
+    )
+
+
 @pytest.mark.parametrize(
     ["model_setup", "mm_enabled"],
     [
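The docstring in the new test above describes the simplified flow being exercised: point vLLM directly at a speculator checkpoint and let the speculative config be detected from its model config, with no explicit --speculative-config. A minimal offline sketch of that flow, assuming a GPU host and reusing one of the speculator checkpoints from the test parametrization (the prompt and sampling values are illustrative and not part of this PR):

    from vllm import LLM, SamplingParams

    # Load the speculator checkpoint directly; the verifier model and the
    # speculative config are expected to be detected from its model config.
    llm = LLM(
        model="RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
        max_model_len=1024,
    )

    # The auto-detected config is visible on the engine, as the test asserts.
    spec_config = llm.llm_engine.vllm_config.speculative_config
    assert spec_config is not None and spec_config.num_speculative_tokens > 0

    # Illustrative generation call; prompt and parameters are placeholders.
    outputs = llm.generate(
        ["The capital of France is"],
        SamplingParams(temperature=0.0, max_tokens=32),
    )
    print(outputs[0].outputs[0].text)
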
@@ -22,10 +22,6 @@ from vllm.model_executor.models.interfaces import supports_eagle3
             "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
             id="qwen3-eagle3-speculator-w4a16-verifier",
         ),
-        pytest.param(
-            "nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
-            id="llama3-eagl3-multiple-layers",
-        ),
     ],
 )
 def test_eagle3_speculators_model(
@@ -81,7 +81,7 @@ from vllm.transformers_utils.config import (
     is_interleaved,
     maybe_override_with_speculators,
 )
-from vllm.transformers_utils.utils import check_gguf_file
+from vllm.transformers_utils.utils import check_gguf_file, is_s3
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.utils.mem_constants import GiB_bytes
 from vllm.utils.network_utils import get_ip
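This import-only hunk pulls in is_s3 alongside check_gguf_file; it gates the speculator auto-detection changed in the next hunk. The diff does not show its implementation, so the helper below is a hypothetical stand-in (not vLLM's actual code), only to illustrate the kind of model reference being excluded:

    # Hypothetical stand-in for the imported is_s3 helper, for illustration only.
    def looks_like_s3(model_or_path: str) -> bool:
        return model_or_path.lower().startswith("s3://")

    assert looks_like_s3("s3://my-bucket/llama-3.1-8b-instruct")     # S3 URL (placeholder)
    assert not looks_like_s3("RedHatAI/Qwen3-8B-speculator.eagle3")  # Hugging Face Hub id
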
@@ -1305,20 +1305,26 @@ class EngineArgs:
 
         device_config = DeviceConfig(device=cast(Device, current_platform.device_type))
 
+        # Check if the model is a speculator and override model/tokenizer/config
+        # BEFORE creating ModelConfig, so the config is created with the target model
+        # Skip speculator detection for S3 models since HuggingFace cannot load
+        # configs directly from S3 URLs. S3 models can still use speculators with
+        # explicit --speculative-config.
+        if not is_s3(self.model):
+            (self.model, self.tokenizer, self.speculative_config) = (
+                maybe_override_with_speculators(
+                    model=self.model,
+                    tokenizer=self.tokenizer,
+                    revision=self.revision,
+                    trust_remote_code=self.trust_remote_code,
+                    vllm_speculative_config=self.speculative_config,
+                )
+            )
+
         model_config = self.create_model_config()
         self.model = model_config.model
         self.tokenizer = model_config.tokenizer
 
-        (self.model, self.tokenizer, self.speculative_config) = (
-            maybe_override_with_speculators(
-                model=self.model,
-                tokenizer=self.tokenizer,
-                revision=self.revision,
-                trust_remote_code=self.trust_remote_code,
-                vllm_speculative_config=self.speculative_config,
-            )
-        )
-
         # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
         # and fall back to V0 for experimental or unsupported features.
         # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
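The new comment block spells out the trade-off: speculator auto-detection now runs before ModelConfig is created, but is skipped for S3-hosted models because their configs cannot be fetched through HuggingFace; those models fall back to an explicit speculative config. A hedged sketch of that explicit form, assuming the engine accepts the speculative config as a dict keyword argument (the S3 path and draft-model values are placeholders, not from this PR):

    from vllm import LLM

    # Auto-detection is skipped for S3 models, so the speculative config is
    # passed explicitly. All values below are placeholders for illustration.
    llm = LLM(
        model="s3://my-bucket/Llama-3.1-8B-Instruct",  # hypothetical S3-hosted verifier
        speculative_config={
            "model": "RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
            "num_speculative_tokens": 3,
        },
    )

On the command line this corresponds to passing the --speculative-config flag mentioned in the comment when serving the S3-hosted model.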