diff --git a/tests/spec_decode/e2e/test_integration.py b/tests/spec_decode/e2e/test_integration.py
index 7608618502966..f15a9224c0030 100644
--- a/tests/spec_decode/e2e/test_integration.py
+++ b/tests/spec_decode/e2e/test_integration.py
@@ -14,10 +14,13 @@ MAIN_MODEL = "JackFram/llama-68m"
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
+        "model_name": "JackFram/llama-68m",
         # Verify equality when cuda graphs allowed.
         "enforce_eager": False,
-        "model_name": "JackFram/llama-68m",
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize(
     "per_test_common_llm_kwargs",
@@ -59,6 +62,9 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [])
 @pytest.mark.parametrize(
@@ -117,6 +123,9 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py
index 1629c69f8ee9d..4de7ee05605ad 100644
--- a/tests/spec_decode/e2e/test_logprobs.py
+++ b/tests/spec_decode/e2e/test_logprobs.py
@@ -17,7 +17,10 @@ from .conftest import run_equality_correctness_test
         "model_name": "JackFram/llama-160m",

         # Skip cuda graph recording for fast test.
-        "enforce_eager": True
+        "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -75,6 +78,9 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -128,6 +134,9 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -182,6 +191,9 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -256,8 +268,12 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
     "common_llm_kwargs",
     [{
         "model_name": "JackFram/llama-160m",
+
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py
index 9f778ca8d179b..0e41d93eaa190 100644
--- a/tests/spec_decode/e2e/test_mlp_correctness.py
+++ b/tests/spec_decode/e2e/test_mlp_correctness.py
@@ -494,6 +494,9 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # Precision
+        "dtype": PRECISION,
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py
index 5aefc1df84980..58d1a6ca7adda 100644
--- a/tests/spec_decode/e2e/test_ngram_correctness.py
+++ b/tests/spec_decode/e2e/test_ngram_correctness.py
@@ -40,6 +40,9 @@ from .conftest import run_equality_correctness_test

         # Print spec metrics.
         "disable_log_stats": False,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [
     {
@@ -97,6 +100,9 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,

         # Print spec metrics.
         "disable_log_stats": False,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [
     {
@@ -160,6 +166,9 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [
     {
@@ -221,6 +230,9 @@ def test_ngram_e2e_greedy_correctness_with_preemption(

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -281,6 +293,9 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -337,6 +352,9 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py
index d219b5228ac31..c551ecd68ef86 100644
--- a/vllm/model_executor/models/eagle.py
+++ b/vllm/model_executor/models/eagle.py
@@ -74,6 +74,7 @@ class EAGLE(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
+        self.dtype = vllm_config.model_config.dtype
         self.config = config

         architectures = getattr(self.config.model, "architectures", [])
@@ -250,7 +251,7 @@ class EAGLE(nn.Module):
             lm_head_weight = torch.zeros(
                 self.lm_head.org_vocab_size,
                 self.lm_head.embedding_dim,
-                dtype=self.config.torch_dtype,
+                dtype=self.dtype,
             )

             weight_loader = getattr(self.lm_head.weight, "weight_loader",