Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-21 13:45:51 +08:00)
[Bugfix][2/n] Fix speculative decoding CI - Fix test_ngram_e2e_greedy_correctness (#19644)
This commit is contained in:
parent e13945f9dd
commit ee1531bc38
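In short, this change pins the small test checkpoints (JackFram/llama-68m and JackFram/llama-160m) to "dtype": "float32" in the speculative-decoding e2e test parametrizations, since the original models are float32 and keeping that precision keeps the baseline and speculative greedy runs numerically stable, and it makes the EAGLE draft model allocate its placeholder lm_head weight in the engine's configured dtype rather than the HF config's torch_dtype. As a hedged, minimal sketch (not the test harness itself) of what these kwargs mean when they reach the engine, the dict entries map directly onto vLLM LLM constructor arguments:

# Minimal sketch, not the actual test code: the parametrized kwargs in the
# diff below correspond to constructor arguments of vllm.LLM.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-68m",  # tiny test checkpoint used throughout these tests
    dtype="float32",             # keep the original float32 weights for numerical stability
    enforce_eager=True,          # skip cuda graph capture to keep the run fast
)
outputs = llm.generate(["Speculative decoding is"],
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)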
@@ -14,10 +14,13 @@ MAIN_MODEL = "JackFram/llama-68m"
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
+        "model_name": "JackFram/llama-68m",

         # Verify equality when cuda graphs allowed.
         "enforce_eager": False,
-        "model_name": "JackFram/llama-68m",
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize(
     "per_test_common_llm_kwargs",
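The same "dtype": "float32" entry is added to each of the layered kwargs dicts in the hunks below. The layering works by ordinary dict merging: the common kwargs apply to every case, and the per-test, baseline, and test-specific dicts refine them. A rough sketch of that merge (an assumption about the harness, not the actual conftest helper):

# Hedged sketch of how layered parametrize kwargs are typically combined;
# the real merging lives in the tests' conftest and may differ in detail.
def merge_llm_kwargs(*layers: dict) -> dict:
    merged: dict = {}
    for layer in layers:
        merged.update(layer)  # later layers override earlier ones
    return merged

common_llm_kwargs = {
    "model_name": "JackFram/llama-68m",
    "enforce_eager": False,
    "dtype": "float32",
}
per_test_common_llm_kwargs = {}  # empty for several of the cases in this diff
print(merge_llm_kwargs(common_llm_kwargs, per_test_common_llm_kwargs))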
@@ -59,6 +62,9 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [])
 @pytest.mark.parametrize(
@@ -117,6 +123,9 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -17,7 +17,10 @@ from .conftest import run_equality_correctness_test
         "model_name": "JackFram/llama-160m",

         # Skip cuda graph recording for fast test.
-        "enforce_eager": True
+        "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -75,6 +78,9 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -128,6 +134,9 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -182,6 +191,9 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -256,8 +268,12 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
     "common_llm_kwargs",
     [{
         "model_name": "JackFram/llama-160m",
+
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -494,6 +494,9 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # Precision
+        "dtype": PRECISION,
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
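The "dtype": PRECISION entry in the hunk above refers to a constant defined elsewhere in that test module and not shown in this diff; presumably it is a precision string along the lines of:

PRECISION = "float32"  # assumed shape of the module-level constant; not part of this diff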
@@ -40,6 +40,9 @@ from .conftest import run_equality_correctness_test

         # Print spec metrics.
         "disable_log_stats": False,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [
     {
@@ -97,6 +100,9 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,

         # Print spec metrics.
         "disable_log_stats": False,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [
     {
@@ -160,6 +166,9 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [
     {
@@ -221,6 +230,9 @@ def test_ngram_e2e_greedy_correctness_with_preemption(

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -281,6 +293,9 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -337,6 +352,9 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
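All of the ngram hunks above feed the same greedy-equality harness (run_equality_correctness_test, imported at the top of the file): with temperature 0, the speculative run must reproduce the baseline's tokens exactly, since speculative decoding is supposed to be lossless under greedy sampling. A hedged sketch of that check, not the actual helper:

# Sketch of the equality check performed by run_equality_correctness_test;
# the real helper also handles logprobs, metrics, and batching details.
def assert_greedy_equal(baseline_token_ids, spec_token_ids):
    assert len(baseline_token_ids) == len(spec_token_ids)
    for i, (base, spec) in enumerate(zip(baseline_token_ids, spec_token_ids)):
        assert base == spec, (
            f"request {i}: speculative output diverged from baseline: "
            f"{spec} != {base}")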
@@ -74,6 +74,7 @@ class EAGLE(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
+        self.dtype = vllm_config.model_config.dtype
         self.config = config

         architectures = getattr(self.config.model, "architectures", [])
@@ -250,7 +251,7 @@ class EAGLE(nn.Module):
             lm_head_weight = torch.zeros(
                 self.lm_head.org_vocab_size,
                 self.lm_head.embedding_dim,
-                dtype=self.config.torch_dtype,
+                dtype=self.dtype,
             )

             weight_loader = getattr(self.lm_head.weight, "weight_loader",
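The EAGLE change above avoids a dtype mismatch: the placeholder lm_head weight is now allocated in the dtype the engine actually runs the model in (vllm_config.model_config.dtype, cached as self.dtype in __init__), rather than whatever the HF config's torch_dtype reports, which can disagree once dtype is overridden, e.g. to float32 as in the tests above. A hedged, self-contained illustration of the failure mode (hypothetical shapes and dtypes, not vLLM code):

import torch

# What a checkpoint's hf_config.torch_dtype might claim vs. the dtype the
# engine was actually asked to run in (e.g. float32, as pinned by the tests above).
hf_torch_dtype = torch.float16
engine_dtype = torch.float32

real_lm_head = torch.randn(256, 64, dtype=engine_dtype)

# Old behaviour: placeholder follows config.torch_dtype, so copying the real
# float32 weights into it silently downcasts them to float16.
old_placeholder = torch.zeros(256, 64, dtype=hf_torch_dtype)
old_placeholder.copy_(real_lm_head)

# New behaviour: placeholder follows model_config.dtype and keeps full precision.
new_placeholder = torch.zeros(256, 64, dtype=engine_dtype)
new_placeholder.copy_(real_lm_head)

print(old_placeholder.dtype, new_placeholder.dtype)  # torch.float16 torch.float32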