diff --git a/vllm/envs.py b/vllm/envs.py index 4c413006a6413..46c5b3a1dc5d0 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -99,7 +99,7 @@ if TYPE_CHECKING: VLLM_MARLIN_USE_ATOMIC_ADD: bool = False VLLM_V0_USE_OUTLINES_CACHE: bool = False VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False - VLLM_TPU_BUCKET_PADDING_GAP: int = 64 + VLLM_TPU_BUCKET_PADDING_GAP: int = 0 def get_default_cache_root(): @@ -648,7 +648,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # 8, we will run forward pass with [16, 24, 32, ...]. "VLLM_TPU_BUCKET_PADDING_GAP": lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"]) - if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 64, + if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0, } # end-env-vars-definition diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index edf859f0b9463..cf5c56b98beaa 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -944,18 +944,35 @@ def _get_paddings(min_token_size: int, max_token_size: int, padding_gap: int) -> list[int]: """Generate a list of padding size, starting from min_token_size, ending with a number that can cover max_token_size - first increase the size to twice, - then increase the padding size by padding_gap. + + If padding_gap == 0 then: + increase 2X each time (exponential) + else: + first increase the size to twice, + then increase the padding size by padding_gap. """ paddings = [] num = min_token_size - while num <= padding_gap: - paddings.append(num) - num *= 2 - num //= 2 - while num < max_token_size: - num += padding_gap - paddings.append(num) + + if padding_gap == 0: + logger.info("Using exponential paddings:") + while num <= max_token_size: + logger.info(" %d", num) + paddings.append(num) + num *= 2 + + else: + logger.info("Using incremental paddings:") + while num <= padding_gap: + logger.info(" %d", num) + paddings.append(num) + num *= 2 + num //= 2 + while num < max_token_size: + num += padding_gap + logger.info(" %d", num) + paddings.append(num) + return paddings