[Kernel] Add nvfp4 gemm flashinfer backends (#22346)
Signed-off-by: Julien Lin <jullin@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent b8ff05361a
commit 279a5f31b3
@@ -669,6 +669,7 @@ steps:
    - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
    - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
    - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
    - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
    - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
    # Fusion
    - pytest -v -s tests/compile/test_fusion_all_reduce.py
tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py (new file, 139 lines)
@@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX,
                         convert_swizzled_to_linear, dequantize_nvfp4_to_dtype)

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm

if not current_platform.has_device_capability(100):
    pytest.skip(
        reason="Nvfp4 Requires compute capability of 10 or above.",
        allow_module_level=True,
    )

DTYPES = [torch.float16, torch.bfloat16]
# m, n, k
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES.extend(PAD_SHAPES)

SEEDS = [42]
CUDA_DEVICES = ["cuda:0"]


def get_ref_results(
    a_fp4,
    b_fp4,
    a_sf,
    b_sf,
    a_global_scale,
    b_global_scale,
    m,
    n,
    dtype,
    block_size,
    device,
):
    _, m_k = a_fp4.shape
    _, n_k = b_fp4.shape
    assert m_k == n_k
    a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
                                           a_sf,
                                           a_global_scale,
                                           dtype=dtype,
                                           device=device,
                                           block_size=block_size)
    b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4,
                                           b_sf,
                                           b_global_scale,
                                           dtype=dtype,
                                           device=device,
                                           block_size=block_size)
    return torch.matmul(a_in_dtype, b_in_dtype.t())


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("backend", ["cutlass", "trtllm"])
@pytest.mark.parametrize("autotune", [False, True])
@torch.inference_mode()
def test_flashinfer_nvfp4_gemm(
    dtype: torch.dtype,
    shape: tuple[int, int, int],
    seed: int,
    device: str,
    backend: str,
    autotune: bool,
) -> None:
    if backend == "trtllm" and dtype == torch.float16:
        pytest.skip(
            "Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations")

    current_platform.seed_everything(seed)
    m, n, packed_k = shape
    k = packed_k * 2
    block_size = 16
    a_dtype = torch.randn((m, k), dtype=dtype, device=device)
    b_dtype = torch.randn((n, k), dtype=dtype, device=device)

    a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
                      torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
    b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
                      torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
    alpha = 1.0 / (a_global_scale * b_global_scale)
    # ops.scaled_fp4_quant returns swizzled scales, while weights
    # from checkpoints are in linear scales.
    # So instead of needing to swizzle for cutlass as in modelopt.py,
    # we need to unswizzle for trtllm here.
    a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
    b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)

    # get_ref_results unswizzles the scales internally.
    expected_out = get_ref_results(
        a_fp4,
        b_fp4,
        a_scale_interleaved,
        b_scale_interleaved,
        a_global_scale,
        b_global_scale,
        m,
        n,
        dtype,
        block_size,
        device,
    )

    import flashinfer

    if backend == "trtllm":
        epilogue_tile_m = 128
        b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8),
                                            epilogue_tile_m)

        b_scale_interleaved = convert_swizzled_to_linear(
            b_scale_interleaved, n, k, block_size)
        b_scale_interleaved = (flashinfer.shuffle_matrix_sf_a(
            b_scale_interleaved.view(torch.uint8), epilogue_tile_m).reshape(
                b_scale_interleaved.shape).view(torch.float8_e4m3fn))

    with flashinfer.autotune(autotune):
        out = flashinfer_scaled_fp4_mm(
            a_fp4,
            b_fp4,
            a_scale_interleaved,
            b_scale_interleaved,
            alpha,
            dtype,
            backend=backend,
        )

    torch.testing.assert_close(out,
                               expected_out.to(dtype=dtype),
                               atol=1e-1,
                               rtol=1e-1)
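A note on the test above: the reason alpha folds both global scales falls out of the dequantization that get_ref_results applies. The lines below are an illustrative reading of that math, not code from the commit:

# Sketch, assuming the usual nvfp4 dequantization used by get_ref_results:
#   A_hp ~ (A_fp4 scaled by its block scales) / a_global_scale
#   B_hp ~ (B_fp4 scaled by its block scales) / b_global_scale
# so A_hp @ B_hp.T carries a constant factor of
# 1 / (a_global_scale * b_global_scale), which is exactly the alpha handed
# to both FlashInfer backends:
alpha = 1.0 / (a_global_scale * b_global_scale)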
@@ -65,9 +65,12 @@ def test_nvfp4_gemm(
    b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
                      torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
    alpha = 1. / (a_global_scale * b_global_scale)
    # ops.scaled_fp4_quant returns swizzled scales, while weights
    # from checkpoints are in linear scales.
    a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
    b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)

    # get_ref_results unswizzles the scales internally.
    expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved,
                                   b_scale_interleaved, a_global_scale,
                                   b_global_scale, m, n, dtype, block_size,
@@ -1101,6 +1101,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_USE_TRTLLM_ATTENTION":
    lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),

    # If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer.
    # Otherwise, uses the first available of: flashinfer cutlass GEMM,
    # vllm cutlass GEMM, marlin GEMM.
    "VLLM_USE_TRTLLM_FP4_GEMM":
    lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0"))),

    # Controls garbage collection during CUDA graph capture.
    # If set to 0 (default), enables GC freezing to speed up capture time.
    # If set to 1, allows GC to run during capture.
@@ -1208,6 +1214,7 @@ def compute_hash() -> str:
        "VLLM_DP_SIZE",
        "VLLM_USE_STANDALONE_COMPILE",
        "VLLM_FUSED_MOE_CHUNK_SIZE",
        "VLLM_USE_TRTLLM_FP4_GEMM",
    ]
    for key in environment_variables_to_hash:
        if key in environment_variables:
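For context, a minimal way to exercise the new flag from user code; the model name below is a placeholder, not something this commit ships:

# Illustrative only: force the FlashInfer TRTLLM FP4 GEMM backend.
# The variable must be set before vLLM reads its environment variables.
import os

os.environ["VLLM_USE_TRTLLM_FP4_GEMM"] = "1"

from vllm import LLM  # import after setting the flag

llm = LLM(model="path/to/nvfp4-quantized-model")  # placeholder model id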
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer

logger = init_logger(__name__)

@@ -24,6 +25,13 @@ __all__ = ["CompressedTensorsW4A4Fp4"]
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):

    def __init__(self):
        if envs.VLLM_USE_TRTLLM_FP4_GEMM:
            assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
            self.backend = "flashinfer-trtllm"
        elif has_flashinfer():
            self.backend = "flashinfer-cutlass"
        else:
            self.backend = "cutlass"
        self.group_size = 16

    @classmethod
@@ -108,16 +116,36 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
            layer.weight_global_scale.max().to(torch.float32),
            requires_grad=False)

        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                requires_grad=False)
        if self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

        # required by cutlass kernel; need Parameter, not ModelWeightParameter
        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
            weight = layer.weight_packed.data
            weight_scale = layer.weight_scale.data

        layer.alpha = Parameter(layer.input_global_scale *
                                layer.weight_global_scale,
                                requires_grad=False)
            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8),
                                      epilogue_tile_m)
            weight_scale = (shuffle_matrix_sf_a(weight_scale.view(
                torch.uint8), epilogue_tile_m).reshape(
                    weight_scale.shape).view(torch.float8_e4m3fn))

            layer.weight_scale_swizzled = Parameter(weight_scale,
                                                    requires_grad=False)
            layer.weight_packed = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
            layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                    requires_grad=False)
            layer.weight_packed = Parameter(layer.weight_packed.data,
                                            requires_grad=False)

        layer.alpha = Parameter(
            1 / (layer.input_global_scale * layer.weight_global_scale),
            requires_grad=False)

    def apply_weights(self,
                      layer: torch.nn.Module,
@@ -128,7 +156,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
            out = run_nvfp4_emulations(
                x=x,
                input_global_scale=layer.input_global_scale,
                weight=layer.weight,
                weight=layer.weight_packed,
                weight_scale_swizzled=layer.weight_scale_swizzled,
                weight_global_scale=layer.weight_global_scale)
            if bias is not None:
@@ -136,14 +164,20 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
            return out

        output_dtype = x.dtype
        output_shape = [x.shape[0], layer.weight.shape[0]]
        output_shape = [x.shape[0], layer.weight_packed.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)

        out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
                                    layer.weight_scale_swizzled,
                                    1 / layer.alpha, output_dtype)
        mm_args = (x_fp4, layer.weight_packed, x_blockscale,
                   layer.weight_scale_swizzled, layer.alpha, output_dtype)
        if self.backend == "flashinfer-trtllm":
            out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
        elif self.backend == "flashinfer-cutlass":
            out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
        else:
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)
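The TRTLLM re-layout step above can be read as one small helper (hypothetical; both quant schemes in this diff inline the same two FlashInfer calls rather than sharing a function):

import torch
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

def shuffle_for_trtllm(weight: torch.Tensor, weight_scale: torch.Tensor,
                       epilogue_tile_m: int = 128):
    # weight: packed fp4 data; weight_scale: linear fp8-e4m3 block scales
    weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
    weight_scale = (shuffle_matrix_sf_a(weight_scale.view(torch.uint8),
                                        epilogue_tile_m)
                    .reshape(weight_scale.shape)
                    .view(torch.float8_e4m3fn))
    return weight, weight_scale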
@@ -38,7 +38,8 @@ from vllm.model_executor.parameter import (ModelWeightParameter,
                                           PerTensorScaleParameter)
from vllm.scalar_type import scalar_types
from vllm.utils import next_power_of_2
from vllm.utils.flashinfer import has_flashinfer_moe
from vllm.utils.flashinfer import (flashinfer_scaled_fp4_mm, has_flashinfer,
                                   has_flashinfer_moe)

logger = init_logger(__name__)

@@ -724,16 +725,20 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config
        self.cutlass_nvfp4_supported = cutlass_fp4_supported()
        self.use_marlin = False

        if not self.cutlass_nvfp4_supported:
            if is_fp4_marlin_supported():
                self.use_marlin = True
            else:
                raise ValueError("Current platform does not support NVFP4"
                                 " quantization. Please use Blackwell and"
                                 " above.")
        if envs.VLLM_USE_TRTLLM_FP4_GEMM:
            assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
            self.backend = "flashinfer-trtllm"
        elif has_flashinfer():
            self.backend = "flashinfer-cutlass"
        elif cutlass_fp4_supported():
            self.backend = "cutlass"
        elif is_fp4_marlin_supported():
            self.backend = "marlin"
        else:
            raise ValueError("Current platform does not support NVFP4"
                             " quantization. Please use Blackwell and"
                             " above.")

    def create_weights(
        self,
@@ -815,17 +820,38 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
        # block_size = 16;
        assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
            "Weight Block scale must be represented as FP8-E4M3")
        swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)

        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                requires_grad=False)
        layer.weight = Parameter(layer.weight.data, requires_grad=False)
        if self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

        if self.use_marlin:
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
            del layer.weight_scale_swizzled
            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8),
                                      epilogue_tile_m)
            weight_scale = (shuffle_matrix_sf_a(weight_scale.view(
                torch.uint8), epilogue_tile_m).reshape(
                    weight_scale.shape).view(torch.float8_e4m3fn))

            layer.weight_scale_swizzled = Parameter(weight_scale,
                                                    requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                    requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
            del layer.weight_scale_swizzled

    def apply(
        self,
@@ -833,7 +859,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.use_marlin:
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
@@ -859,9 +885,21 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
        assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
        assert (layer.alpha.dtype == torch.float32)

        out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
                                    layer.weight_scale_swizzled, layer.alpha,
                                    output_dtype)
        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale_swizzled,
            layer.alpha,
            output_dtype,
        )
        if self.backend == "flashinfer-trtllm":
            out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
        elif self.backend == "flashinfer-cutlass":
            out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
        else:
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)
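Condensed into one place, the backend priority that ModelOptNvFp4LinearMethod.__init__ now implements (and that CompressedTensorsW4A4Fp4 follows for its first three branches) looks like the sketch below; the helper itself is not part of the commit:

def select_nvfp4_gemm_backend(force_trtllm: bool, has_fi: bool,
                              cutlass_ok: bool, marlin_ok: bool) -> str:
    if force_trtllm:  # VLLM_USE_TRTLLM_FP4_GEMM=1
        assert has_fi, "TRTLLM FP4 GEMM requires FlashInfer"
        return "flashinfer-trtllm"
    if has_fi:
        return "flashinfer-cutlass"
    if cutlass_ok:
        return "cutlass"
    if marlin_ok:
        return "marlin"
    raise ValueError("Current platform does not support NVFP4 quantization. "
                     "Please use Blackwell and above.")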
@@ -5,16 +5,53 @@ Warmup kernels used during model execution.
This is useful specifically for JIT'ed kernels as we don't want JIT'ing to
happen during model execution.
"""
from typing import TYPE_CHECKING

import torch

import vllm.envs as envs
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer

if TYPE_CHECKING:
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
    from vllm.v1.worker.gpu_worker import Worker


def kernel_warmup(model: torch.nn.Module, max_tokens: int):
def kernel_warmup(worker: "Worker"):
    # Deep GEMM warmup
    do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM
                           and is_deep_gemm_supported()
                           and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP)
    if do_deep_gemm_warmup:
        model = worker.get_model()
        max_tokens = worker.scheduler_config.max_num_batched_tokens
        deep_gemm_warmup(model, max_tokens)

    # FlashInfer autotune for Blackwell (SM 10.0) GPUs
    if has_flashinfer() and current_platform.is_device_capability(100):
        flashinfer_autotune(worker.model_runner)


def flashinfer_autotune(runner: "GPUModelRunner") -> None:
    """
    Autotune FlashInfer operations.
    FlashInfer has many implementations for the same operation;
    autotuning runs benchmarks for each implementation and stores
    the results. The results are cached transparently and
    future calls to FlashInfer will use the best implementation.
    Without autotuning, FlashInfer will rely on heuristics, which may
    be significantly slower.
    """
    from vllm.utils.flashinfer import autotune

    with torch.inference_mode(), autotune():
        # We skip EPLB here since we don't want to record dummy metrics
        # When autotuning with number of tokens m, flashinfer will autotune
        # operations for all number of tokens up to m.
        # So we only need to run with the max number of tokens.
        runner._dummy_run(runner.scheduler_config.max_num_batched_tokens,
                          skip_eplb=True,
                          is_profile=True)
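For intuition, the same autotune cache can be primed outside the worker; a minimal sketch, assuming quantized inputs prepared exactly as in test_flashinfer_nvfp4_gemm earlier (they are not constructed here):

import torch
from vllm.utils.flashinfer import autotune, flashinfer_scaled_fp4_mm

def tune_fp4_gemm(a_fp4, b_fp4, a_scale, b_scale, alpha):
    # The first call inside autotune() benchmarks candidate kernels; later
    # calls at the same or smaller token counts reuse the cached choice.
    with torch.inference_mode(), autotune():
        return flashinfer_scaled_fp4_mm(a_fp4, b_fp4, a_scale, b_scale,
                                        alpha, torch.bfloat16,
                                        backend="trtllm")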
@@ -14,6 +14,7 @@ import os
from typing import Any, Callable, NoReturn, Optional

import requests
import torch

import vllm.envs as envs
from vllm.logger import init_logger
@@ -193,6 +194,75 @@ def use_trtllm_attention(
    return use_trtllm


if has_flashinfer():

    @torch.library.custom_op(
        "vllm::flashinfer_mm_fp4",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_mm_fp4(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        from flashinfer import mm_fp4 as flashinfer_mm_fp4_
        return flashinfer_mm_fp4_(A,
                                  B,
                                  A_scale,
                                  B_scale,
                                  g_scale,
                                  dtype,
                                  block_size=16,
                                  backend=backend)

    @torch.library.register_fake("vllm::flashinfer_mm_fp4", )
    def flashinfer_mm_fp4_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        return torch.empty(A.shape[0],
                           B.shape[1],
                           dtype=dtype,
                           device=A.device)


def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
                             block_scale_a: torch.Tensor,
                             block_scale_b: torch.Tensor, alpha: torch.Tensor,
                             out_dtype: torch.dtype,
                             backend: str) -> torch.Tensor:
    assert a.ndim == 2 and b.ndim == 2
    assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
    assert a.stride(-1) == 1 and b.stride(-1) == 1
    assert a.shape[1] == b.shape[1]
    assert block_scale_a.shape[1] == a.shape[1] // 8
    assert block_scale_b.shape[1] == b.shape[1] // 8

    if backend == "cutlass":
        block_scale_a = block_scale_a.view(torch.uint8)
        block_scale_b = block_scale_b.view(torch.uint8)

    return flashinfer_mm_fp4(
        a,
        b.t(),
        block_scale_a,
        block_scale_b.t(),
        alpha,
        out_dtype,
        backend=backend,
    )


__all__ = [
    "has_flashinfer",
    "flashinfer_trtllm_fp8_block_scale_moe",
@@ -205,4 +275,5 @@ __all__ = [
    "has_flashinfer_cutlass_fused_moe",
    "has_nvidia_artifactory",
    "use_trtllm_attention",
    "flashinfer_scaled_fp4_mm",
]
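A quick worked check of the block-scale shape asserts in flashinfer_scaled_fp4_mm above, assuming the packing used throughout this diff (two FP4 values per byte, 16-element scale blocks):

m, k, block_size = 128, 128, 16
packed_k = k // 2             # fp4 tensors store two values per uint8 byte
scale_cols = k // block_size  # one fp8-e4m3 scale per 16-element block
assert scale_cols == packed_k // 8  # the "a.shape[1] // 8" relation asserted above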
@@ -310,6 +310,7 @@ class Worker(WorkerBase):
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True)

        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()

@@ -340,8 +341,7 @@ class Worker(WorkerBase):
            hidden_states=last_hidden_states)

        # Warmup kernels used during model execution
        kernel_warmup(self.get_model(),
                      max_tokens=self.scheduler_config.max_num_batched_tokens)
        kernel_warmup(self)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.