mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 07:55:01 +08:00
[core] use forward context for flash infer (#9097)
This commit is contained in:
parent
5df1834895
commit
f4dd830e09
@ -26,6 +26,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
|
||||
make_tensor_with_pad)
|
||||
|
||||
@ -761,73 +762,132 @@ class FlashInferImpl(AttentionImpl):
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashInferImpl")
|
||||
num_tokens, hidden_size = query.shape
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
if attn_metadata.num_prefill_tokens > 0:
|
||||
assert attn_metadata.num_decode_tokens == 0, (
|
||||
"Chunked prefill is not supported with flashinfer yet.")
|
||||
if attn_metadata.num_decode_tokens > 0:
|
||||
assert attn_metadata.num_prefill_tokens == 0, (
|
||||
"Chunked prefill is not supported with flashinfer yet.")
|
||||
if kv_cache.numel() > 0:
|
||||
# Use the same reshape and cache kernel as flash attention.
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
kv_cache[:, 0],
|
||||
kv_cache[:, 1],
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
self.kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
return torch.ops.vllm.unified_flash_infer(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.num_kv_heads,
|
||||
kv_cache,
|
||||
self.kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
self.scale,
|
||||
self.sliding_window,
|
||||
self.alibi_slopes,
|
||||
self.logits_soft_cap,
|
||||
)
|
||||
|
||||
|
||||
@torch.library.custom_op("vllm::unified_flash_infer",
|
||||
mutates_args=["kv_cache"])
|
||||
def unified_flash_infer(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
num_kv_heads: int,
|
||||
kv_cache: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
softmax_scale: float,
|
||||
window_size: Optional[List[int]] = None,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
current_metadata = get_forward_context()
|
||||
assert current_metadata is not None
|
||||
assert isinstance(current_metadata, FlashInferMetadata)
|
||||
attn_metadata: FlashInferMetadata = current_metadata
|
||||
|
||||
num_tokens, hidden_size = query.shape
|
||||
query = query.view(-1, num_heads, head_size)
|
||||
key = key.view(-1, num_kv_heads, head_size)
|
||||
value = value.view(-1, num_kv_heads, head_size)
|
||||
|
||||
if attn_metadata.num_prefill_tokens > 0:
|
||||
assert attn_metadata.num_decode_tokens == 0, (
|
||||
"Chunked prefill is not supported with flashinfer yet.")
|
||||
if attn_metadata.num_decode_tokens > 0:
|
||||
assert attn_metadata.num_prefill_tokens == 0, (
|
||||
"Chunked prefill is not supported with flashinfer yet.")
|
||||
if kv_cache.numel() > 0:
|
||||
# Use the same reshape and cache kernel as flash attention.
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
kv_cache[:, 0],
|
||||
kv_cache[:, 1],
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
|
||||
# to process the cache when the kv_cache_dtype is fp8
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||
kv_cache_dtype)
|
||||
kv_cache = kv_cache.view(torch_dtype)
|
||||
|
||||
query = query.contiguous() # Flashinfer requires query to be contiguous
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# We will use flash attention for prefill
|
||||
# when kv_cache is not provided.
|
||||
# This happens when vllm runs the profiling to
|
||||
# determine the number of blocks.
|
||||
if kv_cache.numel() == 0:
|
||||
output = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
)
|
||||
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
|
||||
# to process the cache when the kv_cache_dtype is fp8
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||
self.kv_cache_dtype)
|
||||
kv_cache = kv_cache.view(torch_dtype)
|
||||
|
||||
query = query.contiguous(
|
||||
) # Flashinfer requires query to be contiguous
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# We will use flash attention for prefill
|
||||
# when kv_cache is not provided.
|
||||
# This happens when vllm runs the profiling to
|
||||
# determine the number of blocks.
|
||||
if kv_cache.numel() == 0:
|
||||
output = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
)
|
||||
else:
|
||||
assert prefill_meta is not None
|
||||
assert prefill_meta.prefill_wrapper is not None
|
||||
output = prefill_meta.prefill_wrapper.forward(
|
||||
query,
|
||||
kv_cache,
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
causal=True)
|
||||
else:
|
||||
assert attn_metadata.decode_metadata is not None
|
||||
assert attn_metadata.decode_metadata.decode_wrapper is not None
|
||||
output = attn_metadata.decode_metadata.decode_wrapper.forward(
|
||||
query,
|
||||
kv_cache,
|
||||
sm_scale=self.scale,
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale)
|
||||
return output.view(num_tokens, hidden_size)
|
||||
assert prefill_meta is not None
|
||||
assert prefill_meta.prefill_wrapper is not None
|
||||
output = prefill_meta.prefill_wrapper.forward(
|
||||
query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True)
|
||||
else:
|
||||
assert attn_metadata.decode_metadata is not None
|
||||
assert attn_metadata.decode_metadata.decode_wrapper is not None
|
||||
output = attn_metadata.decode_metadata.decode_wrapper.forward(
|
||||
query,
|
||||
kv_cache,
|
||||
sm_scale=softmax_scale,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale)
|
||||
return output.view(num_tokens, hidden_size)
|
||||
|
||||
|
||||
@unified_flash_infer.register_fake
|
||||
def _(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
num_kv_heads: int,
|
||||
kv_cache: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
softmax_scale: float,
|
||||
window_size: Optional[List[int]] = None,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(query).contiguous()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user