mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 08:45:00 +08:00
[Frontend][TPU] Add TPU default max-num-batched-tokens based on device name (#17508)
Signed-off-by: Chenyaaang <chenyangli@google.com>
This commit is contained in:
parent
e3d0a1d190
commit
87baebebd8
@ -3140,6 +3140,14 @@ def _get_and_verify_max_len(
|
||||
# derived length from the HF model config.
|
||||
if max_model_len is None:
|
||||
max_model_len = int(derived_max_model_len)
|
||||
if current_platform.is_tpu():
|
||||
logger.warning(
|
||||
"--max-model-len is not specified, "
|
||||
"it's currently using model's default length %s, "
|
||||
"which might be too large."
|
||||
"Please input with --max-model-len based on your "
|
||||
"request input length and output length, to avoid "
|
||||
"unnecessary degradation.", max_model_len)
|
||||
elif max_model_len > derived_max_model_len:
|
||||
# Some models might have a separate key for specifying model_max_length
|
||||
# that will be bigger than derived_max_model_len. We compare user input
|
||||
|
||||
@ -1441,8 +1441,8 @@ class EngineArgs:
|
||||
# as the platform that vLLM is running on (e.g. the case of scaling
|
||||
# vLLM with Ray) and has no GPUs. In this case we use the default
|
||||
# values for non-H100/H200 GPUs.
|
||||
from vllm.platforms import current_platform
|
||||
try:
|
||||
from vllm.platforms import current_platform
|
||||
device_memory = current_platform.get_device_total_memory()
|
||||
except Exception:
|
||||
# This is only used to set default_max_num_batched_tokens
|
||||
@ -1463,11 +1463,37 @@ class EngineArgs:
|
||||
}
|
||||
default_max_num_seqs = 256
|
||||
|
||||
# tpu specific default values.
|
||||
if current_platform.is_tpu():
|
||||
default_max_num_batched_tokens_tpu = {
|
||||
UsageContext.LLM_CLASS: {
|
||||
'V6E': 2048,
|
||||
'V5E': 1024,
|
||||
'V5P': 512,
|
||||
},
|
||||
UsageContext.OPENAI_API_SERVER: {
|
||||
'V6E': 1024,
|
||||
'V5E': 512,
|
||||
'V5P': 256,
|
||||
}
|
||||
}
|
||||
|
||||
use_context_value = usage_context.value if usage_context else None
|
||||
if (self.max_num_batched_tokens is None
|
||||
and usage_context in default_max_num_batched_tokens):
|
||||
self.max_num_batched_tokens = default_max_num_batched_tokens[
|
||||
usage_context]
|
||||
if current_platform.is_tpu():
|
||||
chip_name = current_platform.get_device_name()
|
||||
if chip_name in default_max_num_batched_tokens_tpu[
|
||||
usage_context]:
|
||||
self.max_num_batched_tokens = \
|
||||
default_max_num_batched_tokens_tpu[
|
||||
usage_context][chip_name]
|
||||
else:
|
||||
self.max_num_batched_tokens = \
|
||||
default_max_num_batched_tokens[usage_context]
|
||||
else:
|
||||
self.max_num_batched_tokens = default_max_num_batched_tokens[
|
||||
usage_context]
|
||||
logger.debug(
|
||||
"Setting max_num_batched_tokens to %d for %s usage context.",
|
||||
self.max_num_batched_tokens, use_context_value)
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
from tpu_info import device
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
@ -54,7 +55,8 @@ class TpuPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
return "tpu"
|
||||
chip_type, _ = device.get_local_chips()
|
||||
return f"TPU {chip_type.name}"
|
||||
|
||||
@classmethod
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user