parent: bc1bdecebf
commit: 4c3aac51e1
@@ -156,9 +156,13 @@ class Attention(nn.Module):
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        if self.calculate_kv_scales and \
-           attn_metadata.enable_kv_scales_calculation:
-            self.calc_kv_scales(key, value)
+        # NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
+        # directly, use `self.kv_cache` and
+        # `get_forward_context().attn_metadata` instead.
+        if self.calculate_kv_scales:
+            ctx_attn_metadata = get_forward_context().attn_metadata
+            if ctx_attn_metadata.enable_kv_scales_calculation:
+                self.calc_kv_scales(key, value)
         if self.use_output:
             output = torch.empty_like(query)
             hidden_size = query.size(-1)
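For context, this first hunk moves the KV-scale guard off the `attn_metadata` argument and onto the forward context, in line with the new NOTE. Below is a minimal, runnable sketch of that lookup pattern; the dataclasses and the module-level `_CTX` holder are illustrative stand-ins, not vLLM's actual forward-context machinery.

from dataclasses import dataclass

# Stand-ins for vLLM's forward-context types; names here are
# illustrative, not the real vLLM API surface.
@dataclass
class AttnMetadata:
    enable_kv_scales_calculation: bool

@dataclass
class ForwardContext:
    attn_metadata: AttnMetadata

_CTX = ForwardContext(AttnMetadata(enable_kv_scales_calculation=True))

def get_forward_context() -> ForwardContext:
    return _CTX

def should_calc_kv_scales(calculate_kv_scales: bool) -> bool:
    # Mirrors the new guard: consult the context's metadata,
    # not the attn_metadata argument passed into forward().
    if not calculate_kv_scales:
        return False
    ctx_attn_metadata = get_forward_context().attn_metadata
    return ctx_attn_metadata.enable_kv_scales_calculation

assert should_calc_kv_scales(True) is True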
@@ -172,15 +176,27 @@ class Attention(nn.Module):
             if value is not None:
                 value = value.view(-1, self.num_kv_heads, self.head_size)
             if self.use_direct_call:
-                unified_attention_with_output(query, key, value, output,
-                                              self.layer_name)
+                forward_context: ForwardContext = get_forward_context()
+                ctx_attn_metadata = forward_context.attn_metadata
+                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+                self.impl.forward(self,
+                                  query,
+                                  key,
+                                  value,
+                                  self_kv_cache,
+                                  ctx_attn_metadata,
+                                  output=output)
             else:
                 torch.ops.vllm.unified_attention_with_output(
                     query, key, value, output, self.layer_name)
             return output.view(-1, hidden_size)
         else:
             if self.use_direct_call:
-                return unified_attention(query, key, value, self.layer_name)
+                forward_context = get_forward_context()
+                ctx_attn_metadata = forward_context.attn_metadata
+                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+                return self.impl.forward(self, query, key, value,
+                                         self_kv_cache, ctx_attn_metadata)
             else:
                 return torch.ops.vllm.unified_attention(
                     query, key, value, self.layer_name)
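This second hunk makes the direct-call branches resolve the per-virtual-engine KV cache and invoke `self.impl.forward` themselves, instead of routing through the `torch.ops.vllm.unified_attention*` custom ops. The following is a hedged, self-contained sketch of that dispatch shape; `Impl` and `Layer` are simplified stand-ins, not vLLM's real classes.

import torch

class Impl:
    # Simplified backend: a real impl would run attention here;
    # this stub just copies the query so the sketch is runnable.
    def forward(self, layer, q, k, v, kv_cache, attn_metadata, output=None):
        if output is not None:
            output.copy_(q)
            return output
        return q.clone()

class Layer:
    def __init__(self, num_virtual_engines: int = 1):
        self.impl = Impl()
        self.use_direct_call = True
        # One KV-cache tensor per virtual engine, indexed at call time,
        # analogous to self.kv_cache[forward_context.virtual_engine].
        self.kv_cache = [torch.empty(0) for _ in range(num_virtual_engines)]

    def forward(self, q, k, v, virtual_engine, attn_metadata):
        if self.use_direct_call:
            # Direct path: look up this engine's cache and call the
            # backend impl, bypassing the custom-op wrappers.
            self_kv_cache = self.kv_cache[virtual_engine]
            return self.impl.forward(self, q, k, v, self_kv_cache,
                                     attn_metadata)
        raise NotImplementedError("custom-op path omitted in this sketch")

layer = Layer()
out = layer.forward(torch.randn(2, 4), None, None, 0, attn_metadata=None)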