Remove xformers

Woosuk Kwon 2023-02-24 08:36:16 +00:00
parent afdbe5d373
commit 7f22f90e8c

@@ -2,7 +2,6 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-import xformers.ops as xops
 
 from cacheflow import ops
 from cacheflow.models import InputMetadata
@@ -14,8 +13,20 @@ class OPTCacheFlowAttention(nn.Module):
         super().__init__()
         self.scale = scale
 
-        # Shape-agnostic attention mask.
-        self.attention_mask = xops.LowerTriangularMask()
+    def _masked_attention(
+        self,
+        query: torch.Tensor,                        # [num_queries, num_heads, head_size]
+        key: torch.Tensor,                          # [num_keys, num_heads, head_size]
+        value: torch.Tensor,                        # [num_keys, num_heads, head_size]
+        attn_mask: Optional[torch.Tensor] = None,   # [num_queries, num_keys]
+    ) -> torch.Tensor:                              # [num_queries, num_heads, head_size]
+        query = query * self.scale
+        attn = torch.einsum('qhd,khd->hqk', query, key)
+        if attn_mask is not None:
+            attn = attn + attn_mask
+        attn = torch.softmax(attn, dim=-1)
+        out = torch.einsum('hqk,khd->qhd', attn, value)
+        return out
 
     def multi_query_kv_attention(
         self,
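The new _masked_attention helper computes per-head scaled dot-product attention in plain PyTorch: 'qhd,khd->hqk' forms one [num_queries, num_keys] score matrix per head, and 'hqk,khd->qhd' applies the softmax weights to the values. Below is a minimal standalone sketch (not part of the commit; the names and sizes are made up) that checks the einsum path against a plain-matmul reference:

    import torch

    num_queries, num_keys, num_heads, head_size = 4, 6, 2, 8
    scale = head_size ** -0.5
    query = torch.randn(num_queries, num_heads, head_size)
    key = torch.randn(num_keys, num_heads, head_size)
    value = torch.randn(num_keys, num_heads, head_size)

    # 'qhd,khd->hqk': per-head Q @ K^T, giving one
    # [num_queries, num_keys] score matrix per head.
    attn = torch.einsum('qhd,khd->hqk', query * scale, key)
    attn = torch.softmax(attn, dim=-1)
    # 'hqk,khd->qhd': per-head weighted sum over the values.
    out = torch.einsum('hqk,khd->qhd', attn, value)

    # Reference: move heads to the batch dimension, use plain matmuls.
    q, k, v = (t.transpose(0, 1) for t in (query, key, value))
    ref = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v
    assert torch.allclose(out, ref.transpose(0, 1), atol=1e-5)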
@@ -24,13 +35,11 @@ class OPTCacheFlowAttention(nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
     ) -> None:
-        query = query.unsqueeze(0)
-        key = key.unsqueeze(0)
-        value = value.unsqueeze(0)
-        out = xops.memory_efficient_attention(
-            query, key, value, attn_bias=self.attention_mask, scale=self.scale)
-        out = out.squeeze(0)
-        # FIXME(woosuk): Directly write the attention output.
+        # FIXME(woosuk): Replace this with a custom op call.
+        attention_mask = torch.triu(
+            torch.ones(query.shape[0], key.shape[0]), diagonal=1) * -1e5
+        attention_mask = attention_mask.to(dtype=query.dtype, device=query.device)
+        out = self._masked_attention(query, key, value, attention_mask)
         output.copy_(out, non_blocking=True)
 
     def single_query_cached_kv_attention(
@@ -64,15 +73,10 @@ class OPTCacheFlowAttention(nn.Module):
                 v = value_cache[block_number, :, block_offset, :]
                 values.append(v)
 
             keys = torch.stack(keys, dim=0)
             values = torch.stack(values, dim=0)
 
-            q = q.unsqueeze(0)
-            keys = keys.unsqueeze(0)
-            values = values.unsqueeze(0)
-            out = xops.memory_efficient_attention(
-                q, keys, values, scale=self.scale)
-
+            out = self._masked_attention(q, keys, values)
             out = out.view(num_heads, head_size)
             output[i].copy_(out, non_blocking=True)
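For reference: multi_query_kv_attention builds its causal mask additively, with torch.triu(..., diagonal=1) * -1e5 placing a large negative bias on positions above the diagonal so that softmax drives their weights to near zero, while single_query_cached_kv_attention passes no mask because its single query may attend to every cached key. A minimal sketch of the mask's effect (not part of the commit):

    import torch

    num_queries = num_keys = 3
    # 1s strictly above the diagonal, scaled to a large negative bias.
    mask = torch.triu(torch.ones(num_queries, num_keys), diagonal=1) * -1e5

    # With uniform scores, softmax over (scores + mask) is causal:
    # row i spreads its weight over keys 0..i only.
    probs = torch.softmax(torch.zeros(num_queries, num_keys) + mask, dim=-1)
    print(probs)
    # tensor([[1.0000, 0.0000, 0.0000],
    #         [0.5000, 0.5000, 0.0000],
    #         [0.3333, 0.3333, 0.3333]])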