From ac3cd6e83c264255703651c085e567a684112938 Mon Sep 17 00:00:00 2001
From: Chenyaaang <42742451+Chenyaaang@users.noreply.github.com>
Date: Tue, 25 Mar 2025 14:27:22 -0700
Subject: [PATCH] [core] add bucket padding to tpu_model_runner (#14995)

Signed-off-by: Chenyaaang
Signed-off-by: rshaw@neuralmagic.com
Co-authored-by: rshaw@neuralmagic.com
---
 tests/v1/tpu/worker/test_tpu_model_runner.py | 22 +++++++-
 vllm/envs.py                                 |  7 +++
 vllm/v1/worker/tpu_model_runner.py           | 53 +++++++++++++-------
 3 files changed, 63 insertions(+), 19 deletions(-)

diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py
index 40ae52ef05cd3..d5f812ed4d543 100644
--- a/tests/v1/tpu/worker/test_tpu_model_runner.py
+++ b/tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -8,7 +8,9 @@ from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.worker.tpu_model_runner import TPUModelRunner
+from vllm.v1.worker.tpu_model_runner import (TPUModelRunner,
+                                             _get_padded_token_len,
+                                             _get_paddings)
 
 # Mock torch_xla module since it may not be available in the test environments
 torch_xla_patcher = mock.patch.dict(
@@ -305,3 +307,21 @@ def test_update_states_request_unscheduled(model_runner):
 
     assert _is_req_added(model_runner, req_ids[1])
     assert not _is_req_scheduled(model_runner, req_ids[1])
+
+
+def test_get_paddings():
+    min_token_size, max_token_size, padding_gap = 16, 512, 64
+    expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
+    actual_paddings = _get_paddings(min_token_size, max_token_size,
+                                    padding_gap)
+    assert actual_paddings == expected_paddings
+
+
+def test_get_padded_token_len():
+    min_token_size, max_token_size, padding_gap = 16, 512, 64
+    paddings = _get_paddings(min_token_size, max_token_size, padding_gap)
+    assert _get_padded_token_len(paddings, 1) == 16
+    assert _get_padded_token_len(paddings, 16) == 16
+    assert _get_padded_token_len(paddings, 20) == 32
+    assert _get_padded_token_len(paddings, 300) == 320
+    assert _get_padded_token_len(paddings, 512) == 512
diff --git a/vllm/envs.py b/vllm/envs.py
index f0fd20c70e3b2..b4305d9c8e22c 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -97,6 +97,7 @@ if TYPE_CHECKING:
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
     VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
+    VLLM_TPU_BUCKET_PADDING_GAP: int = 64
 
 
 def get_default_cache_root():
@@ -627,6 +628,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
     lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
     if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
+
+    # Gap between padding buckets for the forward pass. For example, with a
+    # gap of 8, the forward pass is run with paddings of [16, 24, 32, ...].
+    "VLLM_TPU_BUCKET_PADDING_GAP":
+    lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
+    if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 64,
 }
 
 # end-env-vars-definition
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 0e9473a33453d..edf859f0b9463 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+import bisect
 import time
 from typing import TYPE_CHECKING, Optional, cast
 from unittest.mock import patch
@@ -170,6 +171,10 @@ class TPUModelRunner:
         # Range tensor with values [0 .. self.max_num_tokens - 1].
         # Used to initialize positions / context_lens / seq_lens
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
+        self.num_tokens_paddings = _get_paddings(
+            min_token_size=16,
+            max_token_size=self.max_num_tokens,
+            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         """Update the cached states and the persistent batch with the scheduler
@@ -422,7 +427,7 @@ class TPUModelRunner:
 
         # Do the padding and copy the tensors to the TPU.
         padded_total_num_scheduled_tokens = _get_padded_token_len(
-            total_num_scheduled_tokens)
+            self.num_tokens_paddings, total_num_scheduled_tokens)
         # Zero out to avoid spurious values from prev iteration (last cp chunk)
         self.input_ids_cpu[
             total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0
@@ -573,7 +578,6 @@ class TPUModelRunner:
 
         # Prepare inputs
         attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
-
         if self.is_multimodal_model:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
@@ -764,26 +768,21 @@ class TPUModelRunner:
 
         logger.info("Compiling the model with different input shapes.")
         start = time.perf_counter()
-        num_tokens = 16
-        while True:
+        for num_tokens in self.num_tokens_paddings:
             logger.info("  -- num_tokens: %d", num_tokens)
             self._dummy_run(self.kv_caches, num_tokens)
             xm.mark_step()
-            if num_tokens >= self.max_num_tokens:
-                break
-            num_tokens *= 2
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in in %.2f [secs].", end - start)
 
         logger.info("Compiling sampling with different input shapes.")
         start = time.perf_counter()
-        num_tokens = 16
         hsize = self.model_config.get_hidden_size()
         device = self.device
         # Compile sampling step for different model+sampler outputs in bucketed
         # n_tokens x max_num_reqs. Graph is really small so this is fine.
-        while True:
+        for num_tokens in self.num_tokens_paddings:
             num_reqs_to_sample = MIN_NUM_SEQS
             dummy_hidden = torch.randn((num_tokens, hsize),
                                        device=device,
@@ -805,9 +804,6 @@ class TPUModelRunner:
                 if num_reqs_to_sample >= self.max_num_reqs:
                     break
                 num_reqs_to_sample *= 2
-            if num_tokens >= self.max_num_tokens:
-                break
-            num_tokens *= 2
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in in %.2f [secs].", end - start)
@@ -939,12 +935,33 @@ def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
 
 
-def _get_padded_token_len(x: int) -> int:
-    if x <= 16:
-        return 16
-    return 1 << (x - 1).bit_length()
-
-
 def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
     res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
     return min(res, upper_limit)
+
+
+def _get_paddings(min_token_size: int, max_token_size: int,
+                  padding_gap: int) -> list[int]:
+    """Generate a list of padding sizes, starting from min_token_size and
+    ending with a value that can cover max_token_size.
+    Sizes first double until they reach padding_gap,
+    then increase linearly in steps of padding_gap.
+    """
+    paddings = []
+    num = min_token_size
+    while num <= padding_gap:
+        paddings.append(num)
+        num *= 2
+    num //= 2
+    while num < max_token_size:
+        num += padding_gap
+        paddings.append(num)
+    return paddings
+
+
+def _get_padded_token_len(paddings: list[int], x: int) -> int:
+    """Return the first element in paddings greater than or equal to x.
+    """
+    index = bisect.bisect_left(paddings, x)
+    assert index < len(paddings)
+    return paddings[index]
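
For illustration, here is a minimal standalone sketch of the bucket-padding scheme this patch introduces. It mirrors the _get_paddings and _get_padded_token_len helpers from the diff; the demo values (min size 16, max size 512, gap 64) are taken from the new test and the VLLM_TPU_BUCKET_PADDING_GAP default, and the script wrapper is only for demonstration.

import bisect


def get_paddings(min_token_size: int, max_token_size: int,
                 padding_gap: int) -> list[int]:
    # Bucket sizes double until they reach padding_gap, then grow linearly
    # in steps of padding_gap until max_token_size is covered.
    paddings = []
    num = min_token_size
    while num <= padding_gap:
        paddings.append(num)
        num *= 2
    num //= 2
    while num < max_token_size:
        num += padding_gap
        paddings.append(num)
    return paddings


def get_padded_token_len(paddings: list[int], x: int) -> int:
    # Smallest bucket that can hold x tokens (paddings is sorted ascending).
    return paddings[bisect.bisect_left(paddings, x)]


if __name__ == "__main__":
    buckets = get_paddings(min_token_size=16, max_token_size=512,
                           padding_gap=64)
    print(buckets)
    # [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
    print(get_padded_token_len(buckets, 300))
    # 320 -> the graph compiled for 320 tokens serves any 257-320 token batch.

Compared to the pure power-of-two padding this replaces, bucket sizes above padding_gap waste at most padding_gap tokens per forward pass instead of up to half the padded size.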