Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-05-09 10:35:44 +08:00)
Remove xformers

parent afdbe5d373
commit 7f22f90e8c
@@ -2,7 +2,6 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-import xformers.ops as xops
 
 from cacheflow import ops
 from cacheflow.models import InputMetadata
@@ -14,8 +13,20 @@ class OPTCacheFlowAttention(nn.Module):
         super().__init__()
         self.scale = scale
 
-        # Shape-agnostic attention mask.
-        self.attention_mask = xops.LowerTriangularMask()
+    def _masked_attention(
+        self,
+        query: torch.Tensor,                        # [num_queries, num_heads, head_size]
+        key: torch.Tensor,                          # [num_keys, num_heads, head_size]
+        value: torch.Tensor,                        # [num_keys, num_heads, head_size]
+        attn_mask: Optional[torch.Tensor] = None,   # [num_queries, num_keys]
+    ) -> torch.Tensor:                              # [num_queries, num_heads, head_size]
+        query = query * self.scale
+        attn = torch.einsum('qhd,khd->hqk', query, key)
+        if attn_mask is not None:
+            attn = attn + attn_mask
+        attn = torch.softmax(attn, dim=-1)
+        out = torch.einsum('hqk,khd->qhd', attn, value)
+        return out
 
     def multi_query_kv_attention(
         self,
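The removed xops.LowerTriangularMask() and xops.memory_efficient_attention calls are replaced by the plain-PyTorch _masked_attention helper added above. Below is a minimal standalone sketch of the same einsum-based computation, cross-checked against torch.nn.functional.scaled_dot_product_attention; it is illustrative only (the free function masked_attention, the toy shapes, and the PyTorch >= 2.0 assumption are not part of the commit).

import torch
import torch.nn.functional as F

def masked_attention(query, key, value, attn_mask=None, scale=1.0):
    # Mirrors the new _masked_attention method, written as a free function.
    # query: [num_queries, num_heads, head_size]
    # key, value: [num_keys, num_heads, head_size]
    # attn_mask: additive mask, [num_queries, num_keys] or None
    query = query * scale
    attn = torch.einsum('qhd,khd->hqk', query, key)    # [num_heads, num_queries, num_keys]
    if attn_mask is not None:
        attn = attn + attn_mask
    attn = torch.softmax(attn, dim=-1)
    return torch.einsum('hqk,khd->qhd', attn, value)   # [num_queries, num_heads, head_size]

num_queries = num_keys = 5
num_heads, head_size = 2, 8
q = torch.randn(num_queries, num_heads, head_size)
k = torch.randn(num_keys, num_heads, head_size)
v = torch.randn(num_keys, num_heads, head_size)

# Additive causal mask, built the same way as in the multi_query_kv_attention hunk below.
mask = torch.triu(torch.ones(num_queries, num_keys), diagonal=1) * -1e5
out = masked_attention(q, k, v, mask, scale=head_size ** -0.5)

# Reference: SDPA expects [num_heads, seq_len, head_size], so transpose in and out.
ref = F.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), is_causal=True,
).transpose(0, 1)
torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)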
@@ -24,13 +35,11 @@ class OPTCacheFlowAttention(nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
     ) -> None:
-        query = query.unsqueeze(0)
-        key = key.unsqueeze(0)
-        value = value.unsqueeze(0)
-        out = xops.memory_efficient_attention(
-            query, key, value, attn_bias=self.attention_mask, scale=self.scale)
-        out = out.squeeze(0)
-        # FIXME(woosuk): Directly write the attention output.
+        # FIXME(woosuk): Replace this with a custom op call.
+        attention_mask = torch.triu(
+            torch.ones(query.shape[0], key.shape[0]), diagonal=1) * -1e5
+        attention_mask = attention_mask.to(dtype=query.dtype, device=query.device)
+        out = self._masked_attention(query, key, value, attention_mask)
         output.copy_(out, non_blocking=True)
 
     def single_query_cached_kv_attention(
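In multi_query_kv_attention (the prompt path), the xformers lower-triangular bias is now built by hand as an additive mask: entries above the diagonal are set to -1e5 so they contribute essentially zero probability after softmax, and the FIXME marks the block as a stopgap until a custom op replaces it. A tiny sketch of what that mask does (illustrative only, not part of the commit):

import torch

num_queries = num_keys = 4
attention_mask = torch.triu(torch.ones(num_queries, num_keys), diagonal=1) * -1e5
# Upper triangle (future positions) is -100000.0, everything else is 0.0.

# After softmax over the key dimension, masked positions get effectively zero weight.
probs = torch.softmax(attention_mask, dim=-1)
print(probs[0])   # tensor([1., 0., 0., 0.]) -- query 0 can only attend to key 0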
@@ -64,15 +73,10 @@ class OPTCacheFlowAttention(nn.Module):
 
                 v = value_cache[block_number, :, block_offset, :]
                 values.append(v)
 
             keys = torch.stack(keys, dim=0)
             values = torch.stack(values, dim=0)
 
-            q = q.unsqueeze(0)
-            keys = keys.unsqueeze(0)
-            values = values.unsqueeze(0)
-            out = xops.memory_efficient_attention(
-                q, keys, values, scale=self.scale)
+            out = self._masked_attention(q, keys, values)
             out = out.view(num_heads, head_size)
             output[i].copy_(out, non_blocking=True)
 
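The single_query_cached_kv_attention path (decode) gathers each token's cached key and value block by block, stacks them, and attends with a single query token; no causal mask is needed because that one query is allowed to see every cached position. A rough standalone sketch of the gather-then-attend pattern follows; the [num_blocks, num_heads, block_size, head_size] cache layout is inferred from the value_cache indexing above and reused for key_cache here for simplicity, so it may not match the repository's actual key-cache layout, and all names and sizes are illustrative.

import torch

num_blocks, num_heads, block_size, head_size = 8, 2, 4, 16
context_len = 7                  # tokens already cached for this sequence
block_table = [3, 5]             # physical blocks backing this sequence

key_cache = torch.randn(num_blocks, num_heads, block_size, head_size)
value_cache = torch.randn(num_blocks, num_heads, block_size, head_size)

keys, values = [], []
for j in range(context_len):
    block_number = block_table[j // block_size]
    block_offset = j % block_size
    keys.append(key_cache[block_number, :, block_offset, :])      # [num_heads, head_size]
    values.append(value_cache[block_number, :, block_offset, :])  # [num_heads, head_size]
keys = torch.stack(keys, dim=0)      # [context_len, num_heads, head_size]
values = torch.stack(values, dim=0)  # [context_len, num_heads, head_size]

# One query token attends over all cached positions, mirroring
# out = self._masked_attention(q, keys, values) with attn_mask=None.
q = torch.randn(1, num_heads, head_size)
attn = torch.softmax(
    torch.einsum('qhd,khd->hqk', q * head_size ** -0.5, keys), dim=-1)
out = torch.einsum('hqk,khd->qhd', attn, values).view(num_heads, head_size)
print(out.shape)                 # torch.Size([2, 16])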