[Test] Batch Invariant: Rename and organize tests (#27421)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in: parent 95ae50b7d1, commit a289cc1dde
tests/v1/determinism/conftest.py (new file, 11 lines)
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest


@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
    """Automatically enable batch invariant kernel overrides for all tests."""
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
    yield
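Because the fixture is autouse, every test collected under tests/v1/determinism/ runs with VLLM_BATCH_INVARIANT=1 without any per-test setup. As a minimal sketch (not part of this diff; the test name is illustrative), a single test could opt in explicitly the same way:

# Hypothetical standalone example (not part of this commit): a single test can
# opt into the same override explicitly instead of relying on the autouse fixture.
import os

import pytest


def test_batch_invariant_flag_visible(monkeypatch: pytest.MonkeyPatch) -> None:
    # Same mechanism the conftest fixture uses, scoped to this one test only.
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
    assert os.environ["VLLM_BATCH_INVARIANT"] == "1"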
@@ -6,66 +6,9 @@ import random

import pytest
import torch
from utils import _extract_step_logprobs, _random_prompt, skip_unsupported

from vllm import LLM, SamplingParams
from vllm.platforms import current_platform

skip_unsupported = pytest.mark.skipif(
    not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
    reason="Requires CUDA and >= Hopper (SM90)",
)


@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
    """Automatically enable batch invariant kernel overrides for all tests."""
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
    yield


def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
    # Generate more realistic prompts that will actually produce varied tokens
    # Use a mix of common English text patterns
    prompt_templates = [
        # Question-answer style
        "Question: What is the capital of France?\nAnswer: The capital of France is",
        "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
        "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
        # Story/narrative style
        "Once upon a time in a distant galaxy, there lived",
        "The old man walked slowly down the street, remembering",
        "In the year 2157, humanity finally discovered",
        # Technical/code style
        "To implement a binary search tree in Python, first we need to",
        "The algorithm works by iterating through the array and",
        "Here's how to optimize database queries using indexing:",
        # Factual/informative style
        "The Renaissance was a period in European history that",
        "Climate change is caused by several factors including",
        "The human brain contains approximately 86 billion neurons which",
        # Conversational style
        "I've been thinking about getting a new laptop because",
        "Yesterday I went to the store and bought",
        "My favorite thing about summer is definitely",
    ]

    # Pick a random template
    base_prompt = random.choice(prompt_templates)

    if max_words < min_words:
        max_words = min_words
    target_words = random.randint(min_words, max_words)

    if target_words > 50:
        # For longer prompts, repeat context
        padding_text = (
            " This is an interesting topic that deserves more explanation. "
            * (target_words // 50)
        )
        base_prompt = base_prompt + padding_text

    return base_prompt


@skip_unsupported
@@ -204,22 +147,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
    llm_bsN.shutdown()


def _extract_step_logprobs(request_output):
    if getattr(request_output, "outputs", None):
        inner = request_output.outputs[0]
        if hasattr(inner, "logprobs") and inner.logprobs is not None:
            t = torch.tensor(
                [
                    inner.logprobs[i][tid].logprob
                    for i, tid in enumerate(inner.token_ids)
                ],
                dtype=torch.float32,
            )
            return t, inner.token_ids
    return None, None


@skip_unsupported
@pytest.mark.parametrize(
    "backend",
tests/v1/determinism/test_online_batch_invariance.py (new file, 161 lines)
@@ -0,0 +1,161 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
HTTP-based batch invariance test: send requests to a running
vLLM server and compare BS=1 vs BS=N results (tokens and per-step logprobs).

Environment variables:
- VLLM_TEST_MODEL: served model name (e.g., Qwen/Qwen3-1.7B / DeepSeek-R1)
- VLLM_TP_SIZE: tensor parallelism size (e.g., 4)
"""

import os
import random
import sys
from typing import Any

import openai
from utils import _random_prompt, skip_unsupported

from tests.utils import RemoteOpenAIServer


def _request_completion(
    client: openai.OpenAI,
    model: str,
    prompt: Any,
    sp: dict[str, Any],
    max_retries: int = 3,
    retry_backoff: float = 0.5,
) -> dict[str, Any] | None:
    payload: dict[str, Any] = {"model": model, "prompt": prompt}
    payload.update(sp)

    for attempt in range(max_retries + 1):
        try:
            completion = client.completions.create(**payload)
            # Convert to plain dict so downstream logic can keep using
            # dict-style access just like with raw HTTP JSON.
            return completion.model_dump()
        except Exception as e:  # pragma: no cover
            if attempt < max_retries:
                import time as _t

                _t.sleep(retry_backoff * (2**attempt))
                continue
            sys.stderr.write(f"Error: {e}\n")
            return None
    return None


def _extract_tokens_and_logprobs(
    choice: dict[str, Any],
) -> tuple[list[Any], list[float] | None]:
    tokens: list[Any] = []
    token_logprobs: list[float] | None = None
    lp = choice.get("logprobs")
    if lp and isinstance(lp, dict):
        tokens = lp.get("token_ids") or lp.get("tokens") or []
        token_logprobs = lp.get("token_logprobs", None)
    return tokens, token_logprobs


def _compare_bs1_vs_bsn_single_process(
    prompts: list[str],
    sp_kwargs: dict[str, Any],
    client: openai.OpenAI,
    model_name: str,
) -> None:
    # BS=1
    bs1_tokens_per_prompt: list[list[Any]] = []
    bs1_logprobs_per_prompt: list[list[float] | None] = []
    for p in prompts:
        resp = _request_completion(client, model_name, p, sp_kwargs)
        if resp is None or not resp.get("choices"):
            raise AssertionError("BS=1 empty/failed response")
        choice = resp["choices"][0]
        toks, lps = _extract_tokens_and_logprobs(choice)
        if lps is None:
            raise AssertionError(
                "logprobs not returned; ensure server supports 'logprobs'"
            )
        bs1_tokens_per_prompt.append(list(toks))
        bs1_logprobs_per_prompt.append(list(lps))

    # BS=N
    bsN_tokens_per_prompt: list[list[Any]] = [None] * len(prompts)  # type: ignore[list-item]
    bsN_logprobs_per_prompt: list[list[float] | None] = [None] * len(prompts)
    resp = _request_completion(client, model_name, prompts, sp_kwargs)
    if resp is None or not resp.get("choices"):
        raise AssertionError("BS=N empty/failed batched response")
    choices = resp.get("choices", [])
    if len(choices) != len(prompts):
        raise AssertionError(
            f"BS=N choices length {len(choices)} != num prompts {len(prompts)}"
        )
    for idx, choice in enumerate(choices):
        toks, lps = _extract_tokens_and_logprobs(choice)
        if lps is None:
            raise AssertionError(f"BS=N missing logprobs for prompt {idx}")
        bsN_tokens_per_prompt[idx] = list(toks)
        bsN_logprobs_per_prompt[idx] = list(lps)

    # compare
    for i, (tokens_bs1, tokens_bsN, logprobs_bs1, logprobs_bsN) in enumerate(
        zip(
            bs1_tokens_per_prompt,
            bsN_tokens_per_prompt,
            bs1_logprobs_per_prompt,
            bsN_logprobs_per_prompt,
        )
    ):
        if tokens_bs1 != tokens_bsN:
            raise AssertionError(
                f"Prompt {i} (sampling): Different tokens sampled. "
                f"BS=1 tokens: {tokens_bs1} BS=N tokens: {tokens_bsN}"
            )
        if logprobs_bs1 is None or logprobs_bsN is None:
            raise AssertionError(f"Prompt {i}: Missing logprobs in one of the runs")
        if len(logprobs_bs1) != len(logprobs_bsN):
            raise AssertionError(
                f"Prompt {i}: Different number of steps: "
                f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)."
            )
        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
            if a != b:
                diff = abs(a - b)
                raise AssertionError(
                    f"Prompt {i} Step {t}: Bitwise mismatch "
                    f"(abs diff={diff:.6e}). "
                    f"BS=1 tokens: {tokens_bs1} BS=N tokens: {tokens_bsN}"
                )


@skip_unsupported
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN():
    random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    prompts_all = [_random_prompt(10, 50) for _ in range(32)]

    sp_kwargs: dict[str, Any] = {
        "temperature": 0.6,
        "top_p": 1.0,
        "max_tokens": 8,
        "seed": 42,
        "logprobs": 5,
    }

    tp_size = os.getenv("VLLM_TP_SIZE", "1")
    server_args: list[str] = []
    if tp_size:
        server_args += ["-tp", tp_size]

    with RemoteOpenAIServer(model_name, server_args) as server:
        client = server.get_client()
        _compare_bs1_vs_bsn_single_process(
            prompts=prompts_all,
            sp_kwargs=sp_kwargs,
            client=client,
            model_name=model_name,
        )
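Not part of this diff: a minimal sketch of reusing the comparison helper against a vLLM server that is already serving the model, rather than one launched by RemoteOpenAIServer. The base URL, api_key, and import paths are illustrative assumptions.

# Hypothetical usage sketch (not part of this commit). Base URL, api_key, and
# import paths are placeholders; adjust them to how the test package is resolved.
import openai

from tests.v1.determinism.test_online_batch_invariance import (
    _compare_bs1_vs_bsn_single_process,
)
from tests.v1.determinism.utils import _random_prompt

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
_compare_bs1_vs_bsn_single_process(
    prompts=[_random_prompt(10, 50) for _ in range(8)],
    sp_kwargs={"temperature": 0.6, "top_p": 1.0, "max_tokens": 8, "seed": 42, "logprobs": 5},
    client=client,
    model_name="Qwen/Qwen3-1.7B",
)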
@@ -9,15 +9,10 @@ with the standard CUDA-based implementation to ensure numerical accuracy.

import pytest
import torch
from utils import skip_unsupported

from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform

skip_unsupported = pytest.mark.skipif(
    not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
    reason="Requires CUDA and >= Hopper (SM90)",
)


@skip_unsupported
tests/v1/determinism/utils.py (new file, 74 lines)
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random

import pytest
import torch

from vllm.platforms import current_platform

skip_unsupported = pytest.mark.skipif(
    not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
    reason="Requires CUDA and >= Hopper (SM90)",
)


def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
    # Generate more realistic prompts that will actually produce varied tokens
    # Use a mix of common English text patterns
    prompt_templates = [
        # Question-answer style
        "Question: What is the capital of France?\nAnswer: The capital of France is",
        "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
        "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
        # Story/narrative style
        "Once upon a time in a distant galaxy, there lived",
        "The old man walked slowly down the street, remembering",
        "In the year 2157, humanity finally discovered",
        # Technical/code style
        "To implement a binary search tree in Python, first we need to",
        "The algorithm works by iterating through the array and",
        "Here's how to optimize database queries using indexing:",
        # Factual/informative style
        "The Renaissance was a period in European history that",
        "Climate change is caused by several factors including",
        "The human brain contains approximately 86 billion neurons which",
        # Conversational style
        "I've been thinking about getting a new laptop because",
        "Yesterday I went to the store and bought",
        "My favorite thing about summer is definitely",
    ]

    # Pick a random template
    base_prompt = random.choice(prompt_templates)

    if max_words < min_words:
        max_words = min_words
    target_words = random.randint(min_words, max_words)

    if target_words > 50:
        # For longer prompts, repeat context
        padding_text = (
            " This is an interesting topic that deserves more explanation. "
            * (target_words // 50)
        )
        base_prompt = base_prompt + padding_text

    return base_prompt


def _extract_step_logprobs(request_output):
    if getattr(request_output, "outputs", None):
        inner = request_output.outputs[0]
        if hasattr(inner, "logprobs") and inner.logprobs is not None:
            t = torch.tensor(
                [
                    inner.logprobs[i][tid].logprob
                    for i, tid in enumerate(inner.token_ids)
                ],
                dtype=torch.float32,
            )
            return t, inner.token_ids
    return None, None
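Not part of this diff: a minimal offline sketch of how these helpers pair with vLLM's LLM API to pull per-step logprobs for a single prompt. The model name, sampling values, and import path are illustrative, and the same CUDA SM90+ requirement applies as in the tests above.

# Hypothetical offline sketch (not part of this commit). Model name, sampling
# values, and the import path are illustrative.
from vllm import LLM, SamplingParams

from tests.v1.determinism.utils import _extract_step_logprobs, _random_prompt

llm = LLM(model="Qwen/Qwen3-1.7B")
params = SamplingParams(temperature=0.6, top_p=1.0, max_tokens=8, seed=42, logprobs=1)
outputs = llm.generate([_random_prompt(10, 50)], params)

# _extract_step_logprobs returns (float32 tensor of chosen-token logprobs, token ids),
# or (None, None) when logprobs were not requested in SamplingParams.
step_logprobs, token_ids = _extract_step_logprobs(outputs[0])
print(token_ids, step_logprobs)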