From b3f2fddd172cef23b42ffaf6c226877b6588964c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?=
Date: Mon, 14 Apr 2025 19:01:05 +0200
Subject: [PATCH] [TPU][V1] Fix exponential padding when
 `max-num-batched-tokens` is not a power of 2 (#16596)

Signed-off-by: NickLucche
---
 tests/v1/tpu/worker/test_tpu_model_runner.py | 12 ++++++++++++
 vllm/v1/worker/tpu_model_runner.py           |  4 +++-
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py
index 8ea8c890613a3..5c7eab0b6b11b 100644
--- a/tests/v1/tpu/worker/test_tpu_model_runner.py
+++ b/tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -299,6 +299,18 @@ def test_get_paddings():
     actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                           padding_gap)
     assert actual_paddings == expected_paddings
+    # Exponential padding.
+    max_token_size, padding_gap = 1024, 0
+    expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
+    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
+                                          padding_gap)
+    assert actual_paddings == expected_paddings
+    # Exponential padding with max_token_size not a power of two.
+    max_token_size = 317
+    expected_paddings = [16, 32, 64, 128, 256, 512]
+    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
+                                          padding_gap)
+    assert actual_paddings == expected_paddings
 
 
 def test_get_padded_token_len():
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 69251d8bbb31f..6300f16c0b3fb 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -1040,9 +1040,11 @@ def _get_token_paddings(min_token_size: int, max_token_size: int,
     if padding_gap == 0:
         logger.info("Using exponential token paddings:")
-        while num <= max_token_size:
+        while True:
             logger.info(" %d", num)
             paddings.append(num)
+            if num >= max_token_size:
+                break
             num *= 2
     else:
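
For context, here is a minimal standalone sketch of the exponential branch of `_get_token_paddings` before and after the fix. It is not the actual vLLM function (logging and the incremental `padding_gap > 0` branch are omitted), and the helper names `exponential_paddings_before` / `exponential_paddings_after` are illustrative only; it just demonstrates why the old loop can miss the final bucket when `max_token_size` is not a power of two.

```python
# Illustrative sketch only; mirrors the loop shapes shown in the diff above.

def exponential_paddings_before(min_token_size: int,
                                max_token_size: int) -> list[int]:
    """Old behavior: stop as soon as num exceeds max_token_size.

    If max_token_size is not a power of two, the bucket that would
    cover it is never appended.
    """
    paddings, num = [], min_token_size
    while num <= max_token_size:
        paddings.append(num)
        num *= 2
    return paddings


def exponential_paddings_after(min_token_size: int,
                               max_token_size: int) -> list[int]:
    """Fixed behavior: append first, then stop once the list already
    contains a padding that covers max_token_size."""
    paddings, num = [], min_token_size
    while True:
        paddings.append(num)
        if num >= max_token_size:
            break
        num *= 2
    return paddings


if __name__ == "__main__":
    # Power of two: both versions agree.
    assert exponential_paddings_before(16, 1024) == \
        exponential_paddings_after(16, 1024) == \
        [16, 32, 64, 128, 256, 512, 1024]

    # Not a power of two: the old loop tops out at 256, so a 317-token
    # batch has no bucket large enough; the fix adds 512.
    assert exponential_paddings_before(16, 317) == [16, 32, 64, 128, 256]
    assert exponential_paddings_after(16, 317) == [16, 32, 64, 128, 256, 512]
    print("padding examples OK")
```

The key change is moving the termination check after the append: the padding list is then guaranteed to end with a value at least as large as `max_token_size`, so any batch up to `max-num-batched-tokens` has a bucket to be padded into, which is exactly what the new test cases above assert.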