mirror of https://git.datalinker.icu/vllm-project/vllm.git
[CI/Build] Add test decorator for minimum GPU memory (#8925)
parent d081da0064
commit 26a68d5d7e
@@ -63,12 +63,11 @@ def test_baichuan_lora(baichuan_lora_files):
         assert output2[i] == expected_lora_output[i]
 
 
-@pytest.mark.skip("Requires multiple GPUs")
 @pytest.mark.parametrize("fully_sharded", [True, False])
-def test_baichuan_tensor_parallel_equality(baichuan_lora_files, fully_sharded):
-    # Cannot use as it will initialize torch.cuda too early...
-    # if torch.cuda.device_count() < 4:
-    #     pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
+def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
+                                           num_gpus_available, fully_sharded):
+    if num_gpus_available < 4:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
 
     llm_tp1 = vllm.LLM(MODEL_PATH,
                        enable_lora=True,
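Note: these tests now receive the GPU count through a `num_gpus_available` pytest fixture instead of calling `torch.cuda.device_count()` directly, which (per the removed comments) initializes torch.cuda too early. The fixture itself is not part of this diff; a minimal sketch of what it could look like, assuming it simply wraps vLLM's fork-safe `cuda_device_count_stateless` helper:

    import pytest

    from vllm.utils import cuda_device_count_stateless


    @pytest.fixture(scope="session")
    def num_gpus_available() -> int:
        # Hypothetical sketch: count visible GPUs without initializing
        # torch.cuda in the main pytest process.
        return cuda_device_count_stateless()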
@@ -71,10 +71,10 @@ def do_sample(llm: vllm.LLM,
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("tp_size", [1])
-def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
-    # Cannot use as it will initialize torch.cuda too early...
-    # if torch.cuda.device_count() < tp_size:
-    #     pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
+def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
+                          tp_size):
+    if num_gpus_available < tp_size:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
 
     llm = vllm.LLM(
         model=model.model_path,
@@ -164,11 +164,10 @@ def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
 
 
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.skip("Requires multiple GPUs")
-def test_quant_model_tp_equality(tinyllama_lora_files, model):
-    # Cannot use as it will initialize torch.cuda too early...
-    # if torch.cuda.device_count() < 2:
-    #     pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
+def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
+                                 model):
+    if num_gpus_available < 2:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
 
     llm_tp1 = vllm.LLM(
         model=model.model_path,
@@ -7,6 +7,7 @@ import torch
 
 from vllm.utils import is_cpu
 
+from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
 
 MODELS = [
@@ -69,20 +70,10 @@ def test_phimoe_routing_function():
     assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"])
 
 
-def get_gpu_memory():
-    try:
-        props = torch.cuda.get_device_properties(torch.cuda.current_device())
-        gpu_memory = props.total_memory / (1024**3)
-        return gpu_memory
-    except Exception:
-        return 0
-
-
 @pytest.mark.skipif(condition=is_cpu(),
                     reason="This test takes a lot time to run on CPU, "
                     "and vllm CI's disk space is not enough for this model.")
-@pytest.mark.skipif(condition=get_gpu_memory() < 100,
-                    reason="Skip this test if GPU memory is insufficient.")
+@large_gpu_test(min_gb=80)
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
@@ -11,6 +11,7 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
 from ....conftest import (VIDEO_ASSETS, HfRunner, PromptImageInput, VllmRunner,
                           _VideoAssets)
+from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
 
 # Video test
@@ -164,9 +165,7 @@ def run_video_test(
     )
 
 
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
+@large_gpu_test(min_gb=48)
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "size_factors",
@@ -210,9 +209,7 @@ def test_models(hf_runner, vllm_runner, video_assets, model, size_factors,
     )
 
 
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
+@large_gpu_test(min_gb=48)
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "sizes",
@@ -306,9 +303,7 @@ def run_image_test(
     )
 
 
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
+@large_gpu_test(min_gb=48)
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
@@ -17,7 +17,7 @@ from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
 from vllm.multimodal import MultiModalDataBuiltins
 from vllm.sequence import Logprob, SampleLogprobs
 
-from ....utils import VLLM_PATH
+from ....utils import VLLM_PATH, large_gpu_test
 from ...utils import check_logprobs_close
 
 if TYPE_CHECKING:
@@ -121,10 +121,7 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs:
             for tokens, text, logprobs in json_data]
 
 
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on A100 locally but will OOM on CI machine."
-)
+@large_gpu_test(min_gb=80)
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
@@ -157,10 +154,7 @@ def test_chat(
                         name_1="output")
 
 
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on A100 locally but will OOM on CI machine."
-)
+@large_gpu_test(min_gb=80)
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
@@ -9,6 +9,7 @@ from vllm.sequence import SampleLogprobs
 
 from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
                           _ImageAssets)
+from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
 
 _LIMIT_IMAGE_PER_PROMPT = 1
@@ -227,29 +228,26 @@ def _run_test(
     )
 
 
-SIZES = [
-    # Text only
-    [],
-    # Single-size
-    [(512, 512)],
-    # Single-size, batched
-    [(512, 512), (512, 512), (512, 512)],
-    # Multi-size, batched
-    [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
-     (1024, 1024), (512, 1536), (512, 2028)],
-    # Multi-size, batched, including text only
-    [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
-     (1024, 1024), (512, 1536), (512, 2028), None],
-    # mllama has 8 possible aspect ratios, carefully set the sizes
-    # to cover all of them
-]
-
-
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
+@large_gpu_test(min_gb=48)
 @pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("sizes", SIZES)
+@pytest.mark.parametrize(
+    "sizes",
+    [
+        # Text only
+        [],
+        # Single-size
+        [(512, 512)],
+        # Single-size, batched
+        [(512, 512), (512, 512), (512, 512)],
+        # Multi-size, batched
+        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
+         (1024, 1024), (512, 1536), (512, 2028)],
+        # Multi-size, batched, including text only
+        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
+         (1024, 1024), (512, 1536), (512, 2028), None],
+        # mllama has 8 possible aspect ratios, carefully set the sizes
+        # to cover all of them
+    ])
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
@@ -24,8 +24,8 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.openai.cli_args import make_arg_parser
 from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.platforms import current_platform
-from vllm.utils import (FlexibleArgumentParser, cuda_device_count_stateless,
-                        get_open_port, is_hip)
+from vllm.utils import (FlexibleArgumentParser, GB_bytes,
+                        cuda_device_count_stateless, get_open_port, is_hip)
 
 if current_platform.is_rocm():
     from amdsmi import (amdsmi_get_gpu_vram_usage,
@@ -455,6 +455,37 @@ def fork_new_process_for_each_test(
     return wrapper
 
 
+def large_gpu_test(*, min_gb: int):
+    """
+    Decorate a test to be skipped if no GPU is available or it does not have
+    sufficient memory.
+
+    Currently, the CI machine uses L4 GPU which has 24 GB VRAM.
+    """
+    try:
+        if current_platform.is_cpu():
+            memory_gb = 0
+        else:
+            memory_gb = current_platform.get_device_total_memory() / GB_bytes
+    except Exception as e:
+        warnings.warn(
+            f"An error occurred when finding the available memory: {e}",
+            stacklevel=2,
+        )
+
+        memory_gb = 0
+
+    test_skipif = pytest.mark.skipif(
+        memory_gb < min_gb,
+        reason=f"Need at least {memory_gb}GB GPU memory to run the test.",
+    )
+
+    def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
+        return test_skipif(fork_new_process_for_each_test(f))
+
+    return wrapper
+
+
 def multi_gpu_test(*, num_gpus: int):
     """
     Decorate a test to be run only when multiple GPUs are available.
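As the test files above show, the new decorator is stacked alongside the existing pytest marks. It composes `pytest.mark.skipif` with `fork_new_process_for_each_test`, so the memory probe goes through `current_platform` rather than `torch.cuda`, and the test body runs in a forked process. A hedged usage sketch (the test name, fixture, and marks below are illustrative, not from this diff):

    @large_gpu_test(min_gb=48)
    @pytest.mark.parametrize("dtype", ["bfloat16"])
    def test_some_large_model(vllm_runner, dtype: str) -> None:
        # Skipped unless the current device reports at least 48 GB of total
        # memory; the CI L4 has 24 GB, so this only runs on bigger GPUs.
        ...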
@@ -1,3 +1,4 @@
+import psutil
 import torch
 
 from .interface import Platform, PlatformEnum
@@ -10,6 +11,10 @@ class CpuPlatform(Platform):
     def get_device_name(cls, device_id: int = 0) -> str:
         return "cpu"
 
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        return psutil.virtual_memory().total
+
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()
@@ -59,6 +59,13 @@ def get_physical_device_name(device_id: int = 0) -> str:
     return pynvml.nvmlDeviceGetName(handle)
 
 
+@lru_cache(maxsize=8)
+@with_nvml_context
+def get_physical_device_total_memory(device_id: int = 0) -> int:
+    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+    return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
+
+
 @with_nvml_context
 def warn_if_different_devices():
     device_ids: int = pynvml.nvmlDeviceGetCount()
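The new helper queries total memory through NVML. For context, a rough standalone equivalent, assuming `with_nvml_context` merely brackets the call with NVML init/shutdown (that decorator is not shown in this diff):

    import pynvml

    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        # Total framebuffer memory of GPU 0, in bytes.
        total_bytes = int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
    finally:
        pynvml.nvmlShutdown()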
@@ -107,6 +114,11 @@ class CudaPlatform(Platform):
         physical_device_id = device_id_to_physical_device_id(device_id)
         return get_physical_device_name(physical_device_id)
 
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        physical_device_id = device_id_to_physical_device_id(device_id)
+        return get_physical_device_total_memory(physical_device_id)
+
     @classmethod
     @with_nvml_context
     def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
@@ -85,6 +85,12 @@ class Platform:
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
+        """Get the name of a device."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        """Get the total memory of a device in bytes."""
         raise NotImplementedError
 
     @classmethod
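Callers are expected to go through the platform abstraction rather than a specific backend. A small sketch mirroring what `large_gpu_test` does, using only names introduced or already imported in this commit:

    from vllm.platforms import current_platform
    from vllm.utils import GB_bytes

    if not current_platform.is_cpu():
        # Total memory of device 0, converted from bytes to decimal GB.
        memory_gb = current_platform.get_device_total_memory() / GB_bytes
        print(f"{current_platform.get_device_name()}: {memory_gb:.1f} GB")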
@@ -29,3 +29,8 @@ class RocmPlatform(Platform):
     @lru_cache(maxsize=8)
     def get_device_name(cls, device_id: int = 0) -> str:
         return torch.cuda.get_device_name(device_id)
+
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        device_props = torch.cuda.get_device_properties(device_id)
+        return device_props.total_memory
@@ -10,6 +10,10 @@ class TpuPlatform(Platform):
     def get_device_name(cls, device_id: int = 0) -> str:
         raise NotImplementedError
 
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        raise NotImplementedError
+
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()
@@ -8,13 +8,15 @@ class XPUPlatform(Platform):
 
     @staticmethod
     def get_device_capability(device_id: int = 0) -> DeviceCapability:
-        return DeviceCapability(major=int(
-            torch.xpu.get_device_capability(device_id)['version'].split('.')
-            [0]),
-                                minor=int(
-                                    torch.xpu.get_device_capability(device_id)
-                                    ['version'].split('.')[1]))
+        major, minor, *_ = torch.xpu.get_device_capability(
+            device_id)['version'].split('.')
+        return DeviceCapability(major=int(major), minor=int(minor))
 
     @staticmethod
     def get_device_name(device_id: int = 0) -> str:
         return torch.xpu.get_device_name(device_id)
+
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        device_props = torch.xpu.get_device_properties(device_id)
+        return device_props.total_memory
@@ -119,6 +119,9 @@ STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
 STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
 STR_INVALID_VAL: str = "INVALID"
 
+GB_bytes = 1_000_000_000
+"""The number of bytes in one gigabyte (GB)."""
+
 GiB_bytes = 1 << 30
 """The number of bytes in one gibibyte (GiB)."""
 
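For reference, `GB_bytes` is a decimal gigabyte while the existing `GiB_bytes` is a binary gibibyte; `large_gpu_test` above divides by `GB_bytes`, so thresholds like `min_gb=80` are decimal. A quick illustration (the 24 GB figure is hypothetical, chosen to match the L4 CI GPU mentioned in the decorator's docstring):

    from vllm.utils import GB_bytes, GiB_bytes

    total = 24_000_000_000       # a hypothetical 24 GB device
    print(total / GB_bytes)      # 24.0 (decimal gigabytes)
    print(total / GiB_bytes)     # ~22.35 (binary gibibytes)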