Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-04 19:00:56 +08:00)
[Hardware] using current_platform.seed_everything (#9785)
Signed-off-by: wangshuai09 <391746016@qq.com>
This commit is contained in:
parent 09500f7dde / commit 622b7ab955
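In short, this commit replaces the free function `seed_everything` from `vllm.utils` with a `seed_everything` classmethod exposed through `current_platform` in `vllm.platforms`, and updates every benchmark and kernel/LoRA test call site accordingly. A minimal sketch of the call-site migration (illustrative only, not part of the diff below):

# Before this commit:
#     from vllm.utils import seed_everything
#     seed_everything(42)

# After this commit:
from vllm.platforms import current_platform

current_platform.seed_everything(42)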
@@ -3,8 +3,8 @@ import time
 import torch

 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
-                        seed_everything)
+from vllm.platforms import current_platform
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


 @torch.inference_mode()
@@ -16,7 +16,7 @@ def main(num_tokens: int,
          do_profile: bool = False,
          num_warmup_iters: int = 5,
          num_iters: int = 100) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     torch.set_default_device("cuda")

     layer = RMSNorm(hidden_size).to(dtype=dtype)
@@ -10,7 +10,8 @@ from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig

 from vllm.model_executor.layers.fused_moe.fused_moe import *
-from vllm.utils import FlexibleArgumentParser, seed_everything
+from vllm.platforms import current_platform
+from vllm.utils import FlexibleArgumentParser


 class BenchmarkConfig(TypedDict):
@@ -167,7 +168,7 @@ class BenchmarkWorker:

     def __init__(self, seed: int) -> None:
         torch.set_default_device("cuda")
-        seed_everything(seed)
+        current_platform.seed_everything(seed)
         self.seed = seed

     def benchmark(
@@ -181,7 +182,7 @@ class BenchmarkWorker:
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
     ) -> Tuple[Dict[str, int], float]:
-        seed_everything(self.seed)
+        current_platform.seed_everything(self.seed)
         dtype_str = get_config_dtype_str(dtype,
                                          use_int8_w8a16=use_int8_w8a16,
                                          use_fp8_w8a8=use_fp8_w8a8)
@@ -5,8 +5,9 @@ from typing import List, Optional
 import torch

 from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
-                        create_kv_caches_with_random, seed_everything)
+                        create_kv_caches_with_random)

 NUM_BLOCKS = 1024
 PARTITION_SIZE = 512
@@ -28,7 +29,7 @@ def main(
     device: str = "cuda",
     kv_cache_dtype: Optional[str] = None,
 ) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)

     scale = float(1.0 / (head_size**0.5))
     query = torch.empty(num_seqs,
@@ -3,8 +3,8 @@ import time
 import torch

 from vllm import _custom_ops as ops
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
-                        seed_everything)
+from vllm.platforms import current_platform
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


 @torch.inference_mode()
@@ -17,7 +17,7 @@ def main(num_tokens: int,
          do_profile: bool = False,
          num_warmup_iters: int = 5,
          num_iters: int = 100) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     torch.set_default_device("cuda")

     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
@@ -6,7 +6,8 @@ import torch

 from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
                                                           get_rope)
-from vllm.utils import FlexibleArgumentParser, seed_everything
+from vllm.platforms import current_platform
+from vllm.utils import FlexibleArgumentParser


 def benchmark_rope_kernels_multi_lora(
@@ -22,7 +23,7 @@ def benchmark_rope_kernels_multi_lora(
     max_position: int = 8192,
     base: int = 10000,
 ) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     torch.set_default_device(device)
     if rotary_dim is None:
         rotary_dim = head_size
@@ -8,7 +8,7 @@ from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
                                                     GeluAndMul, NewGELU,
                                                     QuickGELU, SiluAndMul)
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform

 from .allclose_default import get_default_atol, get_default_rtol

@@ -37,7 +37,7 @@ def test_act_and_mul(
     seed: int,
     device: str,
 ) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype)
     if activation == "silu":
@@ -85,7 +85,7 @@ def test_activation(
     seed: int,
     device: str,
 ) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, d, dtype=dtype)
     layer = activation[0]()
@@ -7,7 +7,7 @@ import torch
 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
-from vllm.utils import get_max_shared_memory_bytes, seed_everything
+from vllm.utils import get_max_shared_memory_bytes

 from .allclose_default import get_default_atol, get_default_rtol

@@ -144,7 +144,7 @@ def test_paged_attention(
             or (version == "rocm" and head_size not in (64, 128))):
         pytest.skip()

-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     torch.set_default_device(device)
     scale = float(1.0 / (head_size**0.5))
     num_query_heads, num_kv_heads = num_heads
@@ -382,7 +382,7 @@ def test_multi_query_kv_attention(
     seed: int,
     device: str,
 ) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     torch.set_default_device(device)
     # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
     # As the xformers library is already tested with its own tests, we can use
@@ -7,7 +7,7 @@ import torch

 from vllm.model_executor.layers.quantization.awq_triton import (
     AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform

 device = "cuda"

@@ -80,7 +80,7 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
     zeros_cols = qweight_cols
     zeros_dtype = torch.int32

-    seed_everything(0)
+    current_platform.seed_everything(0)

     qweight = torch.randint(0,
                             torch.iinfo(torch.int32).max,
@@ -134,7 +134,7 @@ def test_gemm(N, K, M, splitK, group_size):
     qzeros_rows = scales_rows
     qzeros_cols = qweight_cols

-    seed_everything(0)
+    current_platform.seed_everything(0)

     input = torch.rand((input_rows, input_cols),
                        dtype=input_dtype,
@@ -8,7 +8,7 @@ from vllm import _custom_ops as ops
 from vllm.attention.ops.blocksparse_attention.interface import (
     LocalStridedBlockSparseAttn)
 from vllm.platforms import current_platform
-from vllm.utils import get_max_shared_memory_bytes, seed_everything
+from vllm.utils import get_max_shared_memory_bytes

 from .allclose_default import get_default_atol, get_default_rtol

@@ -173,7 +173,7 @@ def test_paged_attention(
     blocksparse_block_size: int,
     blocksparse_head_sliding_step: int,
 ) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     torch.set_default_device(device)
     scale = float(1.0 / (head_size**0.5))
     num_query_heads, num_kv_heads = num_heads
@@ -384,7 +384,7 @@ def test_varlen_blocksparse_attention_prefill(
     seed: int,
     device: str,
 ) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     torch.set_default_device(device)
     # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
     # As the xformers library is already tested with its own tests, we can use
@@ -6,7 +6,7 @@ import torch

 from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
 from vllm import _custom_ops as ops
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform

 COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -56,7 +56,7 @@ def test_copy_blocks(
 ) -> None:
     if kv_cache_dtype == "fp8" and head_size % 16:
         pytest.skip()
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     torch.set_default_device(device)
     # Generate random block mappings where each source block is mapped to two
     # destination blocks.
@@ -132,7 +132,7 @@ def test_reshape_and_cache(
 ) -> None:
     if kv_cache_dtype == "fp8" and head_size % 16:
         pytest.skip()
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     torch.set_default_device(device)
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
@@ -224,7 +224,7 @@ def test_reshape_and_cache_flash(
     device: str,
     kv_cache_dtype: str,
 ) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     torch.set_default_device(device)

     # Create a random slot mapping.
@@ -339,7 +339,7 @@ def test_swap_blocks(
     if kv_cache_dtype == "fp8" and head_size % 16:
         pytest.skip()

-    seed_everything(seed)
+    current_platform.seed_everything(seed)

     src_device = device if direction[0] == "cuda" else 'cpu'
     dst_device = device if direction[1] == "cuda" else 'cpu'
@@ -408,7 +408,7 @@ def test_fp8_e4m3_conversion(
     seed: int,
     device: str,
 ) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)

     low = -224.0
     high = 224.0
@@ -9,7 +9,7 @@ from vllm import _custom_ops as ops  # noqa: F401
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform


 def causal_conv1d_ref(
@@ -70,7 +70,7 @@ def causal_conv1d_update_ref(x,
     bias: (dim,)
     cache_seqlens: (batch,), dtype int32.
         If not None, the conv_state is treated as a circular buffer.
-        The conv_state will be updated by copying x to the
+        The conv_state will be updated by copying x to the
         conv_state starting at the index
         @cache_seqlens % state_len before performing the convolution.
@@ -161,7 +161,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
     if itype == torch.bfloat16:
         rtol, atol = 1e-2, 5e-2
     # set seed
-    seed_everything(0)
+    current_platform.seed_everything(0)
     x = torch.randn(batch, dim, seqlen, device=device,
                     dtype=itype).contiguous()
@@ -223,7 +223,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
     if itype == torch.bfloat16:
         rtol, atol = 1e-2, 5e-2
     # set seed
-    seed_everything(0)
+    current_platform.seed_everything(0)
     batch = 2
     x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
     x_ref = x.clone()
@@ -270,7 +270,7 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
         rtol, atol = 1e-2, 5e-2

     # set seed
-    seed_everything(0)
+    current_platform.seed_everything(0)

     batch_size = 3
     padding = 5 if with_padding else 0
@@ -343,7 +343,7 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
     if itype == torch.bfloat16:
         rtol, atol = 1e-2, 5e-2
     # set seed
-    seed_everything(0)
+    current_platform.seed_everything(0)
     seqlens = []
     batch_size = 4
     if seqlen < 10:
@@ -3,7 +3,7 @@ from typing import List, Optional, Tuple
 import pytest
 import torch

-from vllm.utils import seed_everything
+from vllm.platforms import current_platform
 from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                   flash_attn_with_kvcache)
@@ -91,7 +91,7 @@ def test_flash_attn_with_paged_kv(
     sliding_window: Optional[int],
 ) -> None:
     torch.set_default_device("cuda")
-    seed_everything(0)
+    current_platform.seed_everything(0)
     num_seqs = len(kv_lens)
     num_query_heads = num_heads[0]
     num_kv_heads = num_heads[1]
@@ -161,7 +161,7 @@ def test_varlen_with_paged_kv(
     num_blocks: int,
 ) -> None:
     torch.set_default_device("cuda")
-    seed_everything(0)
+    current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]
     kv_lens = [x[1] for x in seq_lens]
@@ -4,7 +4,7 @@ import flashinfer
 import pytest
 import torch

-from vllm.utils import seed_everything
+from vllm.platforms import current_platform

 NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)]
 HEAD_SIZES = [128, 256]
@@ -84,7 +84,7 @@ def test_flashinfer_decode_with_paged_kv(
     soft_cap: Optional[float],
 ) -> None:
     torch.set_default_device("cuda")
-    seed_everything(0)
+    current_platform.seed_everything(0)
     num_seqs = len(kv_lens)
     num_query_heads = num_heads[0]
     num_kv_heads = num_heads[1]
@@ -170,7 +170,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
                                           block_size: int,
                                           soft_cap: Optional[float]) -> None:
     torch.set_default_device("cuda")
-    seed_everything(0)
+    current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]
     kv_lens = [x[1] for x in seq_lens]
@@ -268,7 +268,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
         head_size: int, dtype: torch.dtype, block_size: int,
         soft_cap: Optional[float]) -> None:
     torch.set_default_device("cuda")
-    seed_everything(0)
+    current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]
     kv_lens = [x[1] for x in seq_lens]
@@ -381,7 +381,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
 ) -> None:
     # test doesn't work for num_heads = (16,16)
     torch.set_default_device("cuda")
-    seed_everything(0)
+    current_platform.seed_everything(0)
     num_seqs = len(kv_lens)
     num_query_heads = num_heads[0]
     num_kv_heads = num_heads[1]
@@ -6,7 +6,7 @@ from tests.kernels.quant_utils import (FP8_DTYPE,
                                        ref_dynamic_per_tensor_fp8_quant,
                                        ref_dynamic_per_token_quant)
 from tests.kernels.utils import opcheck
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform

 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
@@ -46,7 +46,7 @@ def opcheck_fp8_quant(output,
 def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
                                      dtype: torch.dtype, scale_ub: bool,
                                      seed: int) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)

     x = torch.rand(num_tokens, hidden_size, dtype=dtype,
                    device="cuda") + 1e-6  # avoid nans
@@ -76,7 +76,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
 @torch.inference_mode()
 def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
                                       dtype: torch.dtype, seed: int) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)

     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
@@ -95,7 +95,7 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
 @torch.inference_mode()
 @pytest.mark.parametrize("seed", SEEDS)
 def test_fp8_quant_large(seed: int) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)

     num_tokens = 1024000  # Mistral-Nemo's max_position_embeddings
     hidden_size = 1152  # Smallest hidden_size to reproduce the error
@@ -7,7 +7,7 @@ from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
 from huggingface_hub import snapshot_download

 import vllm._custom_ops as ops
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform

 GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")

@@ -75,7 +75,7 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype,
 @torch.inference_mode()
 def test_mmvq(hidden_size: int, dtype: torch.dtype,
               quant_type: GGMLQuantizationType):
-    seed_everything(0)
+    current_platform.seed_everything(0)

     tensors = get_gguf_sample_tensors(hidden_size, quant_type)
     x = torch.rand((1, hidden_size), dtype=dtype, device="cuda")
@@ -111,7 +111,7 @@ def test_mmvq(hidden_size: int, dtype: torch.dtype,
 @torch.inference_mode()
 def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
              quant_type: GGMLQuantizationType):
-    seed_everything(0)
+    current_platform.seed_everything(0)

     tensors = get_gguf_sample_tensors(hidden_size, quant_type)
     x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda")
@@ -4,7 +4,7 @@ import torch
 from tests.kernels.quant_utils import ref_dynamic_per_token_quant
 from tests.kernels.utils import opcheck
 from vllm._custom_ops import scaled_int8_quant
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform

 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
@@ -45,7 +45,7 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True):
 @torch.inference_mode()
 def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
                                    dtype: torch.dtype, seed: int) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)

     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
@@ -68,7 +68,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
 @torch.inference_mode()
 def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
                                        dtype: torch.dtype, seed: int) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     int8_traits = torch.iinfo(torch.int8)

     x = torch.rand(num_tokens, hidden_size, dtype=dtype,
@@ -112,7 +112,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
 def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
                                   dtype: torch.dtype, seed: int,
                                   scale: float) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     int8_traits = torch.iinfo(torch.int8)

     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
@@ -138,7 +138,7 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
 def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
                                       dtype: torch.dtype, seed: int,
                                       scale: float, azp: int) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     int8_traits = torch.iinfo(torch.int8)

     x = torch.rand(num_tokens, hidden_size, dtype=dtype,
@@ -3,7 +3,7 @@ import torch

 from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform

 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
@@ -31,7 +31,7 @@ def test_rms_norm(
     seed: int,
     device: str,
 ) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     torch.set_default_device(device)
     layer = RMSNorm(hidden_size).to(dtype=dtype)
     layer.weight.data.normal_(mean=1.0, std=0.1)
@@ -8,7 +8,7 @@ from vllm import _custom_ops as ops  # noqa: F401
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform


 def selective_state_update_ref(state,
@@ -235,7 +235,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
     rtolw = max(rtolw, rtol)
     atolw = max(atolw, atol)
     # set seed
-    seed_everything(0)
+    current_platform.seed_everything(0)
     batch_size = 1
     dim = 4
     dstate = 8
@@ -358,7 +358,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
     if torch.version.hip:
         atol *= 2
     # set seed
-    seed_everything(0)
+    current_platform.seed_everything(0)
     batch_size = 1
     state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
     x = torch.randn(batch_size, dim, device=device, dtype=itype)
@@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
 from vllm.model_executor.models.mixtral import MixtralMoE
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
-from vllm.utils import seed_everything


 @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@@ -115,7 +114,7 @@ def test_fused_marlin_moe(
     num_bits: int,
     is_k_full: bool,
 ):
-    seed_everything(7)
+    current_platform.seed_everything(7)

     # Filter act_order
     if act_order:
@@ -5,7 +5,7 @@ import pytest
 import torch

 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform

 from .allclose_default import get_default_atol, get_default_rtol

@@ -48,7 +48,7 @@ def test_rotary_embedding(
     if rotary_dim is None:
         rotary_dim = head_size

-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     torch.set_default_device(device)
     if rotary_dim is None:
         rotary_dim = head_size
@@ -100,7 +100,7 @@ def test_batched_rotary_embedding(
     max_position: int = 8192,
     base: int = 10000,
 ) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     torch.set_default_device(device)
     if rotary_dim is None:
         rotary_dim = head_size
@@ -160,7 +160,7 @@ def test_batched_rotary_embedding_multi_lora(
     max_position: int = 8192,
     base: int = 10000,
 ) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     torch.set_default_device(device)
     if rotary_dim is None:
         rotary_dim = head_size
@@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

 from vllm.attention.backends.xformers import _make_alibi_bias
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, seed_everything
+from vllm.platforms import current_platform
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

 NUM_HEADS = [64]
 NUM_QUERIES_PER_KV = [1, 8, 64]
@@ -39,7 +40,7 @@ def test_contexted_kv_attention(
     kv_cache_dtype: str,
     device: str,
 ) -> None:
-    seed_everything(0)
+    current_platform.seed_everything(0)
     torch.set_default_device(device)

     # Need this, otherwise when we capture the graph the process
@@ -234,7 +235,7 @@ def test_contexted_kv_attention_alibi(
     kv_cache_dtype: str,
     device: str,
 ) -> None:
-    seed_everything(0)
+    current_platform.seed_everything(0)
     torch.set_default_device(device)

     # Need this, otherwise when we capture the graph the process
@@ -39,7 +39,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
 from vllm.model_executor.utils import set_random_seed
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform

 from .utils import DummyLoRAManager

@@ -923,7 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
                                        seq_len) -> None:
     dtype = torch.float16
     seed = 0
-    seed_everything(seed)
+    current_platform.seed_everything(seed)
     torch.set_default_device(device)
     punica_wrapper = PunicaWrapper(8192, 256, device)
     max_loras = 8
@@ -1,5 +1,5 @@
 """
-This script is mainly used to tests various hidden_sizes. We have collected the
+This script is mainly used to tests various hidden_sizes. We have collected the
 hidden_sizes included in the LoRA models currently supported by vLLM. It tests
 whether the corresponding Triton kernel can run normally when tensor parallelism
 is set to [1, 2, 4, 8, 16, 32, 64].
@@ -15,8 +15,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink
 from vllm.lora.ops.sgmv_expand import sgmv_expand
 from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
 from vllm.lora.ops.sgmv_shrink import sgmv_shrink
+from vllm.platforms import current_platform
 from vllm.triton_utils.libentry import LibEntry
-from vllm.utils import seed_everything

 from .utils import (generate_data, generate_data_for_expand_nslices,
                     ref_torch_groupgemm)
@@ -146,7 +146,7 @@ def test_punica_sgmv(
     device: str,
 ):
     torch.set_default_device(device)
-    seed_everything(seed)
+    current_platform.seed_everything(seed)

     seq_length = 128
     (
@@ -239,7 +239,7 @@ def test_punica_bgmv(
     from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel

     torch.set_default_device(device)
-    seed_everything(seed)
+    current_platform.seed_everything(seed)

     seq_length = 1
     (
@@ -327,7 +327,7 @@ def test_punica_expand_nslices(
     from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel

     torch.set_default_device(device)
-    seed_everything(seed)
+    current_platform.seed_everything(seed)

     seq_length = 128 if op_type == "sgmv" else 1
     (
@@ -1,6 +1,6 @@
 """
-This script is mainly used to test whether trtion kernels can run normally
-under different conditions, including various batches, numbers of LoRA , and
+This script is mainly used to test whether trtion kernels can run normally
+under different conditions, including various batches, numbers of LoRA , and
 maximum ranks.
 """
 from unittest.mock import patch
@@ -14,8 +14,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink
 from vllm.lora.ops.sgmv_expand import sgmv_expand
 from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
 from vllm.lora.ops.sgmv_shrink import sgmv_shrink
+from vllm.platforms import current_platform
 from vllm.triton_utils.libentry import LibEntry
-from vllm.utils import seed_everything

 from .utils import (generate_data, generate_data_for_expand_nslices,
                     ref_torch_groupgemm)
@@ -61,7 +61,7 @@ def test_punica_sgmv(
     device: str,
 ):
     torch.set_default_device(device)
-    seed_everything(seed)
+    current_platform.seed_everything(seed)

     seq_length = 128
     (
@@ -154,7 +154,7 @@ def test_punica_bgmv(
     from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel

     torch.set_default_device(device)
-    seed_everything(seed)
+    current_platform.seed_everything(seed)

     seq_length = 1
     (
@@ -242,7 +242,7 @@ def test_punica_expand_nslices(
     from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel

     torch.set_default_device(device)
-    seed_everything(seed)
+    current_platform.seed_everything(seed)

     seq_length = 128 if op_type == "sgmv" else 1
     (
@@ -4,11 +4,10 @@ from typing import Any, Dict, Optional
 import torch

-from vllm.utils import seed_everything
+from vllm.platforms import current_platform


 def set_random_seed(seed: int) -> None:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)


 def set_weight_attrs(
@@ -1,6 +1,8 @@
 import enum
+import random
 from typing import NamedTuple, Optional, Tuple, Union

+import numpy as np
 import torch


@@ -111,6 +113,18 @@ class Platform:
         """
         return torch.inference_mode(mode=True)

+    @classmethod
+    def seed_everything(cls, seed: int) -> None:
+        """
+        Set the seed of each random module.
+        `torch.manual_seed` will set seed on all devices.
+
+        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
+        """
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+

 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
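The new base `Platform.seed_everything` above seeds Python's `random`, NumPy, and `torch` in one place, leaving any device-specific behaviour to platform subclasses. A hedged sketch of what such an override might look like; the class name and the XPU call are illustrative assumptions, not part of this commit:

import random

import numpy as np
import torch


class HypotheticalXPUPlatform:  # illustrative stand-in for a Platform subclass
    @classmethod
    def seed_everything(cls, seed: int) -> None:
        # Same base behaviour as Platform.seed_everything ...
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # ... plus whatever device-specific seeding the platform needs, e.g.
        # torch.xpu.manual_seed_all(seed) on an XPU build (assumption).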
@@ -7,7 +7,6 @@ import gc
 import inspect
 import ipaddress
 import os
-import random
 import socket
 import subprocess
 import sys
@@ -331,22 +330,6 @@ def get_cpu_memory() -> int:
     return psutil.virtual_memory().total


-def seed_everything(seed: int) -> None:
-    """
-    Set the seed of each random module.
-
-    Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
-    """
-    random.seed(seed)
-    np.random.seed(seed)
-
-    if current_platform.is_cuda_alike():
-        torch.cuda.manual_seed_all(seed)
-
-    if current_platform.is_xpu():
-        torch.xpu.manual_seed_all(seed)
-
-
 def random_uuid() -> str:
     return str(uuid.uuid4().hex)
@@ -643,7 +626,7 @@ def create_kv_caches_with_random_flash(
     seed: int = 0,
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-    seed_everything(seed)
+    current_platform.seed_everything(seed)

     torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
     key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
@@ -685,7 +668,7 @@ def create_kv_caches_with_random(
             f"Does not support key cache of type fp8 with head_size {head_size}"
         )

-    seed_everything(seed)
+    current_platform.seed_everything(seed)

     torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
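With the old helper removed from `vllm.utils`, reproducibility now goes through the platform object. A small self-check one could run against this change (illustrative; assumes a working vLLM install and uses `torch.randn` as the generator of interest):

import torch

from vllm.platforms import current_platform

current_platform.seed_everything(1234)
a = torch.randn(4)
current_platform.seed_everything(1234)
b = torch.randn(4)
assert torch.equal(a, b), "same seed should reproduce the same sample"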