[CI/Build] Automatically retry flaky tests (#17856)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

parent 200da9a517
commit 6e5595ca39
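
This change replaces a hand-rolled retry_until_skip decorator defined in a
test conftest with pytest.mark.flaky, and applies the same marker to a flaky
MoE kernel test. The flaky marker with a reruns= argument is provided by the
pytest-rerunfailures plugin. As a minimal sketch of how the marker behaves
(the test below is hypothetical and assumes pytest-rerunfailures is
installed):

    import random

    import pytest


    @pytest.mark.flaky(reruns=2)
    def test_sometimes_fails():
        # pytest-rerunfailures reruns a failing test up to `reruns` more
        # times; it is reported as failed only if every attempt fails.
        assert random.random() < 0.5

The plugin also exposes a --reruns N command-line option when a rerun budget
should apply suite-wide rather than per test.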
@@ -286,6 +286,7 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
                                atol=mixtral_moe_tol[dtype])
 
 
+@pytest.mark.flaky(reruns=2)
 @pytest.mark.parametrize("m", [1, 123, 666])
 @pytest.mark.parametrize("n", [128, 1024])
 @pytest.mark.parametrize("k", [256, 2048])
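
pytest expands each parametrization into a separate test item, so every
(m, n, k) combination above carries its own budget of two reruns.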

@@ -1,12 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-
-import functools
-import gc
-from typing import Callable, TypeVar
-
 import pytest
-import torch
-from typing_extensions import ParamSpec
 
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
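
All of the dropped imports (functools, gc, Callable/TypeVar, torch, and
ParamSpec) were used only by the retry_until_skip helper removed in the next
hunk.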

@@ -25,32 +18,6 @@ def cleanup():
     cleanup_dist_env_and_memory(shutdown_ray=True)
 
 
-_P = ParamSpec("_P")
-_R = TypeVar("_R")
-
-
-def retry_until_skip(n: int):
-
-    def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]:
-
-        @functools.wraps(func)
-        def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R:
-            for i in range(n):
-                try:
-                    return func(*args, **kwargs)
-                except AssertionError:
-                    gc.collect()
-                    torch.cuda.empty_cache()
-                    if i == n - 1:
-                        pytest.skip(f"Skipping test after {n} attempts.")
-
-            raise AssertionError("Code should not be reached")
-
-        return wrapper_retry
-
-    return decorator_retry
-
-
 @pytest.fixture(autouse=True)
 def tensorizer_config():
     config = TensorizerConfig(tensorizer_uri="vllm")
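
Note the behavioral shift: retry_until_skip cleared CUDA memory between
attempts and *skipped* a test that kept failing, while pytest.mark.flaky
reports a *failure* once its reruns are exhausted, so persistent breakage
now surfaces in CI instead of being quietly skipped.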

@@ -28,7 +28,6 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
 from vllm.utils import PlaceholderModule, import_from_path
 
 from ..utils import VLLM_PATH, RemoteOpenAIServer
-from .conftest import retry_until_skip
 
 try:
     from tensorizer import EncryptionParams

@@ -325,7 +324,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
     assert outputs == deserialized_outputs
 
 
-@retry_until_skip(3)
+@pytest.mark.flaky(reruns=3)
 def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
     gc.collect()
     torch.cuda.empty_cache()
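
Each rerun executes the test body from the top, so the gc.collect() /
torch.cuda.empty_cache() calls that open this test still provide the
between-attempt cleanup the removed decorator used to perform.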