mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:44:57 +08:00
[torch.compile] use empty tensor instead of None for profiling (#8875)
This commit is contained in:
parent
8df2dc3c88
commit
a9b15c606f
@ -136,7 +136,9 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
|
||||
)
|
||||
if test_pt.num_blocks is None or test_pt.num_heads is None:
|
||||
# Caller does not require a KV cache
|
||||
return TestResources(scale, attn_backend, attn, None)
|
||||
return TestResources(
|
||||
scale, attn_backend, attn,
|
||||
torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))
|
||||
|
||||
# Construct KV cache
|
||||
kv_cache = make_kv_cache(test_pt.num_blocks,
|
||||
@ -620,7 +622,9 @@ def _run_encoder_attention_test(
|
||||
return attn.forward(packed_qkv.query,
|
||||
packed_qkv.key,
|
||||
packed_qkv.value,
|
||||
None,
|
||||
torch.tensor([],
|
||||
dtype=torch.float32,
|
||||
device=packed_qkv.query.device),
|
||||
attn_metadata,
|
||||
attn_type=attn_type)
|
||||
|
||||
|
||||
@ -357,6 +357,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
@ -373,7 +375,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
if kv_cache is not None:
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
@ -399,7 +401,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
# prompt, and they have the same length.
|
||||
|
||||
assert kv_cache is None \
|
||||
assert kv_cache.numel() == 0 \
|
||||
or prefill_meta.block_tables is None \
|
||||
or prefill_meta.block_tables.numel() == 0, \
|
||||
"Does not support prefix-enabled attention."
|
||||
|
||||
@ -665,6 +665,8 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
@ -685,7 +687,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
if kv_cache is not None:
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache = kv_cache[0]
|
||||
value_cache = kv_cache[1]
|
||||
|
||||
@ -722,7 +724,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
if (kv_cache is None or prefill_meta.block_tables is None
|
||||
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
|
||||
or prefill_meta.block_tables.numel() == 0):
|
||||
# normal attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
|
||||
@ -746,7 +746,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashInferMetadata,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
@ -770,7 +770,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
if attn_metadata.num_decode_tokens > 0:
|
||||
assert attn_metadata.num_prefill_tokens == 0, (
|
||||
"Chunked prefill is not supported with flashinfer yet.")
|
||||
if kv_cache is not None:
|
||||
if kv_cache.numel() > 0:
|
||||
# Use the same reshape and cache kernel as flash attention.
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
@ -796,7 +796,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
# when kv_cache is not provided.
|
||||
# This happens when vllm runs the profiling to
|
||||
# determine the number of blocks.
|
||||
if kv_cache is None:
|
||||
if kv_cache.numel() == 0:
|
||||
output = torch.ops.vllm.flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
|
||||
@ -167,7 +167,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: IpexAttnMetadata, # type: ignore
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
@ -180,6 +180,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
@ -196,7 +198,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
if kv_cache is not None:
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache, value_cache = self.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
ipex_ops.reshape_and_cache(
|
||||
@ -212,7 +214,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
||||
|
||||
if attn_metadata.is_prompt:
|
||||
assert attn_metadata.seq_lens is not None
|
||||
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
|
||||
if (kv_cache.numel() == 0
|
||||
or attn_metadata.block_tables.numel() == 0):
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
|
||||
value = value.repeat_interleave(self.num_queries_per_kv,
|
||||
|
||||
@ -143,7 +143,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||
attn_metadata: PallasMetadata,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
@ -155,8 +155,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
query: shape = [batch_size, seq_len, num_heads * head_size]
|
||||
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
key_cache = [num_kv_heads, num_blocks, block_size, head_size]
|
||||
value_cache = [num_kv_heads, num_blocks, block_size, head_size]
|
||||
kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
|
||||
kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
|
||||
NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
|
||||
with shape [0] for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [batch_size, seq_len, num_heads * head_size]
|
||||
@ -173,7 +175,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
value = value.view(batch_size, seq_len, self.num_kv_heads,
|
||||
self.head_size)
|
||||
|
||||
if kv_cache[0] is not None:
|
||||
if kv_cache[0].numel() > 0:
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
key_cache, value_cache = kv_cache
|
||||
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
|
||||
@ -205,7 +207,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
output = output.permute(0, 2, 1, 3)
|
||||
else:
|
||||
# Decoding run.
|
||||
assert kv_cache is not None
|
||||
assert kv_cache[0].numel() > 0
|
||||
|
||||
pages_per_compute_block = 16 # TODO(woosuk): Tune this value.
|
||||
if self.megacore_mode == "batch" and batch_size % 2 != 0:
|
||||
|
||||
@ -396,6 +396,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
@ -412,7 +414,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
if kv_cache is not None:
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
@ -449,7 +451,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
assert prefill_meta.seq_lens is not None
|
||||
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
|
||||
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
|
||||
# triton attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
# prompt, and they have the same length.
|
||||
|
||||
@ -151,7 +151,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: TorchSDPAMetadata, # type: ignore
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
@ -164,6 +164,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
@ -180,7 +182,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
if kv_cache is not None:
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||
@ -191,7 +193,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
|
||||
if attn_metadata.is_prompt:
|
||||
assert attn_metadata.seq_lens is not None
|
||||
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
|
||||
if (kv_cache.numel() == 0
|
||||
or attn_metadata.block_tables.numel() == 0):
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
|
||||
value = value.repeat_interleave(self.num_queries_per_kv,
|
||||
|
||||
@ -445,7 +445,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor],
|
||||
value: Optional[torch.Tensor],
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: "XFormersMetadata",
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
@ -489,6 +489,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
attn_type: Select attention type, between encoder attention,
|
||||
decoder self-attention, or encoder/decoder cross-
|
||||
@ -522,7 +524,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
# which KV cache memory-mapping & which
|
||||
# seqlen datastructures we utilize
|
||||
|
||||
if (attn_type != AttentionType.ENCODER and kv_cache is not None):
|
||||
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
|
||||
# KV-cache during decoder-self- or
|
||||
# encoder-decoder-cross-attention, but not
|
||||
# during encoder attention.
|
||||
@ -588,7 +590,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
|
||||
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
|
||||
# normal attention.
|
||||
# block tables are empty if the prompt does not have a cached
|
||||
# prefix.
|
||||
|
||||
@ -97,7 +97,13 @@ class EmbeddingModelRunner(
|
||||
model_executable = self.model
|
||||
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
kv_caches = [None] * num_layers
|
||||
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value ``None``.
|
||||
# the `dtype` argument does not matter, and we use `float32` as
|
||||
# a placeholder (it has wide hardware support).
|
||||
kv_caches = [
|
||||
torch.tensor([], dtype=torch.float32, device=self.device)
|
||||
] * num_layers
|
||||
|
||||
execute_model_kwargs = {
|
||||
"input_ids":
|
||||
|
||||
@ -340,7 +340,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
|
||||
# Run the model with the dummy inputs.
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
kv_caches = [None] * num_layers
|
||||
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value ``None``.
|
||||
# the `dtype` argument does not matter, and we use `float32` as
|
||||
# a placeholder (it has wide hardware support).
|
||||
kv_caches = [
|
||||
torch.tensor([], dtype=torch.float32, device=self.device)
|
||||
] * num_layers
|
||||
finished_requests_ids = [seq.request_id for seq in seqs]
|
||||
model_input = self.prepare_model_input(
|
||||
seqs, finished_requests_ids=finished_requests_ids)
|
||||
|
||||
@ -1223,7 +1223,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
|
||||
# Run the model with the dummy inputs.
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
kv_caches = [None] * num_layers
|
||||
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value ``None``.
|
||||
# the `dtype` argument does not matter, and we use `float32` as
|
||||
# a placeholder (it has wide hardware support).
|
||||
kv_caches = [
|
||||
torch.tensor([], dtype=torch.float32, device=self.device)
|
||||
] * num_layers
|
||||
finished_requests_ids = [seq.request_id for seq in seqs]
|
||||
model_input = self.prepare_model_input(
|
||||
seqs, finished_requests_ids=finished_requests_ids)
|
||||
|
||||
@ -714,7 +714,7 @@ class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
|
||||
t: torch.Tensor,
|
||||
p: torch.Tensor,
|
||||
num_samples: int,
|
||||
kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
) -> torch.Tensor:
|
||||
"""Executes the forward pass of the model and samples the next token.
|
||||
|
||||
@ -745,7 +745,7 @@ class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
|
||||
)
|
||||
|
||||
# Skip this in memory profiling at initialization.
|
||||
if kv_caches[0][0] is not None:
|
||||
if kv_caches[0][0].numel() > 0:
|
||||
# index_copy_(slot_mapping) only works when the inserted dimension
|
||||
# is 0. However, the KV cache in the Pallas backend has the shape
|
||||
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
|
||||
|
||||
@ -115,7 +115,15 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
||||
head_size = self.model_config.get_head_size()
|
||||
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||
|
||||
kv_caches = [(None, None) for _ in range(num_layers)]
|
||||
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value ``None``.
|
||||
# the `dtype` argument does not matter, and we use `float32` as
|
||||
# a placeholder (it has wide hardware support).
|
||||
kv_caches = [(torch.tensor([], dtype=torch.float32,
|
||||
device=self.device),
|
||||
torch.tensor([], dtype=torch.float32,
|
||||
device=self.device))
|
||||
for _ in range(num_layers)]
|
||||
self.model_runner._dummy_run(
|
||||
batch_size=1,
|
||||
seq_len=self.scheduler_config.max_num_batched_tokens,
|
||||
|
||||
@ -464,7 +464,13 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
|
||||
|
||||
# Run the model with the dummy inputs.
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
kv_caches = [None] * num_layers
|
||||
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value ``None``.
|
||||
# the `dtype` argument does not matter, and we use `float32` as
|
||||
# a placeholder (it has wide hardware support).
|
||||
kv_caches = [
|
||||
torch.tensor([], dtype=torch.float32, device=self.device)
|
||||
] * num_layers
|
||||
finished_requests_ids = [seq.request_id for seq in seqs]
|
||||
model_input = self.prepare_model_input(
|
||||
seqs, finished_requests_ids=finished_requests_ids)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user