update flashinfer to v0.2.9rc1 (#21485)
Signed-off-by: Weiliang Liu <weiliangl@nvidia.com>
parent a6c7fb8cff
commit 2dd72d23d9
@@ -386,7 +386,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
 # Install FlashInfer from source
 ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
-ARG FLASHINFER_GIT_REF="v0.2.8"
+ARG FLASHINFER_GIT_REF="v0.2.9rc1"
 RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
     . /etc/environment
     git clone --depth 1 --recursive --shallow-submodules \
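The Dockerfile hunk above only bumps the pinned source ref from v0.2.8 to v0.2.9rc1; the clone and build steps stay the same. As a hedged illustration that is not part of this commit, a quick runtime check along these lines can confirm a rebuilt image actually picked up the new version, assuming FlashInfer exposes the conventional __version__ attribute:

# Hypothetical sanity check, not part of this diff: confirm the FlashInfer
# build inside the image matches the pinned v0.2.9rc1 ref.
import flashinfer  # assumes the package imports under this name

version = getattr(flashinfer, "__version__", "unknown")
if not version.startswith("0.2.9"):
    raise RuntimeError(f"expected FlashInfer 0.2.9rc1, found {version}")
print(f"FlashInfer {version} detected")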
@@ -1169,16 +1169,12 @@ class FlashInferImpl(AttentionImpl):
                 query=decode_query,
                 kv_cache=kv_cache.permute(*stride_order),
                 workspace_buffer=workspace_buffer,
-                num_heads=num_heads,
-                num_kv_heads=num_kv_heads,
-                scale=softmax_scale,
                 block_tables=attn_metadata.block_tables,
                 seq_lens=decode_meta.seq_lens_tensor,
-                block_size=attn_metadata.page_size,
                 max_seq_len=attn_metadata.max_decode_seq_len,
-                kv_cache_dtype=kv_cache_dtype,
-                k_scale=layer._k_scale_float,
-                v_scale=layer._v_scale_float)
+                bmm1_scale=layer._k_scale_float * softmax_scale,
+                bmm2_scale=layer._v_scale_float,
+            )
 
         if prefill_output is None and decode_output is not None:
             # Decode only batch.
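The substantive change in this hunk is how the quantization scales reach the TRT-LLM decode kernel: instead of passing the softmax scale, kv_cache_dtype, and per-tensor k_scale/v_scale separately, the call folds the K dequantization scale into the first-matmul scale (bmm1_scale = k_scale * softmax_scale) and passes the V dequantization scale as the second-matmul scale (bmm2_scale = v_scale). The sketch below is not vLLM or FlashInfer code; it only demonstrates on plain dense tensors why that folding is algebraically equivalent to dequantizing K and V up front. The hunk that follows applies the same change in the other FlashInfer backend, with self.scale in place of softmax_scale.

# Minimal sketch (assumptions, not vLLM/FlashInfer code): folding the K
# dequantization scale into the softmax scale (bmm1_scale) and applying the V
# dequantization scale after the second matmul (bmm2_scale) matches explicit
# dequantization.
import torch

def reference_attention(q, k_quant, v_quant, k_scale, v_scale, softmax_scale):
    # Dequantize first, then run standard scaled-dot-product attention.
    k = k_quant * k_scale
    v = v_quant * v_scale
    p = torch.softmax(q @ k.T * softmax_scale, dim=-1)
    return p @ v

def fused_scale_attention(q, k_quant, v_quant, bmm1_scale, bmm2_scale):
    # Keep quantized values in both matmuls and apply the folded scales:
    # bmm1_scale = k_scale * softmax_scale, bmm2_scale = v_scale.
    p = torch.softmax(q @ k_quant.T * bmm1_scale, dim=-1)
    return (p @ v_quant) * bmm2_scale

q = torch.randn(4, 64)
k_quant = torch.randn(16, 64)   # stand-in for dequantizable (e.g. FP8) values
v_quant = torch.randn(16, 64)
k_scale, v_scale, softmax_scale = 0.5, 0.25, 64 ** -0.5

out_ref = reference_attention(q, k_quant, v_quant, k_scale, v_scale, softmax_scale)
out_fused = fused_scale_attention(q, k_quant, v_quant,
                                  k_scale * softmax_scale, v_scale)
torch.testing.assert_close(out_ref, out_fused)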
@@ -678,15 +678,10 @@ class FlashInferImpl(AttentionImpl):
                     query=decode_query,
                     kv_cache=kv_cache_permute,
                     workspace_buffer=attn_metadata.workspace_buffer,
-                    num_heads=self.num_heads,
-                    num_kv_heads=self.num_kv_heads,
-                    scale=self.scale,
                     block_tables=block_tables_decode,
                     seq_lens=seq_lens_decode,
-                    block_size=attn_metadata.page_size,
                     max_seq_len=attn_metadata.max_seq_len,
-                    kv_cache_dtype=self.kv_cache_dtype,
-                    k_scale=layer._k_scale_float,
-                    v_scale=layer._v_scale_float,
+                    bmm1_scale=layer._k_scale_float * self.scale,
+                    bmm2_scale=layer._v_scale_float,
                 ))
         return output_padded
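Besides the same scale folding, this hunk also stops passing num_heads, num_kv_heads, block_size, and kv_cache_dtype. These are all recoverable from the tensors the call still receives, which is presumably why the v0.2.9 kernel no longer takes them; the sketch below illustrates that inference under an assumed HND paged-KV layout and is not the actual FlashInfer signature.

# Illustration only (assumed layout, not the FlashInfer API): every argument
# dropped by this diff can be read off the query and paged KV-cache tensors.
import torch

def inferred_call_metadata(query: torch.Tensor, kv_cache: torch.Tensor):
    # query: [num_tokens, num_heads, head_dim]
    # kv_cache, assumed HND layout: [num_blocks, 2, num_kv_heads, block_size, head_dim]
    num_heads = query.shape[1]
    num_kv_heads = kv_cache.shape[2]
    block_size = kv_cache.shape[3]
    kv_cache_dtype = kv_cache.dtype
    return num_heads, num_kv_heads, block_size, kv_cache_dtype

q = torch.randn(8, 32, 128)
kv = torch.randn(64, 2, 8, 16, 128)
print(inferred_call_metadata(q, kv))  # (32, 8, 16, torch.float32)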