[core] add bucket padding to tpu_model_runner (#14995)
Signed-off-by: Chenyaaang <llccyy1212@gmail.com>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Co-authored-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>

commit ac3cd6e83c (parent 082ab86f5f)
Tests (TPU model runner unit tests):

@@ -8,7 +8,9 @@ from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.worker.tpu_model_runner import TPUModelRunner
+from vllm.v1.worker.tpu_model_runner import (TPUModelRunner,
+                                             _get_padded_token_len,
+                                             _get_paddings)

 # Mock torch_xla module since it may not be available in the test environments
 torch_xla_patcher = mock.patch.dict(
@@ -305,3 +307,21 @@ def test_update_states_request_unscheduled(model_runner):

     assert _is_req_added(model_runner, req_ids[1])
     assert not _is_req_scheduled(model_runner, req_ids[1])
+
+
+def test_get_paddings():
+    min_token_size, max_token_size, padding_gap = 16, 512, 64
+    expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
+    actual_paddings = _get_paddings(min_token_size, max_token_size,
+                                    padding_gap)
+    assert actual_paddings == expected_paddings
+
+
+def test_get_padded_token_len():
+    min_token_size, max_token_size, padding_gap = 16, 512, 64
+    paddings = _get_paddings(min_token_size, max_token_size, padding_gap)
+    assert _get_padded_token_len(paddings, 1) == 16
+    assert _get_padded_token_len(paddings, 16) == 16
+    assert _get_padded_token_len(paddings, 20) == 32
+    assert _get_padded_token_len(paddings, 300) == 320
+    assert _get_padded_token_len(paddings, 512) == 512
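For readers checking the expected list in test_get_paddings: with min_token_size=16 and padding_gap=64, bucket sizes double while they stay at or below the gap, then grow linearly by the gap up to max_token_size. A minimal sketch of that arithmetic, using only the values from the test above:

    # Doubling phase: 16 -> 32 -> 64 (64 == padding_gap, so it is still included).
    doubling_phase = [16, 32, 64]
    # Linear phase: step by the gap until max_token_size (512) is covered.
    linear_phase = list(range(128, 512 + 1, 64))   # 128, 192, ..., 512
    assert doubling_phase + linear_phase == [
        16, 32, 64, 128, 192, 256, 320, 384, 448, 512]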
vllm/envs.py:

@@ -97,6 +97,7 @@ if TYPE_CHECKING:
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
     VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
+    VLLM_TPU_BUCKET_PADDING_GAP: int = 64


 def get_default_cache_root():

@@ -627,6 +628,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
     lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
     if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
+
+    # Gap between padding buckets for the forward pass. For example, with a
+    # gap of 8 the forward pass runs with bucket sizes [16, 24, 32, ...].
+    "VLLM_TPU_BUCKET_PADDING_GAP":
+    lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
+    if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 64,
 }

 # end-env-vars-definition
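To illustrate what the new environment variable controls, here is a small standalone sketch; the helper body mirrors the _get_paddings() function added to tpu_model_runner.py below, and the gap values are only examples:

    # Effect of VLLM_TPU_BUCKET_PADDING_GAP on the token-padding buckets.
    def buckets(min_token_size: int, max_token_size: int,
                padding_gap: int) -> list[int]:
        paddings, num = [], min_token_size
        while num <= padding_gap:      # doubling phase
            paddings.append(num)
            num *= 2
        num //= 2
        while num < max_token_size:    # linear phase, stepping by the gap
            num += padding_gap
            paddings.append(num)
        return paddings

    print(buckets(16, 64, 8))    # [16, 24, 32, 40, 48, 56, 64]
    print(buckets(16, 512, 64))  # [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]

A smaller gap gives finer-grained padding (less wasted compute per step) at the cost of more compiled graphs during warm-up.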
vllm/v1/worker/tpu_model_runner.py:

@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+import bisect
 import time
 from typing import TYPE_CHECKING, Optional, cast
 from unittest.mock import patch

@@ -170,6 +171,10 @@ class TPUModelRunner:
         # Range tensor with values [0 .. self.max_num_tokens - 1].
         # Used to initialize positions / context_lens / seq_lens
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
+        self.num_tokens_paddings = _get_paddings(
+            min_token_size=16,
+            max_token_size=self.max_num_tokens,
+            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)

     def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         """Update the cached states and the persistent batch with the scheduler
@@ -422,7 +427,7 @@ class TPUModelRunner:

         # Do the padding and copy the tensors to the TPU.
         padded_total_num_scheduled_tokens = _get_padded_token_len(
-            total_num_scheduled_tokens)
+            self.num_tokens_paddings, total_num_scheduled_tokens)
         # Zero out to avoid spurious values from prev iteration (last cp chunk)
         self.input_ids_cpu[
             total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0
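A standalone sketch of the padding and zero-fill step above, using a NumPy array in place of the runner's input_ids_cpu buffer (buffer size and token count here are illustrative, not the runner's actual configuration):

    import bisect
    import numpy as np

    num_tokens_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
    # Pretend the buffer still holds token ids from the previous iteration.
    input_ids_cpu = np.arange(512, dtype=np.int32)

    total_num_scheduled_tokens = 20
    padded = num_tokens_paddings[
        bisect.bisect_left(num_tokens_paddings, total_num_scheduled_tokens)]
    # Zero the padded tail so stale values are not fed to the model.
    input_ids_cpu[total_num_scheduled_tokens:padded] = 0

    assert padded == 32
    assert not input_ids_cpu[total_num_scheduled_tokens:padded].any()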
@@ -573,7 +578,6 @@ class TPUModelRunner:

         # Prepare inputs
         attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
-
         if self.is_multimodal_model:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
@@ -764,26 +768,21 @@ class TPUModelRunner:
         logger.info("Compiling the model with different input shapes.")

         start = time.perf_counter()
-        num_tokens = 16
-        while True:
+        for num_tokens in self.num_tokens_paddings:
             logger.info("  -- num_tokens: %d", num_tokens)
             self._dummy_run(self.kv_caches, num_tokens)
             xm.mark_step()
-            if num_tokens >= self.max_num_tokens:
-                break
-            num_tokens *= 2
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in %.2f [secs].", end - start)

         logger.info("Compiling sampling with different input shapes.")
         start = time.perf_counter()
-        num_tokens = 16
         hsize = self.model_config.get_hidden_size()
         device = self.device
         # Compile sampling step for different model+sampler outputs in bucketed
         # n_tokens x max_num_reqs. Graph is really small so this is fine.
-        while True:
+        for num_tokens in self.num_tokens_paddings:
             num_reqs_to_sample = MIN_NUM_SEQS
             dummy_hidden = torch.randn((num_tokens, hsize),
                                        device=device,

@@ -805,9 +804,6 @@ class TPUModelRunner:
                 if num_reqs_to_sample >= self.max_num_reqs:
                     break
                 num_reqs_to_sample *= 2
-            if num_tokens >= self.max_num_tokens:
-                break
-            num_tokens *= 2
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in %.2f [secs].", end - start)
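As a sanity check on the new warm-up flow, a toy sketch (standalone, not the vLLM API): each bucket is compiled exactly once, and any runtime token count, once padded, lands on a shape that was already compiled, so no recompilation happens during serving:

    import bisect

    # Buckets as produced by _get_paddings(16, 512, 64); see the tests above.
    paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]

    # Stand-in for the new for-loop in the compilation warm-up above: record one
    # "compiled" shape per bucket instead of calling _dummy_run() / xm.mark_step().
    compiled_shapes = set(paddings)

    def padded(x: int) -> int:
        return paddings[bisect.bisect_left(paddings, x)]

    assert all(padded(x) in compiled_shapes for x in (1, 20, 300, 512))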
@@ -939,12 +935,33 @@ def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple


-def _get_padded_token_len(x: int) -> int:
-    if x <= 16:
-        return 16
-    return 1 << (x - 1).bit_length()
-
-
 def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
     res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
     return min(res, upper_limit)
+
+
+def _get_paddings(min_token_size: int, max_token_size: int,
+                  padding_gap: int) -> list[int]:
+    """Generate a list of padding sizes, starting from min_token_size and
+    ending with a number that covers max_token_size.
+    The size is first doubled repeatedly (up to padding_gap),
+    then increased by padding_gap.
+    """
+    paddings = []
+    num = min_token_size
+    while num <= padding_gap:
+        paddings.append(num)
+        num *= 2
+    num //= 2
+    while num < max_token_size:
+        num += padding_gap
+        paddings.append(num)
+    return paddings
+
+
+def _get_padded_token_len(paddings: list[int], x: int) -> int:
+    """Return the first element in paddings that is greater than or equal to x.
+    """
+    index = bisect.bisect_left(paddings, x)
+    assert index < len(paddings)
+    return paddings[index]