remove attn output view kernel (#26680)

Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Boyuan Feng <fby.1994@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Boyuan Feng 2025-10-14 15:53:10 -07:00 committed by GitHub
parent ff4810ba73
commit a86b4c58e8
10 changed files with 12 additions and 12 deletions
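
The gist of the change: attention output buffers that were zero-initialized with torch.zeros are now allocated with torch.empty, and the profiling path (attn_metadata is None) compensates by zeroing the buffer explicitly with output.fill_(0). A minimal sketch of that pattern, using plain PyTorch and an invented run_attention helper rather than vLLM's actual call path:

import torch

def run_attention(query: torch.Tensor, profiling: bool) -> torch.Tensor:
    # torch.empty skips the device-side fill that torch.zeros performs; this is
    # safe because the real backend overwrites every element of `output`.
    output = torch.empty_like(query)
    if profiling:
        # No attention kernel runs during a profiling pass, so the buffer would
        # hold uninitialized memory; zero it so downstream dummy math stays finite.
        return output.fill_(0)
    # Stand-in for the real attention backend, which writes `output` in place.
    output.copy_(torch.softmax(query, dim=-1))
    return output

x = torch.randn(2, 4, 8)
assert torch.isfinite(run_attention(x, profiling=True)).all()
assert torch.isfinite(run_attention(x, profiling=False)).all()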

@@ -346,7 +346,7 @@ class Attention(nn.Module, AttentionLayerBase):
         if self.use_output:
             output_shape = output_shape if output_shape is not None else query.shape
-            output = torch.zeros(output_shape, dtype=output_dtype, device=query.device)
+            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
             hidden_size = output_shape[-1]
             # Reshape the query, key, and value tensors.
             # NOTE(woosuk): We do this outside the custom op to minimize the
@@ -705,7 +705,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             self.calc_kv_scales(q, kv_c_normed, k_pe)
         if self.attn_backend.accept_output_buffer:
-            output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
+            output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
             self.impl.forward(
                 self,
                 q,
@@ -722,7 +722,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             )
         else:
             if self.attn_backend.accept_output_buffer:
-                output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
+                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                 torch.ops.vllm.unified_mla_attention_with_output(
                     q,
                     kv_c_normed,
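
For context on why the zero-fill is worth removing from the hot path: torch.zeros has to fill the freshly allocated memory, while torch.empty only reserves it. A rough micro-benchmark sketch, with assumed shapes that are not taken from the PR:

import time
import torch

def time_alloc(alloc, shape=(8192, 4096), iters=100):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for _ in range(10):                      # warm up the caching allocator
        alloc(shape, dtype=torch.float16, device=device)
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        alloc(shape, dtype=torch.float16, device=device)
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

print("torch.zeros:", time_alloc(torch.zeros))
print("torch.empty:", time_alloc(torch.empty))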

@@ -530,7 +530,7 @@ class FlashAttentionImpl(AttentionImpl):
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
         attn_type = self.attn_type
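
One detail that keeps the backend contract intact: Tensor.fill_ is in-place and returns the same tensor object, so return output.fill_(0) still hands back the caller-provided buffer rather than a fresh allocation. A quick check of that behavior (generic PyTorch, not vLLM-specific):

import torch

buf = torch.empty(4)
out = buf.fill_(0)
assert out is buf                      # same object, no new allocation
assert torch.equal(buf, torch.zeros(4))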

@@ -857,7 +857,7 @@ class FlashInferImpl(AttentionImpl):
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
         if self.bmm1_scale is None:
             self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale

@@ -767,7 +767,7 @@ class FlexAttentionImpl(AttentionImpl):
         if attn_metadata is None:
            # Profiling run.
-            return output
+            return output.fill_(0)
            # query = self.view_as_4d(query).permute(0, 2, 1, 3)
            # return torch.empty_like(query)

@@ -485,7 +485,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
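
The NOTE above about piece-wise CUDA graphs is also why the allocation pattern matters: buffers captured in a graph are reused in place on every replay, so zero-initializing them per step buys nothing. A sketch of that behavior with plain torch.cuda.CUDAGraph (assumed shapes, a relu stand-in for the attention kernel, not vLLM's capture logic):

import torch

if torch.cuda.is_available():
    x = torch.randn(16, 64, device="cuda")
    out = torch.empty_like(x)              # allocated once, outside the graph
    # warm up on a side stream before capture, per the CUDA graphs recipe
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        out.copy_(torch.relu(x))
    torch.cuda.current_stream().wait_stream(s)
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out.copy_(torch.relu(x))           # stand-in for the attention kernel
    x.copy_(torch.randn_like(x))           # new input data, same buffers
    g.replay()                             # `out` is fully rewritten in place
    assert torch.equal(out, torch.relu(x))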

@@ -130,7 +130,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
         assert attn_metadata.use_cascade is False

@@ -299,7 +299,7 @@ class RocmAttentionImpl(AttentionImpl):
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
         assert attn_metadata.use_cascade is False

@@ -379,7 +379,7 @@ class TreeAttentionImpl(AttentionImpl):
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
         # Cache the input KVs.
         key_cache, value_cache = kv_cache.unbind(0)

@@ -298,7 +298,7 @@ class TritonAttentionImpl(AttentionImpl):
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
         assert attn_metadata.use_cascade is False

@@ -354,7 +354,7 @@ class XFormersAttentionImpl(AttentionImpl):
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
         # Cache the input KVs.
         key_cache, value_cache = kv_cache.unbind(0)