Fix more broken speculative decode tests (#17450)
Signed-off-by: Huy Do <huydhn@gmail.com>
parent 2007d4d54f
commit b74d888c63
@@ -205,7 +205,7 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "block_size": 8,
+        "block_size": 16,
         # 2 for small prompt, 256//8 for generated.
         "num_gpu_blocks_override": 2 + 256 // 8,
         "max_model_len": (2 + 256 // 8) * 8,

@@ -267,7 +267,7 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "block_size": 8,
+        "block_size": 16,
         # 2 for small prompt, 256//8 for generated.
         "num_gpu_blocks_override": 2 + 256 // 8,
         "max_model_len": (2 + 256 // 8) * 8,

@@ -321,7 +321,7 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "block_size": 8,
+        "block_size": 16,
         # 2 for small prompt, 256//8 for generated.
         "num_gpu_blocks_override": 2 + 256 // 8,
         "max_model_len": (2 + 256 // 8) * 8,

@@ -152,7 +152,7 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "block_size": 8,
+        "block_size": 16,
         # 2 for small prompt, 256//8 for generated.
         "num_gpu_blocks_override": 2 + 256 // 8,
         "max_model_len": (2 + 256 // 8) * 8,

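All four test hunks apply the same one-line fix: "block_size" is doubled from 8 to 16, while "num_gpu_blocks_override" and "max_model_len" keep their original 256 // 8 arithmetic. A minimal sketch of that block accounting, under the assumption that the KV cache needs one block per block_size tokens (the blocks_needed helper is illustrative, not a vLLM API):

# Sketch of the KV-cache block accounting behind these kwargs.
# blocks_needed is an illustrative helper, not part of vLLM.

def blocks_needed(num_tokens: int, block_size: int) -> int:
    """Blocks required to hold num_tokens, rounding up."""
    return -(-num_tokens // block_size)  # ceiling division

# Values fixed by the tests, independent of block_size.
num_gpu_blocks_override = 2 + 256 // 8  # 34 blocks
max_model_len = (2 + 256 // 8) * 8      # 272 tokens

# With the old block_size=8, 272 tokens need exactly 34 blocks,
# so the override was an exact fit; with block_size=16 the same
# 272 tokens need only 17 blocks, leaving headroom.
for block_size in (8, 16):
    print(block_size, blocks_needed(max_model_len, block_size))
# prints: 8 34
#         16 17
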
@@ -51,9 +51,14 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
     def set_include_gpu_probs_tensor(self) -> None:
         # Need include_gpu_probs_tensor for MultiStepWorker
         self.model_runner.sampler.include_gpu_probs_tensor = True
+        if hasattr(self.model_runner.model, "sampler"):
+            (self.model_runner.model.sampler.include_gpu_probs_tensor) = True
 
     def set_should_modify_greedy_probs_inplace(self) -> None:
         self.model_runner.sampler.should_modify_greedy_probs_inplace = True
+        if hasattr(self.model_runner.model, "sampler"):
+            (self.model_runner.model.sampler.should_modify_greedy_probs_inplace
+             ) = True
 
     @torch.inference_mode()
     def sampler_output(
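This hunk propagates the sampler flags to the model's own sampler when one exists, guarded by hasattr so that models without a sampler attribute are left untouched. A minimal sketch of that guard pattern in isolation, using stand-in classes rather than vLLM types:

# Sketch of the hasattr-guarded flag propagation used above.
# Sampler, ModelWithSampler, and enable_gpu_probs are stand-ins,
# not vLLM types or APIs.

class Sampler:
    include_gpu_probs_tensor: bool = False

class ModelWithSampler:
    def __init__(self) -> None:
        self.sampler = Sampler()  # some models carry their own sampler

class ModelWithoutSampler:
    pass  # others rely solely on the runner's sampler

def enable_gpu_probs(runner_sampler: Sampler, model) -> None:
    # Always set the flag on the runner-level sampler.
    runner_sampler.include_gpu_probs_tensor = True
    # The hasattr check makes the second assignment a no-op for
    # models with no sampler attribute, instead of raising
    # AttributeError.
    if hasattr(model, "sampler"):
        model.sampler.include_gpu_probs_tensor = True

for model in (ModelWithSampler(), ModelWithoutSampler()):
    enable_gpu_probs(Sampler(), model)  # safe for both model shapes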