mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 09:06:02 +08:00
[Refactor] Remove duplicate ceil_div (#20023)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
7108934142
commit
879f69bed3
@ -19,7 +19,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
w8a8_block_fp8_matmul,
|
||||
)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
from vllm.utils import FlexibleArgumentParser, cdiv
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||
@ -117,14 +117,9 @@ def bench_fp8(
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
|
||||
def ceil_div(x: int, y: int) -> int:
|
||||
return (x + y - 1) // y
|
||||
|
||||
block_scale_a = torch.rand(
|
||||
(m, ceil_div(k, 128)), device="cuda", dtype=torch.float32
|
||||
)
|
||||
block_scale_a = torch.rand((m, cdiv(k, 128)), device="cuda", dtype=torch.float32)
|
||||
block_scale_b = torch.rand(
|
||||
ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32
|
||||
cdiv(k, 128), cdiv(n, 128), device="cuda", dtype=torch.float32
|
||||
)
|
||||
block_scale_a_M_major = block_scale_a.t().contiguous().t()
|
||||
block_scale_b_K_major = block_scale_b.t().contiguous().t()
|
||||
|
||||
@ -7,10 +7,7 @@ from torch import Tensor
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def cdiv(a, b):
|
||||
return (a + b - 1) // b
|
||||
from vllm.utils import cdiv
|
||||
|
||||
|
||||
def ref_mla(
|
||||
|
||||
@ -5,10 +5,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
|
||||
|
||||
def cdiv(a, b):
|
||||
return (a + b - 1) // b
|
||||
from vllm.utils import cdiv
|
||||
|
||||
|
||||
@pytest.mark.parametrize("B", [3, 5])
|
||||
|
||||
@ -7,6 +7,8 @@ import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.utils import cdiv
|
||||
|
||||
|
||||
class BlockDiagonalCausalFromBottomRightMask:
|
||||
|
||||
@ -398,11 +400,8 @@ def test_contexted_kv_attention(
|
||||
assert (large_tile_size >= B_P_SIZE
|
||||
), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}"
|
||||
|
||||
def ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
def pad_to_multiple(a, b):
|
||||
return ceil_div(a, b) * b
|
||||
return cdiv(a, b) * b
|
||||
|
||||
def pad_to_next_power_of_2(a):
|
||||
assert a > 0
|
||||
@ -411,7 +410,7 @@ def test_contexted_kv_attention(
|
||||
# calculate input shapes
|
||||
max_num_queries = pad_to_next_power_of_2(sum(query_lens))
|
||||
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
|
||||
num_active_blocks = ceil_div(context_lens, block_size).sum().item()
|
||||
num_active_blocks = cdiv(context_lens, block_size).sum().item()
|
||||
num_active_blocks = pad_to_multiple(num_active_blocks,
|
||||
large_tile_size // block_size)
|
||||
context_kv_len = num_active_blocks * block_size
|
||||
|
||||
@ -8,9 +8,7 @@ import torch
|
||||
from neuronxcc import nki
|
||||
from neuronxcc.nki.language import par_dim
|
||||
|
||||
|
||||
def ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
from vllm.utils import cdiv
|
||||
|
||||
|
||||
def is_power_of_2(x):
|
||||
@ -35,11 +33,10 @@ def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
|
||||
(num_tiles, num_blocks_per_tile))
|
||||
|
||||
block_tables_sbuf = nl.zeros(
|
||||
(ceil_div(num_tiles,
|
||||
B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
|
||||
(cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
|
||||
dtype=nl.int32,
|
||||
)
|
||||
for i in nl.affine_range(ceil_div(num_tiles, B_P_SIZE)):
|
||||
for i in nl.affine_range(cdiv(num_tiles, B_P_SIZE)):
|
||||
i_p = nl.arange(B_P_SIZE)[:, None]
|
||||
i_f = nl.arange(num_blocks_per_tile)[None, :]
|
||||
block_tables_sbuf[i, i_p, i_f] = nl.load(
|
||||
@ -83,7 +80,7 @@ def transform_block_tables_for_indirect_load(
|
||||
assert is_power_of_2(
|
||||
num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2"
|
||||
|
||||
num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE)
|
||||
num_loads = cdiv(num_blocks_per_tile, B_P_SIZE)
|
||||
block_tables_transposed = nl.ndarray(
|
||||
(
|
||||
num_loads,
|
||||
@ -165,7 +162,7 @@ def load_kv_tile_from_cache(
|
||||
equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
|
||||
"""
|
||||
# load key cache
|
||||
num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
|
||||
num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
|
||||
for load_idx in nl.affine_range(num_loads):
|
||||
i_p = nl.arange(B_P_SIZE)[:, None]
|
||||
i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
|
||||
@ -605,7 +602,7 @@ def flash_paged_attention(
|
||||
)
|
||||
|
||||
for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
|
||||
num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
|
||||
num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
|
||||
cur_k_tile = nl.ndarray(
|
||||
(par_dim(B_D_SIZE), LARGE_TILE_SZ),
|
||||
dtype=kernel_dtype,
|
||||
|
||||
@ -6,11 +6,7 @@ import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import round_up
|
||||
|
||||
|
||||
def ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
from vllm.utils import cdiv, round_up
|
||||
|
||||
|
||||
@triton.jit
|
||||
@ -115,7 +111,7 @@ def moe_align_block_size_triton(
|
||||
cumsum = torch.zeros((num_experts + 1, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
tokens_per_thread = ceil_div(numel, num_experts)
|
||||
tokens_per_thread = cdiv(numel, num_experts)
|
||||
|
||||
moe_align_block_size_stage1[grid](
|
||||
topk_ids,
|
||||
|
||||
@ -19,7 +19,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.utils import cdiv, direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
|
||||
@ -158,12 +158,9 @@ def apply_w8a8_block_fp8_linear(
|
||||
if current_platform.is_cuda():
|
||||
if current_platform.has_device_capability(100):
|
||||
|
||||
def ceil_div(x: int, y: int) -> int:
|
||||
return (x + y - 1) // y
|
||||
|
||||
use_cutlass = cutlass_block_fp8_supported and (
|
||||
ceil_div(weight.shape[0], 128) == weight_scale.shape[0]
|
||||
and ceil_div(weight.shape[1], 128) == weight_scale.shape[1])
|
||||
cdiv(weight.shape[0], 128) == weight_scale.shape[0]
|
||||
and cdiv(weight.shape[1], 128) == weight_scale.shape[1])
|
||||
else:
|
||||
# TODO: update this after switching to public sm90 block scale gemm
|
||||
# as it also supports weight.shape % 128 != 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user