diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py
index 6d385184d264a..ccc8e745ab371 100644
--- a/tests/spec_decode/e2e/test_multistep_correctness.py
+++ b/tests/spec_decode/e2e/test_multistep_correctness.py
@@ -57,6 +57,9 @@ from .conftest import (get_output_from_llm_generator,
 
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize(
     "per_test_common_llm_kwargs",
@@ -139,6 +142,9 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
 
         # Print spec metrics.
         "disable_log_stats": False,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize(
     "per_test_common_llm_kwargs",
@@ -216,6 +222,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
 
         # Print spec metrics.
         "disable_log_stats": False,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize(
     "per_test_common_llm_kwargs",
@@ -279,6 +288,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
     [{
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize(
     "per_test_common_llm_kwargs",
@@ -464,6 +476,8 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
 
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [
     {
@@ -523,6 +537,8 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
 
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize(
     "per_test_common_llm_kwargs",
@@ -589,6 +605,8 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
 
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -655,6 +673,8 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
 
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -706,6 +726,8 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
 
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -763,6 +785,8 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
 
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
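
A note on what these hunks amount to at runtime: the test fixtures merge `common_llm_kwargs`, `per_test_common_llm_kwargs`, and the per-run kwargs into a single vLLM `LLM(...)` constructor call, so each hunk pins the engine dtype to the checkpoint's native float32 rather than letting it be downcast to half precision, where rounding can flip greedy token choices and make equality-based correctness checks flaky. Below is a minimal standalone sketch of the merged configuration as a direct `LLM` call; the model name is illustrative (assumed, not taken from these hunks), while the keyword arguments mirror the ones added or referenced above.

# Minimal sketch: the merged test kwargs expressed as a direct LLM call.
# The model name is an assumption for illustration; dtype/enforce_eager/
# disable_log_stats mirror the kwargs in the hunks above.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-68m",  # illustrative tiny model (assumed)
    dtype="float32",             # the change in this diff: keep the checkpoint's native dtype
    enforce_eager=True,          # skip CUDA graph recording for a fast test
    disable_log_stats=False,     # keep spec-decode metrics in the logs
)

outputs = llm.generate(["The quick brown fox"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)

With greedy sampling (temperature=0.0), float32 keeps the baseline and speculative runs bit-for-bit comparable, which is exactly what these equality-based tests rely on.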