diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 4d9a113e39ee..c8691ee87fe6 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -161,7 +161,13 @@ class TPUWorker:
         # intermediate activations.
         m = xm.get_memory_info(self.device)
         total_memory_size = m["bytes_limit"]
-        profiled = m["peak_bytes_used"]  # Weights + intermediate activations.
+        current_mem = m["bytes_used"]
+        # Ideally we would use profiled = m["peak_bytes_used"] to
+        # get weights + activations. But there is memory used during
+        # compilation / weight loading that inflates the peak, and
+        # there is no way to reset peak memory in XLA, so we use a
+        # heuristic of 2% headroom on top of the weights.
+        profiled = current_mem * 1.02

         # Calculate the TPU KV cache size based on profiling.
         usable_memory_size = int(total_memory_size *
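For context, the sizing arithmetic this patch sets up looks roughly like the sketch below. Only the 2% headroom factor and the `bytes_limit`/`bytes_used` memory-info fields come from the diff itself; the `memory_utilization` fraction and the sample byte counts are illustrative assumptions, since the actual multiplier is truncated out of the hunk.

```python
# Illustrative sketch of the KV-cache sizing heuristic from the diff.
# `memory_utilization` and the example numbers are hypothetical; only
# the 1.02 headroom factor and the memory-info fields are from the patch.

def estimate_kv_cache_bytes(
    bytes_limit: int,                  # m["bytes_limit"]: total TPU memory
    bytes_used: int,                   # m["bytes_used"]: memory after weight loading
    memory_utilization: float = 0.9,   # hypothetical usable fraction
) -> int:
    # Approximate the true peak (weights + activations) by adding 2%
    # headroom on top of current usage, since XLA's peak_bytes_used
    # also counts compilation-time allocations and cannot be reset.
    profiled = int(bytes_used * 1.02)
    usable = int(bytes_limit * memory_utilization)
    return max(0, usable - profiled)

# Example: a 16 GiB TPU with 8 GiB of weights loaded.
GiB = 1024 ** 3
print(estimate_kv_cache_bytes(16 * GiB, 8 * GiB) / GiB)  # ~6.24 GiB for KV cache
```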