mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 09:24:27 +08:00
[Feature] Integrate SM100 DeepGEMM support (#20087)
This commit is contained in:
parent
5b032352cc
commit
e2de455c34
@ -86,6 +86,9 @@ def benchmark_config(
|
|||||||
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
|
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
|
||||||
)
|
)
|
||||||
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
|
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
|
||||||
|
if use_deep_gemm:
|
||||||
|
# we use the default block shape for deepgemm
|
||||||
|
block_quant_shape = [128, 128]
|
||||||
if use_fp8_w8a8:
|
if use_fp8_w8a8:
|
||||||
if block_quant_shape:
|
if block_quant_shape:
|
||||||
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
|
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
|
||||||
|
|||||||
@ -15,13 +15,13 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
|||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
fused_topk, modular_triton_fused_moe)
|
fused_topk, modular_triton_fused_moe)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import has_deep_gemm
|
||||||
|
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
|
||||||
|
|
||||||
dg_available = False
|
dg_available = has_deep_gemm()
|
||||||
try:
|
|
||||||
import deep_gemm
|
if dg_available:
|
||||||
dg_available = True
|
from deep_gemm import get_m_alignment_for_contiguous_layout
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if current_platform.get_device_capability() < (9, 0):
|
if current_platform.get_device_capability() < (9, 0):
|
||||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
|
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
|
||||||
@ -224,6 +224,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
|||||||
@pytest.mark.parametrize("topk", TOP_KS)
|
@pytest.mark.parametrize("topk", TOP_KS)
|
||||||
@pytest.mark.parametrize("seed", SEEDS)
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
|
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
|
||||||
|
@pytest.mark.skipif(is_blackwell_deep_gemm_used(), reason="Not E8M0 scale MOE")
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
|
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
|
||||||
monkeypatch):
|
monkeypatch):
|
||||||
@ -238,8 +239,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
|
|||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
|
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
|
||||||
|
block_m = get_m_alignment_for_contiguous_layout()
|
||||||
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
|
|
||||||
block_size = [block_m, block_m]
|
block_size = [block_m, block_m]
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
|
|||||||
@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
|||||||
FusedMoEModularKernel)
|
FusedMoEModularKernel)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import has_deep_ep, has_deep_gemm
|
from vllm.utils import has_deep_ep, has_deep_gemm
|
||||||
|
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
|
||||||
|
|
||||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||||
from .utils import make_test_weights
|
from .utils import make_test_weights
|
||||||
@ -368,6 +369,8 @@ NUM_EXPERTS = [32]
|
|||||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||||
@requires_deep_ep
|
@requires_deep_ep
|
||||||
@requires_deep_gemm
|
@requires_deep_gemm
|
||||||
|
@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
|
||||||
|
reason="Skipping test for Blackwell DeepGEMM")
|
||||||
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
|
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
|
||||||
topk: int, world_dp_size: tuple[int, int]):
|
topk: int, world_dp_size: tuple[int, int]):
|
||||||
"""
|
"""
|
||||||
@ -423,6 +426,8 @@ USE_FP8_DISPATCH = [False]
|
|||||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||||
@requires_deep_ep
|
@requires_deep_ep
|
||||||
@requires_deep_gemm
|
@requires_deep_gemm
|
||||||
|
@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
|
||||||
|
reason="Skipping test for Blackwell DeepGEMM")
|
||||||
def test_ll_deepep_deepgemm_moe(
|
def test_ll_deepep_deepgemm_moe(
|
||||||
mnk: tuple[int, int, int],
|
mnk: tuple[int, int, int],
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
|
|||||||
@ -13,48 +13,18 @@ import torch
|
|||||||
|
|
||||||
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
|
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.utils import has_deep_gemm
|
||||||
per_token_group_quant_fp8)
|
from vllm.utils.deep_gemm import (calc_diff, per_block_cast_to_fp8,
|
||||||
from vllm.utils import cdiv
|
per_token_group_cast_to_fp8)
|
||||||
|
|
||||||
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
|
BLOCK_SIZE = [128, 128]
|
||||||
|
|
||||||
if has_deep_gemm:
|
|
||||||
import deep_gemm
|
|
||||||
BLOCK_M = deep_gemm.get_m_alignment_for_contiguous_layout()
|
|
||||||
BLOCK_SIZE = [BLOCK_M, BLOCK_M]
|
|
||||||
|
|
||||||
requires_deep_gemm = pytest.mark.skipif(
|
requires_deep_gemm = pytest.mark.skipif(
|
||||||
not has_deep_gemm,
|
not has_deep_gemm(),
|
||||||
reason="Requires deep_gemm kernels",
|
reason="Requires deep_gemm kernels",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def calc_diff(x: torch.Tensor, y: torch.Tensor):
|
|
||||||
x, y = x.double(), y.double()
|
|
||||||
denominator = (x * x + y * y).sum()
|
|
||||||
sim = 2 * (x * y).sum() / denominator
|
|
||||||
return 1 - sim
|
|
||||||
|
|
||||||
|
|
||||||
def per_block_cast_to_fp8(
|
|
||||||
x: torch.Tensor,
|
|
||||||
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
assert x.dim() == 2
|
|
||||||
m, n = x.shape
|
|
||||||
x_padded = torch.zeros(
|
|
||||||
(cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n),
|
|
||||||
dtype=x.dtype,
|
|
||||||
device=x.device)
|
|
||||||
x_padded[:m, :n] = x
|
|
||||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
|
|
||||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
|
||||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
|
||||||
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
|
|
||||||
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
|
|
||||||
return x_scaled_sub, scales
|
|
||||||
|
|
||||||
|
|
||||||
def make_block_quant_fp8_weights(
|
def make_block_quant_fp8_weights(
|
||||||
e: int,
|
e: int,
|
||||||
n: int,
|
n: int,
|
||||||
@ -111,7 +81,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
|||||||
"""
|
"""
|
||||||
tokens_bf16 = torch.randn(
|
tokens_bf16 = torch.randn(
|
||||||
m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
|
m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
|
||||||
_, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])
|
_, a1_scale = per_token_group_cast_to_fp8(tokens_bf16, block_size[1])
|
||||||
|
|
||||||
# expert weight tensors
|
# expert weight tensors
|
||||||
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
|
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
|
||||||
@ -155,17 +125,8 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
|||||||
block_shape=block_size,
|
block_shape=block_size,
|
||||||
allow_deep_gemm=True,
|
allow_deep_gemm=True,
|
||||||
)
|
)
|
||||||
|
diff = calc_diff(out_deepgemm, out_triton)
|
||||||
base = out_triton.abs().mean()
|
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
|
||||||
atol = 0.1 * base.clamp(min=1e-2) # 10% of mean, but not lower than 1e-3
|
|
||||||
rtol = 0.05
|
|
||||||
# ----- Compare -----
|
|
||||||
torch.testing.assert_close(
|
|
||||||
out_deepgemm.to(torch.float32),
|
|
||||||
out_triton.to(torch.float32),
|
|
||||||
rtol=rtol,
|
|
||||||
atol=float(atol),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Note: W1 has shape (E, 2N, K), so N = 512
|
# Note: W1 has shape (E, 2N, K), so N = 512
|
||||||
|
|||||||
@ -8,19 +8,15 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
|
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
|
||||||
native_w8a8_block_matmul,
|
native_w8a8_block_matmul)
|
||||||
per_block_cast_to_fp8)
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
get_col_major_tma_aligned_tensor, per_token_group_quant_fp8,
|
||||||
|
w8a8_block_fp8_matmul)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import has_deep_gemm
|
||||||
dg_available = False
|
from vllm.utils.deep_gemm import (fp8_gemm_nt, per_block_cast_to_fp8,
|
||||||
try:
|
per_token_group_cast_to_fp8)
|
||||||
import deep_gemm
|
|
||||||
dg_available = True
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if current_platform.get_device_capability() < (9, 0):
|
if current_platform.get_device_capability() < (9, 0):
|
||||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
|
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
|
||||||
@ -106,7 +102,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"M,N,K,block_size,out_dtype,seed",
|
"M,N,K,block_size,out_dtype,seed",
|
||||||
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
|
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
|
||||||
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
|
@pytest.mark.skipif(not has_deep_gemm(),
|
||||||
|
reason="DeepGemm kernels not available.")
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||||
# only aligned sizes
|
# only aligned sizes
|
||||||
@ -120,9 +117,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
|||||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||||
|
|
||||||
_, block_k = block_size[0], block_size[1]
|
A_fp8, As_fp8 = per_token_group_cast_to_fp8(A_fp32, block_size[1])
|
||||||
|
|
||||||
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k)
|
|
||||||
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
|
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
|
||||||
|
|
||||||
As = As_fp8.to(torch.float32)
|
As = As_fp8.to(torch.float32)
|
||||||
@ -132,14 +127,14 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
|||||||
out_dtype)
|
out_dtype)
|
||||||
|
|
||||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||||
As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8)
|
As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)
|
||||||
|
|
||||||
out = torch.zeros((M, N), device='cuda', dtype=out_dtype)
|
out = torch.zeros((M, N), device='cuda', dtype=out_dtype)
|
||||||
|
|
||||||
assert As_fp8.shape == (M, (K + 127) //
|
assert As_fp8.shape == (M, (K + 127) //
|
||||||
128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
|
128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
|
||||||
|
|
||||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
|
fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
|
||||||
|
|
||||||
rel_diff = (torch.mean(
|
rel_diff = (torch.mean(
|
||||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
|||||||
TopKWeightAndReduceDelegate)
|
TopKWeightAndReduceDelegate)
|
||||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
|
from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -271,7 +272,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
assert expert_tokens_meta is not None
|
assert expert_tokens_meta is not None
|
||||||
expert_num_tokens = expert_tokens_meta.expert_num_tokens
|
expert_num_tokens = expert_tokens_meta.expert_num_tokens
|
||||||
|
|
||||||
import deep_gemm as dg
|
|
||||||
assert hidden_states.ndim == 3
|
assert hidden_states.ndim == 3
|
||||||
assert self.block_shape is not None
|
assert self.block_shape is not None
|
||||||
|
|
||||||
@ -289,18 +289,15 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
# for the M expectation of each batch, correctly setting this value
|
# for the M expectation of each batch, correctly setting this value
|
||||||
# may lead to better performance.
|
# may lead to better performance.
|
||||||
expected_m = max_num_tokens
|
expected_m = max_num_tokens
|
||||||
|
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
|
||||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a1q, a1q_scale),
|
out=workspace1,
|
||||||
(w1, w1_scale),
|
masked_m=expert_num_tokens,
|
||||||
out=workspace1,
|
expected_m=expected_m)
|
||||||
masked_m=expert_num_tokens,
|
|
||||||
expected_m=expected_m)
|
|
||||||
|
|
||||||
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
|
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
|
||||||
expert_num_tokens)
|
expert_num_tokens)
|
||||||
|
|
||||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
|
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale),
|
||||||
(w2, w2_scale),
|
out=output,
|
||||||
out=output,
|
masked_m=expert_num_tokens,
|
||||||
masked_m=expert_num_tokens,
|
expected_m=expected_m)
|
||||||
expected_m=expected_m)
|
|
||||||
|
|||||||
@ -14,9 +14,10 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
|||||||
MoEPrepareAndFinalizeNoEP)
|
MoEPrepareAndFinalizeNoEP)
|
||||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||||
TopKWeightAndReduceDelegate)
|
TopKWeightAndReduceDelegate)
|
||||||
from vllm.model_executor.layers.fused_moe.utils import (
|
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||||
_resize_cache, per_token_group_quant_fp8)
|
|
||||||
from vllm.utils import has_deep_gemm, round_up
|
from vllm.utils import has_deep_gemm, round_up
|
||||||
|
from vllm.utils.deep_gemm import (m_grouped_fp8_gemm_nt_contiguous,
|
||||||
|
per_token_group_cast_to_fp8)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -127,7 +128,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
workspace2: torch.Tensor,
|
workspace2: torch.Tensor,
|
||||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||||
):
|
):
|
||||||
import deep_gemm as dg
|
|
||||||
assert self.block_shape is not None
|
assert self.block_shape is not None
|
||||||
|
|
||||||
a1q = hidden_states
|
a1q = hidden_states
|
||||||
@ -164,19 +164,19 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
(M_sum, N // 2))
|
(M_sum, N // 2))
|
||||||
mm2_out = _resize_cache(workspace2, (M_sum, K))
|
mm2_out = _resize_cache(workspace2, (M_sum, K))
|
||||||
|
|
||||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
|
||||||
(a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)
|
mm1_out, expert_ids)
|
||||||
|
|
||||||
self.activation(activation, act_out, mm1_out.view(-1, N))
|
self.activation(activation, act_out, mm1_out.view(-1, N))
|
||||||
|
|
||||||
a2q_scale: Optional[torch.Tensor] = None
|
a2q_scale: Optional[torch.Tensor] = None
|
||||||
a2q, a2q_scale = per_token_group_quant_fp8(act_out,
|
a2q, a2q_scale = per_token_group_cast_to_fp8(act_out,
|
||||||
self.block_shape[1],
|
self.block_shape[1],
|
||||||
column_major_scales=True,
|
column_major_scales=True,
|
||||||
out_q=quant_out)
|
out_q=quant_out)
|
||||||
|
|
||||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
|
||||||
(a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)
|
mm2_out, expert_ids)
|
||||||
|
|
||||||
torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K)))
|
torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K)))
|
||||||
|
|
||||||
|
|||||||
@ -34,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
|
||||||
|
|
||||||
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
|
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
|
||||||
|
|
||||||
@ -1171,9 +1172,15 @@ def fused_experts(
|
|||||||
allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor:
|
allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor:
|
||||||
# For now, disable DeepGemm for small N (<= 512) until better
|
# For now, disable DeepGemm for small N (<= 512) until better
|
||||||
# permute/unpermute ops are available.
|
# permute/unpermute ops are available.
|
||||||
|
# However, on B200, we use DeepGemm for all cases becuase they only support
|
||||||
|
# E8M0 scale, which means we requantize the weight and input to the specific
|
||||||
|
# scale. Fallen back to cutlass or triton for some cases would cause
|
||||||
|
# accuracy issue.
|
||||||
N = w1.size(1)
|
N = w1.size(1)
|
||||||
if (allow_deep_gemm and use_fp8_w8a8 and N > 512
|
should_use_deep_gemm = ((N > 512
|
||||||
and _valid_deep_gemm(hidden_states, w1, w2)):
|
and _valid_deep_gemm(hidden_states, w1, w2))
|
||||||
|
or is_blackwell_deep_gemm_used())
|
||||||
|
if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
|
||||||
assert apply_router_weight_on_input is False
|
assert apply_router_weight_on_input is False
|
||||||
return deep_gemm_moe_fp8(
|
return deep_gemm_moe_fp8(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
@ -1363,7 +1370,6 @@ def fused_experts_impl(
|
|||||||
|
|
||||||
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
||||||
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
||||||
|
|
||||||
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
|
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
|
||||||
A=curr_hidden_states,
|
A=curr_hidden_states,
|
||||||
A_scale=a1_scale,
|
A_scale=a1_scale,
|
||||||
|
|||||||
@ -48,7 +48,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
|||||||
assert topk == 1, \
|
assert topk == 1, \
|
||||||
"apply_router_weight_on_input is only implemented for topk=1"
|
"apply_router_weight_on_input is only implemented for topk=1"
|
||||||
a1.mul_(topk_weights.to(a1.dtype))
|
a1.mul_(topk_weights.to(a1.dtype))
|
||||||
|
|
||||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||||
a1, a1_scale, quant_config.quant_dtype,
|
a1, a1_scale, quant_config.quant_dtype,
|
||||||
quant_config.per_act_token_quant, quant_config.block_shape)
|
quant_config.per_act_token_quant, quant_config.block_shape)
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
|||||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||||
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
|
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
||||||
|
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
|
||||||
|
|
||||||
|
|
||||||
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
@ -102,7 +103,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||||
# even if we fall back to triton later, e.g. if expert maps are set.
|
# even if we fall back to triton later, e.g. if expert maps are set.
|
||||||
if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
|
if self.allow_deep_gemm and (_valid_deep_gemm_shape(M, N, K)
|
||||||
|
or is_blackwell_deep_gemm_used()):
|
||||||
assert self.deep_gemm_expert is not None
|
assert self.deep_gemm_expert is not None
|
||||||
return self.deep_gemm_expert.workspace_shapes(
|
return self.deep_gemm_expert.workspace_shapes(
|
||||||
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
|
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
|
||||||
@ -132,7 +134,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||||
):
|
):
|
||||||
use_deep_gemm = (self.allow_deep_gemm
|
use_deep_gemm = (self.allow_deep_gemm
|
||||||
and _valid_deep_gemm(hidden_states, w1, w2))
|
and (_valid_deep_gemm(hidden_states, w1, w2)
|
||||||
|
or is_blackwell_deep_gemm_used()))
|
||||||
|
|
||||||
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
|
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
|
||||||
assert experts is not None
|
assert experts is not None
|
||||||
|
|||||||
@ -15,6 +15,8 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv
|
||||||
|
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_used,
|
||||||
|
per_token_group_cast_to_fp8)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@ -115,7 +117,10 @@ def _fp8_quantize(
|
|||||||
assert not per_act_token
|
assert not per_act_token
|
||||||
assert len(block_shape) == 2
|
assert len(block_shape) == 2
|
||||||
_, block_k = block_shape[0], block_shape[1]
|
_, block_k = block_shape[0], block_shape[1]
|
||||||
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
if is_blackwell_deep_gemm_used():
|
||||||
|
A, A_scale = per_token_group_cast_to_fp8(A, block_k)
|
||||||
|
else:
|
||||||
|
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
||||||
assert cdiv(A.size(-1), block_k) == A_scale.size(-1)
|
assert cdiv(A.size(-1), block_k) == A_scale.size(-1)
|
||||||
|
|
||||||
return A, A_scale
|
return A, A_scale
|
||||||
|
|||||||
@ -8,10 +8,9 @@ from typing import Optional, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
|
||||||
@triton.jit()
|
@triton.jit()
|
||||||
|
|||||||
@ -6,10 +6,8 @@ import torch
|
|||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils import triton
|
from vllm.triton_utils import triton
|
||||||
from vllm.utils import direct_register_custom_op, has_deep_gemm
|
from vllm.utils import direct_register_custom_op
|
||||||
|
from vllm.utils.deep_gemm import fp8_gemm_nt
|
||||||
if has_deep_gemm():
|
|
||||||
import deep_gemm
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -57,7 +55,7 @@ def w8a8_block_fp8_matmul_deepgemm(
|
|||||||
output_dtype)
|
output_dtype)
|
||||||
# Deepgemm only supports output tensor type as bfloat16
|
# Deepgemm only supports output tensor type as bfloat16
|
||||||
assert C.dtype == torch.bfloat16
|
assert C.dtype == torch.bfloat16
|
||||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
fp8_gemm_nt((A, As), (B, Bs), C)
|
||||||
return C
|
return C
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -23,6 +23,8 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
|
|||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig, QuantizeMethodBase)
|
QuantizationConfig, QuantizeMethodBase)
|
||||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
|
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
|
||||||
prepare_moe_fp8_layer_for_marlin)
|
prepare_moe_fp8_layer_for_marlin)
|
||||||
@ -40,6 +42,7 @@ from vllm.model_executor.utils import set_weight_attrs
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.scalar_type import scalar_types
|
from vllm.scalar_type import scalar_types
|
||||||
from vllm.utils import has_deep_gemm
|
from vllm.utils import has_deep_gemm
|
||||||
|
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.model_executor.models.utils import WeightsMapper
|
from vllm.model_executor.models.utils import WeightsMapper
|
||||||
@ -393,6 +396,19 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
# Activations not quantized for marlin.
|
# Activations not quantized for marlin.
|
||||||
del layer.input_scale
|
del layer.input_scale
|
||||||
|
|
||||||
|
# On B200, DeepGemm only support E8M0 scale, which means we need to
|
||||||
|
# requantize the weight and input to the specific scale
|
||||||
|
# at the same time.
|
||||||
|
if is_blackwell_deep_gemm_used():
|
||||||
|
assert layer.weight_block_size is not None
|
||||||
|
block_sz = tuple(layer.weight_block_size)
|
||||||
|
requant_weight_ue8m0_inplace(
|
||||||
|
layer.weight.data,
|
||||||
|
layer.weight_scale_inv.data if hasattr(
|
||||||
|
layer, "weight_scale_inv") else layer.weight_scale.data,
|
||||||
|
block_sz,
|
||||||
|
)
|
||||||
|
|
||||||
def apply(self,
|
def apply(self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -670,15 +686,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
# DeepGemm scales need to be transposed and aligned. We try to do
|
# DeepGemm scales need to be transposed and aligned. We try to do
|
||||||
# it ahead of time for performance reasons.
|
# it ahead of time for performance reasons.
|
||||||
if self.allow_deep_gemm:
|
if self.allow_deep_gemm and not is_blackwell_deep_gemm_used():
|
||||||
# Lazy import to avoid CUDA initialization problems.
|
# Lazy import to avoid CUDA initialization problems.
|
||||||
import deep_gemm as dg
|
|
||||||
if _is_col_major(layer.w13_weight_scale_inv):
|
if _is_col_major(layer.w13_weight_scale_inv):
|
||||||
layer.w13_weight_scale_inv = \
|
layer.w13_weight_scale_inv = \
|
||||||
dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
|
get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
|
||||||
if _is_col_major(layer.w2_weight_scale_inv):
|
if _is_col_major(layer.w2_weight_scale_inv):
|
||||||
layer.w2_weight_scale_inv = \
|
layer.w2_weight_scale_inv = \
|
||||||
dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
|
get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
|
||||||
|
|
||||||
# If checkpoint is fp16, quantize in place.
|
# If checkpoint is fp16, quantize in place.
|
||||||
elif not self.quant_config.is_checkpoint_fp8_serialized:
|
elif not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
@ -797,6 +812,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
del layer.w13_input_scale
|
del layer.w13_input_scale
|
||||||
del layer.w2_input_scale
|
del layer.w2_input_scale
|
||||||
|
|
||||||
|
if is_blackwell_deep_gemm_used():
|
||||||
|
assert layer.weight_block_size is not None
|
||||||
|
# Re-quantise the expert weights so their scales are UE8M0.
|
||||||
|
block_sz = tuple(layer.weight_block_size)
|
||||||
|
requant_weight_ue8m0_inplace(
|
||||||
|
layer.w13_weight.data,
|
||||||
|
layer.w13_weight_scale_inv.data,
|
||||||
|
block_sz,
|
||||||
|
)
|
||||||
|
requant_weight_ue8m0_inplace(
|
||||||
|
layer.w2_weight.data,
|
||||||
|
layer.w2_weight_scale_inv.data,
|
||||||
|
block_sz,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure column-major TMA alignment expected by DeepGEMM.
|
||||||
|
if _is_col_major(layer.w13_weight_scale_inv):
|
||||||
|
layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
|
||||||
|
layer.w13_weight_scale_inv).contiguous()
|
||||||
|
if _is_col_major(layer.w2_weight_scale_inv):
|
||||||
|
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
|
||||||
|
layer.w2_weight_scale_inv).contiguous()
|
||||||
|
|
||||||
def select_gemm_impl(
|
def select_gemm_impl(
|
||||||
self,
|
self,
|
||||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||||
|
|||||||
@ -5,6 +5,7 @@
|
|||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from collections.abc import Sequence
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -13,7 +14,7 @@ import vllm.envs as envs
|
|||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
scaled_dequantize)
|
group_broadcast)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
CUTLASS_BLOCK_FP8_SUPPORTED)
|
CUTLASS_BLOCK_FP8_SUPPORTED)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -235,7 +236,7 @@ def block_quant_to_tensor_quant(
|
|||||||
The outputs are tensor-wise quantization tensor and tensor-wise
|
The outputs are tensor-wise quantization tensor and tensor-wise
|
||||||
quantization scale. Note only float8 is supported for now.
|
quantization scale. Note only float8 is supported for now.
|
||||||
"""
|
"""
|
||||||
x_dq_block = scaled_dequantize(x_q_block, x_s)
|
x_dq_block = group_broadcast(x_q_block, x_s)
|
||||||
x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
|
x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
|
||||||
return x_q_tensor, scale
|
return x_q_tensor, scale
|
||||||
|
|
||||||
@ -651,3 +652,124 @@ def w8a8_block_fp8_matmul(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return C
|
return C
|
||||||
|
|
||||||
|
|
||||||
|
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
|
||||||
|
# TODO(wentao): remove this function when DeepGEMM exposes this function
|
||||||
|
def get_tma_aligned_size(x: int, element_size: int) -> int:
|
||||||
|
"""
|
||||||
|
Global memory address of TMA must be 16-byte aligned.
|
||||||
|
Since we use column-major layout for the LHS scaling tensor,
|
||||||
|
the M-axis of the LHS scaling tensor needs to be padded to a multiple of
|
||||||
|
16 bytes.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
x: original M-axis shape of the LHS scaling tensor.
|
||||||
|
element_size: element size of the LHS scaling tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
M-axis shape of the LHS scaling tensor after padding.
|
||||||
|
"""
|
||||||
|
tma_alignment_bytes = 16
|
||||||
|
assert tma_alignment_bytes % element_size == 0
|
||||||
|
alignment = tma_alignment_bytes // element_size
|
||||||
|
return cdiv(x, alignment) * alignment
|
||||||
|
|
||||||
|
|
||||||
|
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
|
||||||
|
# TODO(wentao): remove this function when DeepGEMM exposes this function
|
||||||
|
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Returns TMA-aligned transposed format of the input tensor. `torch.transpose`
|
||||||
|
will be called if necessary.
|
||||||
|
If the input tensor is already column-major layout and 16-byte aligned along
|
||||||
|
the M axis (thus meets the requirement of LHS scaling tensor in
|
||||||
|
DeepGEMM), this function will do nothing.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
x: usually the LHS scaling tensor in GEMM.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The LHS scaling tensor of TMA-aligned transposed format.
|
||||||
|
"""
|
||||||
|
# NOTES: for the extreme performance, you may rewrite/fuse this function in
|
||||||
|
# CUDA
|
||||||
|
assert x.dim() in (2, 3)
|
||||||
|
remove_dim = False
|
||||||
|
m, n = x.shape[-2], x.shape[-1]
|
||||||
|
aligned_m = get_tma_aligned_size(m, x.element_size())
|
||||||
|
if x.dim() == 2:
|
||||||
|
if x.stride(0) == 1 and x.stride(1) == aligned_m:
|
||||||
|
return x
|
||||||
|
x, remove_dim = x.unsqueeze(0), True
|
||||||
|
|
||||||
|
b = x.shape[0]
|
||||||
|
|
||||||
|
# The last kernel gives a column-major TMA aligned layout
|
||||||
|
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(
|
||||||
|
2) == aligned_m:
|
||||||
|
return x.squeeze(0) if remove_dim else x
|
||||||
|
|
||||||
|
# Normal layout requires transposing
|
||||||
|
aligned_x = torch.transpose(
|
||||||
|
torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
|
||||||
|
aligned_x[:, :m, :] = x
|
||||||
|
aligned_x = aligned_x[:, :m, :]
|
||||||
|
return aligned_x.squeeze(0) if remove_dim else aligned_x
|
||||||
|
|
||||||
|
|
||||||
|
def requant_weight_ue8m0_inplace(
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
block_size: Sequence[int] = (128, 128),
|
||||||
|
) -> None:
|
||||||
|
"""Re-quantise *weight* so that its per-block scaling factors are in the
|
||||||
|
UE8M0 (power-of-two) format expected by the new DeepGEMM kernels inplace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weight: Block-quantised weight tensor stored in ``torch.float8_e4m3fn``.
|
||||||
|
Expected shape ``(..., M, K)``.
|
||||||
|
weight_scale: Corresponding per-block scale tensor (``torch.float32``)
|
||||||
|
with shape ``(..., M // block_size[0], K // block_size[1])``.
|
||||||
|
block_size: 2-element iterable ``[block_m, block_k]`` describing the
|
||||||
|
block quantisation granularity.
|
||||||
|
"""
|
||||||
|
if weight.numel() == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
if weight.dtype != torch.float8_e4m3fn:
|
||||||
|
raise ValueError("Expected *weight* to be torch.float8_e4m3fn, got "
|
||||||
|
f"{weight.dtype} instead.")
|
||||||
|
|
||||||
|
from vllm.utils.deep_gemm import per_block_cast_to_fp8
|
||||||
|
|
||||||
|
block_m, block_k = int(block_size[0]), int(block_size[1])
|
||||||
|
|
||||||
|
# Flatten leading dimensions so we can iterate over the last two dims.
|
||||||
|
leading_shape = weight.shape[:-2]
|
||||||
|
if len(leading_shape) == 0:
|
||||||
|
w_view = weight.unsqueeze(0)
|
||||||
|
s_view = weight_scale.unsqueeze(0)
|
||||||
|
else:
|
||||||
|
w_view = weight.reshape(-1, weight.shape[-2], weight.shape[-1])
|
||||||
|
s_view = weight_scale.reshape(-1, *weight_scale.shape[-2:])
|
||||||
|
|
||||||
|
num_mats = w_view.size(0)
|
||||||
|
for idx in range(num_mats):
|
||||||
|
w_q = w_view[idx]
|
||||||
|
s_old = s_view[idx]
|
||||||
|
|
||||||
|
# De-quantise with the *old* scaling factors (float32).
|
||||||
|
m_cur, k_cur = w_q.shape
|
||||||
|
s_float = s_old.to(torch.float32)
|
||||||
|
# Expand scales along rows and cols by block size, then crop.
|
||||||
|
s_exp_r = torch.repeat_interleave(s_float, block_m, dim=0)
|
||||||
|
s_exp = torch.repeat_interleave(s_exp_r, block_k, dim=1)
|
||||||
|
s_exp = s_exp[:m_cur, :k_cur]
|
||||||
|
w_dq = w_q.to(torch.float32) * s_exp
|
||||||
|
# Re-quantise using power-of-two scaling (UE8M0).
|
||||||
|
w_requant, s_requant = per_block_cast_to_fp8(w_dq, [block_m, block_k])
|
||||||
|
|
||||||
|
# Write back the results in-place.
|
||||||
|
w_q.copy_(w_requant)
|
||||||
|
s_old.copy_(s_requant)
|
||||||
|
|||||||
152
vllm/utils/deep_gemm.py
Normal file
152
vllm/utils/deep_gemm.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Compatibility wrapper for DeepGEMM API changes.
|
||||||
|
|
||||||
|
Users of vLLM should always import **only** these wrappers.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import importlib
|
||||||
|
from typing import Any, Callable, NoReturn
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.utils import cuda_get_device_properties, has_deep_gemm
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def is_blackwell_deep_gemm_used() -> bool:
|
||||||
|
"""Return ``True`` if vLLM is configured to use DeepGEMM on a
|
||||||
|
Blackwell-class GPU.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()
|
||||||
|
and _per_block_cast_impl is not None):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return cuda_get_device_properties(0, ("major", ))[0] == 10
|
||||||
|
|
||||||
|
|
||||||
|
def _missing(*_: Any, **__: Any) -> NoReturn:
|
||||||
|
"""Placeholder for unavailable DeepGEMM backend."""
|
||||||
|
raise RuntimeError(
|
||||||
|
"DeepGEMM backend is not available. Please install the `deep_gemm` "
|
||||||
|
"package to enable FP8 kernels.")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
|
||||||
|
"""Return the *new* symbol if it exists, otherwise the *old* one."""
|
||||||
|
if hasattr(module, new):
|
||||||
|
return getattr(module, new)
|
||||||
|
if hasattr(module, old):
|
||||||
|
return getattr(module, old)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
if not has_deep_gemm():
|
||||||
|
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
|
||||||
|
_grouped_impl: Callable[..., Any] | None = None
|
||||||
|
_grouped_masked_impl: Callable[..., Any] | None = None
|
||||||
|
_per_token_cast_impl: Callable[..., Any] | None = None
|
||||||
|
_per_block_cast_impl: Callable[..., Any] | None = None
|
||||||
|
else:
|
||||||
|
_dg = importlib.import_module("deep_gemm") # type: ignore
|
||||||
|
|
||||||
|
_fp8_gemm_nt_impl = _resolve_symbol(
|
||||||
|
_dg,
|
||||||
|
"fp8_gemm_nt",
|
||||||
|
"gemm_fp8_fp8_bf16_nt",
|
||||||
|
)
|
||||||
|
_grouped_impl = _resolve_symbol(
|
||||||
|
_dg,
|
||||||
|
"m_grouped_fp8_gemm_nt_contiguous",
|
||||||
|
"m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
|
||||||
|
)
|
||||||
|
_grouped_masked_impl = _resolve_symbol(
|
||||||
|
_dg,
|
||||||
|
"fp8_m_grouped_gemm_nt_masked",
|
||||||
|
"m_grouped_gemm_fp8_fp8_bf16_nt_masked",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
|
||||||
|
try:
|
||||||
|
_math_mod = importlib.import_module(
|
||||||
|
"deep_gemm.utils.math") # type: ignore
|
||||||
|
_per_token_cast_impl = getattr(_math_mod, "per_token_cast_to_fp8",
|
||||||
|
None)
|
||||||
|
_per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8",
|
||||||
|
None)
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
_per_token_cast_impl = None
|
||||||
|
_per_block_cast_impl = None
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_gemm_nt(*args, **kwargs):
|
||||||
|
if _fp8_gemm_nt_impl is None:
|
||||||
|
return _missing(*args, **kwargs)
|
||||||
|
return _fp8_gemm_nt_impl(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
|
||||||
|
if _grouped_impl is None:
|
||||||
|
return _missing(*args, **kwargs)
|
||||||
|
return _grouped_impl(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
|
||||||
|
if _grouped_masked_impl is None:
|
||||||
|
return _missing(*args, **kwargs)
|
||||||
|
return _grouped_masked_impl(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def per_token_group_cast_to_fp8(x, group_size, *args, **kwargs):
|
||||||
|
"""Wrapper for token-wise FP8 quantisation.
|
||||||
|
|
||||||
|
• If DeepGEMM provides ``per_token_cast_to_fp8`` (new API), use it.
|
||||||
|
• Otherwise, fall back to vLLM's ``per_token_group_quant_fp8``
|
||||||
|
"""
|
||||||
|
|
||||||
|
if _per_token_cast_impl is not None and is_blackwell_deep_gemm_used():
|
||||||
|
assert group_size == 128, "group_size must be 128 for deepgemm"
|
||||||
|
return _per_token_cast_impl(x)
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
per_token_group_quant_fp8 as _ptg)
|
||||||
|
return _ptg(x, group_size, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def per_block_cast_to_fp8(x, *args, **kwargs):
|
||||||
|
if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
|
||||||
|
return _per_block_cast_impl(x)
|
||||||
|
# TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
|
||||||
|
from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf
|
||||||
|
return _pbcf(x, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def calc_diff(x: torch.Tensor, y: torch.Tensor):
|
||||||
|
"""Return a global difference metric for unit tests.
|
||||||
|
|
||||||
|
DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
|
||||||
|
error, causing ``torch.testing.assert_close`` to fail. Instead of checking
|
||||||
|
every element, we compute a cosine-style similarity over the whole tensor
|
||||||
|
and report ``1 - sim``. Once kernel accuracy improves this helper can be
|
||||||
|
removed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
x, y = x.double(), y.double()
|
||||||
|
denominator = (x * x + y * y).sum()
|
||||||
|
sim = 2 * (x * y).sum() / denominator
|
||||||
|
return 1 - sim
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"calc_diff",
|
||||||
|
"fp8_gemm_nt",
|
||||||
|
"m_grouped_fp8_gemm_nt_contiguous",
|
||||||
|
"fp8_m_grouped_gemm_nt_masked",
|
||||||
|
"per_token_group_cast_to_fp8",
|
||||||
|
"per_block_cast_to_fp8",
|
||||||
|
"is_blackwell_deep_gemm_used",
|
||||||
|
]
|
||||||
Loading…
x
Reference in New Issue
Block a user