diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0ff6a6fbbc1c1..a3b34f4ba6544 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1438,11 +1438,15 @@ class EngineArgs: from vllm.platforms import current_platform try: device_memory = current_platform.get_device_total_memory() + device_name = current_platform.get_device_name().lower() except Exception: # This is only used to set default_max_num_batched_tokens device_memory = 0 - if device_memory >= 70 * GiB_bytes: + # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces + # throughput, see PR #17885 for more details. + # So here we do an extra device name check to prevent such regression. + if device_memory >= 70 * GiB_bytes and "a100" not in device_name: # For GPUs like H100 and MI300x, use larger default values. default_max_num_batched_tokens = { UsageContext.LLM_CLASS: 16384,