Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-17 15:25:40 +08:00)
[Bugfix] Fix persistent_masked_m_silu_mul_quant tests (#28366)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>

This commit is contained in:
parent d0e186c16f
commit b039bfda8f
@@ -578,11 +578,13 @@ void persistent_masked_m_silu_mul_quant(
 
   // This kernel currently only supports H % 128 == 0 and assumes a
   // fixed GROUP_SIZE of 128.
+  static constexpr int GROUP_SIZE = 128;
+
   TORCH_CHECK(input.dtype() == torch::kBFloat16);
   TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
               y_q.dtype() == torch::kFloat8_e4m3fnuz);
   TORCH_CHECK(y_s.dtype() == torch::kFloat32);
-  TORCH_CHECK(input.size(-1) % 256 == 0);
+  TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);
 
   using Idx_t = int64_t;
 
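The rewritten check ties the divisibility requirement to the group size rather than a hard-coded 256: the input packs the gate and up projections along the last dimension (2*H), and scales are produced per GROUP_SIZE-wide group of the H output columns, so 2*H must be a multiple of GROUP_SIZE * 2. A minimal Python sketch of the same host-side validation (illustrative only; the helper name is made up):

import torch

GROUP_SIZE = 128  # fixed quantization group width assumed by the kernel

def check_silu_mul_quant_input(y: torch.Tensor) -> int:
    """y has shape (E, T, 2*H); return H after validating the layout."""
    assert y.dtype == torch.bfloat16
    two_h = y.size(-1)
    # Mirrors TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0): H must
    # split evenly into GROUP_SIZE-wide quantization groups.
    assert two_h % (GROUP_SIZE * 2) == 0, f"2*H={two_h} not a multiple of {GROUP_SIZE * 2}"
    return two_h // 2

H = check_silu_mul_quant_input(torch.empty(8, 16, 2 * 128 * 33, dtype=torch.bfloat16))
print(H, H // GROUP_SIZE)  # 4224 33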
@@ -601,8 +603,6 @@ void persistent_masked_m_silu_mul_quant(
 
   Idx_t stride_counts_e = tokens_per_expert.stride(0);
 
-  static constexpr int GROUP_SIZE = 128;
-
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
 #define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \
@@ -628,21 +628,26 @@ void persistent_masked_m_silu_mul_quant(
 
   static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
 
+  int const NUM_GROUPS = H / GROUP_SIZE;
   if (!use_ue8m0) {
-    if (H >= 4096) {
+    if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
+      /* 8 warps config */
       static constexpr int NUM_STAGES = 4;
       static constexpr int THREAD_COUNT = 256;
       KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES);
     } else {
+      /* 1 warp config */
       static constexpr int THREAD_COUNT = 32;
       KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
     }
   } else {
-    if (H >= 4096) {
+    if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
+      /* 8 warps config */
       static constexpr int NUM_STAGES = 4;
       static constexpr int THREAD_COUNT = 256;
       KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
     } else {
+      /* 1 warp config */
       static constexpr int THREAD_COUNT = 32;
       KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
     }
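The launch heuristic now selects the 8-warp configuration only when H is large enough and the number of GROUP_SIZE-wide groups divides evenly by 8; anything else falls back to the 1-warp configuration, which is presumably why the new (8, 128, 128 * 33, fp8_dtype) test case below exercises an H of 4224 with 33 groups. A rough Python rendering of that selection (the helper is hypothetical, for illustration only):

GROUP_SIZE = 128

def select_launch_config(h: int) -> tuple[int, int]:
    """Return (thread_count, num_stages) mirroring the dispatch above."""
    num_groups = h // GROUP_SIZE
    if h >= 4096 and num_groups % 8 == 0:
        return 256, 4  # 8 warps, 4 pipeline stages
    return 32, 2       # 1 warp, 2 pipeline stages

assert select_launch_config(7168) == (256, 4)     # 56 groups -> 8-warp path
assert select_launch_config(128 * 33) == (32, 2)  # 33 groups -> 1-warp path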
@@ -25,6 +25,7 @@ CASES = [
     (8, 16, 128 * 2, fp8_dtype),
     (8, 16, 128 * 3, fp8_dtype),
     (8, 64, 7168, fp8_dtype),
+    (8, 128, 128 * 33, fp8_dtype),
     (8, 128, 7168, fp8_dtype),
     (8, 512, 7168, fp8_dtype),
     (8, 1024, 7168, fp8_dtype),
@@ -54,8 +55,10 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
     )
 
     # Run the SiLU V2 kernel
+    # TODO (varun): use_e8m0 is set to false as the reference impl does
+    # not handle that case.
     y_q, y_s = persistent_masked_m_silu_mul_quant(
-        y, tokens_per_expert, group_size=group_size
+        y, tokens_per_expert, group_size=group_size, use_ue8m0=False
     )
 
     torch.cuda.synchronize()
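For context, the operation under test is silu(y[..., :H]) * y[..., H:] followed by per-token, per-group FP8 quantization with one float32 scale per group_size columns. A minimal PyTorch reference of that math (an illustrative sketch, not the test's actual reference implementation; masking by tokens_per_expert is omitted):

import torch

def ref_silu_mul_fp8_quant(y: torch.Tensor, group_size: int = 128):
    """y: (E, T, 2*H) bfloat16 -> (fp8 values, float32 per-group scales)."""
    H = y.shape[-1] // 2
    gate, up = y[..., :H].float(), y[..., H:].float()
    out = torch.nn.functional.silu(gate) * up                  # (E, T, H)
    groups = out.view(*out.shape[:-1], H // group_size, group_size)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max             # 448 for e4m3fn
    y_s = groups.abs().amax(dim=-1, keepdim=True) / fp8_max    # per-group scale
    y_s = y_s.clamp(min=1e-12)
    y_q = (groups / y_s).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return y_q.view(*out.shape), y_s.squeeze(-1)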
@@ -100,6 +100,7 @@ def persistent_masked_m_silu_mul_quant(
     tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
     num_parallel_tokens=16,
     group_size: int = 128,
+    use_ue8m0: bool | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
     y has shape (E, T, 2*H). The first half of the last dimension is
@@ -164,7 +165,7 @@ def persistent_masked_m_silu_mul_quant(
         device=y.device,
     )
 
-    use_ue8m0 = is_deep_gemm_e8m0_used()
+    use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used()
 
     cuda_arch = current_platform.get_device_capability(
         device_id=y.device.index
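With the new keyword argument the wrapper keeps its old behaviour by default: leaving use_ue8m0 as None still defers to is_deep_gemm_e8m0_used(), while callers such as the test above can pin it explicitly. The defaulting pattern in isolation (a sketch with a stand-in flag instead of the real global query):

def resolve_use_ue8m0(use_ue8m0: bool | None, e8m0_enabled_globally: bool) -> bool:
    # Explicit caller choice wins; None falls back to the global DeepGEMM setting,
    # matching `use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used()`.
    return use_ue8m0 if use_ue8m0 is not None else e8m0_enabled_globally

assert resolve_use_ue8m0(None, True) is True     # default: follow the global setting
assert resolve_use_ue8m0(False, True) is False   # tests force the non-ue8m0 path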