[CI/Build] Avoid CUDA initialization (#8534)

This commit is contained in:
Cyrus Leung 2024-09-18 18:38:11 +08:00 committed by GitHub
parent e351572900
commit 6ffa3f314c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
55 changed files with 256 additions and 256 deletions

View File

@ -1,10 +1,10 @@
import random
import time
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
seed_everything)
@torch.inference_mode()
@ -16,10 +16,7 @@ def main(num_tokens: int,
do_profile: bool = False,
num_warmup_iters: int = 5,
num_iters: int = 100) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device("cuda")
layer = RMSNorm(hidden_size).to(dtype=dtype)

View File

@ -10,7 +10,7 @@ from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.utils import FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser, seed_everything
class BenchmarkConfig(TypedDict):
@ -166,7 +166,7 @@ class BenchmarkWorker:
def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(seed)
seed_everything(seed)
self.seed = seed
def benchmark(
@ -180,7 +180,7 @@ class BenchmarkWorker:
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
) -> Tuple[Dict[str, int], float]:
torch.cuda.manual_seed_all(self.seed)
seed_everything(self.seed)
dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)

View File

@ -6,7 +6,7 @@ import torch
from vllm import _custom_ops as ops
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
create_kv_caches_with_random)
create_kv_caches_with_random, seed_everything)
NUM_BLOCKS = 1024
PARTITION_SIZE = 512
@ -28,10 +28,7 @@ def main(
device: str = "cuda",
kv_cache_dtype: Optional[str] = None,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
scale = float(1.0 / (head_size**0.5))
query = torch.empty(num_seqs,

View File

@ -1,10 +1,10 @@
import random
import time
import torch
from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
seed_everything)
@torch.inference_mode()
@ -17,10 +17,7 @@ def main(num_tokens: int,
do_profile: bool = False,
num_warmup_iters: int = 5,
num_iters: int = 100) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device("cuda")
x = torch.randn(num_tokens, hidden_size, dtype=dtype)

View File

@ -6,7 +6,7 @@ import torch
from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
get_rope)
from vllm.utils import FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser, seed_everything
def benchmark_rope_kernels_multi_lora(
@ -22,9 +22,7 @@ def benchmark_rope_kernels_multi_lora(
max_position: int = 8192,
base: int = 10000,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size

View File

@ -7,6 +7,7 @@ from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
NewGELU, QuickGELU,
SiluAndMul)
from vllm.utils import seed_everything
from .allclose_default import get_default_atol, get_default_rtol
@ -34,9 +35,7 @@ def test_act_and_mul(
seed: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
if activation == "silu":
@ -77,9 +76,7 @@ def test_activation(
seed: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, d, dtype=dtype)
layer = activation[0]()

View File

@ -6,7 +6,7 @@ import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
from .allclose_default import get_default_atol, get_default_rtol
@ -139,10 +139,8 @@ def test_paged_attention(
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
@ -354,10 +352,7 @@ def test_paged_attention_rocm(
seed: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
@ -506,10 +501,7 @@ def test_multi_query_kv_attention(
seed: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use

View File

@ -45,7 +45,7 @@ def test_flash_attn(monkeypatch):
override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
# Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
assert backend.name != STR_FLASH_ATTN_VAL

View File

@ -7,6 +7,7 @@ import torch
from vllm.model_executor.layers.quantization.awq_triton import (
AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)
from vllm.utils import seed_everything
device = "cuda"
@ -79,7 +80,7 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
zeros_cols = qweight_cols
zeros_dtype = torch.int32
torch.manual_seed(0)
seed_everything(0)
qweight = torch.randint(0,
torch.iinfo(torch.int32).max,
@ -133,7 +134,7 @@ def test_gemm(N, K, M, splitK, group_size):
qzeros_rows = scales_rows
qzeros_cols = qweight_cols
torch.manual_seed(0)
seed_everything(0)
input = torch.rand((input_rows, input_cols),
dtype=input_dtype,

View File

@ -7,7 +7,7 @@ import torch
from vllm import _custom_ops as ops
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn)
from vllm.utils import get_max_shared_memory_bytes, is_hip
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
from .allclose_default import get_default_atol, get_default_rtol
@ -172,10 +172,7 @@ def test_paged_attention(
blocksparse_block_size: int,
blocksparse_head_sliding_step: int,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
@ -386,10 +383,7 @@ def test_varlen_blocksparse_attention_prefill(
seed: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use

View File

@ -6,6 +6,7 @@ import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.utils import seed_everything
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
@ -55,10 +56,7 @@ def test_copy_blocks(
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
# Generate random block mappings where each source block is mapped to two
# destination blocks.
@ -134,10 +132,7 @@ def test_reshape_and_cache(
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
# Create a random slot mapping.
num_slots = block_size * num_blocks
@ -229,9 +224,7 @@ def test_reshape_and_cache_flash(
device: str,
kv_cache_dtype: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
# Create a random slot mapping.
@ -345,10 +338,8 @@ def test_swap_blocks(
pytest.skip()
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
src_device = device if direction[0] == "cuda" else 'cpu'
dst_device = device if direction[1] == "cuda" else 'cpu'
@ -417,9 +408,7 @@ def test_fp8_e4m3_conversion(
seed: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)
low = -224.0
high = 224.0

View File

@ -7,6 +7,7 @@ from einops import rearrange
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.utils import seed_everything
def causal_conv1d_ref(
@ -104,7 +105,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
torch.random.manual_seed(0)
seed_everything(0)
if not channel_last:
x = torch.randn(batch,
4096 + dim + 64,
@ -175,7 +176,7 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
torch.random.manual_seed(0)
seed_everything(0)
batch = 2
x = torch.randn(batch, dim, device=device, dtype=itype)
conv_state = torch.randn(batch, dim, width, device=device, dtype=itype)

View File

@ -15,9 +15,6 @@ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
@ -119,7 +116,7 @@ def cutlass_int8_gemm_helper(m: int,
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
per_out_ch: bool, use_bias: bool):
@ -157,7 +154,7 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
out_dtype: Type[torch.dtype],
@ -175,7 +172,7 @@ def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(capability < 89,
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
use_bias: bool, device: str):
@ -207,7 +204,7 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
use_bias: bool):

View File

@ -4,6 +4,7 @@ import pytest
import torch
import vllm.attention.backends.flash_attn # noqa: F401
from vllm.utils import seed_everything
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
@ -87,7 +88,7 @@ def test_flash_attn_with_paged_kv(
num_blocks: int,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
seed_everything(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
@ -174,7 +175,7 @@ def test_varlen_with_paged_kv(
num_blocks: int,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]

View File

@ -4,6 +4,8 @@ import flashinfer
import pytest
import torch
from vllm.utils import seed_everything
NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
@ -82,7 +84,7 @@ def test_flashinfer_decode_with_paged_kv(
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
seed_everything(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
@ -168,7 +170,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
block_size: int,
soft_cap: Optional[float]) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
@ -266,7 +268,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
head_size: int, dtype: torch.dtype, block_size: int,
soft_cap: Optional[float]) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
@ -379,7 +381,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
) -> None:
# test doesn't work for num_heads = (16,16)
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
seed_everything(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]

View File

@ -5,6 +5,7 @@ import vllm._custom_ops as ops
from tests.kernels.quant_utils import (FP8_DTYPE,
ref_dynamic_per_tensor_fp8_quant,
ref_dynamic_per_token_quant)
from vllm.utils import seed_everything
DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
@ -24,8 +25,7 @@ SEEDS = [0]
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, scale_ub: bool,
seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
device="cuda") + 1e-6 # avoid nans
@ -49,8 +49,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
@torch.inference_mode()
def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
@ -67,8 +66,7 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
@torch.inference_mode()
@pytest.mark.parametrize("seed", SEEDS)
def test_fp8_quant_large(seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)
num_tokens = 1024000 # Mistral-Nemo's max_position_embeddings
hidden_size = 1152 # Smallest hidden_size to reproduce the error

View File

@ -7,6 +7,7 @@ from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
from huggingface_hub import snapshot_download
import vllm._custom_ops as ops
from vllm.utils import seed_everything
GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
@ -74,7 +75,7 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype,
@torch.inference_mode()
def test_mmvq(hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType):
torch.cuda.manual_seed_all(0)
seed_everything(0)
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
x = torch.rand((1, hidden_size), dtype=dtype, device="cuda")
@ -110,7 +111,7 @@ def test_mmvq(hidden_size: int, dtype: torch.dtype,
@torch.inference_mode()
def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType):
torch.cuda.manual_seed_all(0)
seed_everything(0)
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda")

View File

@ -4,6 +4,7 @@ import torch
from tests.kernels.quant_utils import ref_dynamic_per_token_quant
from tests.kernels.utils import opcheck
from vllm._custom_ops import scaled_int8_quant
from vllm.utils import seed_everything
DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
@ -44,8 +45,7 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True):
@torch.inference_mode()
def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
@ -68,8 +68,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
@torch.inference_mode()
def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)
int8_traits = torch.iinfo(torch.int8)
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
@ -113,8 +112,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int,
scale: float) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)
int8_traits = torch.iinfo(torch.int8)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
@ -140,8 +138,7 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int,
scale: float, azp: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)
int8_traits = torch.iinfo(torch.int8)
x = torch.rand(num_tokens, hidden_size, dtype=dtype,

View File

@ -3,6 +3,7 @@ import torch
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.utils import seed_everything
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
@ -30,9 +31,7 @@ def test_rms_norm(
seed: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
layer = RMSNorm(hidden_size).to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1)

View File

@ -48,7 +48,7 @@ WTYPE_ZEROPOINTS = [
# `is_quant_method_supported` conflates kernels with quantization methods
# an assumption which is breaking down as quantizations methods can have
# have kernels and some kernels support multiple quantization methods.
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)
def rand_data(shape, dtype=torch.float16):

View File

@ -5,6 +5,7 @@ from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.utils import seed_everything
def selective_state_update_ref(state,
@ -186,7 +187,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
rtolw = max(rtolw, rtol)
atolw = max(atolw, atol)
# set seed
torch.random.manual_seed(0)
seed_everything(0)
batch_size = 2
dim = 4
dstate = 8
@ -287,7 +288,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
if torch.version.hip:
atol *= 2
# set seed
torch.random.manual_seed(0)
seed_everything(0)
batch_size = 1
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
x = torch.randn(batch_size, dim, device=device, dtype=itype)

View File

@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.scalar_type import scalar_types
from vllm.utils import seed_everything
def torch_moe(a, w1, w2, score, topk):
@ -151,7 +152,7 @@ def test_fused_marlin_moe(
act_order: bool,
num_bits: int,
):
torch.manual_seed(7)
seed_everything(7)
if topk > e:
return

View File

@ -5,6 +5,7 @@ import pytest
import torch
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.utils import seed_everything
from .allclose_default import get_default_atol, get_default_rtol
@ -46,9 +47,8 @@ def test_rotary_embedding(
) -> None:
if rotary_dim is None:
rotary_dim = head_size
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
@ -100,9 +100,7 @@ def test_batched_rotary_embedding(
max_position: int = 8192,
base: int = 10000,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
@ -162,9 +160,7 @@ def test_batched_rotary_embedding_multi_lora(
max_position: int = 8192,
base: int = 10000,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size

View File

@ -9,7 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, seed_everything
NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64]
@ -39,10 +39,7 @@ def test_contexted_kv_attention(
kv_cache_dtype: str,
device: str,
) -> None:
random.seed(0)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed(0)
seed_everything(0)
torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process
@ -237,10 +234,7 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype: str,
device: str,
) -> None:
random.seed(0)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed(0)
seed_everything(0)
torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process

View File

@ -39,6 +39,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
from vllm.model_executor.utils import set_random_seed
from vllm.utils import seed_everything
from .utils import DummyLoRAManager
@ -922,9 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
seq_len) -> None:
dtype = torch.float16
seed = 0
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8

View File

@ -4,7 +4,6 @@ hidden_sizes included in the LoRA models currently supported by vLLM. It tests
whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64].
"""
import random
from unittest.mock import patch
import pytest
@ -17,6 +16,7 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.triton_utils.libentry import LibEntry
from vllm.utils import seed_everything
from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
@ -145,11 +145,8 @@ def test_punica_sgmv(
seed: int,
device: str,
):
random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
seq_length = 128
(
@ -238,11 +235,8 @@ def test_punica_bgmv(
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
seq_length = 1
(
@ -329,11 +323,9 @@ def test_punica_expand_nslices(
):
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
seq_length = 128 if op_type == "sgmv" else 1
(
inputs_tensor,

View File

@ -3,7 +3,6 @@ This script is mainly used to test whether trtion kernels can run normally
under different conditions, including various batches, numbers of LoRA , and
maximum ranks.
"""
import random
from unittest.mock import patch
import pytest
@ -16,6 +15,7 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.triton_utils.libentry import LibEntry
from vllm.utils import seed_everything
from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
@ -60,11 +60,8 @@ def test_punica_sgmv(
seed: int,
device: str,
):
random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
seq_length = 128
(
@ -153,11 +150,8 @@ def test_punica_bgmv(
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
seq_length = 1
(
@ -244,11 +238,9 @@ def test_punica_expand_nslices(
):
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
seq_length = 128 if op_type == "sgmv" else 1
(
inputs_tensor,

View File

@ -2,23 +2,18 @@
Run `pytest tests/models/test_granite.py`.
"""
import importlib.metadata
import pytest
import transformers
from ...utils import check_logprobs_close
TRANSFORMERS_VERSION = tuple(
map(int,
importlib.metadata.version("transformers").split(".")))
MODELS = [
"ibm/PowerLM-3b",
]
# GraniteForCausalLM will be in transformers >= 4.45
@pytest.mark.skipif(TRANSFORMERS_VERSION < (4, 45),
@pytest.mark.skipif(transformers.__version__ < "4.45",
reason="granite model test requires transformers >= 4.45")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])

View File

@ -86,9 +86,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
assert attn._k_scale == 1.0
assert attn._v_scale == 1.0
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability >= 89 and not force_marlin:
if current_platform.has_device_capability(89) and not force_marlin:
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fn
else:

View File

@ -8,6 +8,8 @@ def is_quant_method_supported(quant_method: str) -> bool:
return False
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
return (capability >=
QUANTIZATION_METHODS[quant_method].get_min_capability())
assert capability is not None
min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability()
return capability.to_int() >= min_capability

View File

@ -13,6 +13,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
@ -299,7 +300,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either
if torch.cuda.get_device_capability()[0] != 9:
if not current_platform.has_device_capability(90):
self.use_naive_attn = True
else:
try:

View File

@ -8,8 +8,7 @@ from vllm.utils import is_cpu, is_hip
from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask)
IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
and current_platform.get_device_capability()[0] >= 8)
IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80)
if IS_COMPUTE_8_OR_ABOVE:
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
@ -36,7 +35,7 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
use_spda = is_hip() or is_cpu() or not \
IS_COMPUTE_8_OR_ABOVE
device = device or (torch.cuda.current_device()
if torch.cuda.is_available() else "cpu")
if current_platform.is_cuda_alike() else "cpu")
device = torch.device(device)
# NOTE: vllm CPU backend support BF16 instead of FP16.
dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE

View File

@ -709,8 +709,7 @@ if triton.__version__ >= "2.1.0":
alibi_slopes=None,
sliding_window=None):
cap = current_platform.get_device_capability()
BLOCK = 128 if cap[0] >= 8 else 64
BLOCK = 128 if current_platform.has_device_capability(80) else 64
NUM_WARPS = 8
# need to reduce num. blocks when using fp32

View File

@ -203,7 +203,7 @@ def which_attn_to_use(
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:
if current_platform.get_device_capability()[0] != 9:
if not current_platform.has_device_capability(90):
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
@ -212,7 +212,7 @@ def which_attn_to_use(
# FlashAttn in NVIDIA GPUs.
if selected_backend == _Backend.FLASH_ATTN:
if current_platform.get_device_capability()[0] < 8:
if not current_platform.has_device_capability(80):
# Volta and Turing NVIDIA GPUs.
logger.info(
"Cannot use FlashAttention-2 backend for Volta and Turing "

View File

@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config,
get_hf_text_config)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
is_cpu, is_hip, is_neuron, is_openvino, is_xpu,
is_hip, is_neuron, is_openvino, is_xpu,
print_warning_once)
if TYPE_CHECKING:
@ -1035,20 +1035,20 @@ class DeviceConfig:
def __init__(self, device: str = "auto") -> None:
if device == "auto":
# Automated device type detection
if is_neuron():
if current_platform.is_cuda_alike():
self.device_type = "cuda"
elif is_neuron():
self.device_type = "neuron"
elif is_openvino():
self.device_type = "openvino"
elif current_platform.is_tpu():
self.device_type = "tpu"
elif is_cpu():
elif current_platform.is_cpu():
self.device_type = "cpu"
elif is_xpu():
self.device_type = "xpu"
else:
# We don't call torch.cuda.is_available() here to
# avoid initializing CUDA before workers are forked
self.device_type = "cuda"
raise RuntimeError("Failed to infer device type")
else:
# Device type is assigned explicitly
self.device_type = device

View File

@ -35,6 +35,7 @@ from torch.distributed import Backend, ProcessGroup
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
@dataclass
@ -191,7 +192,7 @@ class GroupCoordinator:
assert self.cpu_group is not None
assert self.device_group is not None
if torch.cuda.is_available():
if current_platform.is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}")
else:
self.device = torch.device("cpu")

View File

@ -60,6 +60,7 @@ if TYPE_CHECKING:
VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False

View File

@ -116,10 +116,10 @@ class CompressedTensorsConfig(QuantizationConfig):
def _check_scheme_supported(self,
min_capability: int,
error: bool = True) -> bool:
capability = current_platform.get_device_capability() # type: ignore
capability_tuple = current_platform.get_device_capability()
if capability is not None:
capability = capability[0] * 10 + capability[1]
if capability_tuple is not None:
capability = capability_tuple.to_int()
supported = capability >= min_capability
if error and not supported:
raise RuntimeError(

View File

@ -32,9 +32,7 @@ class FBGEMMFp8Config(QuantizationConfig):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
self.use_marlin = capability < 89
self.use_marlin = not current_platform.has_device_capability(89)
@classmethod
def get_name(cls) -> str:

View File

@ -120,9 +120,8 @@ class Fp8LinearMethod(LinearMethodBase):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if is_hip():
self.use_marlin = False

View File

@ -29,8 +29,9 @@ def query_marlin_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())
if device_capability < 80:
return []
@ -52,8 +53,9 @@ def _check_marlin_supported(
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())
supported_types = query_marlin_supported_quant_types(
has_zp, device_capability)

View File

@ -10,8 +10,7 @@ from .marlin_utils import marlin_make_workspace, marlin_permute_scales
def is_fp8_marlin_supported():
capability = current_platform.get_device_capability()
return capability[0] >= 8
return current_platform.has_device_capability(80)
def apply_fp8_marlin_linear(

View File

@ -17,8 +17,9 @@ def cutlass_fp8_supported() -> bool:
# cutlass is not supported on Rocm
if is_hip():
return False
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()
return ops.cutlass_scaled_mm_supports_fp8(capability)

View File

@ -97,10 +97,10 @@ def _get_quantization_config(
"""Get the quantization config."""
if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config)
capability = current_platform.get_device_capability() # type: ignore
capability_tuple = current_platform.get_device_capability()
if capability is not None:
capability = capability[0] * 10 + capability[1]
if capability_tuple is not None:
capability = capability_tuple.to_int()
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} "

View File

@ -207,7 +207,7 @@ class Qwen2VisionAttention(nn.Module):
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
# For Volta and Turing GPUs, use xformers instead.
device_available = current_platform.get_device_capability()[0] >= 8
device_available = current_platform.has_device_capability(80)
if device_available:
from transformers.utils import is_flash_attn_2_available

View File

@ -1,17 +1,13 @@
"""Utils for model executor."""
import random
from typing import Any, Dict, Optional
import numpy as np
import torch
from vllm.utils import seed_everything
def set_random_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
seed_everything(seed)
def set_weight_attrs(

View File

@ -6,10 +6,10 @@ from .interface import Platform, PlatformEnum
class CpuPlatform(Platform):
_enum = PlatformEnum.CPU
@staticmethod
def get_device_name(device_id: int = 0) -> str:
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
return "cpu"
@staticmethod
def inference_mode():
@classmethod
def inference_mode(cls):
return torch.no_grad()

View File

@ -11,7 +11,7 @@ from typing_extensions import ParamSpec
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum
from .interface import DeviceCapability, Platform, PlatformEnum
logger = init_logger(__name__)
@ -96,19 +96,20 @@ def device_id_to_physical_device_id(device_id: int) -> int:
class CudaPlatform(Platform):
_enum = PlatformEnum.CUDA
@staticmethod
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
@classmethod
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_capability(physical_device_id)
major, minor = get_physical_device_capability(physical_device_id)
return DeviceCapability(major=major, minor=minor)
@staticmethod
def get_device_name(device_id: int = 0) -> str:
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_name(physical_device_id)
@staticmethod
@classmethod
@with_nvml_context
def is_full_nvlink(physical_device_ids: List[int]) -> bool:
def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
"""

View File

@ -1,5 +1,5 @@
import enum
from typing import Optional, Tuple
from typing import NamedTuple, Optional, Tuple, Union
import torch
@ -12,6 +12,23 @@ class PlatformEnum(enum.Enum):
UNSPECIFIED = enum.auto()
class DeviceCapability(NamedTuple):
major: int
minor: int
def as_version_str(self) -> str:
return f"{self.major}.{self.minor}"
def to_int(self) -> int:
"""
Express device capability as an integer ``<major><minor>``.
It is assumed that the minor version is always a single digit.
"""
assert 0 <= self.minor < 10
return self.major * 10 + self.minor
class Platform:
_enum: PlatformEnum
@ -27,16 +44,47 @@ class Platform:
def is_cpu(self) -> bool:
return self._enum == PlatformEnum.CPU
@staticmethod
def get_device_capability(device_id: int = 0) -> Optional[Tuple[int, int]]:
def is_cuda_alike(self) -> bool:
"""Stateless version of :func:`torch.cuda.is_available`."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
@classmethod
def get_device_capability(
cls,
device_id: int = 0,
) -> Optional[DeviceCapability]:
"""Stateless version of :func:`torch.cuda.get_device_capability`."""
return None
@staticmethod
def get_device_name(device_id: int = 0) -> str:
@classmethod
def has_device_capability(
cls,
capability: Union[Tuple[int, int], int],
device_id: int = 0,
) -> bool:
"""
Test whether this platform is compatible with a device capability.
The ``capability`` argument can either be:
- A tuple ``(major, minor)``.
- An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
"""
current_capability = cls.get_device_capability(device_id=device_id)
if current_capability is None:
return False
if isinstance(capability, tuple):
return current_capability >= capability
return current_capability.to_int() >= capability
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError
@staticmethod
def inference_mode():
@classmethod
def inference_mode(cls):
"""A device-specific wrapper of `torch.inference_mode`.
This wrapper is recommended because some hardware backends such as TPU

View File

@ -1,12 +1,11 @@
import os
from functools import lru_cache
from typing import Tuple
import torch
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum
from .interface import DeviceCapability, Platform, PlatformEnum
logger = init_logger(__name__)
@ -20,12 +19,13 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
@staticmethod
@classmethod
@lru_cache(maxsize=8)
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
return torch.cuda.get_device_capability(device_id)
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
major, minor = torch.cuda.get_device_capability(device_id)
return DeviceCapability(major=major, minor=minor)
@staticmethod
@classmethod
@lru_cache(maxsize=8)
def get_device_name(device_id: int = 0) -> str:
def get_device_name(cls, device_id: int = 0) -> str:
return torch.cuda.get_device_name(device_id)

View File

@ -6,6 +6,10 @@ from .interface import Platform, PlatformEnum
class TpuPlatform(Platform):
_enum = PlatformEnum.TPU
@staticmethod
def inference_mode():
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError
@classmethod
def inference_mode(cls):
return torch.no_grad()

View File

@ -8,13 +8,15 @@ from huggingface_hub import file_exists, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from safetensors.torch import load_file as safe_load_file
from vllm.platforms import current_platform
WEIGHTS_NAME = "adapter_model.bin"
SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"
# Get current device name based on available devices
def infer_device() -> str:
if torch.cuda.is_available():
if current_platform.is_cuda_alike():
return "cuda"
return "cpu"

View File

@ -17,6 +17,7 @@ import torch
import vllm.envs as envs
from vllm.connections import global_http_connection
from vllm.platforms import current_platform
from vllm.version import __version__ as VLLM_VERSION
_config_home = envs.VLLM_CONFIG_ROOT
@ -151,7 +152,7 @@ class UsageMessage:
usage_context: UsageContext,
extra_kvs: Dict[str, Any]) -> None:
# Platform information
if torch.cuda.is_available():
if current_platform.is_cuda_alike():
device_property = torch.cuda.get_device_properties(0)
self.gpu_count = torch.cuda.device_count()
self.gpu_type = device_property.name

View File

@ -5,6 +5,7 @@ import datetime
import enum
import gc
import os
import random
import socket
import subprocess
import sys
@ -32,6 +33,7 @@ from typing_extensions import ParamSpec, TypeIs, assert_never
import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
@ -373,6 +375,22 @@ def get_cpu_memory() -> int:
return psutil.virtual_memory().total
def seed_everything(seed: int) -> None:
"""
Set the seed of each random module.
Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
"""
random.seed(seed)
np.random.seed(seed)
if current_platform.is_cuda_alike():
torch.cuda.manual_seed_all(seed)
if is_xpu():
torch.xpu.manual_seed_all(seed)
def random_uuid() -> str:
return str(uuid.uuid4().hex)
@ -634,9 +652,7 @@ def create_kv_caches_with_random_flash(
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
@ -678,9 +694,7 @@ def create_kv_caches_with_random(
f"Does not support key cache of type fp8 with head_size {head_size}"
)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
@ -750,7 +764,7 @@ class CudaMemoryProfiler:
def current_memory_usage(self) -> float:
# Return the memory usage in bytes.
if torch.cuda.is_available():
if current_platform.is_cuda_alike():
torch.cuda.reset_peak_memory_stats(self.device)
mem = torch.cuda.max_memory_allocated(self.device)
elif is_xpu():

View File

@ -454,14 +454,20 @@ def init_worker_distributed_environment(
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16:
compute_capability = current_platform.get_device_capability()
if compute_capability[0] < 8:
if torch_dtype == torch.bfloat16: # noqa: SIM102
if not current_platform.has_device_capability(80):
capability = current_platform.get_device_capability()
gpu_name = current_platform.get_device_name()
if capability is None:
compute_str = "does not have a compute capability"
else:
version_str = capability.as_version_str()
compute_str = f"has compute capability {version_str}"
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
f"{compute_capability[0]}.{compute_capability[1]}. "
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the"
"`dtype` flag in CLI, for example: --dtype=half.")