[Frontend][TPU] Add TPU default max-num-batched-tokens based on device name (#17508)
Signed-off-by: Chenyaaang <chenyangli@google.com>
This commit is contained in: parent e3d0a1d190, commit 87baebebd8
@@ -3140,6 +3140,14 @@ def _get_and_verify_max_len(
     # derived length from the HF model config.
     if max_model_len is None:
         max_model_len = int(derived_max_model_len)
+        if current_platform.is_tpu():
+            logger.warning(
+                "--max-model-len is not specified, "
+                "it's currently using model's default length %s, "
+                "which might be too large."
+                "Please input with --max-model-len based on your "
+                "request input length and output length, to avoid "
+                "unnecessary degradation.", max_model_len)
     elif max_model_len > derived_max_model_len:
         # Some models might have a separate key for specifying model_max_length
         # that will be bigger than derived_max_model_len. We compare user input
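The hunk above (in _get_and_verify_max_len, presumably vllm/config.py) only changes behavior when --max-model-len is left unset on a TPU host: the HF-derived length is still used, but a warning nudges the user to set an explicit limit. A minimal standalone sketch of that fallback-and-warn logic, with is_tpu and the length values passed in as plain parameters rather than vLLM's real config objects:

import logging

logger = logging.getLogger(__name__)


def resolve_max_model_len(max_model_len, derived_max_model_len, is_tpu):
    # Fall back to the HF-derived length when the user gave no limit.
    if max_model_len is None:
        max_model_len = int(derived_max_model_len)
        # On TPU the model default is often far larger than the workload
        # needs, so warn instead of silently accepting it.
        if is_tpu:
            logger.warning(
                "--max-model-len is not specified, using the model's "
                "default length %s, which might be too large.",
                max_model_len)
    return max_model_len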
@@ -1441,8 +1441,8 @@ class EngineArgs:
         # as the platform that vLLM is running on (e.g. the case of scaling
         # vLLM with Ray) and has no GPUs. In this case we use the default
         # values for non-H100/H200 GPUs.
+        from vllm.platforms import current_platform
         try:
-            from vllm.platforms import current_platform
             device_memory = current_platform.get_device_total_memory()
         except Exception:
             # This is only used to set default_max_num_batched_tokens
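Moving the from vllm.platforms import current_platform statement out of the try block (presumably in vllm/engine/arg_utils.py) means only the memory probe itself is allowed to fail; current_platform is now needed unconditionally for the TPU checks added below. A rough sketch of the resulting pattern, with the fallback value chosen purely for illustration:

from vllm.platforms import current_platform  # needed even when probing fails


def probe_device_memory(default_bytes: int = 80 * 2**30) -> int:
    # Only the platform query may raise (e.g. on a GPU-less Ray head node);
    # the import above is intentionally outside the try block.
    try:
        return current_platform.get_device_total_memory()
    except Exception:
        return default_bytes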
@@ -1463,11 +1463,37 @@ class EngineArgs:
         }
         default_max_num_seqs = 256
 
+        # tpu specific default values.
+        if current_platform.is_tpu():
+            default_max_num_batched_tokens_tpu = {
+                UsageContext.LLM_CLASS: {
+                    'V6E': 2048,
+                    'V5E': 1024,
+                    'V5P': 512,
+                },
+                UsageContext.OPENAI_API_SERVER: {
+                    'V6E': 1024,
+                    'V5E': 512,
+                    'V5P': 256,
+                }
+            }
+
         use_context_value = usage_context.value if usage_context else None
         if (self.max_num_batched_tokens is None
                 and usage_context in default_max_num_batched_tokens):
-            self.max_num_batched_tokens = default_max_num_batched_tokens[
-                usage_context]
+            if current_platform.is_tpu():
+                chip_name = current_platform.get_device_name()
+                if chip_name in default_max_num_batched_tokens_tpu[
+                        usage_context]:
+                    self.max_num_batched_tokens = \
+                        default_max_num_batched_tokens_tpu[
+                            usage_context][chip_name]
+                else:
+                    self.max_num_batched_tokens = \
+                        default_max_num_batched_tokens[usage_context]
+            else:
+                self.max_num_batched_tokens = default_max_num_batched_tokens[
+                    usage_context]
         logger.debug(
             "Setting max_num_batched_tokens to %d for %s usage context.",
             self.max_num_batched_tokens, use_context_value)
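The selection logic in the hunk above first consults a per-chip table keyed by usage context and falls back to the pre-existing platform-agnostic defaults when the chip is not listed. A condensed sketch of that lookup; the TPU values are copied from the diff, while the generic values and the helper name are illustrative placeholders, not vLLM's actual numbers or API:

# Illustrative generic defaults: placeholders, not the values vLLM ships.
DEFAULT_MAX_NUM_BATCHED_TOKENS = {
    "LLM_CLASS": 8192,
    "OPENAI_API_SERVER": 2048,
}

# Per-chip TPU defaults, as introduced by this commit.
DEFAULT_MAX_NUM_BATCHED_TOKENS_TPU = {
    "LLM_CLASS": {"V6E": 2048, "V5E": 1024, "V5P": 512},
    "OPENAI_API_SERVER": {"V6E": 1024, "V5E": 512, "V5P": 256},
}


def pick_max_num_batched_tokens(usage_context, chip_name=None):
    # Prefer the per-chip TPU default when the chip is known; otherwise
    # keep the generic default for this usage context.
    if chip_name is not None:
        per_chip = DEFAULT_MAX_NUM_BATCHED_TOKENS_TPU.get(usage_context, {})
        if chip_name in per_chip:
            return per_chip[chip_name]
    return DEFAULT_MAX_NUM_BATCHED_TOKENS[usage_context]


print(pick_max_num_batched_tokens("OPENAI_API_SERVER", "V5E"))  # 512
print(pick_max_num_batched_tokens("OPENAI_API_SERVER"))         # 2048 (generic)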
@@ -3,6 +3,7 @@
 from typing import TYPE_CHECKING, Optional, Union
 
 import torch
+from tpu_info import device
 
 import vllm.envs as envs
 from vllm.inputs import ProcessorInputs, PromptType
@@ -54,7 +55,8 @@ class TpuPlatform(Platform):
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
-        return "tpu"
+        chip_type, _ = device.get_local_chips()
+        return f"TPU {chip_type.name}"
 
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
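With the two tpu.py hunks above, get_device_name now reports the chip generation via the tpu_info package instead of the literal string "tpu", which is what lets the EngineArgs hunk key its defaults by chip. A hedged sketch of a defensive wrapper for callers outside vLLM that may not have tpu_info installed; the fallback branch is illustrative (the diff itself imports tpu_info unconditionally), and the (chip_type, count) return shape is assumed from the unpacking in the diff:

def tpu_device_name() -> str:
    # Mirror the new get_device_name, but degrade to the old generic
    # "tpu" string when tpu_info (or a local TPU) is unavailable.
    try:
        from tpu_info import device
        chip_type, _ = device.get_local_chips()
        return f"TPU {chip_type.name}"  # e.g. "TPU V6E"
    except Exception:
        return "tpu"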