diff --git a/tests/v1/determinism/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py index fc953a66f0820..1c45e7fe366ff 100644 --- a/tests/v1/determinism/test_batch_invariance.py +++ b/tests/v1/determinism/test_batch_invariance.py @@ -10,6 +10,7 @@ from utils import ( BACKENDS, _extract_step_logprobs, _random_prompt, + is_device_capability_below_90, resolve_model_name, skip_unsupported, ) @@ -17,6 +18,8 @@ from utils import ( import vllm.model_executor.layers.batch_invariant as batch_invariant from vllm import LLM, SamplingParams +IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90() + @skip_unsupported @pytest.mark.timeout(1000) @@ -190,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( max_model_len=8192, dtype="bfloat16", # not everything is supported gpu_memory_utilization=0.9, + enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, ) # Use more realistic prompts for better token generation @@ -393,6 +397,8 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch): gpu_memory_utilization=0.9, max_model_len=2048, dtype="bfloat16", + enable_prefix_caching=False, + enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, ) prompt = "the capital of france is" @@ -459,6 +465,7 @@ def test_logprobs_without_batch_invariance_should_fail( max_num_seqs=32, max_model_len=8192, dtype="bfloat16", + enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, ) # build ragged prompts to change shapes significantly across BS=1 vs BS=N @@ -682,6 +689,7 @@ def test_decode_logprobs_match_prefill_logprobs( max_num_seqs=32, max_model_len=8192, dtype="bfloat16", + enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, ) # Use a few test prompts @@ -925,6 +933,8 @@ def LLM_with_max_seqs( max_model_len=max_model_len, dtype="bfloat16", tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), + enable_prefix_caching=False, + enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, # Enable for MOE models # enable_expert_parallel=True, ) diff --git a/tests/v1/determinism/utils.py b/tests/v1/determinism/utils.py index 6aab50cf84ab6..a8013ed229cfc 100644 --- a/tests/v1/determinism/utils.py +++ b/tests/v1/determinism/utils.py @@ -11,8 +11,10 @@ from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer skip_unsupported = pytest.mark.skipif( - not (current_platform.is_cuda() and current_platform.has_device_capability(90)), - reason="Requires CUDA and >= Hopper (SM90)", + not (current_platform.is_cuda() and current_platform.has_device_capability(80)), + # Supports testing on Ampere and Ada Lovelace devices. + # Note: For devices with SM < 90, batch invariance does not support CUDA Graphs. + reason="Requires CUDA and >= Ampere (SM80)", ) BACKENDS: list[str] = [ @@ -97,3 +99,7 @@ def _extract_step_logprobs(request_output): return t, inner.token_ids return None, None + + +def is_device_capability_below_90() -> bool: + return not current_platform.has_device_capability(90) diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index 4cab47f4192a2..b14e7dad77f9a 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -935,7 +935,11 @@ def enable_batch_invariant_mode(): # Batch invariant matmuls are no longer needed after cublas overrides if not is_torch_equal_or_newer("2.10.0.dev"): - if current_platform.is_device_capability(100): + if ( + current_platform.is_device_capability(100) + or current_platform.is_device_capability(80) + or current_platform.is_device_capability(89) + ): # For PyTorch 2.9, B200 uses GEMV for bs=1 # Requires https://github.com/pytorch/pytorch/pull/166735 _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")