[TPU][V1] Implicitly adjust page size when there's SMEM OOM (#16871)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Chengji Yao 2025-04-21 14:43:13 -07:00 committed by GitHub
parent 3a0fba5cf4
commit 471fe65630
3 changed files with 34 additions and 2 deletions
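
In plain terms: the Pallas attention backend keeps the per-request block tables in TPU SMEM, and with a large max_num_seqs those tables stop fitting, so the platform now bumps the KV-cache page (block) size automatically instead of failing. A rough back-of-the-envelope sketch of the arithmetic (plain Python; the 1 MiB SMEM size and int32 entries come from the diffs below, while the default page size of 16 tokens is an assumption for illustration):

```python
# Back-of-the-envelope: size of the block tables that must live in SMEM.
# Assumed numbers: 1 MiB SMEM, int32 entries, default page size 16 tokens.
SMEM_BYTES = 1024 * 1024
BYTES_PER_ENTRY = 4  # each block-table entry is an int32 page id


def block_table_bytes(max_num_seqs: int, max_model_len: int,
                      page_size: int) -> int:
    pages_per_req = -(-max_model_len // page_size)  # ceiling division
    return max_num_seqs * pages_per_req * BYTES_PER_ENTRY


# max_num_seqs=1024, max_model_len=8192, page_size=16:
# 1024 * 512 * 4 bytes = 2 MiB, i.e. twice the SMEM capacity.
print(block_table_bytes(1024, 8192, 16) / SMEM_BYTES)  # 2.0
# Bumping the page size to 64 shrinks the tables to half of SMEM.
print(block_table_bytes(1024, 8192, 64) / SMEM_BYTES)  # 0.5
```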

@@ -22,6 +22,7 @@ MODELS = [
 ]
 TENSOR_PARALLEL_SIZES = [1]
+MAX_NUM_REQS = [16, 1024]
 # TODO: Enable when CI/CD will have a multi-tpu instance
 # TENSOR_PARALLEL_SIZES = [1, 4]
@@ -32,12 +33,14 @@ TENSOR_PARALLEL_SIZES = [1]
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
+@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
 def test_basic(
     vllm_runner: type[VllmRunner],
     monkeypatch: pytest.MonkeyPatch,
     model: str,
     max_tokens: int,
     tensor_parallel_size: int,
+    max_num_seqs: int,
 ) -> None:
     prompt = "The next numbers of the sequence " + ", ".join(
         str(i) for i in range(1024)) + " are:"
@@ -51,9 +54,9 @@ def test_basic(
             # Note: max_num_batched_tokens == 1024 is needed here to
             # actually test chunked prompt
             max_num_batched_tokens=1024,
-            max_model_len=8196,
+            max_model_len=8192,
             gpu_memory_utilization=0.7,
-            max_num_seqs=16,
+            max_num_seqs=max_num_seqs,
             tensor_parallel_size=tensor_parallel_size) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                   max_tokens)

@@ -97,6 +97,20 @@ class TpuPlatform(Platform):
                 "Using bfloat16 instead.", vllm_config.model_config.dtype)
             vllm_config.model_config.dtype = torch.bfloat16
 
+        if envs.VLLM_USE_V1:
+            from vllm.v1.attention.backends.pallas import (
+                PallasAttentionBackend)
+            min_page_size = PallasAttentionBackend.get_min_page_size(
+                vllm_config)
+            if min_page_size > vllm_config.cache_config.block_size:
+                logger.warning(
+                    "Increasing the page size from %s to %s to avoid "
+                    "SMEM OOM.",
+                    vllm_config.cache_config.block_size,
+                    min_page_size,
+                )
+                vllm_config.cache_config.block_size = min_page_size
+
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         if parallel_config.worker_cls == "auto":
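
To see the adjustment in isolation: the check only ever grows the page size, never shrinks it, so a user-configured block_size that already satisfies the SMEM bound is left untouched. A minimal sketch with the vLLM config objects replaced by plain integers (the helper name adjust_page_size is hypothetical):

```python
def adjust_page_size(block_size: int, min_page_size: int) -> int:
    # Only grow: a block size that already satisfies the SMEM bound
    # is left exactly as the user configured it.
    if min_page_size > block_size:
        print(f"Increasing the page size from {block_size} to "
              f"{min_page_size} to avoid SMEM OOM.")
        return min_page_size
    return block_size


assert adjust_page_size(16, 64) == 64  # too small -> bumped implicitly
assert adjust_page_size(16, 1) == 16   # large enough -> untouched
```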

@@ -10,7 +10,9 @@ import torch_xla.experimental.custom_kernel  # noqa: F401
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.utils import cdiv
 
 logger = init_logger(__name__)
@@ -50,6 +52,19 @@ class PallasAttentionBackend(AttentionBackend):
     ) -> None:
         raise RuntimeError("swap_blocks is not used for the TPU backend.")
 
+    # In recent TPU generations, up to v6e, the SMEM size is 1MB. The
+    # block_tables within the PallasMetadata constitute almost the entire
+    # SMEM requirement: max_num_seqs * num_page_per_seq * 4 bytes (int32).
+    # Here we simply make sure that this stays below half of the capacity.
+    @staticmethod
+    def get_min_page_size(vllm_config: VllmConfig) -> int:
+        max_num_page_per_req = (1024 * 1024 // 2 //
+                                vllm_config.scheduler_config.max_num_seqs // 4)
+        min_page_size = cdiv(vllm_config.model_config.max_model_len,
+                             max_num_page_per_req)
+        min_page_size = 1 << (min_page_size - 1).bit_length()
+        return min_page_size
+
 
 @dataclass
 class PallasMetadata:
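
For a concrete feel of get_min_page_size, here is a self-contained re-implementation with cdiv inlined (a sketch; the 1 MiB, half-capacity, and int32 constants follow the comment above), evaluated at the two MAX_NUM_REQS values from the test:

```python
def cdiv(a: int, b: int) -> int:
    # Same behavior as vllm.utils.cdiv: ceiling division.
    return -(-a // b)


def min_page_size(max_num_seqs: int, max_model_len: int) -> int:
    # Budget: half of the 1 MiB SMEM, split across max_num_seqs requests,
    # 4 bytes per block-table entry.
    max_num_page_per_req = 1024 * 1024 // 2 // max_num_seqs // 4
    # Smallest page size that keeps every request within its page budget.
    size = cdiv(max_model_len, max_num_page_per_req)
    # Round up to the next power of two.
    return 1 << (size - 1).bit_length()


# The two MAX_NUM_REQS values from the test, with max_model_len=8192:
print(min_page_size(16, 8192))    # 1  -> configured block size already fine
print(min_page_size(1024, 8192))  # 64 -> block size is bumped to 64
```

Note that the final rounding can overshoot the strict minimum: a requirement of, say, 33 tokens per page would yield a page size of 64, trading a little KV-cache granularity for a power-of-two page size.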