refine commit, polish PR
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
parent: 5a5506c661
commit: e019391cd8
@@ -211,6 +211,10 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
     assert rel_diff < 0.001


+@pytest.mark.skipif(
+    current_platform.is_fp8_fnuz(),
+    reason="This platform supports e4m3fnuz, not e4m3fn.",
+)
 @pytest.mark.parametrize(
     "M,N,K,block_size,out_dtype,seed",
     itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
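Note: the new skipif guard simply stacks on top of the existing parametrize decorator, so the whole test matrix is skipped on platforms that only provide e4m3fnuz. A minimal standalone sketch of that decorator stacking, with a placeholder platform check standing in for current_platform.is_fp8_fnuz() and made-up dimensions:

```python
import itertools

import pytest

# Placeholder dimensions; the real test sweeps M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS.
M, N, K = [1, 16], [128], [256]


def platform_uses_fp8_fnuz() -> bool:
    # Stand-in for current_platform.is_fp8_fnuz(); assume e4m3fn is available here.
    return False


@pytest.mark.skipif(
    platform_uses_fp8_fnuz(),
    reason="This platform supports e4m3fnuz, not e4m3fn.",
)
@pytest.mark.parametrize("M,N,K", itertools.product(M, N, K))
def test_shapes(M, N, K):
    assert M > 0 and N % 128 == 0 and K % 128 == 0
```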
@@ -239,13 +243,6 @@ def test_w8a8_block_fp8_flashinfer_matmul(M, N, K, block_size, out_dtype, seed):
     Bs = Bs_fp8.to(torch.float32)

     ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
-    As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)
-
-    # Transpose earlier so that the testing will not trigger transposing kernels
-    assert As_fp8.shape == (M, (K + 127) // 128), (
-        f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
-    )
-
     out = flashinfer_fp8_blockscale_gemm(
         input=A_bf16,
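For reference, the accuracy check around this hunk is a relative-error comparison of the kernel output against the native_w8a8_block_matmul reference, with the rel_diff < 0.001 bound seen in the first hunk. A rough sketch of that kind of check with made-up tensors rather than the test's real FP8 inputs (the exact rel_diff formula in the test may differ):

```python
import torch


def relative_diff(out: torch.Tensor, ref: torch.Tensor) -> float:
    # Mean absolute error normalized by the mean magnitude of the reference.
    out, ref = out.float(), ref.float()
    return (torch.mean(torch.abs(out - ref)) / torch.mean(torch.abs(ref))).item()


# Stand-ins for the kernel output and the native reference output.
ref_out = torch.randn(16, 512)
out = ref_out + 1e-5 * torch.randn_like(ref_out)

assert relative_diff(out, ref_out) < 0.001
```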
@@ -168,7 +168,7 @@ if TYPE_CHECKING:
         "relax",
     ] = "relax"
     VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
-    VLLM_USE_FLASHINFER_FP8_LINEAR: bool = False
+    VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER: bool = False
     VLLM_USE_FLASHINFER_MOE_FP16: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
@@ -1211,8 +1211,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     ),
     # Allow use of FlashInfer FP8 block-scale GEMM for linear layers.
     # This uses TensorRT-LLM kernels and requires SM90+ (Hopper).
-    "VLLM_USE_FLASHINFER_FP8_LINEAR": lambda: bool(
-        int(os.getenv("VLLM_USE_FLASHINFER_FP8_LINEAR", "0"))
+    "VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER": lambda: bool(
+        int(os.getenv("VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER", "0"))
     ),
     # Allow use of FlashInfer MoE kernels for fused moe ops.
     "VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool(
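The renamed variable keeps the usual pattern in the envs module: the lambda parses the raw environment string with bool(int(os.getenv(...))), so only "1" (or another non-zero integer string) enables the FlashInfer block-scale GEMM path. A small standalone illustration of that parsing behavior; the read_flag helper here is made up, only the variable name comes from this diff:

```python
import os


def read_flag(name: str, default: str = "0") -> bool:
    # Unset -> default; "0"/"1" strings -> False/True; non-integer strings raise.
    return bool(int(os.getenv(name, default)))


os.environ["VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER"] = "1"
assert read_flag("VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER") is True

del os.environ["VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER"]
assert read_flag("VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER") is False
```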
@@ -38,7 +38,7 @@ from vllm.utils.deep_gemm import (
 from vllm.utils.flashinfer import (
     flashinfer_fp8_blockscale_gemm,
     is_flashinfer_fp8_blockscale_gemm_supported,
-    should_use_flashinfer_for_block_scale_fp8_linear,
+    should_use_flashinfer_for_blockscale_fp8_gemm,
 )
 from vllm.utils.torch_utils import direct_register_custom_op

@@ -238,7 +238,7 @@ def _flashinfer_fp8_blockscale_gemm_impl(
     group_size: int,
     use_deep_gemm_e8m0: bool,
 ) -> torch.Tensor:
-    def use_flashinfer(
+    def use_flashinfer_deepgemm_swapAB(
         input: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
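The new branch name spells out that the FlashInfer path relies on a swapAB trick. As I read it (the diff itself does not define the term), swapAB means evaluating the GEMM as the transpose of the swapped product, C = A·B = (Bᵀ·Aᵀ)ᵀ, which tends to map better onto tensor-core tiles when the activation batch M is small. A tiny numerical check of that identity:

```python
import torch

# Small M, matching the input.shape[0] < 32 condition in the next hunk.
A = torch.randn(4, 256, dtype=torch.float64)
B = torch.randn(256, 512, dtype=torch.float64)

direct = A @ B
swapped = (B.T @ A.T).T  # "swapAB": swap the operands, transpose the result

assert torch.allclose(direct, swapped, atol=1e-9)
```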
@@ -274,11 +274,18 @@ def _flashinfer_fp8_blockscale_gemm_impl(
         )
         return output

+    # There is no benefit to using FlashInfer DeepGEMM for higher batch sizes,
+    # since the swapAB optimization is only effective for small batch sizes.
+    # There is also a slight accuracy loss when using the FlashInfer block-scale
+    # GEMM for all batch sizes with DeepSeek-V3.
     condition = input.shape[0] < 32

-    # Pass all required variables through operands
+    # torch.cond for torch compile compatibility
     return torch.cond(
-        condition, use_flashinfer, use_deepgemm, (input, weight, weight_scale)
+        condition,
+        use_flashinfer_deepgemm_swapAB,
+        use_deepgemm,
+        (input, weight, weight_scale),
     )


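The rewritten comment points out that torch.cond is there for torch.compile compatibility: both branches stay in the traced graph and the predicate picks one at runtime, instead of a data-dependent Python if. A minimal standalone sketch of the same dispatch shape; the branch bodies are placeholders, not the real FlashInfer/DeepGEMM kernels, and torch.cond is a prototype API whose availability depends on the PyTorch version:

```python
import torch


def small_batch_path(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Placeholder for the FlashInfer swapAB branch.
    return (w @ x.T).T


def large_batch_path(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Placeholder for the DeepGEMM branch.
    return x @ w.T


def dispatch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Same call shape as in the diff: predicate, true_fn, false_fn, operands.
    return torch.cond(x.shape[0] < 32, small_batch_path, large_batch_path, (x, w))


x, w = torch.randn(8, 64), torch.randn(128, 64)
assert dispatch(x, w).shape == (8, 128)
```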
@@ -357,7 +364,7 @@ class W8A8BlockFp8LinearOp:
         output_shape = [*input.shape[:-1], weight.shape[0]]
         output_dtype = input.dtype

-        if should_use_flashinfer_for_block_scale_fp8_linear(
+        if should_use_flashinfer_for_blockscale_fp8_gemm(
             self.is_flashinfer_supported, output_dtype, input_2d, weight
         ):
             output = self._run_flashinfer(input_2d, weight, weight_scale)
@@ -548,18 +548,23 @@ flashinfer_fp8_blockscale_gemm = _lazy_import_wrapper(
 @functools.cache
 def has_flashinfer_fp8_blockscale_gemm() -> bool:
     """Return `True` if FlashInfer block-scale FP8 GEMM is available."""
-    return has_flashinfer() and hasattr(
-        _get_submodule("flashinfer.gemm"), "fp8_blockscale_gemm_sm90"
+    return (
+        has_flashinfer()
+        and current_platform.is_device_capability(90)
+        and hasattr(_get_submodule("flashinfer.gemm"), "fp8_blockscale_gemm_sm90")
     )


 @functools.cache
 def is_flashinfer_fp8_blockscale_gemm_supported() -> bool:
     """Return `True` if FlashInfer block-scale FP8 GEMM is supported."""
-    return envs.VLLM_USE_FLASHINFER_FP8_LINEAR and has_flashinfer_fp8_blockscale_gemm()
+    return (
+        envs.VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER
+        and has_flashinfer_fp8_blockscale_gemm()
+    )


-def should_use_flashinfer_for_block_scale_fp8_linear(
+def should_use_flashinfer_for_blockscale_fp8_gemm(
     is_flashinfer_supported: bool,
     output_dtype: torch.dtype,
     input: torch.Tensor,
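has_flashinfer_fp8_blockscale_gemm now also requires device capability 90, so the probe becomes: package importable, running on SM90, and the fp8_blockscale_gemm_sm90 entry point present, all memoized with functools.cache. A rough standalone sketch of that cached feature-probing style, with a made-up _optional_module helper in place of vllm's _get_submodule and without the capability check (the real helpers live in vllm.utils.flashinfer):

```python
import functools
import importlib


@functools.cache
def _optional_module(name: str):
    # Import lazily; return None when the optional dependency is missing.
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


@functools.cache
def has_blockscale_gemm_kernel() -> bool:
    # Available only if flashinfer.gemm imports and exposes the SM90 kernel.
    gemm = _optional_module("flashinfer.gemm")
    return gemm is not None and hasattr(gemm, "fp8_blockscale_gemm_sm90")
```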
@@ -612,6 +617,6 @@ __all__ = [
     "flashinfer_scaled_fp4_mm",
     "flashinfer_scaled_fp8_mm",
     "flashinfer_fp8_blockscale_gemm",
-    "should_use_flashinfer_for_block_scale_fp8_linear",
+    "should_use_flashinfer_for_blockscale_fp8_gemm",
     "is_flashinfer_fp8_blockscale_gemm_supported",
 ]