[core] add bucket padding to tpu_model_runner (#14995)

Signed-off-by: Chenyaaang <llccyy1212@gmail.com>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Co-authored-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Chenyaaang 2025-03-25 14:27:22 -07:00 committed by GitHub
parent 082ab86f5f
commit ac3cd6e83c
3 changed files with 63 additions and 19 deletions


@@ -8,7 +8,9 @@ from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
from vllm.v1.worker.tpu_model_runner import (TPUModelRunner,
_get_padded_token_len,
_get_paddings)
# Mock torch_xla module since it may not be available in the test environments
torch_xla_patcher = mock.patch.dict(
@@ -305,3 +307,21 @@ def test_update_states_request_unscheduled(model_runner):
assert _is_req_added(model_runner, req_ids[1])
assert not _is_req_scheduled(model_runner, req_ids[1])
def test_get_paddings():
min_token_size, max_token_size, padding_gap = 16, 512, 64
expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
actual_paddings = _get_paddings(min_token_size, max_token_size,
padding_gap)
assert actual_paddings == expected_paddings
def test_get_padded_token_len():
min_token_size, max_token_size, padding_gap = 16, 512, 64
paddings = _get_paddings(min_token_size, max_token_size, padding_gap)
assert _get_padded_token_len(paddings, 1) == 16
assert _get_padded_token_len(paddings, 16) == 16
assert _get_padded_token_len(paddings, 20) == 32
assert _get_padded_token_len(paddings, 300) == 320
assert _get_padded_token_len(paddings, 512) == 512
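
For reference, the lookup exercised by test_get_padded_token_len is just a left bisect over the sorted bucket list. A minimal standalone sketch of that behavior (an illustration with the default buckets, not the real helper):

import bisect

# Bucket list produced for min=16, max=512, gap=64 (see test_get_paddings above).
paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]

def padded_token_len(paddings: list[int], x: int) -> int:
    # First bucket >= x; callers must keep x within the largest bucket.
    return paddings[bisect.bisect_left(paddings, x)]

assert padded_token_len(paddings, 20) == 32    # rounded up to the next bucket
assert padded_token_len(paddings, 300) == 320
assert padded_token_len(paddings, 512) == 512  # the max hits the last bucket exactly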


@@ -97,6 +97,7 @@ if TYPE_CHECKING:
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 64
def get_default_cache_root():
@@ -627,6 +628,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
# Gap between padding buckets for the forward pass. For example, with a gap
# of 8, the forward pass will run with buckets [16, 24, 32, ...].
"VLLM_TPU_BUCKET_PADDING_GAP":
lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 64,
}
# end-env-vars-definition
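
To make the comment above concrete, here is a small sketch that mirrors the _get_paddings helper added in this commit (inlined so it runs without a TPU setup); the 512-token budget is only illustrative:

def bucket_sizes(min_token_size: int, max_token_size: int, padding_gap: int) -> list[int]:
    # Mirrors _get_paddings from this commit: double up to the gap, then step by the gap.
    paddings, num = [], min_token_size
    while num <= padding_gap:
        paddings.append(num)
        num *= 2
    num //= 2
    while num < max_token_size:
        num += padding_gap
        paddings.append(num)
    return paddings

print(bucket_sizes(16, 512, 8)[:5])   # [16, 24, 32, 40, 48] -- the gap-of-8 example above
print(bucket_sizes(16, 512, 64))      # [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]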


@@ -1,4 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
import bisect
import time
from typing import TYPE_CHECKING, Optional, cast
from unittest.mock import patch
@@ -170,6 +171,10 @@ class TPUModelRunner:
# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
self.num_tokens_paddings = _get_paddings(
min_token_size=16,
max_token_size=self.max_num_tokens,
padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
"""Update the cached states and the persistent batch with the scheduler
@@ -422,7 +427,7 @@ class TPUModelRunner:
# Do the padding and copy the tensors to the TPU.
padded_total_num_scheduled_tokens = _get_padded_token_len(
total_num_scheduled_tokens)
self.num_tokens_paddings, total_num_scheduled_tokens)
# Zero out to avoid spurious values from prev iteration (last cp chunk)
self.input_ids_cpu[
total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0
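
A minimal sketch of the pad-and-zero step above, using a plain numpy array as a stand-in for the runner's persistent CPU buffer:

import numpy as np

total_num_scheduled_tokens = 20
padded_len = 32  # bucket chosen by _get_padded_token_len for 20 scheduled tokens

input_ids_cpu = np.ones(512, dtype=np.int32)              # stand-in persistent buffer
input_ids_cpu[total_num_scheduled_tokens:padded_len] = 0  # zero the padded tail so stale
                                                          # values are not copied to the TPU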
@@ -573,7 +578,6 @@
# Prepare inputs
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
if self.is_multimodal_model:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
@@ -764,26 +768,21 @@ class TPUModelRunner:
logger.info("Compiling the model with different input shapes.")
start = time.perf_counter()
num_tokens = 16
while True:
for num_tokens in self.num_tokens_paddings:
logger.info(" -- num_tokens: %d", num_tokens)
self._dummy_run(self.kv_caches, num_tokens)
xm.mark_step()
if num_tokens >= self.max_num_tokens:
break
num_tokens *= 2
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
logger.info("Compiling sampling with different input shapes.")
start = time.perf_counter()
num_tokens = 16
hsize = self.model_config.get_hidden_size()
device = self.device
# Compile sampling step for different model+sampler outputs in bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
while True:
for num_tokens in self.num_tokens_paddings:
num_reqs_to_sample = MIN_NUM_SEQS
dummy_hidden = torch.randn((num_tokens, hsize),
device=device,
@@ -805,9 +804,6 @@
if num_reqs_to_sample >= self.max_num_reqs:
break
num_reqs_to_sample *= 2
if num_tokens >= self.max_num_tokens:
break
num_tokens *= 2
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
@@ -939,12 +935,33 @@ def _get_padded_number(n: int, multiple: int) -> int:
return ((n + multiple - 1) // multiple) * multiple
def _get_padded_token_len(x: int) -> int:
if x <= 16:
return 16
return 1 << (x - 1).bit_length()
def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
return min(res, upper_limit)
def _get_paddings(min_token_size: int, max_token_size: int,
padding_gap: int) -> list[int]:
"""Generate a list of padding size, starting from min_token_size,
ending with a number that can cover max_token_size
first increase the size to twice,
then increase the padding size by padding_gap.
"""
paddings = []
num = min_token_size
while num <= padding_gap:
paddings.append(num)
num *= 2
num //= 2
while num < max_token_size:
num += padding_gap
paddings.append(num)
return paddings
def _get_padded_token_len(paddings: list[int], x: int) -> int:
"""Return the first element in paddings list greater or equal to x.
"""
index = bisect.bisect_left(paddings, x)
assert index < len(paddings)
return paddings[index]