mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 21:44:39 +08:00
remove attn output view kernel (#26680)
Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
ff4810ba73
commit
a86b4c58e8
@ -346,7 +346,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
|
||||
if self.use_output:
|
||||
output_shape = output_shape if output_shape is not None else query.shape
|
||||
output = torch.zeros(output_shape, dtype=output_dtype, device=query.device)
|
||||
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
|
||||
hidden_size = output_shape[-1]
|
||||
# Reshape the query, key, and value tensors.
|
||||
# NOTE(woosuk): We do this outside the custom op to minimize the
|
||||
@ -705,7 +705,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
self.calc_kv_scales(q, kv_c_normed, k_pe)
|
||||
|
||||
if self.attn_backend.accept_output_buffer:
|
||||
output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
|
||||
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
|
||||
self.impl.forward(
|
||||
self,
|
||||
q,
|
||||
@ -722,7 +722,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
)
|
||||
else:
|
||||
if self.attn_backend.accept_output_buffer:
|
||||
output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
|
||||
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
|
||||
torch.ops.vllm.unified_mla_attention_with_output(
|
||||
q,
|
||||
kv_c_normed,
|
||||
|
||||
@ -530,7 +530,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
return output.fill_(0)
|
||||
|
||||
attn_type = self.attn_type
|
||||
|
||||
|
||||
@ -857,7 +857,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
return output.fill_(0)
|
||||
|
||||
if self.bmm1_scale is None:
|
||||
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
|
||||
|
||||
@ -767,7 +767,7 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
return output.fill_(0)
|
||||
# query = self.view_as_4d(query).permute(0, 2, 1, 3)
|
||||
# return torch.empty_like(query)
|
||||
|
||||
|
||||
@ -485,7 +485,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
return output.fill_(0)
|
||||
|
||||
# IMPORTANT!
|
||||
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
|
||||
|
||||
@ -130,7 +130,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
return output.fill_(0)
|
||||
|
||||
assert attn_metadata.use_cascade is False
|
||||
|
||||
|
||||
@ -299,7 +299,7 @@ class RocmAttentionImpl(AttentionImpl):
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
return output.fill_(0)
|
||||
|
||||
assert attn_metadata.use_cascade is False
|
||||
|
||||
|
||||
@ -379,7 +379,7 @@ class TreeAttentionImpl(AttentionImpl):
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
return output.fill_(0)
|
||||
|
||||
# Cache the input KVs.
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
@ -298,7 +298,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
return output.fill_(0)
|
||||
|
||||
assert attn_metadata.use_cascade is False
|
||||
|
||||
|
||||
@ -354,7 +354,7 @@ class XFormersAttentionImpl(AttentionImpl):
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
return output.fill_(0)
|
||||
|
||||
# Cache the input KVs.
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user