From e34d130c1613dbabc708cd5f059506c887ac81b4 Mon Sep 17 00:00:00 2001
From: Chenyaaang <42742451+Chenyaaang@users.noreply.github.com>
Date: Mon, 7 Jul 2025 22:16:16 -0700
Subject: [PATCH] [TPU] Temporary fix vmem oom for long model len by reducing
 page size (#20278)

Signed-off-by: Chenyaaang
---
 vllm/v1/attention/backends/pallas.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index 253d79d925cef..2921e8ed55abe 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -86,6 +86,12 @@ class PallasAttentionBackend(AttentionBackend):
     # spill less likely. Meanwhile we make sure the page size is in [16, 256].
     @staticmethod
     def get_page_size(vllm_config: VllmConfig) -> int:
+        # TODO: This is a temporary fix for vmem OOM.
+        # For long model length, we use 16 page-size to avoid too much
+        # VMEM spill. A more robust solution should be implemented to
+        # handle VREG spills.
+        if vllm_config.model_config.max_model_len > 8192:
+            return 16
         page_size = next_power_of_2(
             vllm_config.model_config.max_model_len) // 16
         if page_size <= 16:
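
Note: below is a minimal standalone sketch of the page-size selection this patch produces. It assumes the truncated tail of get_page_size clamps the result into [16, 256] as the in-code comment states; the next_power_of_2 helper and the standalone function signature here are illustrative, not the actual vLLM source.

    # Sketch (assumptions noted above), not the vLLM implementation.

    def next_power_of_2(n: int) -> int:
        # Smallest power of two >= n (assumed helper; vLLM has its own).
        return 1 if n <= 1 else 1 << (n - 1).bit_length()

    def get_page_size(max_model_len: int) -> int:
        # Temporary fix: long model lengths fall back to the smallest
        # page size (16) to reduce VMEM pressure.
        if max_model_len > 8192:
            return 16
        page_size = next_power_of_2(max_model_len) // 16
        # Assumed clamping into [16, 256], per the in-code comment.
        if page_size <= 16:
            return 16
        if page_size >= 256:
            return 256
        return page_size

    # Examples: get_page_size(4096) -> 256, get_page_size(32768) -> 16

With this change, any max_model_len above 8192 short-circuits to a 16-token page size instead of growing with next_power_of_2(max_model_len) // 16, which is what keeps VMEM usage bounded for long-context models.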