[torch.compile] hide slicing under custom op for inductor (#8384)
This commit is contained in:
parent 42ffba11ad
commit 7de49aa86c
@@ -16,5 +16,7 @@ def test_full_graph(model):
         "The future of AI is",
     ]
     sampling_params = SamplingParams(temperature=0)
-    llm = LLM(model="meta-llama/Meta-Llama-3-8B")
+    llm = LLM(model="meta-llama/Meta-Llama-3-8B",
+              enforce_eager=True,
+              load_format="dummy")
     llm.generate(prompts, sampling_params)
@@ -122,6 +122,40 @@ def _(
     return torch.empty_like(decode_query)
 
 
+@torch.library.custom_op("vllm::reshape_and_cache_flash",
+                         mutates_args=["kv_cache"])
+def reshape_and_cache_flash(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+) -> None:
+    """Inductor cannot deal with inplace operations on views.
+    See https://github.com/pytorch/pytorch/issues/131192
+    and https://github.com/pytorch/pytorch/issues/130174
+    This is a workaround to hide the view operation from the inductor.
+    """
+    return torch.ops._C_cache_ops.reshape_and_cache_flash(
+        key, value, kv_cache[0], kv_cache[1], slot_mapping, kv_cache_dtype,
+        k_scale, v_scale)
+
+
+@reshape_and_cache_flash.register_fake  # type: ignore
+def _(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+) -> None:
+    pass
+
+
 class FlashAttentionBackend(AttentionBackend):
 
     @staticmethod
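The hunk above is the core of the change: the cache-write kernel takes views of the stacked kv_cache ([0] for keys, [1] for values) and writes into them in place, which inductor cannot handle when it traces the views itself, so the whole thing is hidden behind a torch.library custom op with a no-op fake implementation. A minimal, self-contained sketch of that same pattern (not vLLM code; the op name, tensors, and shapes below are made up for illustration):

import torch

@torch.library.custom_op("demo::write_rows", mutates_args=["dst"])
def write_rows(src: torch.Tensor, dst: torch.Tensor,
               rows: torch.Tensor) -> None:
    # The in-place write on a view of dst happens inside the op,
    # so torch.compile/inductor only sees an opaque call.
    dst[rows] = src

@write_rows.register_fake
def _(src: torch.Tensor, dst: torch.Tensor, rows: torch.Tensor) -> None:
    # Fake (meta) implementation: nothing to compute, the op returns None.
    pass

@torch.compile
def step(x: torch.Tensor, cache: torch.Tensor, rows: torch.Tensor):
    torch.ops.demo.write_rows(x, cache, rows)
    return cache.sum()

cache = torch.zeros(8, 4)
print(step(torch.ones(2, 4), cache, torch.tensor([0, 3])))  # tensor(8.)

Because mutates_args declares the side effect and the fake implementation supplies metadata, the compiled graph stays traceable while the mutation still happens in eager inside the op.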
@@ -653,11 +687,10 @@ class FlashAttentionImpl(AttentionImpl):
             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
-            ops.reshape_and_cache_flash(
+            torch.ops.vllm.reshape_and_cache_flash(
                 key,
                 value,
-                key_cache,
-                value_cache,
+                kv_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
                 k_scale,
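At this call site the [0]/[1] slicing disappears: the stacked cache is passed whole and the custom op takes the key and value views internally. A small illustration of that layout, with made-up sizes (the stacked key/value cache shape is stated here as an assumption about the FlashAttention backend's format, not taken from this diff):

import torch

# Illustrative sizes only.
num_blocks, block_size, num_kv_heads, head_size = 16, 16, 8, 128
kv_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)

key_cache = kv_cache[0]    # view previously taken at the call site
value_cache = kv_cache[1]  # view now taken inside reshape_and_cache_flash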
@@ -669,7 +702,6 @@ class FlashAttentionImpl(AttentionImpl):
         assert key.shape[0] == num_prefill_tokens + num_decode_tokens
         assert value.shape[0] == num_prefill_tokens + num_decode_tokens
 
-        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
@@ -680,6 +712,9 @@ class FlashAttentionImpl(AttentionImpl):
         assert query.shape[0] == num_prefill_tokens
         assert decode_query.shape[0] == num_decode_tokens
 
+        prefill_output: Optional[torch.Tensor] = None
+        decode_output: Optional[torch.Tensor] = None
+
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
             if (kv_cache is None or prefill_meta.block_tables is None
@@ -687,7 +722,7 @@ class FlashAttentionImpl(AttentionImpl):
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                out = torch.ops.vllm.flash_attn_varlen_func(
+                prefill_output = torch.ops.vllm.flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,
@@ -701,42 +736,44 @@ class FlashAttentionImpl(AttentionImpl):
                     alibi_slopes=self.alibi_slopes,
                     softcap=self.logits_soft_cap,
                 )
-                assert output[:num_prefill_tokens].shape == out.shape
-                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 assert prefill_meta.seq_lens is not None
                 max_seq_len = max(prefill_meta.seq_lens)
-                output[:
-                       num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func(  # noqa
-                           q=query,
-                           k=key_cache,
-                           v=value_cache,
-                           cu_seqlens_q=prefill_meta.query_start_loc,
-                           max_seqlen_q=prefill_meta.max_query_len,
-                           cu_seqlens_k=prefill_meta.seq_start_loc,
-                           max_seqlen_k=max_seq_len,
-                           softmax_scale=self.scale,
-                           causal=True,
-                           alibi_slopes=self.alibi_slopes,
-                           block_table=prefill_meta.block_tables,
-                           softcap=self.logits_soft_cap,
-                       )
+                prefill_output = torch.ops.vllm.flash_attn_varlen_func(  # noqa
+                    q=query,
+                    k=key_cache,
+                    v=value_cache,
+                    cu_seqlens_q=prefill_meta.query_start_loc,
+                    max_seqlen_q=prefill_meta.max_query_len,
+                    cu_seqlens_k=prefill_meta.seq_start_loc,
+                    max_seqlen_k=max_seq_len,
+                    softmax_scale=self.scale,
+                    causal=True,
+                    alibi_slopes=self.alibi_slopes,
+                    block_table=prefill_meta.block_tables,
+                    softcap=self.logits_soft_cap,
+                )
 
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
-            output[
-                num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache(
-                    decode_query.unsqueeze(1),
-                    key_cache,
-                    value_cache,
-                    block_table=decode_meta.block_tables,
-                    cache_seqlens=decode_meta.seq_lens_tensor,
-                    softmax_scale=self.scale,
-                    causal=True,
-                    alibi_slopes=self.alibi_slopes,
-                    softcap=self.logits_soft_cap,
-                ).squeeze(1)
+            decode_output = torch.ops.vllm.flash_attn_with_kvcache(
+                decode_query.unsqueeze(1),
+                key_cache,
+                value_cache,
+                block_table=decode_meta.block_tables,
+                cache_seqlens=decode_meta.seq_lens_tensor,
+                softmax_scale=self.scale,
+                causal=True,
+                alibi_slopes=self.alibi_slopes,
+                softcap=self.logits_soft_cap,
+            ).squeeze(1)
 
-        # Reshape the output tensor.
+        if prefill_output is None:
+            assert decode_output is not None
+            return decode_output.view(num_decode_tokens, hidden_size)
+        if decode_output is None:
+            assert prefill_output is not None
+            return prefill_output.view(num_prefill_tokens, hidden_size)
+        output = torch.cat([prefill_output, decode_output], dim=0)
        return output.view(num_tokens, hidden_size)
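The tail of this last hunk is the same idea applied to the attention output: instead of slice-assigning the prefill and decode results into one preallocated tensor (in-place writes on views, which inductor handles poorly per the issues cited in the docstring above), the two parts are computed separately and concatenated only when both exist. A minimal sketch of the two equivalent patterns, with made-up shapes (not vLLM code):

import torch

hidden_size = 8
prefill = torch.randn(4, hidden_size)   # stand-in for prefill_output
decode = torch.randn(2, hidden_size)    # stand-in for decode_output

# Old pattern: preallocate, then write into slices (views) in place.
output_inplace = torch.empty(6, hidden_size)
output_inplace[:4] = prefill
output_inplace[4:] = decode

# New pattern: build the result functionally; nothing is mutated.
output_functional = torch.cat([prefill, decode], dim=0)

assert torch.equal(output_inplace, output_functional)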