update flashinfer to v0.2.9rc1 (#21485)

Signed-off-by: Weiliang Liu <weiliangl@nvidia.com>
This commit is contained in:
weiliang 2025-07-25 05:06:11 +08:00 committed by GitHub
parent a6c7fb8cff
commit 2dd72d23d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 6 additions and 15 deletions

View File

@@ -386,7 +386,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
# Install FlashInfer from source
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
ARG FLASHINFER_GIT_REF="v0.2.8"
ARG FLASHINFER_GIT_REF="v0.2.9rc1"
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
. /etc/environment
git clone --depth 1 --recursive --shallow-submodules \

View File

@@ -1169,16 +1169,12 @@ class FlashInferImpl(AttentionImpl):
query=decode_query,
kv_cache=kv_cache.permute(*stride_order),
workspace_buffer=workspace_buffer,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
scale=softmax_scale,
block_tables=attn_metadata.block_tables,
seq_lens=decode_meta.seq_lens_tensor,
block_size=attn_metadata.page_size,
max_seq_len=attn_metadata.max_decode_seq_len,
kv_cache_dtype=kv_cache_dtype,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float)
bmm1_scale=layer._k_scale_float * softmax_scale,
bmm2_scale=layer._v_scale_float,
)
if prefill_output is None and decode_output is not None:
# Decode only batch.

View File

@@ -678,15 +678,10 @@ class FlashInferImpl(AttentionImpl):
query=decode_query,
kv_cache=kv_cache_permute,
workspace_buffer=attn_metadata.workspace_buffer,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
scale=self.scale,
block_tables=block_tables_decode,
seq_lens=seq_lens_decode,
block_size=attn_metadata.page_size,
max_seq_len=attn_metadata.max_seq_len,
kv_cache_dtype=self.kv_cache_dtype,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
bmm1_scale=layer._k_scale_float * self.scale,
bmm2_scale=layer._v_scale_float,
))
return output_padded