diff --git a/tests/v1/generation/test_batch_invariance.py b/tests/v1/generation/test_batch_invariance.py
index 67ef0eb7ed666..8c4e77fd8acf1 100644
--- a/tests/v1/generation/test_batch_invariance.py
+++ b/tests/v1/generation/test_batch_invariance.py
@@ -59,12 +59,15 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
     # Pick a random template
     base_prompt = random.choice(prompt_templates)

-    # Add some padding to vary the length if needed
-    if min_words > 50:
+    if max_words < min_words:
+        max_words = min_words
+    target_words = random.randint(min_words, max_words)
+
+    if target_words > 50:
         # For longer prompts, repeat context
         padding_text = (
             " This is an interesting topic that deserves more explanation. "
-            * (min_words // 50)
+            * (target_words // 50)
         )
         base_prompt = base_prompt + padding_text

@@ -516,8 +519,20 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
         dtype="bfloat16",
     )

-    # Use more realistic prompts for better token generation
-    prompts = [_random_prompt(10, 50) for i in range(32)]
+    # build ragged prompts to change shapes significantly across BS=1 vs BS=N
+    long_min = int(os.getenv("VLLM_MIN_PROMPT", "768"))
+    long_max = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
+    prompts: list[str] = []
+    options = [
+        (max(long_min, 1536), max(long_max, 3072)),  # very long
+        (max(1024, long_min), max(2048, long_max)),  # long
+        (256, 512),  # mid
+        (10, 20),  # short
+    ]
+
+    for _ in range(32):
+        lo, hi = random.choice(options)
+        prompts.append(_random_prompt(lo, hi))

     sp = SamplingParams(
         temperature=0.6,
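
The net effect of the first hunk is that prompt length is now sampled from `[min_words, max_words]` (with inverted bounds clamped so `random.randint` cannot raise `ValueError`) instead of padding every prompt by the same `min_words`-derived amount. Below is a minimal standalone sketch of the patched padding logic; the template list is stubbed with a single placeholder string, since the real test defines its own `prompt_templates`:

```python
import random


def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
    # Stub: the actual test file defines several realistic templates.
    prompt_templates = ["Explain how attention works in transformers."]
    base_prompt = random.choice(prompt_templates)

    # Clamp inverted bounds so random.randint() never sees an empty range,
    # then sample a per-prompt target length to make lengths ragged.
    if max_words < min_words:
        max_words = min_words
    target_words = random.randint(min_words, max_words)

    if target_words > 50:
        # Pad by repetition; padding now scales with the sampled target
        # rather than always with min_words.
        padding_text = (
            " This is an interesting topic that deserves more explanation. "
            * (target_words // 50)
        )
        base_prompt = base_prompt + padding_text

    return base_prompt
```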