mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 09:06:02 +08:00
[Bug] Fix batch invariant test has to is (#27032)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
013abde6ef
commit
2ed8b6b3d0
@ -10,6 +10,11 @@ import torch
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
hopper_only = pytest.mark.skipif(
|
||||
not (current_platform.is_cuda() and current_platform.is_device_capability(90)),
|
||||
reason="Requires CUDA and Hopper (SM90)",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_batch_invariant_mode():
|
||||
@ -66,10 +71,7 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
||||
return base_prompt
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@hopper_only
|
||||
@pytest.mark.timeout(1000)
|
||||
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
|
||||
"""
|
||||
@ -214,14 +216,7 @@ def _extract_step_logprobs(request_output):
|
||||
return None, None
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="Requires CUDA to match production inference path.",
|
||||
)
|
||||
@hopper_only
|
||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
|
||||
@pytest.mark.forked
|
||||
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
|
||||
@ -436,10 +431,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
|
||||
pytest.fail(msg)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@hopper_only
|
||||
def test_simple_generation():
|
||||
"""
|
||||
Simple test that runs the model with a basic prompt and prints the output.
|
||||
@ -485,14 +477,7 @@ def test_simple_generation():
|
||||
llm.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="Requires CUDA to match production inference path.",
|
||||
)
|
||||
@hopper_only
|
||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
|
||||
@pytest.mark.forked
|
||||
def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
|
||||
@ -707,14 +692,7 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
|
||||
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="Requires CUDA to match production inference path.",
|
||||
)
|
||||
@hopper_only
|
||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
|
||||
@pytest.mark.forked
|
||||
def test_decode_logprobs_match_prefill_logprobs(backend):
|
||||
|
||||
@ -14,14 +14,13 @@ from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_no
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
hopper_only = pytest.mark.skipif(
|
||||
not (current_platform.is_cuda() and current_platform.is_device_capability(90)),
|
||||
reason="Requires CUDA and Hopper (SM90)",
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
|
||||
@hopper_only
|
||||
@pytest.mark.parametrize("batch_size", [1, 4, 16, 64])
|
||||
@pytest.mark.parametrize("hidden_size", [512, 2048, 4096, 8192])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@ -70,13 +69,7 @@ def test_rms_norm_batch_invariant_vs_standard(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
@hopper_only
|
||||
@pytest.mark.parametrize("batch_size", [1, 16, 128])
|
||||
@pytest.mark.parametrize("seq_len", [1, 32, 512])
|
||||
@pytest.mark.parametrize("hidden_size", [2048, 4096])
|
||||
@ -118,13 +111,7 @@ def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
@hopper_only
|
||||
def test_rms_norm_numerical_stability():
|
||||
"""
|
||||
Test RMS norm numerical stability with extreme values.
|
||||
@ -184,13 +171,7 @@ def test_rms_norm_numerical_stability():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
@hopper_only
|
||||
def test_rms_norm_formula():
|
||||
"""
|
||||
Test that RMS norm follows the correct mathematical formula.
|
||||
@ -223,13 +204,7 @@ def test_rms_norm_formula():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
@hopper_only
|
||||
@pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384])
|
||||
def test_rms_norm_different_hidden_sizes(hidden_size: int):
|
||||
"""
|
||||
@ -267,13 +242,7 @@ def test_rms_norm_different_hidden_sizes(hidden_size: int):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
@hopper_only
|
||||
def test_rms_norm_determinism():
|
||||
"""
|
||||
Test that batch-invariant RMS norm produces deterministic results.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user