mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-18 22:17:07 +08:00
Merge e019391cd85bc474d38b4a590a3bbe297cdcddeb into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
184421b3dd
@ -24,6 +24,10 @@ from vllm.utils.deep_gemm import (
|
||||
per_block_cast_to_fp8,
|
||||
should_use_deepgemm_for_fp8_linear,
|
||||
)
|
||||
from vllm.utils.flashinfer import (
|
||||
flashinfer_fp8_blockscale_gemm,
|
||||
has_flashinfer_fp8_blockscale_gemm,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
@ -205,3 +209,50 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
|
||||
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
|
||||
assert rel_diff < 0.001
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    current_platform.is_fp8_fnuz(),
    reason="This platform supports e4m3fnuz, not e4m3fn.",
)
@pytest.mark.parametrize(
    "M,N,K,block_size,out_dtype,seed",
    itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_block_fp8_flashinfer_matmul(M, N, K, block_size, out_dtype, seed):
    """Compare FlashInfer's FP8 block-scale GEMM against the native
    block-quantized reference matmul.

    The activation is handed to FlashInfer in BF16 (the kernel quantizes
    it internally), while the reference path quantizes it explicitly with
    per-token-group quantization.
    """
    if not has_flashinfer_fp8_blockscale_gemm():
        pytest.skip(
            "FlashInfer block GEMM not available (requires SM90+ and FlashInfer)"
        )
    # The kernel only supports aligned sizes: K a multiple of 128 and
    # N a multiple of 64.
    if K % 128 != 0 or N % 64 != 0:
        pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")

    torch.manual_seed(seed)
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max = fp8_info.max

    A_bf16 = (torch.rand(M, K, dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
    B_bf16 = (torch.rand(N, K, dtype=torch.bfloat16) - 0.5) * 2 * fp8_max

    A_fp8, As_fp8 = per_token_group_quant_fp8(A_bf16, block_size[1], use_ue8m0=False)
    B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_bf16, block_size, use_ue8m0=False)

    As = As_fp8.to(torch.float32)
    Bs = Bs_fp8.to(torch.float32)

    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)

    out = flashinfer_fp8_blockscale_gemm(
        input=A_bf16,
        weight=B_fp8,
        input_scale=None,
        weight_scale=Bs,
        out_dtype=out_dtype,
    )

    # Accumulate the error metric in float32, matching the DeepGEMM test
    # above; computing the means in bfloat16 loses too much precision for
    # a meaningful relative-difference check.
    rel_diff = torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
    ) / torch.mean(torch.abs(ref_out.to(torch.float32)))
    assert rel_diff < 0.001
|
||||
|
||||
@ -168,6 +168,7 @@ if TYPE_CHECKING:
|
||||
"relax",
|
||||
] = "relax"
|
||||
VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
|
||||
VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_FP16: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
|
||||
@ -1208,6 +1209,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool(
|
||||
int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))
|
||||
),
|
||||
# Allow use of FlashInfer FP8 block-scale GEMM for linear layers.
|
||||
# This uses TensorRT-LLM kernels and requires SM90+ (Hopper).
|
||||
"VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER": lambda: bool(
|
||||
int(os.getenv("VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER", "0"))
|
||||
),
|
||||
# Allow use of FlashInfer MoE kernels for fused moe ops.
|
||||
"VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool(
|
||||
int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0"))
|
||||
|
||||
@ -35,6 +35,11 @@ from vllm.utils.deep_gemm import (
|
||||
should_use_deepgemm_for_fp8_linear,
|
||||
transform_sf_into_required_layout,
|
||||
)
|
||||
from vllm.utils.flashinfer import (
|
||||
flashinfer_fp8_blockscale_gemm,
|
||||
is_flashinfer_fp8_blockscale_gemm_supported,
|
||||
should_use_flashinfer_for_blockscale_fp8_gemm,
|
||||
)
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -226,6 +231,83 @@ direct_register_custom_op(
|
||||
)
|
||||
|
||||
|
||||
def _flashinfer_fp8_blockscale_gemm_impl(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    group_size: int,
    use_deep_gemm_e8m0: bool,
) -> torch.Tensor:
    """Block-scale FP8 GEMM that selects between FlashInfer's swapAB kernel
    (small batches) and DeepGEMM (larger batches) at runtime.

    Args:
        input: BF16 activation; batch/token dimension is ``input.shape[0]``.
        weight: FP8 weight; output dimension is ``weight.shape[0]``.
        weight_scale: per-block scales for ``weight``.
        group_size: per-token-group quantization group size used to quantize
            the activation on the DeepGEMM path.
        use_deep_gemm_e8m0: whether DeepGEMM's UE8M0 scale format is in use.

    Returns:
        BF16 output of shape (input.shape[0], weight.shape[0]).
    """

    def use_flashinfer_deepgemm_swapAB(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
    ) -> torch.Tensor:
        # FlashInfer consumes the BF16 activation directly and quantizes
        # it inside the kernel.
        return flashinfer_fp8_blockscale_gemm(
            input=input,
            weight=weight,
            weight_scale=weight_scale,
            out_dtype=torch.bfloat16,
        )

    def use_deepgemm(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
    ) -> torch.Tensor:
        # DeepGEMM needs an explicitly quantized activation, so quantize
        # per token group first, then run the NT GEMM into a preallocated
        # BF16 output buffer.
        q_input, input_scale = per_token_group_quant_fp8(
            input,
            group_size=group_size,
            column_major_scales=True,
            use_ue8m0=use_deep_gemm_e8m0,
        )
        output = torch.empty(
            (q_input.shape[0], weight.shape[0]),
            dtype=torch.bfloat16,
            device=q_input.device,
        )
        fp8_gemm_nt(
            (q_input, input_scale),
            (weight, weight_scale),
            output,
            is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
        )
        return output

    # There is no benefit to using the FlashInfer path for higher batch
    # sizes since the swapAB optimization is only effective for small
    # batch sizes.
    # There is also a slight accuracy loss when using FlashInfer blockscale
    # gemm for all batch sizes for DeepSeek-V3.
    condition = input.shape[0] < 32

    # torch.cond for torch compile compatibility
    return torch.cond(
        condition,
        use_flashinfer_deepgemm_swapAB,
        use_deepgemm,
        (input, weight, weight_scale),
    )
|
||||
|
||||
|
||||
def _flashinfer_fp8_blockscale_gemm_fake(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
group_size: int,
|
||||
use_deep_gemm_e8m0: bool,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(
|
||||
input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device
|
||||
)
|
||||
|
||||
|
||||
# Register the impl/fake pair as a vLLM custom op so it is usable under
# torch.compile via torch.ops.vllm.flashinfer_fp8_blockscale_gemm.
direct_register_custom_op(
    "flashinfer_fp8_blockscale_gemm",
    _flashinfer_fp8_blockscale_gemm_impl,
    fake_impl=_flashinfer_fp8_blockscale_gemm_fake,
)
|
||||
|
||||
|
||||
# TODO fix ROCm->Triton custom path:
|
||||
# https://github.com/vllm-project/vllm/issues/14397
|
||||
class W8A8BlockFp8LinearOp:
|
||||
@ -246,6 +328,7 @@ class W8A8BlockFp8LinearOp:
|
||||
self.is_deep_gemm_supported = is_deep_gemm_supported()
|
||||
self.is_hopper = current_platform.is_device_capability(90)
|
||||
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
|
||||
self.is_flashinfer_supported = is_flashinfer_fp8_blockscale_gemm_supported()
|
||||
|
||||
# Get the correct blockscale mul and input quant operations.
|
||||
# We can't use _dispatch_w8a8_blockscale_op to figure out if we want
|
||||
@ -281,7 +364,12 @@ class W8A8BlockFp8LinearOp:
|
||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||
output_dtype = input.dtype
|
||||
|
||||
if should_use_deepgemm_for_fp8_linear(
|
||||
if should_use_flashinfer_for_blockscale_fp8_gemm(
|
||||
self.is_flashinfer_supported, output_dtype, input_2d, weight
|
||||
):
|
||||
output = self._run_flashinfer(input_2d, weight, weight_scale)
|
||||
|
||||
elif should_use_deepgemm_for_fp8_linear(
|
||||
output_dtype, weight, self.is_deep_gemm_supported
|
||||
):
|
||||
output = self._run_deepgemm(input_2d, weight, weight_scale)
|
||||
@ -409,6 +497,29 @@ class W8A8BlockFp8LinearOp:
|
||||
input_2d.dtype,
|
||||
)
|
||||
|
||||
def _run_flashinfer(
|
||||
self,
|
||||
input_2d: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Run FlashInfer FP8 block-scale GEMM.
|
||||
|
||||
This backend uses TensorRT-LLM's FP8 block-scale GEMM kernels
|
||||
and supports FP8+FP8 (W8A8 full quantization) on SM90+ (Hopper).
|
||||
"""
|
||||
# Now call FlashInfer with BF16 input + FP8 weight, input will be
|
||||
# quantized with FlashInfer kernel (W8A8)
|
||||
output = torch.ops.vllm.flashinfer_fp8_blockscale_gemm(
|
||||
input=input_2d, # BF16 input
|
||||
weight=weight, # FP8 weight
|
||||
weight_scale=weight_scale, # Weight scales
|
||||
group_size=self.act_quant_group_shape.col,
|
||||
use_deep_gemm_e8m0=self.use_deep_gemm_e8m0,
|
||||
)
|
||||
return output
|
||||
|
||||
def _dispatch_w8a8_blockscale_op(
|
||||
self,
|
||||
use_cutlass: bool,
|
||||
|
||||
@ -540,6 +540,59 @@ def flashinfer_scaled_fp8_mm(
|
||||
return output
|
||||
|
||||
|
||||
# Lazy wrapper around flashinfer.gemm.fp8_blockscale_gemm_sm90; the actual
# FlashInfer module is only imported when the wrapper is first called.
flashinfer_fp8_blockscale_gemm = _lazy_import_wrapper(
    "flashinfer.gemm", "fp8_blockscale_gemm_sm90"
)
|
||||
|
||||
|
||||
@functools.cache
def has_flashinfer_fp8_blockscale_gemm() -> bool:
    """Return `True` if FlashInfer block-scale FP8 GEMM is available."""
    # Guard clauses instead of one boolean chain; evaluation order (and
    # short-circuiting) is identical to `a and b and c`.
    if not has_flashinfer():
        return False
    if not current_platform.is_device_capability(90):
        return False
    gemm_module = _get_submodule("flashinfer.gemm")
    return hasattr(gemm_module, "fp8_blockscale_gemm_sm90")
|
||||
|
||||
|
||||
@functools.cache
def is_flashinfer_fp8_blockscale_gemm_supported() -> bool:
    """Return `True` if FlashInfer block-scale FP8 GEMM is supported."""
    # The env var acts as an opt-in switch; availability is checked second,
    # preserving the original short-circuit order.
    if not envs.VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER:
        return False
    return has_flashinfer_fp8_blockscale_gemm()
|
||||
|
||||
|
||||
def should_use_flashinfer_for_blockscale_fp8_gemm(
|
||||
is_flashinfer_supported: bool,
|
||||
output_dtype: torch.dtype,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
):
|
||||
if not is_flashinfer_supported:
|
||||
return False
|
||||
|
||||
# Verify DeepGEMM N/K dims requirements
|
||||
# NOTE: Also synchronized with test_w8a8_block_fp8_deep_gemm_matmul
|
||||
# test inside kernels/quatization/test_block_fp8.py
|
||||
N_MULTIPLE = 64
|
||||
K_MULTIPLE = 128
|
||||
|
||||
weight_dtype = weight.dtype
|
||||
input_dtype = input.dtype
|
||||
|
||||
should_use_flashinfer = (
|
||||
output_dtype == torch.bfloat16
|
||||
and input_dtype == torch.bfloat16
|
||||
and weight_dtype == torch.float8_e4m3fn
|
||||
and weight.shape[0] % N_MULTIPLE == 0
|
||||
and weight.shape[1] % K_MULTIPLE == 0
|
||||
)
|
||||
|
||||
return should_use_flashinfer
|
||||
|
||||
|
||||
__all__ = [
|
||||
"has_flashinfer",
|
||||
"flashinfer_trtllm_fp8_block_scale_moe",
|
||||
@ -556,10 +609,14 @@ __all__ = [
|
||||
"has_flashinfer_all2all",
|
||||
"has_flashinfer_cutlass_fused_moe",
|
||||
"has_flashinfer_cutedsl_grouped_gemm_nt_masked",
|
||||
"has_flashinfer_fp8_blockscale_gemm",
|
||||
"has_nvidia_artifactory",
|
||||
"supports_trtllm_attention",
|
||||
"can_use_trtllm_attention",
|
||||
"use_trtllm_attention",
|
||||
"flashinfer_scaled_fp4_mm",
|
||||
"flashinfer_scaled_fp8_mm",
|
||||
"flashinfer_fp8_blockscale_gemm",
|
||||
"should_use_flashinfer_for_blockscale_fp8_gemm",
|
||||
"is_flashinfer_fp8_blockscale_gemm_supported",
|
||||
]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user