[torch.compile] use empty tensor instead of None for profiling (#8875)
commit a9b15c606f
parent 8df2dc3c88
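Why this change: during the memory-profiling run the KV caches used to be passed as the Python constant `None`, and Dynamo specializes the compiled graph on that constant value; passing an empty tensor instead keeps the argument an ordinary tensor input that is passed by reference (this is the rationale stated in the new in-code comments below). The sketch that follows is a minimal, hypothetical illustration of the pattern using only the public `torch` / `torch.compile` API; the names `forward`, `query`, and `kv_cache` are placeholders, not vLLM's actual forward signature.

import torch


def forward(query: torch.Tensor, kv_cache: torch.Tensor) -> torch.Tensor:
    # An empty kv_cache (numel() == 0) marks the profiling run, mirroring the
    # `if kv_cache.numel() > 0` checks introduced throughout the diff below.
    if kv_cache.numel() > 0:
        kv_cache.copy_(query)  # stand-in for the real "write to paged cache"
    return query + 1


compiled = torch.compile(forward)
query = torch.randn(4, 8)

# Profiling run: an empty float32 placeholder keeps kv_cache a tensor input,
# so Dynamo traces it by reference instead of baking in the constant None.
print(compiled(query, torch.tensor([], dtype=torch.float32)))

# Normal run: a real cache flows through the same code path.
kv_cache = torch.zeros(4, 8)
print(compiled(query, kv_cache))

The same idea drives the `Optional[torch.Tensor]` to `torch.Tensor` signature changes and the `is None` to `numel()` checks in the hunks that follow.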
@@ -136,7 +136,9 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
     )
     if test_pt.num_blocks is None or test_pt.num_heads is None:
         # Caller does not require a KV cache
-        return TestResources(scale, attn_backend, attn, None)
+        return TestResources(
+            scale, attn_backend, attn,
+            torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))

     # Construct KV cache
     kv_cache = make_kv_cache(test_pt.num_blocks,
@@ -620,7 +622,9 @@ def _run_encoder_attention_test(
     return attn.forward(packed_qkv.query,
                         packed_qkv.key,
                         packed_qkv.value,
-                        None,
+                        torch.tensor([],
+                                     dtype=torch.float32,
+                                     device=packed_qkv.query.device),
                         attn_metadata,
                         attn_type=attn_type)

@@ -357,6 +357,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+                NOTE: kv_cache will be an empty tensor with shape [0]
+                for profiling run.
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
@@ -373,7 +375,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)

-        if kv_cache is not None:
+        if kv_cache.numel() > 0:
             key_cache, value_cache = PagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)

@@ -399,7 +401,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
             # When block_tables are not filled, it means q and k are the
             # prompt, and they have the same length.

-            assert kv_cache is None \
+            assert kv_cache.numel() == 0 \
                     or prefill_meta.block_tables is None \
                     or prefill_meta.block_tables.numel() == 0, \
                 "Does not support prefix-enabled attention."
@@ -665,6 +665,8 @@ class FlashAttentionImpl(AttentionImpl):
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
+                NOTE: kv_cache will be an empty tensor with shape [0]
+                for profiling run.
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
@@ -685,7 +687,7 @@ class FlashAttentionImpl(AttentionImpl):
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)

-        if kv_cache is not None:
+        if kv_cache.numel() > 0:
             key_cache = kv_cache[0]
             value_cache = kv_cache[1]

@@ -722,7 +724,7 @@ class FlashAttentionImpl(AttentionImpl):

         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
-            if (kv_cache is None or prefill_meta.block_tables is None
+            if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
                     or prefill_meta.block_tables.numel() == 0):
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
@@ -746,7 +746,7 @@ class FlashInferImpl(AttentionImpl):
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: Optional[torch.Tensor],
+        kv_cache: torch.Tensor,
         attn_metadata: FlashInferMetadata,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
@@ -770,7 +770,7 @@ class FlashInferImpl(AttentionImpl):
         if attn_metadata.num_decode_tokens > 0:
             assert attn_metadata.num_prefill_tokens == 0, (
                 "Chunked prefill is not supported with flashinfer yet.")
-        if kv_cache is not None:
+        if kv_cache.numel() > 0:
             # Use the same reshape and cache kernel as flash attention.
             ops.reshape_and_cache_flash(
                 key,
@@ -796,7 +796,7 @@ class FlashInferImpl(AttentionImpl):
             # when kv_cache is not provided.
             # This happens when vllm runs the profiling to
             # determine the number of blocks.
-            if kv_cache is None:
+            if kv_cache.numel() == 0:
                 output = torch.ops.vllm.flash_attn_varlen_func(
                     q=query,
                     k=key,
@@ -167,7 +167,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: Optional[torch.Tensor],
+        kv_cache: torch.Tensor,
         attn_metadata: IpexAttnMetadata,  # type: ignore
         k_scale: float = 1.0,
         v_scale: float = 1.0,
@@ -180,6 +180,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+                NOTE: kv_cache will be an empty tensor with shape [0]
+                for profiling run.
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
@@ -196,7 +198,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)

-        if kv_cache is not None:
+        if kv_cache.numel() > 0:
             key_cache, value_cache = self.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
             ipex_ops.reshape_and_cache(
@@ -212,7 +214,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):

         if attn_metadata.is_prompt:
             assert attn_metadata.seq_lens is not None
-            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
+            if (kv_cache.numel() == 0
+                    or attn_metadata.block_tables.numel() == 0):
                 if self.num_kv_heads != self.num_heads:
                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                     value = value.repeat_interleave(self.num_queries_per_kv,
@@ -143,7 +143,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
+        kv_cache: Tuple[torch.Tensor, torch.Tensor],
         attn_metadata: PallasMetadata,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
@@ -155,8 +155,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
             query: shape = [batch_size, seq_len, num_heads * head_size]
             key: shape = [batch_size, seq_len, num_kv_heads * head_size]
             value: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            key_cache = [num_kv_heads, num_blocks, block_size, head_size]
-            value_cache = [num_kv_heads, num_blocks, block_size, head_size]
+            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
+            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
+                NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
+                with shape [0] for profiling run.
             attn_metadata: Metadata for attention.
         Returns:
             shape = [batch_size, seq_len, num_heads * head_size]
@@ -173,7 +175,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         value = value.view(batch_size, seq_len, self.num_kv_heads,
                            self.head_size)

-        if kv_cache[0] is not None:
+        if kv_cache[0].numel() > 0:
            slot_mapping = attn_metadata.slot_mapping
            key_cache, value_cache = kv_cache
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
@@ -205,7 +207,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
             output = output.permute(0, 2, 1, 3)
         else:
             # Decoding run.
-            assert kv_cache is not None
+            assert kv_cache[0].numel() > 0

             pages_per_compute_block = 16  # TODO(woosuk): Tune this value.
             if self.megacore_mode == "batch" and batch_size % 2 != 0:
@@ -396,6 +396,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+                NOTE: kv_cache will be an empty tensor with shape [0]
+                for profiling run.
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
@@ -412,7 +414,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)

-        if kv_cache is not None:
+        if kv_cache.numel() > 0:
             key_cache, value_cache = PagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)

@@ -449,7 +451,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
             assert prefill_meta.seq_lens is not None
-            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+            if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
@@ -151,7 +151,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: Optional[torch.Tensor],
+        kv_cache: torch.Tensor,
         attn_metadata: TorchSDPAMetadata,  # type: ignore
         k_scale: float = 1.0,
         v_scale: float = 1.0,
@@ -164,6 +164,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+                NOTE: kv_cache will be an empty tensor with shape [0]
+                for profiling run.
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
@@ -180,7 +182,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)

-        if kv_cache is not None:
+        if kv_cache.numel() > 0:
             key_cache, value_cache = PagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
             PagedAttention.write_to_paged_cache(key, value, key_cache,
@@ -191,7 +193,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):

         if attn_metadata.is_prompt:
             assert attn_metadata.seq_lens is not None
-            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
+            if (kv_cache.numel() == 0
+                    or attn_metadata.block_tables.numel() == 0):
                 if self.num_kv_heads != self.num_heads:
                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                     value = value.repeat_interleave(self.num_queries_per_kv,
@@ -445,7 +445,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         query: torch.Tensor,
         key: Optional[torch.Tensor],
         value: Optional[torch.Tensor],
-        kv_cache: Optional[torch.Tensor],
+        kv_cache: torch.Tensor,
         attn_metadata: "XFormersMetadata",
         k_scale: float = 1.0,
         v_scale: float = 1.0,
@@ -489,6 +489,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+                NOTE: kv_cache will be an empty tensor with shape [0]
+                for profiling run.
             attn_metadata: Metadata for attention.
             attn_type: Select attention type, between encoder attention,
                        decoder self-attention, or encoder/decoder cross-
@@ -522,7 +524,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
        # which KV cache memory-mapping & which
        # seqlen datastructures we utilize

-        if (attn_type != AttentionType.ENCODER and kv_cache is not None):
+        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
            # KV-cache during decoder-self- or
            # encoder-decoder-cross-attention, but not
            # during encoder attention.
@@ -588,7 +590,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
-            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+            if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
                # normal attention.
                # block tables are empty if the prompt does not have a cached
                # prefix.
@@ -97,7 +97,13 @@ class EmbeddingModelRunner(
         model_executable = self.model

         num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
+        # use an empty tensor instead of `None`` to force Dynamo to pass
+        # it by reference, rather by specializing on the value ``None``.
+        # the `dtype` argument does not matter, and we use `float32` as
+        # a placeholder (it has wide hardware support).
+        kv_caches = [
+            torch.tensor([], dtype=torch.float32, device=self.device)
+        ] * num_layers

         execute_model_kwargs = {
             "input_ids":
@@ -340,7 +340,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):

         # Run the model with the dummy inputs.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
+        # use an empty tensor instead of `None`` to force Dynamo to pass
+        # it by reference, rather by specializing on the value ``None``.
+        # the `dtype` argument does not matter, and we use `float32` as
+        # a placeholder (it has wide hardware support).
+        kv_caches = [
+            torch.tensor([], dtype=torch.float32, device=self.device)
+        ] * num_layers
         finished_requests_ids = [seq.request_id for seq in seqs]
         model_input = self.prepare_model_input(
             seqs, finished_requests_ids=finished_requests_ids)
@@ -1223,7 +1223,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):

         # Run the model with the dummy inputs.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
+        # use an empty tensor instead of `None`` to force Dynamo to pass
+        # it by reference, rather by specializing on the value ``None``.
+        # the `dtype` argument does not matter, and we use `float32` as
+        # a placeholder (it has wide hardware support).
+        kv_caches = [
+            torch.tensor([], dtype=torch.float32, device=self.device)
+        ] * num_layers
         finished_requests_ids = [seq.request_id for seq in seqs]
         model_input = self.prepare_model_input(
             seqs, finished_requests_ids=finished_requests_ids)
@@ -714,7 +714,7 @@ class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
         t: torch.Tensor,
         p: torch.Tensor,
         num_samples: int,
-        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.

@@ -745,7 +745,7 @@ class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
         )

         # Skip this in memory profiling at initialization.
-        if kv_caches[0][0] is not None:
+        if kv_caches[0][0].numel() > 0:
             # index_copy_(slot_mapping) only works when the inserted dimension
             # is 0. However, the KV cache in the Pallas backend has the shape
             # [num_kv_heads, num_blocks, block_size, head_size]. To make it
@@ -115,7 +115,15 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         head_size = self.model_config.get_head_size()
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)

-        kv_caches = [(None, None) for _ in range(num_layers)]
+        # use an empty tensor instead of `None`` to force Dynamo to pass
+        # it by reference, rather by specializing on the value ``None``.
+        # the `dtype` argument does not matter, and we use `float32` as
+        # a placeholder (it has wide hardware support).
+        kv_caches = [(torch.tensor([], dtype=torch.float32,
+                                   device=self.device),
+                      torch.tensor([], dtype=torch.float32,
+                                   device=self.device))
+                     for _ in range(num_layers)]
         self.model_runner._dummy_run(
             batch_size=1,
             seq_len=self.scheduler_config.max_num_batched_tokens,
@@ -464,7 +464,13 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):

         # Run the model with the dummy inputs.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
+        # use an empty tensor instead of `None`` to force Dynamo to pass
+        # it by reference, rather by specializing on the value ``None``.
+        # the `dtype` argument does not matter, and we use `float32` as
+        # a placeholder (it has wide hardware support).
+        kv_caches = [
+            torch.tensor([], dtype=torch.float32, device=self.device)
+        ] * num_layers
         finished_requests_ids = [seq.request_id for seq in seqs]
         model_input = self.prepare_model_input(
             seqs, finished_requests_ids=finished_requests_ids)