diff --git a/vllm/config.py b/vllm/config.py
index 9738d2fd0e00..1ae8673f7775 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3140,6 +3140,14 @@ def _get_and_verify_max_len(
     # derived length from the HF model config.
     if max_model_len is None:
         max_model_len = int(derived_max_model_len)
+        if current_platform.is_tpu():
+            logger.warning(
+                "--max-model-len is not specified, "
+                "so the model's default length %s is used, "
+                "which might be too large. "
+                "Please set --max-model-len based on your "
+                "request input and output lengths to avoid "
+                "unnecessary performance degradation.", max_model_len)
     elif max_model_len > derived_max_model_len:
         # Some models might have a separate key for specifying model_max_length
         # that will be bigger than derived_max_model_len. We compare user input
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 0ba14c4dee04..aefba620e189 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1441,8 +1441,8 @@ class EngineArgs:
         # as the platform that vLLM is running on (e.g. the case of scaling
         # vLLM with Ray) and has no GPUs. In this case we use the default
         # values for non-H100/H200 GPUs.
+        from vllm.platforms import current_platform
         try:
-            from vllm.platforms import current_platform
             device_memory = current_platform.get_device_total_memory()
         except Exception:
             # This is only used to set default_max_num_batched_tokens
@@ -1463,11 +1463,37 @@ class EngineArgs:
             }
             default_max_num_seqs = 256
 
+        # TPU-specific default values.
+        if current_platform.is_tpu():
+            default_max_num_batched_tokens_tpu = {
+                UsageContext.LLM_CLASS: {
+                    'V6E': 2048,
+                    'V5E': 1024,
+                    'V5P': 512,
+                },
+                UsageContext.OPENAI_API_SERVER: {
+                    'V6E': 1024,
+                    'V5E': 512,
+                    'V5P': 256,
+                }
+            }
+
         use_context_value = usage_context.value if usage_context else None
         if (self.max_num_batched_tokens is None
                 and usage_context in default_max_num_batched_tokens):
-            self.max_num_batched_tokens = default_max_num_batched_tokens[
-                usage_context]
+            if current_platform.is_tpu():
+                chip_name = current_platform.get_device_name()
+                if chip_name in default_max_num_batched_tokens_tpu[
+                        usage_context]:
+                    self.max_num_batched_tokens = \
+                        default_max_num_batched_tokens_tpu[
+                            usage_context][chip_name]
+                else:
+                    self.max_num_batched_tokens = \
+                        default_max_num_batched_tokens[usage_context]
+            else:
+                self.max_num_batched_tokens = default_max_num_batched_tokens[
+                    usage_context]
         logger.debug(
             "Setting max_num_batched_tokens to %d for %s usage context.",
             self.max_num_batched_tokens, use_context_value)
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index d5923557a211..9c95e6d3fa08 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -3,6 +3,7 @@
 from typing import TYPE_CHECKING, Optional, Union
 
 import torch
+from tpu_info import device
 
 import vllm.envs as envs
 from vllm.inputs import ProcessorInputs, PromptType
@@ -54,7 +55,8 @@ class TpuPlatform(Platform):
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
-        return "tpu"
+        chip_type, _ = device.get_local_chips()
+        return f"TPU {chip_type.name}"
 
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
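
A minimal standalone sketch of the default-selection behaviour introduced in the arg_utils.py hunk above. The UsageContext enum and the generic per-usage-context defaults below are stand-ins (the real values live outside this diff), and the chip key is passed in directly rather than derived from current_platform.get_device_name().

from enum import Enum


class UsageContext(Enum):
    # Stand-in for vLLM's UsageContext; only the two members used in the diff.
    LLM_CLASS = "LLM_CLASS"
    OPENAI_API_SERVER = "OPENAI_API_SERVER"


# Per-chip TPU defaults, copied from the arg_utils.py hunk above.
DEFAULT_MAX_NUM_BATCHED_TOKENS_TPU = {
    UsageContext.LLM_CLASS: {'V6E': 2048, 'V5E': 1024, 'V5P': 512},
    UsageContext.OPENAI_API_SERVER: {'V6E': 1024, 'V5E': 512, 'V5P': 256},
}

# Illustrative generic defaults; the actual values sit in the unchanged part
# of arg_utils.py and are not visible in this diff.
DEFAULT_MAX_NUM_BATCHED_TOKENS = {
    UsageContext.LLM_CLASS: 8192,
    UsageContext.OPENAI_API_SERVER: 2048,
}


def pick_default(chip_key: str, usage_context: UsageContext) -> int:
    # Prefer the per-chip TPU default; otherwise fall back to the generic
    # per-usage-context default, mirroring the branch added to EngineArgs.
    per_chip = DEFAULT_MAX_NUM_BATCHED_TOKENS_TPU[usage_context]
    if chip_key in per_chip:
        return per_chip[chip_key]
    return DEFAULT_MAX_NUM_BATCHED_TOKENS[usage_context]


print(pick_default('V6E', UsageContext.OPENAI_API_SERVER))  # -> 1024
print(pick_default('V4', UsageContext.LLM_CLASS))  # -> generic fallback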