diff --git a/vllm/v1/worker/gpu_worker_states.py b/vllm/v1/worker/gpu_worker_states.py index 18a47ba3b5506..626129349648c 100644 --- a/vllm/v1/worker/gpu_worker_states.py +++ b/vllm/v1/worker/gpu_worker_states.py @@ -22,6 +22,8 @@ from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +_MAX_SPEC_LEN = 32 + @dataclass class RequestData: @@ -323,7 +325,7 @@ class RequestState: logits_indices, target_logits_indices, bonus_logits_indices, - BLOCK_SIZE=triton.next_power_of_2(32 + 1), + BLOCK_SIZE=triton.next_power_of_2(_MAX_SPEC_LEN + 1), ) draft_token_ids = input_ids[logits_indices]