[Refactor] Remove duplicate ceil_div (#20023)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent 7108934142
commit 879f69bed3
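For reference, the shared helper that replaces the duplicated `ceil_div` definitions is `cdiv` from `vllm.utils`, which performs ceiling division. A minimal sanity check of the equivalence, assuming vLLM is installed (the local `ceil_div` below only reproduces the pattern being deleted; it is not part of the commit):

```python
from vllm.utils import cdiv  # the shared ceiling-division utility this commit switches to


def ceil_div(a: int, b: int) -> int:
    # The duplicated local helper pattern that this commit removes.
    return (a + b - 1) // b


# Both round up: e.g. 129 elements in blocks of 128 require 2 blocks.
for a, b in [(129, 128), (256, 128), (1, 7), (0, 5)]:
    assert cdiv(a, b) == ceil_div(a, b)
```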
@@ -19,7 +19,7 @@ from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     w8a8_block_fp8_matmul,
 )
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, cdiv
 
 DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
 DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
@@ -117,14 +117,9 @@ def bench_fp8(
     scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
 
-    def ceil_div(x: int, y: int) -> int:
-        return (x + y - 1) // y
-
-    block_scale_a = torch.rand(
-        (m, ceil_div(k, 128)), device="cuda", dtype=torch.float32
-    )
+    block_scale_a = torch.rand((m, cdiv(k, 128)), device="cuda", dtype=torch.float32)
     block_scale_b = torch.rand(
-        ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32
+        cdiv(k, 128), cdiv(n, 128), device="cuda", dtype=torch.float32
     )
     block_scale_a_M_major = block_scale_a.t().contiguous().t()
     block_scale_b_K_major = block_scale_b.t().contiguous().t()
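The hunk above also shows why `cdiv` appears in the block-scale shapes: with 128×128 quantization blocks, the activation needs one scale per 128-wide slice of `k` and the weight needs one scale per 128×128 block, both rounded up. A small illustration with hypothetical sizes (these numbers are not the benchmark defaults):

```python
from vllm.utils import cdiv  # requires vLLM to be installed

m, n, k = 4, 7168, 18432  # hypothetical GEMM problem sizes
block = 128               # FP8 block-quantization tile width used above

scale_a_shape = (m, cdiv(k, block))               # one scale per 128-wide slice of k
scale_b_shape = (cdiv(k, block), cdiv(n, block))  # one scale per 128x128 weight block
print(scale_a_shape, scale_b_shape)               # (4, 144) (144, 56)
```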
@@ -7,10 +7,7 @@ from torch import Tensor
 
 import vllm._custom_ops as ops
 from vllm.platforms import current_platform
+from vllm.utils import cdiv
 
 
-def cdiv(a, b):
-    return (a + b - 1) // b
-
-
 def ref_mla(
@@ -5,10 +5,7 @@ import pytest
 import torch
 
 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
+from vllm.utils import cdiv
 
 
-def cdiv(a, b):
-    return (a + b - 1) // b
-
-
 @pytest.mark.parametrize("B", [3, 5])
@@ -7,6 +7,8 @@ import pytest
 import torch
 import torch.nn.functional as F
 
+from vllm.utils import cdiv
+
 
 class BlockDiagonalCausalFromBottomRightMask:
 
@@ -398,11 +400,8 @@ def test_contexted_kv_attention(
     assert (large_tile_size >= B_P_SIZE
             ), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}"
 
-    def ceil_div(a, b):
-        return (a + b - 1) // b
-
     def pad_to_multiple(a, b):
-        return ceil_div(a, b) * b
+        return cdiv(a, b) * b
 
     def pad_to_next_power_of_2(a):
         assert a > 0
@@ -411,7 +410,7 @@ def test_contexted_kv_attention(
     # calculate input shapes
     max_num_queries = pad_to_next_power_of_2(sum(query_lens))
     context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
-    num_active_blocks = ceil_div(context_lens, block_size).sum().item()
+    num_active_blocks = cdiv(context_lens, block_size).sum().item()
     num_active_blocks = pad_to_multiple(num_active_blocks,
                                         large_tile_size // block_size)
     context_kv_len = num_active_blocks * block_size
@@ -8,9 +8,7 @@ import torch
 from neuronxcc import nki
 from neuronxcc.nki.language import par_dim
 
-
-def ceil_div(a, b):
-    return (a + b - 1) // b
+from vllm.utils import cdiv
 
 
 def is_power_of_2(x):
@@ -35,11 +33,10 @@ def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
        (num_tiles, num_blocks_per_tile))
 
     block_tables_sbuf = nl.zeros(
-        (ceil_div(num_tiles,
-                  B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
+        (cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
         dtype=nl.int32,
     )
-    for i in nl.affine_range(ceil_div(num_tiles, B_P_SIZE)):
+    for i in nl.affine_range(cdiv(num_tiles, B_P_SIZE)):
         i_p = nl.arange(B_P_SIZE)[:, None]
         i_f = nl.arange(num_blocks_per_tile)[None, :]
         block_tables_sbuf[i, i_p, i_f] = nl.load(
@@ -83,7 +80,7 @@ def transform_block_tables_for_indirect_load(
     assert is_power_of_2(
         num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2"
 
-    num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE)
+    num_loads = cdiv(num_blocks_per_tile, B_P_SIZE)
     block_tables_transposed = nl.ndarray(
         (
             num_loads,
@@ -165,7 +162,7 @@ def load_kv_tile_from_cache(
     equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
     """
     # load key cache
-    num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
+    num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
     for load_idx in nl.affine_range(num_loads):
         i_p = nl.arange(B_P_SIZE)[:, None]
         i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
@@ -605,7 +602,7 @@ def flash_paged_attention(
     )
 
     for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
-        num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
+        num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
         cur_k_tile = nl.ndarray(
             (par_dim(B_D_SIZE), LARGE_TILE_SZ),
             dtype=kernel_dtype,
@@ -6,11 +6,7 @@ import torch
 
 from vllm import _custom_ops as ops
 from vllm.triton_utils import tl, triton
-from vllm.utils import round_up
-
-
-def ceil_div(a, b):
-    return (a + b - 1) // b
+from vllm.utils import cdiv, round_up
 
 
 @triton.jit
@@ -115,7 +111,7 @@ def moe_align_block_size_triton(
     cumsum = torch.zeros((num_experts + 1, ),
                          dtype=torch.int32,
                          device=topk_ids.device)
-    tokens_per_thread = ceil_div(numel, num_experts)
+    tokens_per_thread = cdiv(numel, num_experts)
 
     moe_align_block_size_stage1[grid](
         topk_ids,
@@ -19,7 +19,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import direct_register_custom_op
+from vllm.utils import cdiv, direct_register_custom_op
 
 logger = init_logger(__name__)
 has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
@@ -158,12 +158,9 @@ def apply_w8a8_block_fp8_linear(
     if current_platform.is_cuda():
         if current_platform.has_device_capability(100):
 
-            def ceil_div(x: int, y: int) -> int:
-                return (x + y - 1) // y
-
             use_cutlass = cutlass_block_fp8_supported and (
-                ceil_div(weight.shape[0], 128) == weight_scale.shape[0]
-                and ceil_div(weight.shape[1], 128) == weight_scale.shape[1])
+                cdiv(weight.shape[0], 128) == weight_scale.shape[0]
+                and cdiv(weight.shape[1], 128) == weight_scale.shape[1])
         else:
             # TODO: update this after switching to public sm90 block scale gemm
             # as it also supports weight.shape % 128 != 0
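In the last hunk, the CUTLASS block-FP8 path is only taken when the scale tensor has exactly one entry per 128×128 weight block, which is what the `cdiv` comparison encodes. A hedged sketch of that check with made-up shapes (the tensors below are illustrative stand-ins, not the vLLM call path):

```python
import torch

from vllm.utils import cdiv  # same utility the hunk above switches to

# Hypothetical 7168 x 18432 weight quantized with 128 x 128 blocks.
weight = torch.empty(7168, 18432)    # stand-in for the FP8 weight tensor
weight_scale = torch.empty(56, 144)  # one scale per 128 x 128 block

shapes_match = (cdiv(weight.shape[0], 128) == weight_scale.shape[0]
                and cdiv(weight.shape[1], 128) == weight_scale.shape[1])
print(shapes_match)  # True -> the CUTLASS block-FP8 path may be used (capability permitting)
```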