[Frontend][TPU] Add TPU default max-num-batched-tokens based on device name (#17508)

Signed-off-by: Chenyaaang <chenyangli@google.com>
Author: Chenyaaang
Date: 2025-05-02 21:42:44 -07:00 (committed by GitHub)
parent e3d0a1d190
commit 87baebebd8
3 changed files with 40 additions and 4 deletions

--- a/vllm/config.py
+++ b/vllm/config.py

@@ -3140,6 +3140,14 @@ def _get_and_verify_max_len(
     # derived length from the HF model config.
     if max_model_len is None:
         max_model_len = int(derived_max_model_len)
+        if current_platform.is_tpu():
+            logger.warning(
+                "--max-model-len is not specified; falling back to the "
+                "model's default length %s, which might be too large. "
+                "Please set --max-model-len based on your request "
+                "input and output lengths to avoid unnecessary "
+                "degradation.",
+                max_model_len)
     elif max_model_len > derived_max_model_len:
         # Some models might have a separate key for specifying model_max_length
         # that will be bigger than derived_max_model_len. We compare user input
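
On TPU the new warning fires only when the flag is omitted. A minimal sketch of avoiding it by pinning the context window explicitly through the `LLM` entry point (model name and length are illustrative):

from vllm import LLM

# Pin the context window instead of inheriting the model's default,
# which the new TPU warning flags as potentially too large.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",  # illustrative model
          max_model_len=4096)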

--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py

@@ -1441,8 +1441,8 @@ class EngineArgs:
         # as the platform that vLLM is running on (e.g. the case of scaling
         # vLLM with Ray) and has no GPUs. In this case we use the default
         # values for non-H100/H200 GPUs.
+        from vllm.platforms import current_platform
         try:
-            from vllm.platforms import current_platform
             device_memory = current_platform.get_device_total_memory()
         except Exception:
             # This is only used to set default_max_num_batched_tokens
@@ -1463,11 +1463,37 @@ class EngineArgs:
             }
             default_max_num_seqs = 256
+            # TPU-specific default values.
+            if current_platform.is_tpu():
+                default_max_num_batched_tokens_tpu = {
+                    UsageContext.LLM_CLASS: {
+                        'V6E': 2048,
+                        'V5E': 1024,
+                        'V5P': 512,
+                    },
+                    UsageContext.OPENAI_API_SERVER: {
+                        'V6E': 1024,
+                        'V5E': 512,
+                        'V5P': 256,
+                    }
+                }
             use_context_value = usage_context.value if usage_context else None
             if (self.max_num_batched_tokens is None
                     and usage_context in default_max_num_batched_tokens):
-                self.max_num_batched_tokens = default_max_num_batched_tokens[
-                    usage_context]
+                if current_platform.is_tpu():
+                    chip_name = current_platform.get_device_name()
+                    if chip_name in default_max_num_batched_tokens_tpu[
+                            usage_context]:
+                        self.max_num_batched_tokens = \
+                            default_max_num_batched_tokens_tpu[
+                                usage_context][chip_name]
+                    else:
+                        self.max_num_batched_tokens = \
+                            default_max_num_batched_tokens[usage_context]
+                else:
+                    self.max_num_batched_tokens = default_max_num_batched_tokens[
+                        usage_context]
                 logger.debug(
                     "Setting max_num_batched_tokens to %d for %s usage context.",
                     self.max_num_batched_tokens, use_context_value)
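
Stripped of engine plumbing, the selection logic is a two-level lookup: usage context first, then chip name, with the generic table as fallback for unrecognized chips. A standalone sketch of the same behavior (the helper name is hypothetical, and the generic values are placeholders since that table is elided in the diff above):

from enum import Enum, auto

class UsageContext(Enum):
    LLM_CLASS = auto()
    OPENAI_API_SERVER = auto()

# Placeholder generic defaults; the real table is elided above.
GENERIC = {UsageContext.LLM_CLASS: 8192, UsageContext.OPENAI_API_SERVER: 2048}

# TPU defaults copied from the diff.
TPU = {
    UsageContext.LLM_CLASS: {'V6E': 2048, 'V5E': 1024, 'V5P': 512},
    UsageContext.OPENAI_API_SERVER: {'V6E': 1024, 'V5E': 512, 'V5P': 256},
}

def default_batched_tokens(ctx: UsageContext, is_tpu: bool,
                           chip: str = "") -> int:
    """Chip-specific default on TPU, generic default otherwise."""
    if is_tpu and chip in TPU[ctx]:
        return TPU[ctx][chip]
    return GENERIC[ctx]

assert default_batched_tokens(UsageContext.OPENAI_API_SERVER, True, 'V5E') == 512
assert default_batched_tokens(UsageContext.LLM_CLASS, True, 'V4') == 8192  # fallback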

--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py

@@ -3,6 +3,7 @@
 from typing import TYPE_CHECKING, Optional, Union

 import torch
+from tpu_info import device

 import vllm.envs as envs
 from vllm.inputs import ProcessorInputs, PromptType
@@ -54,7 +55,8 @@ class TpuPlatform(Platform):
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
-        return "tpu"
+        chip_type, _ = device.get_local_chips()
+        return f"TPU {chip_type.name}"

     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
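
With this change, `get_device_name()` reports the chip generation instead of the generic string "tpu". A quick sanity check on a TPU host, mirroring the call above (tpu_info is the standalone tpu-info package; the printed value is illustrative):

from tpu_info import device  # pip install tpu-info; needs a TPU host

# Same call as the platform code above: the first tuple element
# identifies the local chip generation; the second is unused here.
chip_type, _ = device.get_local_chips()
print(f"TPU {chip_type.name}")  # e.g. "TPU V6E" on a v6e host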