Bump Flashinfer to v0.4.0 (#26326)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
elvischenv 2025-10-09 14:58:44 +08:00 committed by GitHub
parent 0d7c3cb51d
commit 5e49c3e777
7 changed files with 25 additions and 23 deletions


@@ -15,7 +15,7 @@ ARG PYTHON_VERSION=3.12
# Example:
# docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
# Important: We build with an old version of Ubuntu to maintain broad
# compatibility with other Linux OSes. The main reason for this is that the
# glibc version is baked into the distro, and binaries built with one glibc
# version are not backwards compatible with OSes that use an earlier version.
@@ -371,7 +371,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
# Install FlashInfer from source
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
# Keep this in sync with "flashinfer" extra in setup.py
ARG FLASHINFER_GIT_REF="v0.3.1"
ARG FLASHINFER_GIT_REF="v0.4.0"
# Flag to control whether to compile FlashInfer AOT kernels
# Set to "true" to enable AOT compilation:
# docker build --build-arg FLASHINFER_AOT_COMPILE=true ...
@@ -392,7 +392,7 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
fi
pushd flashinfer
if [[ "${CUDA_VERSION}" == 12.8.* ]] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then
if [[ "${CUDA_VERSION}" == 12.8.* ]] && [ "$TARGETPLATFORM" = "linux/amd64" ] && [ "${FLASHINFER_GIT_REF}" = "v0.3.1" ]; then
# NOTE: To make new precompiled wheels, see tools/flashinfer-build.sh
echo "🏗️ Installing FlashInfer from pre-compiled wheel"
uv pip install --system https://wheels.vllm.ai/flashinfer-python/flashinfer_python-0.3.1-cp39-abi3-manylinux1_x86_64.whl \
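For illustration only (not part of the commit): the shell gate above can be read as the Python predicate below, assuming the same build arguments CUDA_VERSION, TARGETPLATFORM, and FLASHINFER_GIT_REF. Since the precompiled wheel on wheels.vllm.ai is built from FlashInfer v0.3.1, bumping FLASHINFER_GIT_REF to v0.4.0 makes the build fall through to the source path.

def use_precompiled_flashinfer_wheel(
    cuda_version: str, target_platform: str, flashinfer_git_ref: str
) -> bool:
    # Mirrors the bash condition: CUDA 12.8.x, x86_64 image, and a git ref
    # matching the version the precompiled wheel was built from.
    return (
        cuda_version.startswith("12.8.")
        and target_platform == "linux/amd64"
        and flashinfer_git_ref == "v0.3.1"
    )

assert use_precompiled_flashinfer_wheel("12.8.1", "linux/amd64", "v0.3.1")
assert not use_precompiled_flashinfer_wheel("12.8.1", "linux/amd64", "v0.4.0")  # source build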


@@ -246,7 +246,7 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.
# build flashinfer for torch nightly from source around 10 mins
-# release version: v0.3.1
+# release version: v0.4.0
# todo(elainewy): cache flashinfer build result for faster build
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
@@ -254,7 +254,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
echo "git clone flashinfer..." \
&& git clone --recursive https://github.com/flashinfer-ai/flashinfer.git \
&& cd flashinfer \
&& git checkout v0.3.1 \
&& git checkout v0.4.0 \
&& git submodule update --init --recursive \
&& echo "finish git clone flashinfer..." \
&& rm -rf build \


@@ -715,7 +715,7 @@ setup(
], # Required for audio processing
"video": [], # Kept for backwards compatibility
# FlashInfer should be updated together with the Dockerfile
"flashinfer": ["flashinfer-python==0.3.1"],
"flashinfer": ["flashinfer-python==0.4.0"],
# Optional deps for AMD FP4 quantization support
"petit-kernel": ["petit-kernel"],
},
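Since the FlashInfer pin now lives in both the Dockerfile (FLASHINFER_GIT_REF) and the setup.py "flashinfer" extra, a small consistency check can catch the two drifting apart. A rough sketch, assuming the Dockerfile sits at docker/Dockerfile and the pins keep their current ARG/extra spelling; this script is hypothetical and not part of the commit.

import re
from pathlib import Path

def read_flashinfer_pins(dockerfile="docker/Dockerfile", setup_py="setup.py"):
    # Dockerfile pin, e.g. ARG FLASHINFER_GIT_REF="v0.4.0"
    docker_ref = re.search(
        r'ARG FLASHINFER_GIT_REF="v?([\d.]+)"', Path(dockerfile).read_text()
    ).group(1)
    # setup.py pin, e.g. "flashinfer": ["flashinfer-python==0.4.0"]
    setup_ref = re.search(
        r"flashinfer-python==([\d.]+)", Path(setup_py).read_text()
    ).group(1)
    return docker_ref, setup_ref

if __name__ == "__main__":
    docker_ref, setup_ref = read_flashinfer_pins()
    assert docker_ref == setup_ref, (
        f"FlashInfer pins differ: Dockerfile has {docker_ref}, setup.py has {setup_ref}"
    )
    print(f"FlashInfer pinned consistently at {docker_ref}")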


@@ -7,9 +7,8 @@ import pytest
import torch
from tests.kernels.quantization.nvfp4_utils import (
-FLOAT4_E2M1_MAX,
-FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype,
+get_nvfp4_global_scale,
)
from vllm.platforms import current_platform
from vllm.utils import round_up
@@ -171,13 +170,12 @@ def test_flashinfer_trtllm_decode_with_baseline(
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0
-o_sf_scale = None
+o_sf_scale_float = None
if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
-o_sf_scale = (
-(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1)
-).to(torch.float32)
+o_sf_scale = get_nvfp4_global_scale(output)
+o_sf_scale_float = o_sf_scale.item()
# TRTLLM Decode
if o_quant_dtype == FP4_DTYPE:
@@ -204,7 +202,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
window_left=window_left,
-o_sf_scale=o_sf_scale,
+o_sf_scale=o_sf_scale_float,
out=output_trtllm,
)
if o_quant_dtype == FP8_DTYPE:
@@ -361,13 +359,12 @@ def test_flashinfer_trtllm_prefill_with_baseline(
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0
-o_sf_scale = None
+o_sf_scale_float = None
if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
-o_sf_scale = (
-(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1)
-).to(torch.float32)
+o_sf_scale = get_nvfp4_global_scale(output)
+o_sf_scale_float = o_sf_scale.item()
# TRTLLM Prefill
if o_quant_dtype == FP4_DTYPE:
@@ -398,7 +395,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
window_left=window_left,
-o_sf_scale=o_sf_scale,
+o_sf_scale=o_sf_scale_float,
out=output_trtllm,
)
if o_quant_dtype == FP8_DTYPE:


@@ -66,9 +66,11 @@ def break_fp4_bytes(a, dtype):
return values.reshape(m, n * 2).to(dtype=dtype)
+def get_nvfp4_global_scale(a: torch.Tensor):
+return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32)
def quant_nvfp4_tensor(a: torch.Tensor):
-a_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(
-torch.float32
-)
+a_global_scale = get_nvfp4_global_scale(a)
a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
return a_quant, a_block_scale, a_global_scale
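To make the new helper concrete: get_nvfp4_global_scale picks a per-tensor scale so the tensor's absolute maximum lands at the top of the combined FP8 (E4M3) times FP4 (E2M1) range, and the updated tests then hand that scale to the TRTLLM kernels as a plain Python float via .item(). Below is a standalone sketch of the same math using the conventional maxima for these formats (448.0 for E4M3, 6.0 for E2M1); it is illustrative, not the test code itself.

import torch

FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
FLOAT4_E2M1_MAX = 6.0  # largest magnitude representable in FP4 E2M1

def nvfp4_global_scale(a: torch.Tensor) -> torch.Tensor:
    # Per-tensor scale: abs-max of the input maps onto FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX.
    return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32)

x = torch.randn(4, 128)
scale = nvfp4_global_scale(x)   # 0-dim float32 tensor, like get_nvfp4_global_scale returns
scale_float = scale.item()      # plain float, the form the tests now pass as o_sf_scale
print(scale, scale_float)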


@@ -50,7 +50,7 @@ def can_initialize(model: str, extra_args: Optional[list[str]] = None):
with RemoteOpenAIServer(
model,
server_args,
-max_wait_seconds=1000, # Due to FlashInfer compile
+max_wait_seconds=1500, # Due to FlashInfer compile
override_hf_configs=dummy_hf_overrides,
) as server:
client = server.get_client()


@@ -1199,7 +1199,7 @@ def fast_plan_decode(
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
try:
-# Make sure we pass exactly 15 arguments for tensor core version
+# Make sure we pass exactly 18 arguments for tensor core version
self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
@@ -1216,6 +1216,9 @@ def fast_plan_decode(
head_dim,
head_dim,
False, # causal
+window_left,
+-1, # fixed_split_size
+False, # disable_split_kv
)
except Exception as e:
raise RuntimeError(f"Error in tensor core plan: {e}") from e
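The three new trailing arguments (window_left, fixed_split_size, disable_split_kv) track the 18-argument tensor-core plan() signature noted in the updated comment. As a hypothetical hardening step, not part of this commit and independent of FlashInfer, one could pin the expected arity so that a future signature drift fails with an explicit message rather than a positional mix-up; a minimal sketch:

from typing import Any, Callable, Sequence

def call_with_expected_arity(fn: Callable[..., Any], args: Sequence[Any], expected: int) -> Any:
    # Fail loudly if the number of positional arguments we are about to pass
    # no longer matches what the pinned FlashInfer version expects.
    if len(args) != expected:
        raise RuntimeError(
            f"plan() arity mismatch: passing {len(args)} arguments, expected {expected}; "
            "re-check the FlashInfer version pin."
        )
    return fn(*args)

# Stand-in usage with a dummy 18-argument callable in place of self._cached_module.plan.
if __name__ == "__main__":
    fake_plan = lambda *a: len(a)
    print(call_with_expected_arity(fake_plan, list(range(18)), expected=18))  # -> 18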