[misc] add forward context for attention (#9029)

commit 9aaf14c62e
parent 63e39937f9
@@ -3,9 +3,9 @@ from typing import List, Optional, Tuple
 import pytest
 import torch
 
-import vllm.attention.backends.flash_attn  # noqa: F401
-from tests.kernels.utils import opcheck
 from vllm.utils import seed_everything
+from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+                                  flash_attn_with_kvcache)
 
 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 256]
@@ -112,10 +112,10 @@ def test_flash_attn_with_paged_kv(
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)
 
-    output = torch.ops.vllm.flash_attn_with_kvcache(
-        decode_query=query.unsqueeze(1),
-        key_cache=key_cache,
-        value_cache=value_cache,
+    output = flash_attn_with_kvcache(
+        q=query.unsqueeze(1),
+        k_cache=key_cache,
+        v_cache=value_cache,
         softmax_scale=scale,
         causal=True,
         block_table=block_tables,
@@ -123,25 +123,6 @@ def test_flash_attn_with_paged_kv(
         softcap=soft_cap if soft_cap is not None else 0,
     ).squeeze(1)
 
-    if num_blocks <= 2048:
-        test_utils = ["test_faketensor", "test_schema"]
-    else:
-        test_utils = ["test_faketensor"]
-
-    opcheck(torch.ops.vllm.flash_attn_with_kvcache,
-            args=tuple(),
-            kwargs=dict(
-                decode_query=query.unsqueeze(1),
-                key_cache=key_cache,
-                value_cache=value_cache,
-                softmax_scale=scale,
-                causal=True,
-                block_table=block_tables,
-                cache_seqlens=kv_lens_tensor,
-                softcap=soft_cap if soft_cap is not None else 0,
-            ),
-            test_utils=test_utils)
-
     ref_output = ref_paged_attn(
         query=query,
         key_cache=key_cache,
@@ -213,7 +194,7 @@ def test_varlen_with_paged_kv(
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)
 
-    output = torch.ops.vllm.flash_attn_varlen_func(
+    output = flash_attn_varlen_func(
         q=query,
         k=key_cache,
         v=value_cache,
@@ -228,29 +209,6 @@ def test_varlen_with_paged_kv(
         softcap=soft_cap if soft_cap is not None else 0,
     )
 
-    if num_blocks <= 2048:
-        test_utils = ["test_faketensor", "test_schema"]
-    else:
-        test_utils = ["test_faketensor"]
-
-    opcheck(torch.ops.vllm.flash_attn_varlen_func,
-            args=tuple(),
-            kwargs=dict(
-                q=query,
-                k=key_cache,
-                v=value_cache,
-                cu_seqlens_q=cu_query_lens,
-                cu_seqlens_k=cu_kv_lens,
-                max_seqlen_q=max_query_len,
-                max_seqlen_k=max_kv_len,
-                softmax_scale=scale,
-                causal=True,
-                window_size=window_size,
-                block_table=block_tables,
-                softcap=soft_cap if soft_cap is not None else 0,
-            ),
-            test_utils=test_utils)
-
     ref_output = ref_paged_attn(
         query=query,
         key_cache=key_cache,

@@ -13,152 +13,15 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
                                            compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
+from vllm.forward_context import get_forward_context
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                           ModelInputForGPUWithSamplingMetadata)
 
-# yapf: disable
-from vllm.vllm_flash_attn import (
-    flash_attn_varlen_func as _flash_attn_varlen_func)
-from vllm.vllm_flash_attn import (
-    flash_attn_with_kvcache as _flash_attn_with_kvcache)
-
-# yapf: enable
-
-
-@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
-def flash_attn_varlen_func(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    cu_seqlens_q: torch.Tensor,
-    cu_seqlens_k: torch.Tensor,
-    max_seqlen_q: int,
-    max_seqlen_k: int,
-    softmax_scale: Optional[float] = None,
-    causal: bool = False,
-    window_size: Optional[List[int]] = None,
-    softcap: float = 0.0,
-    alibi_slopes: Optional[torch.Tensor] = None,
-    block_table: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    # custom op does not support tuple input
-    real_window_size: Tuple[int, int]
-    if window_size is None:
-        real_window_size = (-1, -1)
-    else:
-        assert len(window_size) == 2
-        real_window_size = (window_size[0], window_size[1])
-    return _flash_attn_varlen_func(
-        q=q,
-        k=k,
-        v=v,
-        cu_seqlens_q=cu_seqlens_q,
-        cu_seqlens_k=cu_seqlens_k,
-        max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_k,
-        softmax_scale=softmax_scale,
-        causal=causal,
-        window_size=real_window_size,
-        softcap=softcap,
-        alibi_slopes=alibi_slopes,
-        block_table=block_table,
-    )
-
-
-@flash_attn_varlen_func.register_fake  # type: ignore
-def _(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    cu_seqlens_q: torch.Tensor,
-    cu_seqlens_k: torch.Tensor,
-    max_seqlen_q: int,
-    max_seqlen_k: int,
-    softmax_scale: Optional[float] = None,
-    causal: bool = False,
-    window_size: Optional[List[int]] = None,
-    softcap: float = 0.0,
-    alibi_slopes: Optional[torch.Tensor] = None,
-    block_table: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    return torch.empty_like(q)
-
-
-@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[])
-def flash_attn_with_kvcache(
-    decode_query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    cache_seqlens: Optional[torch.Tensor] = None,
-    block_table: Optional[torch.Tensor] = None,
-    softmax_scale: Optional[float] = None,
-    causal: bool = False,
-    alibi_slopes: Optional[torch.Tensor] = None,
-    softcap: float = 0.0,
-) -> torch.Tensor:
-    return _flash_attn_with_kvcache(
-        decode_query,
-        key_cache,
-        value_cache,
-        cache_seqlens=cache_seqlens,
-        block_table=block_table,
-        softmax_scale=softmax_scale,
-        causal=causal,
-        alibi_slopes=alibi_slopes,
-        softcap=softcap,
-    )
-
-
-@flash_attn_with_kvcache.register_fake  # type: ignore
-def _(
-    decode_query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    cache_seqlens: Optional[torch.Tensor] = None,
-    block_table: Optional[torch.Tensor] = None,
-    softmax_scale: Optional[float] = None,
-    causal: bool = False,
-    alibi_slopes: Optional[torch.Tensor] = None,
-    softcap: float = 0.0,
-) -> torch.Tensor:
-    return torch.empty_like(decode_query)
-
-
-@torch.library.custom_op("vllm::reshape_and_cache_flash",
-                         mutates_args=["kv_cache"])
-def reshape_and_cache_flash(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    kv_cache: torch.Tensor,
-    slot_mapping: torch.Tensor,
-    kv_cache_dtype: str,
-    k_scale: float,
-    v_scale: float,
-) -> None:
-    """Inductor cannot deal with inplace operations on views.
-    See https://github.com/pytorch/pytorch/issues/131192
-    and https://github.com/pytorch/pytorch/issues/130174
-    This is a workaround to hide the view operation from the inductor.
-    """
-    return torch.ops._C_cache_ops.reshape_and_cache_flash(
-        key, value, kv_cache[0], kv_cache[1], slot_mapping, kv_cache_dtype,
-        k_scale, v_scale)
-
-
-@reshape_and_cache_flash.register_fake  # type: ignore
-def _(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    kv_cache: torch.Tensor,
-    slot_mapping: torch.Tensor,
-    kv_cache_dtype: str,
-    k_scale: float,
-    v_scale: float,
-) -> None:
-    pass
+from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+                                  flash_attn_with_kvcache)
 
 
 class FlashAttentionBackend(AttentionBackend):
@@ -721,11 +584,55 @@ class FlashAttentionImpl(AttentionImpl):
         assert k_scale == 1.0 and v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
 
+        output = torch.ops.vllm.unified_flash_attention(
+            query,
+            key,
+            value,
+            self.num_heads,
+            self.head_size,
+            self.num_kv_heads,
+            kv_cache,
+            self.kv_cache_dtype,
+            k_scale,
+            v_scale,
+            self.scale,
+            self.sliding_window,
+            self.alibi_slopes,
+            self.logits_soft_cap,
+        )
+
+        return output
+
+
+@torch.library.custom_op("vllm::unified_flash_attention",
+                         mutates_args=["kv_cache"])
+def unified_flash_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    num_heads: int,
+    head_size: int,
+    num_kv_heads: int,
+    kv_cache: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+    softmax_scale: float,
+    window_size: Optional[List[int]] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    logits_soft_cap: Optional[float] = None,
+) -> torch.Tensor:
+
+    current_metadata = get_forward_context()
+    assert current_metadata is not None
+    assert isinstance(current_metadata, FlashAttentionMetadata)
+    attn_metadata: FlashAttentionMetadata = current_metadata
+
     num_tokens, hidden_size = query.shape
     # Reshape the query, key, and value tensors.
-    query = query.view(-1, self.num_heads, self.head_size)
-    key = key.view(-1, self.num_kv_heads, self.head_size)
-    value = value.view(-1, self.num_kv_heads, self.head_size)
+    query = query.view(-1, num_heads, head_size)
+    key = key.view(-1, num_kv_heads, head_size)
+    value = value.view(-1, num_kv_heads, head_size)
 
     if kv_cache.numel() > 0:
         key_cache = kv_cache[0]
@@ -734,12 +641,13 @@ class FlashAttentionImpl(AttentionImpl):
         # Reshape the input keys and values and store them in the cache.
         # If kv_cache is not provided, the new key and value tensors are
         # not cached. This happens during the initial memory profiling run.
-        torch.ops.vllm.reshape_and_cache_flash(
+        torch.ops._C_cache_ops.reshape_and_cache_flash(
             key,
             value,
-            kv_cache,
+            kv_cache[0],
+            kv_cache[1],
             attn_metadata.slot_mapping.flatten(),
-            self.kv_cache_dtype,
+            kv_cache_dtype,
             k_scale,
             v_scale,
         )
@@ -771,7 +679,7 @@ class FlashAttentionImpl(AttentionImpl):
             # normal attention
             # When block_tables are not filled, it means q and k are the
             # prompt, and they have the same length.
-            prefill_output = torch.ops.vllm.flash_attn_varlen_func(
+            prefill_output = flash_attn_varlen_func(
                 q=query,
                 k=key,
                 v=value,
@@ -779,17 +687,17 @@ class FlashAttentionImpl(AttentionImpl):
                 cu_seqlens_k=prefill_meta.seq_start_loc,
                 max_seqlen_q=prefill_meta.max_prefill_seq_len,
                 max_seqlen_k=prefill_meta.max_prefill_seq_len,
-                softmax_scale=self.scale,
+                softmax_scale=softmax_scale,
                 causal=True,
-                window_size=self.sliding_window,
-                alibi_slopes=self.alibi_slopes,
-                softcap=self.logits_soft_cap,
+                window_size=window_size,
+                alibi_slopes=alibi_slopes,
+                softcap=logits_soft_cap,
             )
         else:
             # prefix-enabled attention
             assert prefill_meta.seq_lens is not None
             max_seq_len = max(prefill_meta.seq_lens)
-            prefill_output = torch.ops.vllm.flash_attn_varlen_func(  # noqa
+            prefill_output = flash_attn_varlen_func(  # noqa
                 q=query,
                 k=key_cache,
                 v=value_cache,
@@ -797,30 +705,29 @@ class FlashAttentionImpl(AttentionImpl):
                 max_seqlen_q=prefill_meta.max_query_len,
                 cu_seqlens_k=prefill_meta.seq_start_loc,
                 max_seqlen_k=max_seq_len,
-                softmax_scale=self.scale,
+                softmax_scale=softmax_scale,
                 causal=True,
-                alibi_slopes=self.alibi_slopes,
+                alibi_slopes=alibi_slopes,
                 block_table=prefill_meta.block_tables,
-                softcap=self.logits_soft_cap,
+                softcap=logits_soft_cap,
             )
 
     if decode_meta := attn_metadata.decode_metadata:
         # Decoding run.
         _, num_head, head_dim = decode_query.shape
-        decode_query = decode_query.reshape(-1,
-                                            decode_meta.decode_query_len,
+        decode_query = decode_query.reshape(-1, decode_meta.decode_query_len,
                                             num_head, head_dim)
-        decode_output = torch.ops.vllm.flash_attn_with_kvcache(
-            decode_query,
-            key_cache,
-            value_cache,
+        decode_output = flash_attn_with_kvcache(
+            q=decode_query,
+            k_cache=key_cache,
+            v_cache=value_cache,
             block_table=decode_meta.block_tables,
             cache_seqlens=decode_meta.seq_lens_tensor,
-            softmax_scale=self.scale,
+            softmax_scale=softmax_scale,
             causal=True,
-            alibi_slopes=self.alibi_slopes,
-            softcap=self.logits_soft_cap,
-        )
+            alibi_slopes=alibi_slopes,
+            softcap=logits_soft_cap,
+        ).squeeze(1)
 
     if prefill_output is None:
         assert decode_output is not None
@@ -836,3 +743,23 @@ class FlashAttentionImpl(AttentionImpl):
         decode_output = decode_output.squeeze(1)
     output = torch.cat([prefill_output, decode_output], dim=0)
     return output.view(num_tokens, hidden_size)
+
+
+@unified_flash_attention.register_fake
+def _(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    num_heads: int,
+    head_size: int,
+    num_kv_heads: int,
+    kv_cache: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+    softmax_scale: float,
+    window_size: Optional[List[int]] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    logits_soft_cap: Optional[float] = None,
+) -> torch.Tensor:
+    return torch.empty_like(query)

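The change above collapses the per-kernel vllm::flash_attn_* custom ops into a single vllm::unified_flash_attention op whose schema carries only tensors and scalars; the attention metadata, a Python object that a torch.library schema cannot express, is fetched inside the op body via get_forward_context(), and the register_fake implementation supplies shapes for tracing. Below is a minimal, self-contained sketch of that pattern, not vLLM code: the demo namespace, the context_scale op, and the bias-tensor "context" are illustrative only, and it assumes PyTorch >= 2.4 for torch.library.custom_op.

from typing import Any

import torch

_forward_context: Any = None  # stand-in for vllm.forward_context's global slot


@torch.library.custom_op("demo::context_scale", mutates_args=[])
def context_scale(x: torch.Tensor, scale: float) -> torch.Tensor:
    # The op schema only sees a tensor and a float; the extra state is read
    # from the ambient context, mirroring get_forward_context() above.
    bias = _forward_context
    assert bias is not None
    return x * scale + bias


@context_scale.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Shape/dtype-only implementation used when tracing or compiling.
    return torch.empty_like(x)


x = torch.randn(4)
_forward_context = torch.ones(4)
print(torch.ops.demo.context_scale(x, 2.0))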
@@ -7,7 +7,7 @@ try:
     from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
     from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
 
-    import vllm.attention.backends.flash_attn  # noqa
+    from vllm.vllm_flash_attn import flash_attn_varlen_func
     FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
 except ImportError:
     BatchDecodeWithPagedKVCacheWrapper = None
@@ -799,7 +799,7 @@ class FlashInferImpl(AttentionImpl):
             # This happens when vllm runs the profiling to
             # determine the number of blocks.
             if kv_cache.numel() == 0:
                 output = torch.ops.vllm.flash_attn_varlen_func(
-                output = torch.ops.vllm.flash_attn_varlen_func(
+                output = flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,

vllm/forward_context.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+from contextlib import contextmanager
+from typing import Any
+
+_forward_context: Any = None
+
+
+def get_forward_context() -> Any:
+    """Get the current forward context."""
+    return _forward_context
+
+
+@contextmanager
+def set_forward_context(context: Any):
+    """A context manager that stores the current forward context,
+    can be attention metadata, etc."""
+    global _forward_context
+    prev_context = _forward_context
+    _forward_context = context
+    try:
+        yield
+    finally:
+        _forward_context = prev_context
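The new module is a process-global slot guarded by a context manager rather than an argument threaded through every forward() signature. A standalone sketch of its save/restore semantics, not part of the diff (the string payloads stand in for real attention metadata):

from vllm.forward_context import get_forward_context, set_forward_context

# Outside any forward pass there is no context.
assert get_forward_context() is None

with set_forward_context("outer-metadata"):
    assert get_forward_context() == "outer-metadata"
    # Nesting is safe: the previous value is stashed and restored on exit.
    with set_forward_context("inner-metadata"):
        assert get_forward_context() == "inner-metadata"
    assert get_forward_context() == "outer-metadata"

# The finally block restores the previous value even if the body raises.
assert get_forward_context() is None

Because the slot is a plain module-level global, this assumes each worker process runs one forward pass at a time, which matches how the model runners in the hunks below use it.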
@@ -2,6 +2,7 @@ from typing import List, Optional
 
 import torch
 
+from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.sampler import SamplerOutput
 
 try:
@@ -291,6 +292,7 @@ class TP1DraftModelRunner(ModelRunner):
                 if previous_hidden_states is not None else {}
 
             # Run model
+            with set_forward_context(model_input.attn_metadata):
                 hidden_states = model_executable(
                     input_ids=model_input.input_tokens,
                     positions=model_input.input_positions,

@@ -6,6 +6,7 @@ import torch
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.multimodal import MultiModalInputs
@@ -119,6 +120,7 @@ class EmbeddingModelRunner(
                              device=self.device),
         }
 
+        with set_forward_context(model_input.attn_metadata):
             hidden_states = model_executable(**execute_model_kwargs)
 
         # Only perform pooling in the driver worker.

@@ -14,6 +14,7 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
+from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
@@ -198,6 +199,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
         } if self.has_seqlen_agnostic else {}
 
         multi_modal_kwargs = model_input.multi_modal_kwargs or {}
+        with set_forward_context(model_input.attn_metadata):
             hidden_or_intermediate_states = model_executable(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,

@@ -24,6 +24,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.distributed import get_pp_group
 from vllm.distributed.parallel_state import graph_capture
+from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
@@ -1499,6 +1500,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                     self._update_inputs_to_capture_for_enc_dec_model(
                         capture_inputs)
 
+                with set_forward_context(attn_metadata):
                     graph_runner.capture(**capture_inputs)
                 self.graph_memory_pool = graph_runner.graph.pool()
                 self.graph_runners[virtual_engine][batch_size] = (
@@ -1641,6 +1643,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
             model_forward_end = torch.cuda.Event(enable_timing=True)
             model_forward_start.record()
 
+        with set_forward_context(model_input.attn_metadata):
             hidden_or_intermediate_states = model_executable(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,

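Every runner hunk above makes the same one-line change: model execution (and CUDA graph capture) now runs inside set_forward_context(attn_metadata), which is what lets vllm::unified_flash_attention recover the metadata it no longer receives as an argument. An illustrative end-to-end sketch of that pairing, not vLLM code (FakeAttnMetadata and fake_attention_layer are made-up names):

from dataclasses import dataclass

import torch

from vllm.forward_context import get_forward_context, set_forward_context


@dataclass
class FakeAttnMetadata:
    num_decode_tokens: int


def fake_attention_layer(query: torch.Tensor) -> torch.Tensor:
    # Mirrors the lookup done inside unified_flash_attention.
    attn_metadata = get_forward_context()
    assert attn_metadata is not None
    return query[:attn_metadata.num_decode_tokens]


# Mirrors what the model runners now do around model_executable(...).
metadata = FakeAttnMetadata(num_decode_tokens=3)
with set_forward_context(metadata):
    out = fake_attention_layer(torch.randn(8, 16))
assert out.shape == (3, 16)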