Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Synced 2025-12-09 14:35:27 +08:00

[CI/Build] Avoid CUDA initialization (#8534)

This commit is contained in:
parent e351572900
commit 6ffa3f314c
@@ -1,10 +1,10 @@
import random
import time

import torch

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
                        seed_everything)


@torch.inference_mode()
@@ -16,10 +16,7 @@ def main(num_tokens: int,
do_profile: bool = False,
num_warmup_iters: int = 5,
num_iters: int = 100) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device("cuda")

layer = RMSNorm(hidden_size).to(dtype=dtype)

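The same change is applied to every benchmark and kernel test touched by this commit: the per-function seeding block (random.seed, torch.random.manual_seed, and a torch.cuda.manual_seed guarded by torch.cuda.is_available()) is replaced with a single call to the new seed_everything helper, which is added to vllm/utils.py further down in this diff. A minimal sketch of the resulting call-site pattern; the function name and tensor sizes here are illustrative placeholders, not part of the commit:

# Sketch only: mirrors the call-site pattern from the hunk above.
import torch

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.utils import seed_everything


@torch.inference_mode()
def bench_layernorm(seed: int = 0, num_tokens: int = 8, hidden_size: int = 1024) -> None:
    seed_everything(seed)  # replaces the random / torch / CUDA seeding block
    torch.set_default_device("cuda")

    layer = RMSNorm(hidden_size).to(dtype=torch.float16)
    x = torch.randn(num_tokens, hidden_size, dtype=torch.float16)
    layer(x)
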
@@ -10,7 +10,7 @@ from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.utils import FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser, seed_everything


class BenchmarkConfig(TypedDict):
@@ -166,7 +166,7 @@ class BenchmarkWorker:

def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(seed)
seed_everything(seed)
self.seed = seed

def benchmark(
@@ -180,7 +180,7 @@ class BenchmarkWorker:
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
) -> Tuple[Dict[str, int], float]:
torch.cuda.manual_seed_all(self.seed)
seed_everything(self.seed)
dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)

@@ -6,7 +6,7 @@ import torch

from vllm import _custom_ops as ops
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
create_kv_caches_with_random)
create_kv_caches_with_random, seed_everything)

NUM_BLOCKS = 1024
PARTITION_SIZE = 512
@@ -28,10 +28,7 @@ def main(
device: str = "cuda",
kv_cache_dtype: Optional[str] = None,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)

scale = float(1.0 / (head_size**0.5))
query = torch.empty(num_seqs,

@@ -1,10 +1,10 @@
import random
import time

import torch

from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
                        seed_everything)


@torch.inference_mode()
@@ -17,10 +17,7 @@ def main(num_tokens: int,
do_profile: bool = False,
num_warmup_iters: int = 5,
num_iters: int = 100) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device("cuda")

x = torch.randn(num_tokens, hidden_size, dtype=dtype)

@@ -6,7 +6,7 @@ import torch

from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
get_rope)
from vllm.utils import FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser, seed_everything


def benchmark_rope_kernels_multi_lora(
@@ -22,9 +22,7 @@ def benchmark_rope_kernels_multi_lora(
max_position: int = 8192,
base: int = 10000,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size

@@ -7,6 +7,7 @@ from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
NewGELU, QuickGELU,
SiluAndMul)
from vllm.utils import seed_everything

from .allclose_default import get_default_atol, get_default_rtol

@@ -34,9 +35,7 @@ def test_act_and_mul(
seed: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
if activation == "silu":
@@ -77,9 +76,7 @@ def test_activation(
seed: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, d, dtype=dtype)
layer = activation[0]()

@@ -6,7 +6,7 @@ import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything

from .allclose_default import get_default_atol, get_default_rtol

@@ -139,10 +139,8 @@ def test_paged_attention(
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)

seed_everything(seed)
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
@@ -354,10 +352,7 @@ def test_paged_attention_rocm(
seed: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
@@ -506,10 +501,7 @@ def test_multi_query_kv_attention(
seed: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use

@@ -45,7 +45,7 @@ def test_flash_attn(monkeypatch):
override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)

# Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
assert backend.name != STR_FLASH_ATTN_VAL

@@ -7,6 +7,7 @@ import torch

from vllm.model_executor.layers.quantization.awq_triton import (
AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)
from vllm.utils import seed_everything

device = "cuda"

@@ -79,7 +80,7 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
zeros_cols = qweight_cols
zeros_dtype = torch.int32

torch.manual_seed(0)
seed_everything(0)

qweight = torch.randint(0,
torch.iinfo(torch.int32).max,
@@ -133,7 +134,7 @@ def test_gemm(N, K, M, splitK, group_size):
qzeros_rows = scales_rows
qzeros_cols = qweight_cols

torch.manual_seed(0)
seed_everything(0)

input = torch.rand((input_rows, input_cols),
dtype=input_dtype,

@@ -7,7 +7,7 @@ import torch
from vllm import _custom_ops as ops
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn)
from vllm.utils import get_max_shared_memory_bytes, is_hip
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything

from .allclose_default import get_default_atol, get_default_rtol

@@ -172,10 +172,7 @@ def test_paged_attention(
blocksparse_block_size: int,
blocksparse_head_sliding_step: int,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
@@ -386,10 +383,7 @@ def test_varlen_blocksparse_attention_prefill(
seed: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use

@@ -6,6 +6,7 @@ import torch

from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.utils import seed_everything

COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -55,10 +56,7 @@ def test_copy_blocks(
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
# Generate random block mappings where each source block is mapped to two
# destination blocks.
@@ -134,10 +132,7 @@ def test_reshape_and_cache(
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
# Create a random slot mapping.
num_slots = block_size * num_blocks
@@ -229,9 +224,7 @@ def test_reshape_and_cache_flash(
device: str,
kv_cache_dtype: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)

# Create a random slot mapping.
@@ -345,10 +338,8 @@ def test_swap_blocks(
pytest.skip()
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)

seed_everything(seed)

src_device = device if direction[0] == "cuda" else 'cpu'
dst_device = device if direction[1] == "cuda" else 'cpu'
@@ -417,9 +408,7 @@ def test_fp8_e4m3_conversion(
seed: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)

low = -224.0
high = 224.0

@@ -7,6 +7,7 @@ from einops import rearrange

from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.utils import seed_everything


def causal_conv1d_ref(
@@ -104,7 +105,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
torch.random.manual_seed(0)
seed_everything(0)
if not channel_last:
x = torch.randn(batch,
4096 + dim + 64,
@@ -175,7 +176,7 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
torch.random.manual_seed(0)
seed_everything(0)
batch = 2
x = torch.randn(batch, dim, device=device, dtype=itype)
conv_state = torch.randn(batch, dim, width, device=device, dtype=itype)

@@ -15,9 +15,6 @@ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]


def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
@@ -119,7 +116,7 @@ def cutlass_int8_gemm_helper(m: int,
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
per_out_ch: bool, use_bias: bool):
@@ -157,7 +154,7 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
out_dtype: Type[torch.dtype],
@@ -175,7 +172,7 @@ def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(capability < 89,
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
use_bias: bool, device: str):
@@ -207,7 +204,7 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
use_bias: bool):

@@ -4,6 +4,7 @@ import pytest
import torch

import vllm.attention.backends.flash_attn  # noqa: F401
from vllm.utils import seed_everything

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
@@ -87,7 +88,7 @@ def test_flash_attn_with_paged_kv(
num_blocks: int,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
seed_everything(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
@@ -174,7 +175,7 @@ def test_varlen_with_paged_kv(
num_blocks: int,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]

@@ -4,6 +4,8 @@ import flashinfer
import pytest
import torch

from vllm.utils import seed_everything

NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
@@ -82,7 +84,7 @@ def test_flashinfer_decode_with_paged_kv(
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
seed_everything(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
@@ -168,7 +170,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
block_size: int,
soft_cap: Optional[float]) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
@@ -266,7 +268,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
head_size: int, dtype: torch.dtype, block_size: int,
soft_cap: Optional[float]) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
@@ -379,7 +381,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
) -> None:
# test doesn't work for num_heads = (16,16)
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
seed_everything(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]

@@ -5,6 +5,7 @@ import vllm._custom_ops as ops
from tests.kernels.quant_utils import (FP8_DTYPE,
ref_dynamic_per_tensor_fp8_quant,
ref_dynamic_per_token_quant)
from vllm.utils import seed_everything

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
@@ -24,8 +25,7 @@ SEEDS = [0]
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, scale_ub: bool,
seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)

x = torch.rand(num_tokens, hidden_size, dtype=dtype,
device="cuda") + 1e-6  # avoid nans
@@ -49,8 +49,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
@torch.inference_mode()
def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")

@@ -67,8 +66,7 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
@torch.inference_mode()
@pytest.mark.parametrize("seed", SEEDS)
def test_fp8_quant_large(seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)

num_tokens = 1024000  # Mistral-Nemo's max_position_embeddings
hidden_size = 1152  # Smallest hidden_size to reproduce the error

@@ -7,6 +7,7 @@ from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
from huggingface_hub import snapshot_download

import vllm._custom_ops as ops
from vllm.utils import seed_everything

GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")

@@ -74,7 +75,7 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype,
@torch.inference_mode()
def test_mmvq(hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType):
torch.cuda.manual_seed_all(0)
seed_everything(0)

tensors = get_gguf_sample_tensors(hidden_size, quant_type)
x = torch.rand((1, hidden_size), dtype=dtype, device="cuda")
@@ -110,7 +111,7 @@ def test_mmvq(hidden_size: int, dtype: torch.dtype,
@torch.inference_mode()
def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType):
torch.cuda.manual_seed_all(0)
seed_everything(0)

tensors = get_gguf_sample_tensors(hidden_size, quant_type)
x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda")

@@ -4,6 +4,7 @@ import torch
from tests.kernels.quant_utils import ref_dynamic_per_token_quant
from tests.kernels.utils import opcheck
from vllm._custom_ops import scaled_int8_quant
from vllm.utils import seed_everything

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
@@ -44,8 +45,7 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True):
@torch.inference_mode()
def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000

@@ -68,8 +68,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
@torch.inference_mode()
def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)
int8_traits = torch.iinfo(torch.int8)

x = torch.rand(num_tokens, hidden_size, dtype=dtype,
@@ -113,8 +112,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int,
scale: float) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)
int8_traits = torch.iinfo(torch.int8)

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
@@ -140,8 +138,7 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int,
scale: float, azp: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
seed_everything(seed)
int8_traits = torch.iinfo(torch.int8)

x = torch.rand(num_tokens, hidden_size, dtype=dtype,

@@ -3,6 +3,7 @@ import torch

from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.utils import seed_everything

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
@@ -30,9 +31,7 @@ def test_rms_norm(
seed: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
layer = RMSNorm(hidden_size).to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1)

@@ -48,7 +48,7 @@ WTYPE_ZEROPOINTS = [
# `is_quant_method_supported` conflates kernels with quantization methods
# an assumption which is breaking down as quantizations methods can have
# have kernels and some kernels support multiple quantization methods.
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)


def rand_data(shape, dtype=torch.float16):

@@ -5,6 +5,7 @@ from einops import rearrange, repeat

from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.utils import seed_everything


def selective_state_update_ref(state,
@@ -186,7 +187,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
rtolw = max(rtolw, rtol)
atolw = max(atolw, atol)
# set seed
torch.random.manual_seed(0)
seed_everything(0)
batch_size = 2
dim = 4
dstate = 8
@@ -287,7 +288,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
if torch.version.hip:
atol *= 2
# set seed
torch.random.manual_seed(0)
seed_everything(0)
batch_size = 1
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
x = torch.randn(batch_size, dim, device=device, dtype=itype)

@@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.scalar_type import scalar_types
from vllm.utils import seed_everything


def torch_moe(a, w1, w2, score, topk):
@@ -151,7 +152,7 @@ def test_fused_marlin_moe(
act_order: bool,
num_bits: int,
):
torch.manual_seed(7)
seed_everything(7)

if topk > e:
return

@@ -5,6 +5,7 @@ import pytest
import torch

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.utils import seed_everything

from .allclose_default import get_default_atol, get_default_rtol

@@ -46,9 +47,8 @@ def test_rotary_embedding(
) -> None:
if rotary_dim is None:
rotary_dim = head_size
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)

seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
@@ -100,9 +100,7 @@ def test_batched_rotary_embedding(
max_position: int = 8192,
base: int = 10000,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
@@ -162,9 +160,7 @@ def test_batched_rotary_embedding_multi_lora(
max_position: int = 8192,
base: int = 10000,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size

@@ -9,7 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, seed_everything

NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64]
@@ -39,10 +39,7 @@ def test_contexted_kv_attention(
kv_cache_dtype: str,
device: str,
) -> None:
random.seed(0)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed(0)
seed_everything(0)
torch.set_default_device(device)

# Need this, otherwise when we capture the graph the process
@@ -237,10 +234,7 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype: str,
device: str,
) -> None:
random.seed(0)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed(0)
seed_everything(0)
torch.set_default_device(device)

# Need this, otherwise when we capture the graph the process

@@ -39,6 +39,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
from vllm.model_executor.utils import set_random_seed
from vllm.utils import seed_everything

from .utils import DummyLoRAManager

@@ -922,9 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
seq_len) -> None:
dtype = torch.float16
seed = 0
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8

@@ -4,7 +4,6 @@ hidden_sizes included in the LoRA models currently supported by vLLM. It tests
whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64].
"""
import random
from unittest.mock import patch

import pytest
@@ -17,6 +16,7 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.triton_utils.libentry import LibEntry
from vllm.utils import seed_everything

from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
@@ -145,11 +145,8 @@ def test_punica_sgmv(
seed: int,
device: str,
):
random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)

seq_length = 128
(
@@ -238,11 +235,8 @@ def test_punica_bgmv(
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel

random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)

seq_length = 1
(
@@ -329,11 +323,9 @@ def test_punica_expand_nslices(
):
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel

random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)

seq_length = 128 if op_type == "sgmv" else 1
(
inputs_tensor,

@@ -3,7 +3,6 @@ This script is mainly used to test whether trtion kernels can run normally
under different conditions, including various batches, numbers of LoRA , and
maximum ranks.
"""
import random
from unittest.mock import patch

import pytest
@@ -16,6 +15,7 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.triton_utils.libentry import LibEntry
from vllm.utils import seed_everything

from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
@@ -60,11 +60,8 @@ def test_punica_sgmv(
seed: int,
device: str,
):
random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)

seq_length = 128
(
@@ -153,11 +150,8 @@ def test_punica_bgmv(
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel

random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)

seq_length = 1
(
@@ -244,11 +238,9 @@ def test_punica_expand_nslices(
):
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel

random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)

seq_length = 128 if op_type == "sgmv" else 1
(
inputs_tensor,

@@ -2,23 +2,18 @@

Run `pytest tests/models/test_granite.py`.
"""
import importlib.metadata

import pytest
import transformers

from ...utils import check_logprobs_close

TRANSFORMERS_VERSION = tuple(
map(int,
importlib.metadata.version("transformers").split(".")))

MODELS = [
"ibm/PowerLM-3b",
]


# GraniteForCausalLM will be in transformers >= 4.45
@pytest.mark.skipif(TRANSFORMERS_VERSION < (4, 45),
@pytest.mark.skipif(transformers.__version__ < "4.45",
reason="granite model test requires transformers >= 4.45")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])

@@ -86,9 +86,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
assert attn._k_scale == 1.0
assert attn._v_scale == 1.0

capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability >= 89 and not force_marlin:
if current_platform.has_device_capability(89) and not force_marlin:
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fn
else:

@@ -8,6 +8,8 @@ def is_quant_method_supported(quant_method: str) -> bool:
return False

capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
return (capability >=
QUANTIZATION_METHODS[quant_method].get_min_capability())
assert capability is not None

min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability()

return capability.to_int() >= min_capability

@@ -13,6 +13,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

@@ -299,7 +300,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either
if torch.cuda.get_device_capability()[0] != 9:
if not current_platform.has_device_capability(90):
self.use_naive_attn = True
else:
try:

@@ -8,8 +8,7 @@ from vllm.utils import is_cpu, is_hip
from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask)

IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
and current_platform.get_device_capability()[0] >= 8)
IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80)

if IS_COMPUTE_8_OR_ABOVE:
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
@@ -36,7 +35,7 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
use_spda = is_hip() or is_cpu() or not \
IS_COMPUTE_8_OR_ABOVE
device = device or (torch.cuda.current_device()
if torch.cuda.is_available() else "cpu")
if current_platform.is_cuda_alike() else "cpu")
device = torch.device(device)
# NOTE: vllm CPU backend support BF16 instead of FP16.
dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE

@@ -709,8 +709,7 @@ if triton.__version__ >= "2.1.0":
alibi_slopes=None,
sliding_window=None):

cap = current_platform.get_device_capability()
BLOCK = 128 if cap[0] >= 8 else 64
BLOCK = 128 if current_platform.has_device_capability(80) else 64
NUM_WARPS = 8

# need to reduce num. blocks when using fp32

@@ -203,7 +203,7 @@ def which_attn_to_use(
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:
if current_platform.get_device_capability()[0] != 9:
if not current_platform.has_device_capability(90):
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
@@ -212,7 +212,7 @@ def which_attn_to_use(

# FlashAttn in NVIDIA GPUs.
if selected_backend == _Backend.FLASH_ATTN:
if current_platform.get_device_capability()[0] < 8:
if not current_platform.has_device_capability(80):
# Volta and Turing NVIDIA GPUs.
logger.info(
"Cannot use FlashAttention-2 backend for Volta and Turing "

@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config,
get_hf_text_config)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
is_cpu, is_hip, is_neuron, is_openvino, is_xpu,
is_hip, is_neuron, is_openvino, is_xpu,
print_warning_once)

if TYPE_CHECKING:
@@ -1035,20 +1035,20 @@ class DeviceConfig:
def __init__(self, device: str = "auto") -> None:
if device == "auto":
# Automated device type detection
if is_neuron():
if current_platform.is_cuda_alike():
self.device_type = "cuda"
elif is_neuron():
self.device_type = "neuron"
elif is_openvino():
self.device_type = "openvino"
elif current_platform.is_tpu():
self.device_type = "tpu"
elif is_cpu():
elif current_platform.is_cpu():
self.device_type = "cpu"
elif is_xpu():
self.device_type = "xpu"
else:
# We don't call torch.cuda.is_available() here to
# avoid initializing CUDA before workers are forked
self.device_type = "cuda"
raise RuntimeError("Failed to infer device type")
else:
# Device type is assigned explicitly
self.device_type = device

@@ -35,6 +35,7 @@ from torch.distributed import Backend, ProcessGroup

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform


@dataclass
@@ -191,7 +192,7 @@ class GroupCoordinator:
assert self.cpu_group is not None
assert self.device_group is not None

if torch.cuda.is_available():
if current_platform.is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}")
else:
self.device = torch.device("cpu")

@@ -60,6 +60,7 @@ if TYPE_CHECKING:
VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False


@@ -116,10 +116,10 @@ class CompressedTensorsConfig(QuantizationConfig):
def _check_scheme_supported(self,
min_capability: int,
error: bool = True) -> bool:
capability = current_platform.get_device_capability()  # type: ignore
capability_tuple = current_platform.get_device_capability()

if capability is not None:
capability = capability[0] * 10 + capability[1]
if capability_tuple is not None:
capability = capability_tuple.to_int()
supported = capability >= min_capability
if error and not supported:
raise RuntimeError(

@@ -32,9 +32,7 @@ class FBGEMMFp8Config(QuantizationConfig):

# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
self.use_marlin = capability < 89
self.use_marlin = not current_platform.has_device_capability(89)

@classmethod
def get_name(cls) -> str:

@@ -120,9 +120,8 @@ class Fp8LinearMethod(LinearMethodBase):

# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if is_hip():
self.use_marlin = False

@@ -29,8 +29,9 @@ def query_marlin_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())

if device_capability < 80:
return []
@@ -52,8 +53,9 @@ def _check_marlin_supported(
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:

if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())

supported_types = query_marlin_supported_quant_types(
has_zp, device_capability)

@@ -10,8 +10,7 @@ from .marlin_utils import marlin_make_workspace, marlin_permute_scales


def is_fp8_marlin_supported():
capability = current_platform.get_device_capability()
return capability[0] >= 8
return current_platform.has_device_capability(80)


def apply_fp8_marlin_linear(

@@ -17,8 +17,9 @@ def cutlass_fp8_supported() -> bool:
# cutlass is not supported on Rocm
if is_hip():
return False
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]

capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()

return ops.cutlass_scaled_mm_supports_fp8(capability)

@@ -97,10 +97,10 @@ def _get_quantization_config(
"""Get the quantization config."""
if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config)
capability = current_platform.get_device_capability()  # type: ignore
capability_tuple = current_platform.get_device_capability()

if capability is not None:
capability = capability[0] * 10 + capability[1]
if capability_tuple is not None:
capability = capability_tuple.to_int()
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} "

@@ -207,7 +207,7 @@ class Qwen2VisionAttention(nn.Module):
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
# For Volta and Turing GPUs, use xformers instead.
device_available = current_platform.get_device_capability()[0] >= 8
device_available = current_platform.has_device_capability(80)
if device_available:
from transformers.utils import is_flash_attn_2_available


@@ -1,17 +1,13 @@
"""Utils for model executor."""
import random
from typing import Any, Dict, Optional

import numpy as np
import torch

from vllm.utils import seed_everything


def set_random_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
seed_everything(seed)


def set_weight_attrs(

@@ -6,10 +6,10 @@ from .interface import Platform, PlatformEnum
class CpuPlatform(Platform):
_enum = PlatformEnum.CPU

@staticmethod
def get_device_name(device_id: int = 0) -> str:
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
return "cpu"

@staticmethod
def inference_mode():
@classmethod
def inference_mode(cls):
return torch.no_grad()

@@ -11,7 +11,7 @@ from typing_extensions import ParamSpec

from vllm.logger import init_logger

from .interface import Platform, PlatformEnum
from .interface import DeviceCapability, Platform, PlatformEnum

logger = init_logger(__name__)

@@ -96,19 +96,20 @@ def device_id_to_physical_device_id(device_id: int) -> int:
class CudaPlatform(Platform):
_enum = PlatformEnum.CUDA

@staticmethod
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
@classmethod
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_capability(physical_device_id)
major, minor = get_physical_device_capability(physical_device_id)
return DeviceCapability(major=major, minor=minor)

@staticmethod
def get_device_name(device_id: int = 0) -> str:
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_name(physical_device_id)

@staticmethod
@classmethod
@with_nvml_context
def is_full_nvlink(physical_device_ids: List[int]) -> bool:
def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
"""

@@ -1,5 +1,5 @@
import enum
from typing import Optional, Tuple
from typing import NamedTuple, Optional, Tuple, Union

import torch

@@ -12,6 +12,23 @@ class PlatformEnum(enum.Enum):
UNSPECIFIED = enum.auto()


class DeviceCapability(NamedTuple):
major: int
minor: int

def as_version_str(self) -> str:
return f"{self.major}.{self.minor}"

def to_int(self) -> int:
"""
Express device capability as an integer ``<major><minor>``.

It is assumed that the minor version is always a single digit.
"""
assert 0 <= self.minor < 10
return self.major * 10 + self.minor


class Platform:
_enum: PlatformEnum

@@ -27,16 +44,47 @@ class Platform:
def is_cpu(self) -> bool:
return self._enum == PlatformEnum.CPU

@staticmethod
def get_device_capability(device_id: int = 0) -> Optional[Tuple[int, int]]:
def is_cuda_alike(self) -> bool:
"""Stateless version of :func:`torch.cuda.is_available`."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

@classmethod
def get_device_capability(
cls,
device_id: int = 0,
) -> Optional[DeviceCapability]:
"""Stateless version of :func:`torch.cuda.get_device_capability`."""
return None

@staticmethod
def get_device_name(device_id: int = 0) -> str:
@classmethod
def has_device_capability(
cls,
capability: Union[Tuple[int, int], int],
device_id: int = 0,
) -> bool:
"""
Test whether this platform is compatible with a device capability.

The ``capability`` argument can either be:

- A tuple ``(major, minor)``.
- An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
"""
current_capability = cls.get_device_capability(device_id=device_id)
if current_capability is None:
return False

if isinstance(capability, tuple):
return current_capability >= capability

return current_capability.to_int() >= capability

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError

@staticmethod
def inference_mode():
@classmethod
def inference_mode(cls):
"""A device-specific wrapper of `torch.inference_mode`.

This wrapper is recommended because some hardware backends such as TPU

@@ -1,12 +1,11 @@
import os
from functools import lru_cache
from typing import Tuple

import torch

from vllm.logger import init_logger

from .interface import Platform, PlatformEnum
from .interface import DeviceCapability, Platform, PlatformEnum

logger = init_logger(__name__)

@@ -20,12 +19,13 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM

@staticmethod
@classmethod
@lru_cache(maxsize=8)
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
return torch.cuda.get_device_capability(device_id)
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
major, minor = torch.cuda.get_device_capability(device_id)
return DeviceCapability(major=major, minor=minor)

@staticmethod
@classmethod
@lru_cache(maxsize=8)
def get_device_name(device_id: int = 0) -> str:
def get_device_name(cls, device_id: int = 0) -> str:
return torch.cuda.get_device_name(device_id)

@@ -6,6 +6,10 @@ from .interface import Platform, PlatformEnum
class TpuPlatform(Platform):
_enum = PlatformEnum.TPU

@staticmethod
def inference_mode():
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError

@classmethod
def inference_mode(cls):
return torch.no_grad()

@@ -8,13 +8,15 @@ from huggingface_hub import file_exists, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from safetensors.torch import load_file as safe_load_file

from vllm.platforms import current_platform

WEIGHTS_NAME = "adapter_model.bin"
SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"


# Get current device name based on available devices
def infer_device() -> str:
if torch.cuda.is_available():
if current_platform.is_cuda_alike():
return "cuda"
return "cpu"


@@ -17,6 +17,7 @@ import torch

import vllm.envs as envs
from vllm.connections import global_http_connection
from vllm.platforms import current_platform
from vllm.version import __version__ as VLLM_VERSION

_config_home = envs.VLLM_CONFIG_ROOT
@@ -151,7 +152,7 @@ class UsageMessage:
usage_context: UsageContext,
extra_kvs: Dict[str, Any]) -> None:
# Platform information
if torch.cuda.is_available():
if current_platform.is_cuda_alike():
device_property = torch.cuda.get_device_properties(0)
self.gpu_count = torch.cuda.device_count()
self.gpu_type = device_property.name

@@ -5,6 +5,7 @@ import datetime
import enum
import gc
import os
import random
import socket
import subprocess
import sys
@@ -32,6 +33,7 @@ from typing_extensions import ParamSpec, TypeIs, assert_never

import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

@@ -373,6 +375,22 @@ def get_cpu_memory() -> int:
return psutil.virtual_memory().total


def seed_everything(seed: int) -> None:
"""
Set the seed of each random module.

Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
"""
random.seed(seed)
np.random.seed(seed)

if current_platform.is_cuda_alike():
torch.cuda.manual_seed_all(seed)

if is_xpu():
torch.xpu.manual_seed_all(seed)


def random_uuid() -> str:
return str(uuid.uuid4().hex)

@@ -634,9 +652,7 @@ def create_kv_caches_with_random_flash(
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)

torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
@@ -678,9 +694,7 @@ def create_kv_caches_with_random(
f"Does not support key cache of type fp8 with head_size {head_size}"
)

torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seed_everything(seed)

torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)

@@ -750,7 +764,7 @@ class CudaMemoryProfiler:

def current_memory_usage(self) -> float:
# Return the memory usage in bytes.
if torch.cuda.is_available():
if current_platform.is_cuda_alike():
torch.cuda.reset_peak_memory_stats(self.device)
mem = torch.cuda.max_memory_allocated(self.device)
elif is_xpu():

@@ -454,14 +454,20 @@ def init_worker_distributed_environment(

def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16:
compute_capability = current_platform.get_device_capability()
if compute_capability[0] < 8:
if torch_dtype == torch.bfloat16:  # noqa: SIM102
if not current_platform.has_device_capability(80):
capability = current_platform.get_device_capability()
gpu_name = current_platform.get_device_name()

if capability is None:
compute_str = "does not have a compute capability"
else:
version_str = capability.as_version_str()
compute_str = f"has compute capability {version_str}"

raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
f"{compute_capability[0]}.{compute_capability[1]}. "
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the"
"`dtype` flag in CLI, for example: --dtype=half.")
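Two short usage sketches for the helpers introduced by this commit, assuming the current_platform object exported by vllm.platforms; they restate the call-site patterns from the hunks above rather than adding new behaviour.

# Sketch 1: DeviceCapability / has_device_capability, which replace the
# scattered `capability[0] * 10 + capability[1]` arithmetic in the hunks above.
from vllm.platforms import current_platform

cap = current_platform.get_device_capability()  # DeviceCapability or None
if cap is not None:
    print(cap.as_version_str())  # e.g. "8.9"
    print(cap.to_int())          # e.g. 89 (major * 10 + minor)

# Accepts an int <major><minor> or a (major, minor) tuple and returns False
# when no capability can be determined, so callers need no None checks.
supports_fp8 = current_platform.has_device_capability(89)
is_ampere_or_newer = current_platform.has_device_capability((8, 0))

# Sketch 2: is_cuda_alike() answers from the platform enum instead of calling
# torch.cuda.is_available(), which is what lets this commit avoid initializing
# CUDA before workers are forked (see the DeviceConfig hunk above).
def pick_device() -> str:
    return "cuda" if current_platform.is_cuda_alike() else "cpu"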