diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py
index dc9c69bf58b9..3bd0b6609d88 100644
--- a/tests/basic_correctness/test_cumem.py
+++ b/tests/basic_correctness/test_cumem.py
@@ -260,13 +260,18 @@ def test_deep_sleep_fp8_kvcache():
     llm.sleep(level=2)
 
     used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
-    assert used_bytes < 3 * GiB_bytes
+
+    # ROCm uses more memory for CUDA graphs, so we add 2 GiB to the threshold
+    rocm_extra_mem_bytes = 2 * GiB_bytes if current_platform.is_rocm() else 0
+    mem_threshold_after_sleep = 3 * GiB_bytes + rocm_extra_mem_bytes
+    assert used_bytes < mem_threshold_after_sleep
 
     llm.wake_up(tags=["weights"])
     llm.collective_rpc("reload_weights")
 
     used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
-    assert used_bytes < 4 * GiB_bytes
+    mem_threshold_after_wake_up = 4 * GiB_bytes + rocm_extra_mem_bytes
+    assert used_bytes < mem_threshold_after_wake_up
 
     # now allocate kv cache and cuda graph memory
     llm.wake_up(tags=["kv_cache"])
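
For reviewers, a minimal sketch of the platform-gated threshold pattern this hunk introduces, assuming vLLM's `current_platform` and `GiB_bytes` helpers as imported by the test; the `assert_memory_below` wrapper is hypothetical and shown only to illustrate how the padded threshold would be applied at both checkpoints.

```python
# Minimal sketch, not part of the patch. Assumes vLLM's helpers:
#   current_platform.is_rocm() / current_platform.get_current_memory_usage()
#   GiB_bytes (1 << 30)
# `assert_memory_below` is a hypothetical wrapper for illustration.
from vllm.platforms import current_platform
from vllm.utils import GiB_bytes

# ROCm builds keep extra memory resident for graph capture, so every
# base threshold in the test is padded by a fixed 2 GiB on that platform.
ROCM_EXTRA_MEM_BYTES = 2 * GiB_bytes if current_platform.is_rocm() else 0


def assert_memory_below(base_threshold_gib: int, used_bytes: int) -> None:
    """Assert measured usage stays under the platform-adjusted cap."""
    threshold = base_threshold_gib * GiB_bytes + ROCM_EXTRA_MEM_BYTES
    assert used_bytes < threshold, (
        f"used {used_bytes} bytes, expected < {threshold}"
    )
```

Computing `rocm_extra_mem_bytes` once and reusing it for both the post-sleep and post-wake-up assertions keeps the two thresholds consistent and avoids scattering platform checks through the test body.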