diff --git a/tests/v1/engine/test_engine_args.py b/tests/v1/engine/test_engine_args.py index cf632f1469893..e96759ed66a79 100644 --- a/tests/v1/engine/test_engine_args.py +++ b/tests/v1/engine/test_engine_args.py @@ -53,10 +53,12 @@ def test_defaults_with_usage_context(): vllm_config: VllmConfig = engine_args.create_engine_config(UsageContext.LLM_CLASS) from vllm.platforms import current_platform + from vllm.utils.mem_constants import GiB_bytes + device_memory = current_platform.get_device_total_memory() device_name = current_platform.get_device_name().lower() - if "h100" in device_name or "h200" in device_name: - # For H100 and H200, we use larger default values. + if device_memory >= 70 * GiB_bytes and "a100" not in device_name: + # For GPUs like H100, H200, and MI300x with >= 70GB memory default_llm_tokens = 16384 default_server_tokens = 8192 default_max_num_seqs = 1024