[Core] Refactor Attention Take 2 (#3462)

parent b0dfa91dd7
commit 925f3332ca
@@ -3,8 +3,7 @@ import pytest
 import time
 
 import torch
-from vllm.model_executor.layers.attention.ops.prefix_prefill import (
-    context_attention_fwd)
+from vllm.attention.ops.prefix_prefill import context_attention_fwd
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
 
@@ -2,7 +2,10 @@
 
 Run `pytest tests/samplers/test_beam_search.py --forked`.
 """
+import gc
+
 import pytest
+import torch
 
 # FIXME(zhuohan): The test can not pass if we:
 #     1. Increase max_tokens to 256.
@@ -36,6 +39,10 @@ def test_beam_search_single_input(
     vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
                                                    max_tokens)
     del vllm_model
+    # NOTE(woosuk): For some reason, the following GC is required to avoid
+    # GPU OOM errors in the following tests using `vllm_runner`.
+    gc.collect()
+    torch.cuda.empty_cache()
 
     for i in range(len(example_prompts)):
         hf_output_ids, _ = hf_outputs[i]
@@ -34,19 +34,19 @@ def test_prepare_prompt(batch_size):
         expected_selected_token_indices.append(selected_token_start_idx +
                                                prompt_len - 1)
         selected_token_start_idx += prompt_len
-    (input_tokens, input_positions, input_metadata, return_prompt_lens, _, _,
-     _, _) = (model_runner._prepare_prompt(seq_group_metadata_list))
+    (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
+     _) = (model_runner._prepare_prompt(seq_group_metadata_list))
     assert return_prompt_lens == prompt_lens
 
     # Verify input metadata is correct for prompts.
     device = model_runner.device
-    assert input_metadata.is_prompt is True
-    assert torch.allclose(input_metadata.prompt_lens_tensor,
+    assert attn_metadata.is_prompt is True
+    assert torch.allclose(attn_metadata.prompt_lens_tensor,
                           torch.tensor(prompt_lens, device=device))
-    assert input_metadata.prompt_lens == prompt_lens
-    assert input_metadata.num_prompt_tokens == sum(prompt_lens)
-    assert input_metadata.num_generation_tokens == 0
-    assert input_metadata.max_seq_len == max(prompt_lens)
+    assert attn_metadata.prompt_lens == prompt_lens
+    assert attn_metadata.num_prompt_tokens == sum(prompt_lens)
+    assert attn_metadata.num_generation_tokens == 0
+    assert attn_metadata.max_prompt_len == max(prompt_lens)
 
     # Test subquery start locs.
     start_idx = 0
@@ -55,7 +55,7 @@ def test_prepare_prompt(batch_size):
         start_idx += prompt_len
         start_loc.append(start_idx)
     assert torch.allclose(
-        input_metadata.subquery_start_loc,
+        attn_metadata.subquery_start_loc,
         torch.tensor(start_loc, dtype=torch.int32, device=device))
 
     # Test seq start locs. Note that for normal prefill it is
@@ -67,22 +67,22 @@ def test_prepare_prompt(batch_size):
         seq_start_loc.append(start_idx)
 
     assert torch.allclose(
-        input_metadata.seq_start_loc,
+        attn_metadata.seq_start_loc,
         torch.tensor(start_loc, dtype=torch.int32, device=device))
-    assert input_metadata.max_context_len is None
+    assert attn_metadata.max_context_len is None
     assert torch.allclose(
-        input_metadata.context_lens,
-        torch.zeros(input_metadata.context_lens.shape[0],
+        attn_metadata.context_lens,
+        torch.zeros(attn_metadata.context_lens.shape[0],
                     dtype=torch.int,
                     device=device))
 
     expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
                             dtype=torch.int32,
                             device=model_runner.device)
-    assert torch.allclose(input_metadata.block_tables, expected)
+    assert torch.allclose(attn_metadata.block_tables, expected)
     # Cuda graph should not be used for prefill.
-    assert input_metadata.use_cuda_graph is False
-    assert input_metadata.kv_cache_dtype == "auto"
+    assert attn_metadata.use_cuda_graph is False
+    assert attn_metadata.kv_cache_dtype == "auto"
 
     assert input_tokens.shape == (sum(prompt_lens), )
     assert input_positions.shape == (sum(prompt_lens), )
@@ -140,34 +140,34 @@ def test_prepare_decode_cuda_graph(batch_size):
             block_tables={0: [1]},
         ))
 
-    input_tokens, input_positions, input_metadata, _, _, _ = (
+    input_tokens, input_positions, attn_metadata, _, _, _ = (
        model_runner._prepare_decode(seq_group_metadata_list))
 
     expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
     # Verify input metadata is correct for prompts.
     device = model_runner.device
-    assert input_metadata.is_prompt is False
-    assert input_metadata.prompt_lens is None
-    assert input_metadata.num_prompt_tokens == 0
-    assert input_metadata.num_generation_tokens == expected_bs
-    assert input_metadata.max_seq_len is None
-    assert input_metadata.subquery_start_loc is None
-    assert input_metadata.seq_start_loc is None
-    assert input_metadata.max_context_len == max(prompt_lens)
+    assert attn_metadata.is_prompt is False
+    assert attn_metadata.prompt_lens is None
+    assert attn_metadata.num_prompt_tokens == 0
+    assert attn_metadata.num_generation_tokens == expected_bs
+    assert attn_metadata.max_prompt_len is None
+    assert attn_metadata.subquery_start_loc is None
+    assert attn_metadata.seq_start_loc is None
+    assert attn_metadata.max_context_len == max(prompt_lens)
     assert torch.allclose(
-        input_metadata.context_lens[:len(prompt_lens)],
+        attn_metadata.context_lens[:len(prompt_lens)],
         torch.tensor(prompt_lens, dtype=torch.int, device=device))
 
     # block table's first index corresponds to each batch, meaning in
     # decoding it is each token.
-    assert input_metadata.block_tables.shape[0] == len(input_tokens)
+    assert attn_metadata.block_tables.shape[0] == len(input_tokens)
     # Block table's second dim corresponds to each token's block number.
     # It is padded up to the maximum number of blocks per batch.
-    assert input_metadata.block_tables.shape[1] == (
+    assert attn_metadata.block_tables.shape[1] == (
         model_runner.get_max_block_per_batch())
     # Cuda graph should be used for decoding.
-    assert input_metadata.use_cuda_graph is True
-    assert input_metadata.kv_cache_dtype == "auto"
+    assert attn_metadata.use_cuda_graph is True
+    assert attn_metadata.kv_cache_dtype == "auto"
 
     assert input_tokens.shape == (expected_bs, )
     assert input_positions.shape == (expected_bs, )
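The expected batch size above comes from `_get_graph_batch_size`, which pads the decode batch up to the nearest CUDA-graph capture size. A minimal sketch of that padding rule (the exact capture sizes live in `model_runner.py`; treat the constants here as an assumption, not part of this diff):

# Sketch of the batch-size padding used for CUDA-graph capture.
# Assumption: graphs are captured for batch sizes 1, 2, 4, and multiples of 8.
def _get_graph_batch_size(batch_size: int) -> int:
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    # Round up to the next multiple of 8.
    return (batch_size + 7) // 8 * 8

assert _get_graph_batch_size(3) == 4
assert _get_graph_batch_size(13) == 16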
vllm/attention/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

__all__ = [
    "AttentionBackend",
    "AttentionMetadata",
    "Attention",
    "get_attn_backend",
]
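For orientation, the new package deliberately exposes a small surface. Downstream code is expected to touch only these four names; a sketch of typical use (not taken from the diff):

# Sketch: consuming the new public API of vllm.attention.
import torch
from vllm.attention import Attention, get_attn_backend

backend = get_attn_backend(torch.float16)  # FlashAttention or xFormers class
# Attention(...) below wires the chosen backend's impl internally.
# attn = Attention(num_heads=32, head_size=128, scale=128 ** -0.5)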
vllm/attention/backends/abstract.py (new file, 85 lines)
@@ -0,0 +1,85 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Tuple, Type

import torch


class AttentionBackend(ABC):
    """Abstract class for attention backends."""

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        raise NotImplementedError


@dataclass
class AttentionMetadata:

    def asdict_zerocopy(self) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying."""
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self)
        }


class AttentionImpl(ABC):

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        raise NotImplementedError
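The split between a backend (static cache-management hooks) and an impl (the stateful kernel wrapper) is the core of this refactor. A minimal third-party backend would plug in roughly as below; every name prefixed with "My" is invented for illustration and is not part of the diff:

# Sketch of a hypothetical backend implementing the new abstract interface.
from dataclasses import dataclass
from typing import Dict, List, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)


@dataclass
class MyMetadata(AttentionMetadata):
    is_prompt: bool
    slot_mapping: torch.Tensor


class MyImpl(AttentionImpl):

    def __init__(self, num_heads, head_size, scale, num_kv_heads=None,
                 alibi_slopes=None, sliding_window=None) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale

    def forward(self, query, key, value, kv_cache, attn_metadata):
        # A real impl would write key/value into kv_cache, then run a kernel.
        raise NotImplementedError


class MyBackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["MyImpl"]:
        return MyImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "MyMetadata":
        return MyMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int,
                           num_kv_heads: int,
                           head_size: int) -> Tuple[int, ...]:
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def swap_blocks(src_kv_cache, dst_kv_cache,
                    src_to_dst: Dict[int, int]) -> None:
        pass  # Move blocks between GPU and CPU caches here.

    @staticmethod
    def copy_blocks(kv_caches: List[torch.Tensor],
                    src_to_dists: Dict[int, List[int]]) -> None:
        pass  # Fan out blocks for copy-on-write here.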
vllm/attention/backends/flash_attn.py (new file, 238 lines)
@@ -0,0 +1,238 @@
"""Attention layer with Flash and PagedAttention.

NOTE(woosuk): At the moment, this file includes a lot of duplicated code from
the XFormers backend. The duplicated code will be removed once we use
flash-attn or flashinfer for all the attention operations.
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

from flash_attn import flash_attn_varlen_func
import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)
from vllm.attention.ops.paged_attn import PagedAttention, PagedAttentionMetadata


class FlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        return FlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
        return FlashAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for FlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # (batch_size,). The prompt length per sequence. None if it is a decoding.
    prompt_lens: Optional[List[int]]
    # prompt_lens stored as a tensor.
    prompt_lens_tensor: Optional[torch.Tensor]
    # The number of prompt tokens. Doesn't include padding.
    num_prompt_tokens: int
    # The number of generation tokens. Doesn't include padding.
    num_generation_tokens: int

    # NOTE(sang): Definition of context_len, subquery_len, and seqlen.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seqlen ----------------------|
    #                                   |- subquery_len -|

    # WARNING(sang): context_len has different definition depending on if it is
    # prefill vs decoding. When it is prefill, it doesn't include new tokens.
    # When it is for decoding, it includes a new token.

    # Maximum subquery length in the batch.
    max_subquery_len: Optional[int]
    # Maximum prompt length in the batch.
    max_prompt_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool


class FlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        suppored_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in suppored_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {suppored_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                attn_metadata.kv_cache_dtype)

        if attn_metadata.is_prompt:
            # Prompt run.
            if kv_cache is None or attn_metadata.block_tables.numel() == 0:
                # normal attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                output = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=attn_metadata.seq_start_loc,
                    cu_seqlens_k=attn_metadata.seq_start_loc,
                    max_seqlen_q=attn_metadata.max_prompt_len,
                    max_seqlen_k=attn_metadata.max_prompt_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                # prefix-enabled attention
                output = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.block_tables,
                    attn_metadata.subquery_start_loc,
                    attn_metadata.prompt_lens_tensor,
                    attn_metadata.context_lens,
                    attn_metadata.max_subquery_len,
                    self.alibi_slopes,
                )
        else:
            # Decoding run.
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                attn_metadata.block_tables,
                attn_metadata.context_lens,
                attn_metadata.max_context_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
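Both `subquery_start_loc` and `seq_start_loc` are just exclusive prefix sums of per-sequence lengths, which is exactly the `cu_seqlens_q`/`cu_seqlens_k` format `flash_attn_varlen_func` consumes. A small worked example, standalone and not part of the diff:

# Worked example: building the (batch_size + 1,) cumulative start locations.
import torch

prompt_lens = [4, 6]
seq_start_loc = torch.zeros(len(prompt_lens) + 1, dtype=torch.int32)
# Prefix-sum the lengths into positions 1..batch_size; position 0 stays 0.
torch.cumsum(torch.tensor(prompt_lens, dtype=torch.int32),
             dim=0, out=seq_start_loc[1:])
print(seq_start_loc)  # tensor([ 0,  4, 10], dtype=torch.int32)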
@@ -1,19 +1,127 @@
 """Attention layer with xFormers and PagedAttention."""
 import importlib
-from typing import List, Optional
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple, Type
 
 import torch
 from xformers import ops as xops
-from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
+from xformers.ops.fmha.attn_bias import (AttentionBias,
+                                         BlockDiagonalCausalMask,
                                          LowerTriangularMaskWithTensorBias)
 
-from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.layers.attention.ops.paged_attn import (
-    PagedAttentionImpl)
+from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionMetadata)
+from vllm.attention.ops.paged_attn import PagedAttention, PagedAttentionMetadata
 from vllm.logger import init_logger
 from vllm.utils import is_hip
 
 logger = init_logger(__name__)
 
-class XFormersBackend:
+
+class XFormersBackend(AttentionBackend):
+
+    @staticmethod
+    def get_impl_cls() -> Type["XFormersImpl"]:
+        return XFormersImpl
+
+    @staticmethod
+    def make_metadata(*args, **kwargs) -> "XFormersMetadata":
+        return XFormersMetadata(*args, **kwargs)
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
+                                                 num_kv_heads, head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: Dict[int, int],
+    ) -> None:
+        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: Dict[int, List[int]],
+    ) -> None:
+        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+
+
+@dataclass
+class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
+    """Metadata for XFormersBackend.
+
+    NOTE: Any python object stored here is not updated when it is
+    cuda-graph replayed. If you have values that need to be changed
+    dynamically, it should be stored in tensor. The tensor has to be
+    updated from `CUDAGraphRunner.forward` API.
+    """
+    # Currently, input sequences can only contain all prompts
+    # or all decoding. True if all sequences are prompts.
+    is_prompt: bool
+    # (num_tokens,). The indices of the token slots that input tokens will be
+    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
+    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
+    # in block 0, and 1st slot in block 1, respectively.
+    slot_mapping: torch.Tensor
+    # (batch_size,). The prompt length per sequence. None if it is a decoding.
+    prompt_lens: Optional[List[int]]
+    # prompt_lens stored as a tensor.
+    prompt_lens_tensor: Optional[torch.Tensor]
+    # The number of prompt tokens. Doesn't include padding.
+    num_prompt_tokens: int
+    # The number of generation tokens. Doesn't include padding.
+    num_generation_tokens: int
+
+    # NOTE(sang): Definition of context_len, subquery_len, and seqlen.
+    # |---------- N-1 iteration --------|
+    # |---------------- N iteration ---------------------|
+    # |- tokenA -|......................|-- newTokens ---|
+    # |---------- context_len ----------|
+    # |-------------------- seqlen ----------------------|
+    #                                   |- subquery_len -|
+
+    # WARNING(sang): context_len has different definition depending on if it is
+    # prefill vs decoding. When it is prefill, it doesn't include new tokens.
+    # When it is for decoding, it includes a new token.
+
+    # Maximum subquery length in the batch.
+    max_subquery_len: Optional[int]
+    # FIXME: It is for flash attn.
+    # Maximum prompt length in the batch.
+    max_prompt_len: Optional[int]
+    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
+    # the batch, used to index into subquery. E.g., if the subquery length
+    # is [4, 6], it is [0, 4, 10].
+    subquery_start_loc: Optional[torch.Tensor]
+    # FIXME: It is for flash attn.
+    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
+    # the batch, used to index into sequence. E.g., if the sequence length is
+    # [4, 6], it is [0, 4, 10].
+    seq_start_loc: Optional[torch.Tensor]
+
+    # Whether or not cuda graph is enabled.
+    # Cuda-graph is currently enabled for decoding only.
+    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
+    use_cuda_graph: bool
+
+    def __post_init__(self):
+        # Set during the execution of the first attention op.
+        # It is a list because it is needed to set per prompt
+        # when alibi slopes is used. It is because of the limitation
+        # from xformer API.
+        # It will not appear in the __repr__ and __init__.
+        self.attn_bias: Optional[List[AttentionBias]] = None
+
+
+class XFormersImpl(AttentionImpl):
     """
     If the input tensors contain prompt tokens, the layout is as follows:
     |<--------------- num_prompt_tokens --------------->|
@@ -50,22 +158,25 @@ class XFormersBackend:
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-        suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
+
+        suppored_head_sizes = PagedAttention.get_supported_head_sizes()
         if head_size not in suppored_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {suppored_head_sizes}.")
 
-        self.use_ref_attention = _check_use_ref_attention()
+        # AMD Radeon 7900 series (gfx1100) currently does not support xFormers
+        # nor FlashAttention. As a temporary workaround, we use naive PyTorch
+        # implementation of attention.
+        self.use_naive_attention = _check_use_naive_attention()
 
     def forward(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        key_cache: Optional[torch.Tensor],
-        value_cache: Optional[torch.Tensor],
-        input_metadata: InputMetadata,
+        kv_cache: Optional[torch.Tensor],
+        attn_metadata: XFormersMetadata,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.
 
@@ -73,11 +184,8 @@ class XFormersBackend:
             query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
-            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
-                block_size, x]
-            value_cache: shape = [num_blocks, num_kv_heads, head_size,
-                block_size]
-            input_metadata: metadata for the inputs.
+            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+            attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
@@ -86,20 +194,24 @@ class XFormersBackend:
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
 
-        # Reshape the keys and values and store them in the cache.
-        # If key_cache and value_cache are not provided, the new key and value
-        # vectors will not be cached. This happens during the initial memory
-        # profiling run.
-        if key_cache is not None and value_cache is not None:
-            PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
-                                                 value_cache, input_metadata)
+        if kv_cache is not None:
+            key_cache, value_cache = PagedAttention.split_kv_cache(
+                kv_cache, self.num_kv_heads, self.head_size)
 
-        if input_metadata.is_prompt:
+            # Reshape the input keys and values and store them in the cache.
+            # If kv_cache is not provided, the new key and value tensors are
+            # not cached. This happens during the initial memory profiling run.
+            PagedAttention.write_to_paged_cache(key, value, key_cache,
+                                                value_cache,
+                                                attn_metadata.slot_mapping,
+                                                attn_metadata.kv_cache_dtype)
+
+        if attn_metadata.is_prompt:
             # Prompt run.
-            # key_cache and value_cache are None when it is a profiling run.
-            # block tables are empty if the prompt has never been computed.
-            if (key_cache is None or value_cache is None
-                    or input_metadata.block_tables.numel() == 0):
+            if kv_cache is None or attn_metadata.block_tables.numel() == 0:
                 # normal attention.
                 # block tables are empty if the prompt does not have a cached
                 # prefix.
                 if self.num_kv_heads != self.num_heads:
                     # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
                     # project the key and value tensors to the desired number of
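The comment above is cut off at the hunk boundary; the idea is that xFormers expects as many KV heads as query heads, so GQA/MQA keys and values are expanded before the call. A rough sketch of that expansion, with illustrative shapes that are not taken from the diff:

# Sketch: expanding KV heads for xFormers, which expects num_heads KV heads.
import torch

num_tokens, num_kv_heads, head_size = 10, 2, 64
num_queries_per_kv = 4  # num_heads (8) // num_kv_heads (2)
key = torch.randn(num_tokens, num_kv_heads, head_size)

# Insert a broadcast dimension and expand it, giving each KV head
# num_queries_per_kv query heads without copying memory.
key = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                num_queries_per_kv, head_size)
print(key.shape)  # torch.Size([10, 2, 4, 64])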
@@ -118,13 +230,12 @@ class XFormersBackend:
                                            self.num_queries_per_kv,
                                            value.shape[-1])
 
-                if self.use_ref_attention:
-                    print("ref attention used.")
+                if self.use_naive_attention:
                     output = torch.empty_like(query)
                     start = 0
-                    for _, prompt_len in enumerate(input_metadata.prompt_lens):
+                    for _, prompt_len in enumerate(attn_metadata.prompt_lens):
                         end = start + prompt_len
-                        out = _ref_masked_attention(
+                        out = _naive_masked_attention(
                             query[None, start:end],
                             key[None, start:end],
                             value[None, start:end],
@@ -143,26 +254,33 @@ class XFormersBackend:
                     # Use reshape instead.
                     return output.reshape(num_tokens, hidden_size)
 
-            output = self._run_memory_efficient_xformer_forward(
-                query, key, value, input_metadata)
+                output = self._run_memory_efficient_xformers_forward(
+                    query, key, value, attn_metadata)
             else:
                 # prefix-enabled attention
-                output = PagedAttentionImpl.forward_prefix(
+                output = PagedAttention.forward_prefix(
                     query,
                     key,
                     value,
                     key_cache,
                     value_cache,
-                    input_metadata,
+                    attn_metadata.block_tables,
+                    attn_metadata.subquery_start_loc,
+                    attn_metadata.prompt_lens_tensor,
+                    attn_metadata.context_lens,
+                    attn_metadata.max_subquery_len,
                     self.alibi_slopes,
                 )
         else:
             # Decoding run.
-            output = PagedAttentionImpl.forward_decode(
+            output = PagedAttention.forward_decode(
                 query,
                 key_cache,
                 value_cache,
-                input_metadata,
+                attn_metadata.block_tables,
+                attn_metadata.context_lens,
+                attn_metadata.max_context_len,
+                attn_metadata.kv_cache_dtype,
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
@@ -171,12 +289,12 @@ class XFormersBackend:
         # Reshape the output tensor.
         return output.view(-1, self.num_heads * self.head_size)
 
-    def _run_memory_efficient_xformer_forward(
+    def _run_memory_efficient_xformers_forward(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        input_metadata: InputMetadata,
+        attn_metadata: XFormersMetadata,
     ) -> torch.Tensor:
         """Attention for 1D query of multiple prompts. Multiple prompt
         tokens are flattened into `query` input.
@@ -186,23 +304,23 @@ class XFormersBackend:
             query: shape = [num_prompt_tokens, num_heads, head_size]
             key: shape = [num_prompt_tokens, num_kv_heads, head_size]
             value: shape = [num_prompt_tokens, num_kv_heads, head_size]
-            input_metadata: metadata for paged attention.
+            attn_metadata: Metadata for attention.
         """
         # Set attention bias if not provided. This typically happens at
         # the very attention layer of every iteration.
         # FIXME(woosuk): This is a hack.
-        if input_metadata.attn_bias is None:
+        if attn_metadata.attn_bias is None:
             if self.alibi_slopes is None:
                 attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                    input_metadata.prompt_lens)
+                    attn_metadata.prompt_lens)
                 if self.sliding_window is not None:
                     attn_bias = attn_bias.make_local_attention(
                         self.sliding_window)
-                input_metadata.attn_bias = [attn_bias]
+                attn_metadata.attn_bias = [attn_bias]
             else:
-                input_metadata.attn_bias = _make_alibi_bias(
+                attn_metadata.attn_bias = _make_alibi_bias(
                     self.alibi_slopes, self.num_kv_heads, query.dtype,
-                    input_metadata)
+                    attn_metadata.prompt_lens)
 
         op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if (
             is_hip()) else None
@@ -217,7 +335,7 @@ class XFormersBackend:
                 query,
                 key,
                 value,
-                attn_bias=input_metadata.attn_bias[0],
+                attn_bias=attn_metadata.attn_bias[0],
                 p=0.0,
                 scale=self.scale,
                 op=op)
@@ -230,13 +348,13 @@ class XFormersBackend:
         # one. This is inefficient, especially when we have many short prompts.
         output = torch.empty_like(query)
         start = 0
-        for i, prompt_len in enumerate(input_metadata.prompt_lens):
+        for i, prompt_len in enumerate(attn_metadata.prompt_lens):
             end = start + prompt_len
             out = xops.memory_efficient_attention_forward(
                 query[None, start:end],
                 key[None, start:end],
                 value[None, start:end],
-                attn_bias=input_metadata.attn_bias[i],
+                attn_bias=attn_metadata.attn_bias[i],
                 p=0.0,
                 scale=self.scale,
                 op=op)
@@ -250,10 +368,10 @@ def _make_alibi_bias(
     alibi_slopes: torch.Tensor,
     num_kv_heads: int,
     dtype: torch.dtype,
-    input_metadata: InputMetadata,
+    prompt_lens: List[int],
 ) -> LowerTriangularMaskWithTensorBias:
     attn_biases = []
-    for prompt_len in input_metadata.prompt_lens:
+    for prompt_len in prompt_lens:
         bias = torch.arange(prompt_len, dtype=dtype)
         # NOTE(zhuohan): HF uses
         #     `bias = bias[None, :].repeat(prompt_len, 1)`
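For intuition, the relative-position tensor built from `torch.arange` above becomes a per-head bias once multiplied by each head's ALiBi slope; a toy example (values illustrative only, not part of the diff):

# Toy example of the ALiBi distance term built from torch.arange.
import torch

prompt_len = 4
bias = torch.arange(prompt_len, dtype=torch.float32)
# Pairwise relative positions: rel[i][j] = bias[j] - bias[i] = j - i.
rel = bias[None, :] - bias[:, None]
print(rel)
# tensor([[ 0.,  1.,  2.,  3.],
#         [-1.,  0.,  1.,  2.],
#         [-2., -1.,  0.,  1.],
#         [-3., -2., -1.,  0.]])
# Each head then scales this matrix by its ALiBi slope.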
@@ -282,15 +400,19 @@ def _make_alibi_bias(
     return attn_biases
 
 
-def _check_use_ref_attention() -> bool:
+def _check_use_naive_attention() -> bool:
     if not is_hip():
         return False
     # For ROCm, check whether flash attention is installed or not.
-    # if not, use_ref_attention needs to be True
-    return importlib.util.find_spec("flash_attn") is None
+    has_flash_attn = importlib.util.find_spec("flash_attn") is not None
+    if not has_flash_attn:
+        logger.warning("flash_attn is not installed. Using naive attention. "
+                       "This will take significantly more GPU memory.")
+        return True
+    return False
 
 
-def _ref_masked_attention(
+def _naive_masked_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
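The renamed `_naive_masked_attention` is the pure-PyTorch fallback referenced in the ROCm comments above; its body lies outside this hunk. A self-contained sketch of what such a causal reference attention looks like (an approximation, not the committed body):

# Sketch of a naive causal attention fallback in plain PyTorch.
import torch

def naive_masked_attention(query: torch.Tensor, key: torch.Tensor,
                           value: torch.Tensor, scale: float) -> torch.Tensor:
    # query/key/value: [seq_len, num_heads, head_size]
    seq_len = query.shape[0]
    # Causal mask: disallow attending to future positions.
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=query.dtype),
                      diagonal=1) * torch.finfo(query.dtype).min
    scores = scale * torch.einsum("qhd,khd->hqk", query, key)
    scores = scores + mask  # mask broadcasts over the head dimension
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, value)

out = naive_masked_attention(torch.randn(5, 8, 64), torch.randn(5, 8, 64),
                             torch.randn(5, 8, 64), scale=64 ** -0.5)
print(out.shape)  # torch.Size([5, 8, 64])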
vllm/attention/layer.py (new file, 46 lines)
@@ -0,0 +1,46 @@
"""Attention layer."""
from typing import List, Optional

import torch
import torch.nn as nn

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend


class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.backend = get_attn_backend(torch.get_default_dtype())
        impl_cls = self.backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        return self.impl.forward(query, key, value, kv_cache, attn_metadata)
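In a model, this layer is constructed once per decoder block and called with the flattened token tensors. Roughly, with shapes and values assumed for illustration (not from the diff):

# Sketch: how a decoder block would drive the new Attention module.
import torch
from vllm.attention import Attention

num_heads, head_size = 32, 128  # illustrative values
attn = Attention(num_heads, head_size, scale=head_size ** -0.5)

num_tokens = 16
query = torch.randn(num_tokens, num_heads * head_size)
key = torch.randn(num_tokens, num_heads * head_size)
value = torch.randn(num_tokens, num_heads * head_size)
# kv_cache may be None during the initial memory-profiling run;
# attn_metadata comes from the model runner's _prepare_prompt/_prepare_decode.
# output = attn(query, key, value, kv_cache, attn_metadata)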
vllm/attention/ops/paged_attn.py (new file, 217 lines)
@@ -0,0 +1,217 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from vllm._C import cache_ops
from vllm._C import ops
from vllm.attention.ops.prefix_prefill import context_attention_fwd

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


@dataclass
class PagedAttentionMetadata:
    """Metadata for PagedAttention."""
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
    # in block 0, and 1st slot in block 1, respectively.
    slot_mapping: torch.Tensor
    # (batch_size,). The length of context (tokens stored in KV cache) per
    # sequence. WARNING: When it is a prefill request, it doesn't include new
    # tokens. When it is for decoding, it includes a new token.
    context_lens: Optional[torch.Tensor]
    # Maximum context length in the batch.
    max_context_len: Optional[int]
    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]
    kv_cache_dtype: str


class PagedAttention:

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [64, 80, 96, 112, 128, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
    ) -> None:
        cache_ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping.flatten(),
            kv_cache_dtype,
        )

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
    ) -> torch.Tensor:
        output = torch.empty_like(query)

        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) //
                              _PARTITION_SIZE)
        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
        use_v1 = (max_context_len <= 8192
                  and (max_num_partitions == 1 or num_seqs * num_heads > 512))
        if use_v1:
            # Run PagedAttention V1.
            ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                context_lens,
                block_size,
                max_context_len,
                alibi_slopes,
                kv_cache_dtype,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                context_lens,
                block_size,
                max_context_len,
                alibi_slopes,
                kv_cache_dtype,
            )
        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        subquery_start_loc: torch.Tensor,
        prompt_lens_tensor: torch.Tensor,
        context_lens: torch.Tensor,
        max_subquery_len: int,
        alibi_slopes: Optional[torch.Tensor],
    ) -> torch.Tensor:
        output = torch.empty_like(query)
        context_attention_fwd(
            query,
            key,
            value,
            output,
            key_cache,
            value_cache,
            block_tables,
            # subquery_start_loc is (batch_size + 1,)
            subquery_start_loc[:-1],
            prompt_lens_tensor,
            context_lens,
            max_subquery_len,
            alibi_slopes,
        )
        return output

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        src_key_cache = src_kv_cache[0]
        dst_key_cache = dst_kv_cache[0]
        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)

        src_value_cache = src_kv_cache[1]
        dst_value_cache = dst_kv_cache[1]
        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
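Two of the conventions above are easy to verify by hand: a flat slot index decomposes into (block, offset) via the block size, and `split_kv_cache` packs the key cache in groups of x elements where x is 16 bytes divided by the element size. A small worked check, standalone and not part of the diff:

# Worked example: slot_mapping decomposition and the key-cache packing factor.
import torch

block_size = 16
for slot in [35, 2, 17]:
    block, offset = slot // block_size, slot % block_size
    print(f"slot {slot} -> block {block}, offset {offset}")
# slot 35 -> block 2, offset 3
# slot 2  -> block 0, offset 2
# slot 17 -> block 1, offset 1

# For fp16 (2 bytes/element), x = 16 // 2 = 8, so the key cache is viewed as
# [num_blocks, num_kv_heads, head_size // 8, block_size, 8].
kv_cache = torch.zeros(2, 4, block_size * 2 * 64, dtype=torch.float16)
x = 16 // kv_cache.element_size()
print(x)  # 8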
vllm/attention/selector.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from functools import lru_cache

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_hip

logger = init_logger(__name__)


@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> AttentionBackend:
    if _can_use_flash_attn(dtype):
        logger.info("Using FlashAttention backend.")
        from vllm.attention.backends.flash_attn import FlashAttentionBackend  # noqa: F401
        return FlashAttentionBackend
    else:
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import XFormersBackend  # noqa: F401
        return XFormersBackend


def _can_use_flash_attn(dtype: torch.dtype) -> bool:
    if is_hip():
        # AMD GPUs.
        logger.info("Cannot use FlashAttention backend for AMD GPUs.")
        return False
    if torch.cuda.get_device_capability()[0] < 8:
        # Volta and Turing NVIDIA GPUs.
        logger.info("Cannot use FlashAttention backend for Volta and Turing "
                    "GPUs.")
        return False
    if dtype not in (torch.float16, torch.bfloat16):
        logger.info("Cannot use FlashAttention backend for dtype other than "
                    "torch.float16 or torch.bfloat16.")
        return False

    try:
        import flash_attn  # noqa: F401
    except ImportError:
        logger.info("flash_attn is not found.")
        return False
    return True
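Because `get_attn_backend` is memoized with `lru_cache`, the capability probes and log lines run once per dtype; repeated calls return the cached class. A short illustration (requires a CUDA-capable environment to actually run):

# Sketch: the selector is cached per dtype, so the probe runs only once.
import torch
from vllm.attention.selector import get_attn_backend

backend_a = get_attn_backend(torch.float16)
backend_b = get_attn_backend(torch.float16)
assert backend_a is backend_b  # served from the lru_cache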
@@ -1,9 +1,7 @@
-from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
 
 __all__ = [
-    "InputMetadata",
     "SamplingMetadata",
     "set_random_seed",
 ]
vllm/model_executor/input_metadata.py (deleted file, 99 lines)
@@ -1,99 +0,0 @@
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, Optional, List, Any, Dict

import torch

if TYPE_CHECKING:
    from xformers.ops.fmha.attn_bias import AttentionBias


@dataclass
class InputMetadata:
    """Metadata for input sequences. Used in PagedAttention.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
    # in block 0, and 1st slot in block 1, respectively.
    slot_mapping: torch.Tensor
    # (batch_size,). The prompt length per sequence. None if it is a decoding.
    prompt_lens: Optional[List[int]]
    # prompt_lens stored as a tensor.
    prompt_lens_tensor: Optional[torch.Tensor]
    # The number of prompt tokens. Doesn't include padding.
    num_prompt_tokens: int
    # The number of generation tokens. Doesn't include padding.
    num_generation_tokens: int
    """
    Definition of context_len, subquery_len, and seqlen.
    |---------- N-1 iteration --------|
    |---------------- N iteration ---------------------|
    |- tokenA -|......................|-- newTokens ---|
    |---------- context_len ----------|
    |-------------------- seqlen ----------------------|
                                      |- subquery_len -|

    WARNING: context_len has different definition depending on if it is
    prefill vs decoding. When it is prefill, it doesn't include new
    tokens. When it is for decoding, it includes a new token.
    """

    # Maximum subquery length in the batch.
    max_subquery_len: Optional[int]
    # Maximum context length in the batch.
    max_context_len: Optional[int]
    # FIXME: It is for flash attn.
    # Maximum sequence length in the batch.
    max_seq_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,). The length of context (tokens stored in KV cache) per
    # sequence. WARNING: When it is a prefill request, it doesn't include new
    # tokens. When it is for decoding, it includes a new token.
    context_lens: Optional[torch.Tensor]
    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]
    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    use_cuda_graph: bool
    kv_cache_dtype: str

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it is needed to set per prompt
        # when alibi slopes is used. It is because of the limitation
        # from xformer API.
        # It will not appear in the __repr__ and __init__.
        self.attn_bias: Optional[List["AttentionBias"]] = None

        # Cuda graph is only used for decoding now.
        if self.use_cuda_graph:
            assert self.num_prompt_tokens == 0

    def asdict_zerocopy(self) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying."""
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self)
        }
vllm/model_executor/layers/attention/__init__.py (deleted file, 5 lines)
@@ -1,5 +0,0 @@
from vllm.model_executor.layers.attention.attention import Attention

__all__ = [
    "Attention",
]
vllm/model_executor/layers/attention/attention.py (deleted file, 85 lines)
@@ -1,85 +0,0 @@
"""Attention layer."""
from functools import lru_cache
from typing import List, Optional

import torch
import torch.nn as nn

from vllm.logger import init_logger
from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip

logger = init_logger(__name__)


class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.

    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Output the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        if _use_flash_attn():
            from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend  # noqa: E501
            self.backend = FlashAttentionBackend(num_heads, head_size, scale,
                                                 num_kv_heads, alibi_slopes,
                                                 sliding_window)
        else:
            from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend  # noqa: E501
            self.backend = XFormersBackend(num_heads, head_size, scale,
                                           num_kv_heads, alibi_slopes,
                                           sliding_window)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        return self.backend.forward(query, key, value, key_cache, value_cache,
                                    input_metadata)


@lru_cache(maxsize=1)
def _use_flash_attn() -> bool:
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        logger.info("flash_attn is not found. Using xformers backend.")
        return False

    if is_hip():
        # AMD GPUs.
        return False
    if torch.cuda.get_device_capability()[0] < 8:
        # Volta and Turing NVIDIA GPUs.
        logger.info("flash_attn is not supported on Turing or older GPUs. "
                    "Using xformers backend.")
        return False
    if torch.get_default_dtype() not in (torch.float16, torch.bfloat16):
        logger.info(
            "flash_attn only supports torch.float16 or torch.bfloat16. "
            "Using xformers backend.")
        return False

    logger.info("Using flash_attn backend.")
    return True
@ -1,139 +0,0 @@
|
||||
"""Attention layer with Flash and PagedAttention."""
|
||||
from typing import List, Optional
|
||||
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.attention.ops.paged_attn import (
|
||||
PagedAttentionImpl)
|
||||
|
||||
|
||||
class FlashAttentionBackend:
|
||||
"""
|
||||
If the input tensors contain prompt tokens, the layout is as follows:
|
||||
|<--------------- num_prompt_tokens -------------->|
|
||||
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
|
||||
|
||||
Otherwise, the layout is as follows:
|
||||
|<------------------ num_generation_tokens (M) ----------------->|
|
||||
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
|
||||
|
||||
Generation tokens can contain padding when cuda-graph is used.
|
||||
Currently, prompt tokens don't contain any padding.
|
||||
|
||||
The prompts might have different lengths, while the generation tokens
|
||||
always have length 1.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
alibi_slopes: Optional[List[float]] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
|
||||
if head_size not in suppored_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by PagedAttention. "
|
||||
f"Supported head sizes are: {suppored_head_sizes}.")
|
||||
|
||||
self.sliding_window = ((self.sliding_window, self.sliding_window) if
|
||||
self.sliding_window is not None else (-1, -1))
|
||||
|
||||
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
            PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
                                                 value_cache, input_metadata)

        if input_metadata.is_prompt:
            # Prompt run.
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
                # normal attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                output = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=input_metadata.seq_start_loc,
                    cu_seqlens_k=input_metadata.seq_start_loc,
                    max_seqlen_q=input_metadata.max_seq_len,
                    max_seqlen_k=input_metadata.max_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                # prefix-enabled attention
                output = PagedAttentionImpl.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    input_metadata,
                    self.alibi_slopes,
                )
        else:
            # Decoding run.
            output = PagedAttentionImpl.forward_decode(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
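For reference, the packed 1-D token layout that this forward pass consumes (described in the class docstring above) can be made concrete. The lengths and token ids below are made up, and the padding value is arbitrary:

import torch

# Prompt run: all prompts are concatenated along one axis, no padding.
prompt_lens = [3, 2, 4]
prompt_tokens = torch.tensor([11, 12, 13,        # prompt_0
                              21, 22,            # prompt_1
                              31, 32, 33, 34])   # prompt_2
assert prompt_tokens.numel() == sum(prompt_lens)  # num_prompt_tokens

# Decode run: one token per sequence, padded up to the captured
# cuda-graph batch size (4 here, a made-up value).
generation_tokens = torch.tensor([41, 51, 61])    # M = 3 sequences
graph_batch_size = 4
padded = torch.zeros(graph_batch_size, dtype=generation_tokens.dtype)
padded[:generation_tokens.numel()] = generation_tokens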
@ -1,139 +0,0 @@
from typing import List, Optional

import torch

from vllm._C import cache_ops
from vllm._C import ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention.ops.prefix_prefill import (
    context_attention_fwd)

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


class PagedAttentionImpl:

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [64, 80, 96, 112, 128, 256]

    @staticmethod
    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        cache_ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            input_metadata.slot_mapping.flatten(),
            input_metadata.kv_cache_dtype,
        )

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
    ) -> torch.Tensor:
        output = torch.empty_like(query)

        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = (
            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
            _PARTITION_SIZE)
        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
        use_v1 = input_metadata.max_context_len <= 8192 and (
            max_num_partitions == 1 or num_seqs * num_heads > 512)
        if use_v1:
            # Run PagedAttention V1.
            ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                input_metadata.block_tables,
                input_metadata.context_lens,
                block_size,
                input_metadata.max_context_len,
                alibi_slopes,
                input_metadata.kv_cache_dtype,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                input_metadata.block_tables,
                input_metadata.context_lens,
                block_size,
                input_metadata.max_context_len,
                alibi_slopes,
                input_metadata.kv_cache_dtype,
            )
        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        alibi_slopes: Optional[torch.Tensor],
    ) -> torch.Tensor:
        output = torch.empty_like(query)
        context_attention_fwd(
            query,
            key,
            value,
            output,
            key_cache,
            value_cache,
            input_metadata.block_tables,
            # subquery_start_loc is (batch_size + 1,)
            input_metadata.subquery_start_loc[:-1],
            input_metadata.prompt_lens_tensor,
            input_metadata.context_lens,
            input_metadata.max_subquery_len,
            alibi_slopes,
        )
        return output
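The V1/V2 dispatch in `forward_decode` above boils down to two integers: a ceiling-divided partition count and the total amount of decode work. A minimal standalone sketch of that heuristic, assuming the same `_PARTITION_SIZE = 512` constant and made-up batch shapes:

_PARTITION_SIZE = 512  # mirrors PARTITION_SIZE in paged_attention_v2_launcher

def use_paged_attention_v1(max_context_len: int, num_seqs: int,
                           num_heads: int) -> bool:
    # Ceiling division: how many partitions V2 would split each context into.
    max_num_partitions = (
        (max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE)
    # V1 when reduction would be pointless (a single partition) or when the
    # batch already saturates the GPU; V2 for very long contexts to avoid
    # the shared-memory shortage noted above.
    return max_context_len <= 8192 and (max_num_partitions == 1
                                        or num_seqs * num_heads > 512)

# Illustrative calls with made-up shapes:
assert use_paged_attention_v1(max_context_len=400, num_seqs=8, num_heads=32)
assert not use_paged_attention_v1(max_context_len=4096, num_seqs=4, num_heads=8)
assert not use_paged_attention_v1(max_context_len=16384, num_seqs=64, num_heads=32)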
@ -25,9 +25,8 @@ import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
@ -45,8 +44,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
@ -170,15 +167,14 @@ class BaiChuanAttention(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.W_pack(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.postion_embedding != "ALIBI":
            q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output

@ -217,8 +213,8 @@ class BaiChuanDecoderLayer(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
@ -232,7 +228,7 @@ class BaiChuanDecoderLayer(nn.Module):
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
@ -267,8 +263,8 @@ class BaiChuanModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
@ -278,7 +274,7 @@ class BaiChuanModel(nn.Module):
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
@ -303,11 +299,11 @@ class BaiChuanBaseForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
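From here on, the commit applies the same mechanical migration to every model file (BLOOM, ChatGLM, Deepseek, Falcon, Gemma, GPT-2, GPTBigCode, GPT-J, GPT-NeoX, InternLM2, Jais, Llama, and Mixtral follow below): `kv_cache: KVCache` plus `input_metadata: InputMetadata` become `kv_cache: torch.Tensor` plus `attn_metadata: AttentionMetadata`, and the per-layer `k_cache, v_cache = kv_cache` unpacking disappears because the attention backend now splits the fused cache tensor itself. A before/after sketch of the call-site change, using a hypothetical stub in place of vLLM's real Attention layer:

from typing import Tuple
import torch

class StubAttention:
    """Hypothetical stand-in for vllm.attention.Attention; only the call
    signatures matter for this sketch."""

    def forward_old(self, q, k, v, k_cache, v_cache, input_metadata):
        return q  # placeholder body

    def forward_new(self, q, k, v, kv_cache, attn_metadata):
        return q  # placeholder body

attn = StubAttention()
q = k = v = torch.randn(4, 64)
kv_cache_tuple = (torch.empty(0), torch.empty(0))  # old KVCache pair
kv_cache_fused = torch.empty(0)                    # new single tensor

# Before the refactor: every model layer unpacked the tuple itself.
k_cache, v_cache = kv_cache_tuple
out = attn.forward_old(q, k, v, k_cache, v_cache, input_metadata=None)

# After the refactor: the fused tensor and metadata pass through opaquely.
out = attn.forward_new(q, k, v, kv_cache_fused, attn_metadata=None)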
@ -17,15 +17,14 @@
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from typing import List, Optional, Tuple
from typing import List, Optional

import torch
from torch import nn
from transformers import BloomConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
@ -117,14 +114,13 @@ class BloomAttention(nn.Module):
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        del position_ids  # Unused.
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output

@ -181,8 +177,8 @@ class BloomBlock(nn.Module):
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
@ -198,7 +194,7 @@ class BloomBlock(nn.Module):
            position_ids=position_ids,
            hidden_states=layernorm_output,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )
        attention_output = attention_output + residual
        layernorm_output = self.post_attention_layernorm(attention_output)
@ -245,8 +241,8 @@ class BloomModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.word_embeddings(input_ids)
        hidden_states = self.word_embeddings_layernorm(hidden_states)
@ -256,7 +252,7 @@ class BloomModel(nn.Module):
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                attn_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
@ -281,11 +277,11 @@ class BloomForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -2,15 +2,14 @@
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import List, Optional, Tuple
from typing import List, Optional

import torch
from torch import nn
from torch.nn import LayerNorm

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
@ -29,8 +28,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GLMAttention(nn.Module):

@ -99,20 +96,18 @@ class GLMAttention(nn.Module):
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        key_cache, value_cache = kv_cache
        context_layer = self.attn(
            q,
            k,
            v,
            key_cache,
            value_cache,
            input_metadata,
            kv_cache,
            attn_metadata,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output
@ -200,8 +195,8 @@ class GLMBlock(nn.Module):
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # hidden_states: [num_tokens, h]
        # Layer norm at the beginning of the transformer layer.
@ -211,7 +206,7 @@ class GLMBlock(nn.Module):
            hidden_states=layernorm_output,
            position_ids=position_ids,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )

        # Residual connection.
@ -264,8 +259,8 @@ class GLMTransformer(nn.Module):
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        for i in range(self.num_layers):
            layer = self.layers[i]
@ -273,7 +268,7 @@ class GLMTransformer(nn.Module):
                hidden_states=hidden_states,
                position_ids=position_ids,
                kv_cache=kv_caches[i],
                input_metadata=input_metadata,
                attn_metadata=attn_metadata,
            )
        # Final layer norm.
        if self.post_layer_norm:
@ -306,8 +301,8 @@ class ChatGLMModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        inputs_embeds = self.embedding(input_ids)

@ -316,7 +311,7 @@ class ChatGLMModel(nn.Module):
            hidden_states=inputs_embeds,
            position_ids=position_ids,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )
        return hidden_states

@ -340,11 +335,11 @@ class ChatGLMForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -21,15 +21,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@ -51,8 +50,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class DeepseekMLP(nn.Module):

@ -239,14 +236,13 @@ class DeepseekAttention(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output

@ -294,8 +290,8 @@ class DeepseekDecoderLayer(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Self Attention
@ -309,7 +305,7 @@ class DeepseekDecoderLayer(nn.Module):
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
@ -346,15 +342,15 @@ class DeepseekModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i], input_metadata,
                                            kv_caches[i], attn_metadata,
                                            residual)
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
@ -379,11 +375,11 @@ class DeepseekForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -19,16 +19,15 @@
"""PyTorch Falcon model."""

import math
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Union

import torch
from torch import nn
from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@ -48,7 +47,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import RWConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]
FalconConfig = Union[HF_FalconConfig, RWConfig]


@ -177,8 +175,8 @@ class FalconAttention(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, bias = self.query_key_value(hidden_states)
        if bias is not None:
@ -186,8 +184,7 @@ class FalconAttention(nn.Module):
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_rotary:
            q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, bias = self.dense(attn_output)
        return attn_output, bias

@ -263,8 +260,8 @@ class FalconDecoderLayer(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states

@ -279,7 +276,7 @@ class FalconDecoderLayer(nn.Module):
            positions=positions,
            hidden_states=attention_layernorm_out,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )
        if self.reduce_row_parallel_results and attention_bias is not None:
            attention_output += attention_bias
@ -343,8 +340,8 @@ class FalconModel(nn.Module):
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.word_embeddings(input_ids)
        for i in range(len(self.h)):
@ -353,7 +350,7 @@ class FalconModel(nn.Module):
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                attn_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
@ -378,14 +375,14 @@ class FalconForCausalLM(nn.Module):
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(
            input_ids,
            positions,
            kv_caches,
            input_metadata,
            attn_metadata,
        )
        return hidden_states

@ -20,10 +20,9 @@ import torch
from torch import nn
from transformers import GemmaConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GemmaMLP(nn.Module):

@ -133,14 +130,13 @@ class GemmaAttention(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output

@ -177,8 +173,8 @@ class GemmaDecoderLayer(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
@ -192,7 +188,7 @@ class GemmaDecoderLayer(nn.Module):
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
@ -226,8 +222,8 @@ class GemmaModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        # Normalize the embedding by sqrt(hidden_size)
@ -240,7 +236,7 @@ class GemmaModel(nn.Module):
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
@ -290,11 +286,11 @@ class GemmaForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -17,15 +17,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
from typing import List, Optional

import torch
from torch import nn
from transformers import GPT2Config

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPT2Attention(nn.Module):

@ -79,14 +76,12 @@ class GPT2Attention(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output

@ -144,15 +139,15 @@ class GPT2Block(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual
@ -190,8 +185,8 @@ class GPT2Model(nn.Module):
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
@ -199,7 +194,7 @@ class GPT2Model(nn.Module):

        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states
@ -224,11 +219,11 @@ class GPT2LMHeadModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -18,15 +18,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
from typing import List, Optional

import torch
from torch import nn
from transformers import GPTBigCodeConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@ -42,8 +41,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPTBigCodeAttention(nn.Module):

@ -94,8 +91,8 @@ class GPTBigCodeAttention(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.split(
@ -105,9 +102,7 @@ class GPTBigCodeAttention(nn.Module):
            ],
            dim=-1,
        )
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output

@ -165,15 +160,15 @@ class GPTBigCodeBlock(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual
@ -211,8 +206,8 @@ class GPTBigCodeModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
@ -220,7 +215,7 @@ class GPTBigCodeModel(nn.Module):

        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states
@ -245,11 +240,11 @@ class GPTBigCodeForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -16,15 +16,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
from typing import List, Optional

import torch
from torch import nn
from transformers import GPTJConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPTJAttention(nn.Module):

@ -93,14 +90,13 @@ class GPTJAttention(nn.Module):
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output

@ -154,8 +150,8 @@ class GPTJBlock(nn.Module):
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
@ -163,7 +159,7 @@ class GPTJBlock(nn.Module):
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
@ -192,8 +188,8 @@ class GPTJModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.h)):
@ -202,7 +198,7 @@ class GPTJModel(nn.Module):
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                attn_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
@ -232,11 +228,11 @@ class GPTJForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -16,15 +16,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
from typing import List, Optional

import torch
from torch import nn
from transformers import GPTNeoXConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPTNeoXAttention(nn.Module):

@ -94,14 +91,13 @@ class GPTNeoXAttention(nn.Module):
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output

@ -155,15 +151,15 @@ class GPTNeoXLayer(nn.Module):
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )

        if self.use_parallel_residual:
@ -208,8 +204,8 @@ class GPTNeoXModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_in(input_ids)
        for i in range(len(self.layers)):
@ -218,7 +214,7 @@ class GPTNeoXModel(nn.Module):
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                attn_metadata,
            )
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states
@ -246,11 +242,11 @@ class GPTNeoXForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      input_metadata)
                                      attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -5,9 +5,8 @@ import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
@ -25,8 +24,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class InternLM2MLP(nn.Module):

@ -124,14 +121,13 @@ class InternLM2Attention(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.wo(attn_output)
        return output

@ -172,8 +168,8 @@ class InternLMDecoderLayer(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
@ -187,7 +183,7 @@ class InternLMDecoderLayer(nn.Module):
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
@ -221,8 +217,8 @@ class InternLM2Model(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.tok_embeddings(input_ids)
        residual = None
@ -232,7 +228,7 @@ class InternLM2Model(nn.Module):
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
@ -258,11 +254,11 @@ class InternLM2ForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -20,14 +20,13 @@
"""Inference-only Jais model compatible with HuggingFace weights."""

import math
from typing import List, Optional, Tuple
from typing import List, Optional

import torch
from torch import nn
from vllm.transformers_utils.configs import JAISConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import Attention
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearMethodBase,
@ -49,8 +48,6 @@ from vllm.model_executor.weight_utils import (
from vllm.sequence import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata

KVCache = Tuple[torch.Tensor, torch.Tensor]


class SwiGLUActivation(nn.Module):

@ -122,14 +119,12 @@ class JAISAttention(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output

@ -196,15 +191,15 @@ class JAISBlock(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual
@ -248,8 +243,8 @@ class JAISModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        inputs_embeds = self.wte(input_ids)
        if self.wpe is not None:
@ -262,7 +257,7 @@ class JAISModel(nn.Module):

        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states
@ -293,11 +288,11 @@ class JAISLMHeadModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -348,4 +343,4 @@ class JAISLMHeadModel(nn.Module):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            weight_loader(param, loaded_weight)
@ -27,10 +27,9 @@ import torch
from torch import nn
from transformers import LlamaConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
@ -48,8 +47,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaMLP(nn.Module):

@ -150,14 +147,13 @@ class LlamaAttention(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output

@ -203,8 +199,8 @@ class LlamaDecoderLayer(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
@ -218,7 +214,7 @@ class LlamaDecoderLayer(nn.Module):
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
@ -258,8 +254,8 @@ class LlamaModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
@ -269,7 +265,7 @@ class LlamaModel(nn.Module):
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
@ -336,11 +332,11 @@ class LlamaForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -21,15 +21,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Mixtral model."""
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import MixtralConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
@ -51,8 +50,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
class MixtralMoE(nn.Module):
|
||||
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
|
||||
@ -209,14 +206,13 @@ class MixtralAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -254,8 +250,8 @@ class MixtralDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
@ -269,7 +265,7 @@ class MixtralDecoderLayer(nn.Module):
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -309,15 +305,15 @@ class MixtralModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states,
|
||||
kv_caches[i], input_metadata,
|
||||
kv_caches[i], attn_metadata,
|
||||
residual)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
@ -377,11 +373,11 @@ class MixtralForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata)
|
||||
attn_metadata)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
|
||||
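The hunks above show the call-site change that this commit repeats across every model: the attention layer now takes the KV cache as a single backend-owned tensor plus an AttentionMetadata object, instead of an unpacked (k_cache, v_cache) pair plus InputMetadata. A minimal sketch of the new calling convention (head count, head size, and scale below are illustrative assumptions, not values from this diff):

    # Sketch only; constructor arguments are assumed for illustration.
    import torch
    from vllm.attention import Attention, AttentionMetadata

    attn = Attention(num_heads=32, head_size=128, scale=128**-0.5)

    def attend(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
               kv_cache: torch.Tensor,
               attn_metadata: AttentionMetadata) -> torch.Tensor:
        # The layout of kv_cache is opaque to the model; the attention
        # backend selected at runtime decides how to index into it.
        return attn(q, k, v, kv_cache, attn_metadata)

The same one-line substitution, attn(q, k, v, kv_cache, attn_metadata), recurs in each model file below.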
@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import List, Optional, Tuple
from typing import List, Optional

import numpy as np

@ -31,8 +31,7 @@ import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import Attention
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               ReplicatedLinear,
@ -52,8 +51,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class MixtralMLP(nn.Module):

@ -227,14 +224,13 @@ class MixtralAttention(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output

@ -269,8 +265,8 @@ class MixtralDecoderLayer(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Self Attention
@ -284,7 +280,7 @@ class MixtralDecoderLayer(nn.Module):
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
@ -319,15 +315,15 @@ class MixtralModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i], input_metadata,
                                            kv_caches[i], attn_metadata,
                                            residual)
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
@ -352,11 +348,11 @@ class MixtralForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -1,14 +1,13 @@
# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import List, Optional, Tuple
from typing import List, Optional

import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@ -25,8 +24,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(
    total_num_heads: int,
@ -116,8 +113,8 @@ class MPTAttention(nn.Module):
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        del position_ids  # unused.
        qkv, _ = self.Wqkv(hidden_states)
@ -127,8 +124,7 @@ class MPTAttention(nn.Module):
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output

@ -184,15 +180,15 @@ class MPTBlock(nn.Module):
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        x = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=x,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )
        hidden_states = hidden_states + x
        x = self.norm_2(hidden_states)
@ -230,8 +226,8 @@ class MPTModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.blocks)):
@ -240,7 +236,7 @@ class MPTModel(nn.Module):
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                attn_metadata,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states
@ -267,11 +263,11 @@ class MPTForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -42,8 +42,7 @@ import torch
import torch.nn.functional as F
from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import Attention
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearMethodBase,
@ -67,8 +66,6 @@ from vllm.sequence import SamplerOutput
# this model must need this dependency
from hf_olmo import OLMoConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


class SwiGLU(nn.Module):

@ -146,16 +143,15 @@ class OlmoAttention(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.attn_norm(hidden_states)
        qkv, _ = self.att_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.config.rope:
            q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.attn_out(attn_output)
        return output

@ -241,12 +237,12 @@ class OlmoBlock(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Attention block.
        og_x = hidden_states
        x = self.attn(positions, hidden_states, kv_cache, input_metadata)
        x = self.attn(positions, hidden_states, kv_cache, attn_metadata)
        x = x + og_x

        # MLP block.
@ -296,8 +292,8 @@ class OlmoModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """
        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
@ -313,7 +309,7 @@ class OlmoModel(nn.Module):
                positions,
                x,
                kv_caches[block_idx],
                input_metadata,
                attn_metadata,
            )

        # Apply final layer norm.
@ -344,14 +340,14 @@ class OLMoForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )
        return hidden_states
@ -17,15 +17,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
from typing import List, Optional

import torch
from torch import nn
from transformers import OPTConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@ -42,8 +41,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class OPTLearnedPositionalEmbedding(nn.Embedding):

@ -97,14 +94,12 @@ class OPTAttention(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output

@ -152,8 +147,8 @@ class OPTDecoderLayer(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
@ -162,7 +157,7 @@ class OPTDecoderLayer(nn.Module):
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       input_metadata=input_metadata)
                                       attn_metadata=attn_metadata)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
@ -241,8 +236,8 @@ class OPTDecoder(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
@ -252,7 +247,7 @@ class OPTDecoder(nn.Module):

        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
@ -275,10 +270,10 @@ class OPTModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        return self.decoder(input_ids, positions, kv_caches, input_metadata)
        return self.decoder(input_ids, positions, kv_caches, attn_metadata)


class OPTForCausalLM(nn.Module):
@ -300,11 +295,11 @@ class OPTForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -10,9 +10,8 @@ import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
@ -29,8 +28,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class OrionMLP(nn.Module):

@ -128,14 +125,13 @@ class OrionAttention(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output

@ -178,8 +174,8 @@ class OrionDecoderLayer(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
@ -189,7 +185,7 @@ class OrionDecoderLayer(nn.Module):
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )

        hidden_states = residual + hidden_states
@ -227,8 +223,8 @@ class OrionModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
@ -238,7 +234,7 @@ class OrionModel(nn.Module):
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                attn_metadata,
                residual,
            )
        hidden_states = self.norm(hidden_states)
@ -264,11 +260,11 @@ class OrionForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -35,15 +35,14 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
from typing import List, Optional

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@ -60,8 +59,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class PhiAttention(nn.Module):

@ -115,14 +112,13 @@ class PhiAttention(nn.Module):
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output

@ -172,8 +168,8 @@ class PhiLayer(nn.Module):
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
@ -181,7 +177,7 @@ class PhiLayer(nn.Module):
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = attn_outputs + feed_forward_hidden_states + residual
@ -209,8 +205,8 @@ class PhiModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(self.config.num_hidden_layers):
@ -219,7 +215,7 @@ class PhiModel(nn.Module):
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                attn_metadata,
            )

        hidden_states = self.final_layernorm(hidden_states)
@ -248,11 +244,11 @@ class PhiForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
                                   attn_metadata)

        return hidden_states
@ -10,9 +10,8 @@ import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
@ -30,8 +29,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class QWenMLP(nn.Module):

@ -111,15 +108,13 @@ class QWenAttention(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)

        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.c_proj(attn_output)
        return output

@ -153,8 +148,8 @@ class QWenBlock(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
@ -167,7 +162,7 @@ class QWenBlock(nn.Module):
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
@ -201,8 +196,8 @@ class QWenModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        residual = None
@ -212,7 +207,7 @@ class QWenModel(nn.Module):
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.ln_f(hidden_states, residual)
@ -238,11 +233,11 @@ class QWenLMHeadModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -28,9 +28,8 @@ import torch
from torch import nn
from transformers import Qwen2Config

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
@ -49,8 +48,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
from vllm.sequence import SamplerOutput
from vllm.config import LoRAConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


class Qwen2MLP(nn.Module):

@ -147,14 +144,13 @@ class Qwen2Attention(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output

@ -197,8 +193,8 @@ class Qwen2DecoderLayer(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
@ -212,7 +208,7 @@ class Qwen2DecoderLayer(nn.Module):
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
@ -248,8 +244,8 @@ class Qwen2Model(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
@ -259,7 +255,7 @@ class Qwen2Model(nn.Module):
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
@ -315,11 +311,11 @@ class Qwen2ForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -25,9 +25,8 @@ import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
@ -44,8 +43,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class StablelmMLP(nn.Module):

@ -134,14 +131,13 @@ class StablelmAttention(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output

@ -166,8 +162,8 @@ class StablelmDecoderLayer(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        residual = hidden_states
@ -176,7 +172,7 @@ class StablelmDecoderLayer(nn.Module):
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states

@ -211,8 +207,8 @@ class StableLMEpochModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(len(self.layers)):
@ -221,7 +217,7 @@ class StableLMEpochModel(nn.Module):
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                attn_metadata,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states
@ -246,11 +242,11 @@ class StablelmForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -18,15 +18,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder2 model."""
from typing import List, Optional, Tuple
from typing import List, Optional

import torch
from torch import nn
from transformers import Starcoder2Config

from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -43,8 +42,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class Starcoder2Attention(nn.Module):

@ -111,14 +108,13 @@ class Starcoder2Attention(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output

@ -171,8 +167,8 @@ class Starcoder2DecoderLayer(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
@ -181,7 +177,7 @@ class Starcoder2DecoderLayer(nn.Module):
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states

@ -217,14 +213,14 @@ class Starcoder2Model(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(positions, hidden_states, kv_caches[i],
                                  input_metadata)
                                  attn_metadata)
        hidden_states = self.norm(hidden_states)
        return hidden_states

@ -258,11 +254,11 @@ class Starcoder2ForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
@ -431,7 +431,7 @@ class SequenceGroup:


class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `InputMetadata`.
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
@ -1,16 +1,15 @@
"""CacheEngine class for managing the KV cache."""
from typing import Dict, List, Tuple
from typing import Dict, List

import torch

from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import is_pin_memory_available, STR_DTYPE_TO_TORCH_DTYPE

logger = init_logger(__name__)

KVCache = Tuple[torch.Tensor, torch.Tensor]


class CacheEngine:
    """Manages the KV cache.
@ -43,95 +42,43 @@ class CacheEngine:
        else:
            self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        # Get attention backend.
        self.attn_backend = get_attn_backend(model_config.dtype)

        # Initialize the cache.
        self.gpu_cache = self.allocate_gpu_cache()
        self.cpu_cache = self.allocate_cpu_cache()
        self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda")
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")

    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
        element_size = torch.tensor([], dtype=self.dtype).element_size()
        x = 16 // element_size
        return (
            self.num_heads,
            self.head_size // x,
            self.block_size,
            x,
        )

    def get_value_block_shape(self) -> Tuple[int, int, int]:
        return (
            self.num_heads,
            self.head_size,
            self.block_size,
        )

    def allocate_gpu_cache(self) -> List[KVCache]:
        gpu_cache: List[KVCache] = []
        key_block_shape = self.get_key_block_shape()
        value_block_shape = self.get_value_block_shape()
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(self.num_gpu_blocks, *key_block_shape),
                dtype=self.dtype,
                device="cuda",
            )
            value_blocks = torch.empty(
                size=(self.num_gpu_blocks, *value_block_shape),
                dtype=self.dtype,
                device="cuda",
            )
            gpu_cache.append((key_blocks, value_blocks))
        return gpu_cache

    def allocate_cpu_cache(self) -> List[KVCache]:
        cpu_cache: List[KVCache] = []
        key_block_shape = self.get_key_block_shape()
        value_block_shape = self.get_value_block_shape()
        pin_memory = is_pin_memory_available()
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(self.num_cpu_blocks, *key_block_shape),
                dtype=self.dtype,
                pin_memory=pin_memory,
                device="cpu",
            )
            value_blocks = torch.empty(
                size=(self.num_cpu_blocks, *value_block_shape),
                dtype=self.dtype,
                pin_memory=pin_memory,
                device="cpu",
            )
            cpu_cache.append((key_blocks, value_blocks))
        return cpu_cache

    def _swap(
    def _allocate_kv_cache(
        self,
        src: List[KVCache],
        dst: List[KVCache],
        src_to_dst: Dict[int, int],
    ) -> None:
        from vllm._C import cache_ops

        for i in range(self.num_layers):
            src_key_cache, src_value_cache = src[i]
            dst_key_cache, dst_value_cache = dst[i]
            # Copy the key blocks.
            cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
            # Copy the value blocks.
            cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
        num_blocks: int,
        device: str,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on the specified device."""
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_heads, self.head_size)
        pin_memory = is_pin_memory_available() if device == "cpu" else False
        kv_cache: List[torch.Tensor] = []
        for _ in range(self.num_layers):
            kv_cache.append(
                torch.empty(kv_cache_shape,
                            dtype=self.dtype,
                            pin_memory=pin_memory,
                            device=device))
        return kv_cache
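_allocate_kv_cache replaces the two shape helpers and the per-device allocators above: the cache layout is no longer hard-coded as separate key/value block shapes but is whatever the backend reports. A rough usage sketch with assumed sizes (1024 blocks, block size 16, 32 heads, head size 128; the returned shape is backend-specific):

    # Sketch only; the sizes below are illustrative assumptions.
    import torch
    from vllm.attention import get_attn_backend

    attn_backend = get_attn_backend(torch.float16)
    # Arguments follow the call above: num_blocks, block_size, num_heads, head_size.
    kv_cache_shape = attn_backend.get_kv_cache_shape(1024, 16, 32, 128)
    # One tensor per layer; the same shape serves both the GPU and CPU caches.
    layer_cache = torch.empty(kv_cache_shape, dtype=torch.float16,
                              device="cuda")

Note that pinned host memory is now requested only for the CPU copy, matching the old allocate_cpu_cache behavior.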
    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)
        for i in range(self.num_layers):
            self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
                                          src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)
        for i in range(self.num_layers):
            self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
                                          src_to_dst)

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        from vllm._C import cache_ops

        key_caches = [key_cache for key_cache, _ in self.gpu_cache]
        value_caches = [value_cache for _, value_cache in self.gpu_cache]
        # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
        self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
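Swapping and copying likewise delegate to the backend one layer at a time, so a backend with a custom block layout can supply its own kernels. A condensed sketch of the new control flow, where num_layers, cpu_cache, gpu_cache, and attn_backend stand in for the CacheEngine attributes and the block mappings are made-up examples:

    # src_to_dst maps source block numbers to destination block numbers.
    src_to_dst = {0: 3, 1: 7}          # illustrative mapping
    for i in range(num_layers):        # swap in, one layer at a time
        attn_backend.swap_blocks(cpu_cache[i], gpu_cache[i], src_to_dst)
    # Copy-on-write fork: block 0 is duplicated into blocks 1 and 2.
    attn_backend.copy_blocks(gpu_cache, {0: [1, 2]})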
@ -6,10 +6,11 @@ import numpy as np
import torch
import torch.nn as nn

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import InputMetadata, SamplingMetadata
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils.communication_op import (
@ -28,7 +29,6 @@ from vllm.utils import (async_tensor_h2d, CudaMemoryProfiler,

logger = init_logger(__name__)

KVCache = Tuple[torch.Tensor, torch.Tensor]
_PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8
_BATCH_SIZE_ALIGNMENT = 8
@ -85,6 +85,9 @@ class ModelRunner:
        self.pin_memory = is_pin_memory_available()
        self.kv_cache_dtype = kv_cache_dtype

        self.attn_backend = get_attn_backend(
            self.model_config.dtype if model_config is not None else None)

    def load_model(self) -> None:
        with CudaMemoryProfiler() as m:
            self.model = get_model(self.model_config,
@ -127,8 +130,8 @@ class ModelRunner:
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
               List[int], List[int], Set[LoRARequest]]:
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               List[int], List[int], List[int], Set[LoRARequest]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
@ -216,7 +219,7 @@ class ModelRunner:
                slot_mapping.append(slot)

        max_subquery_len = max(subquery_lens)
        max_seq_len = max(prompt_lens)
        max_prompt_len = max(prompt_lens)
        num_prompt_tokens = len(input_tokens)
        assert max_subquery_len > 0

@ -270,7 +273,7 @@ class ModelRunner:
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        input_metadata = InputMetadata(
        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            prompt_lens=prompt_lens,
@ -279,7 +282,7 @@ class ModelRunner:
            num_generation_tokens=0,
            max_subquery_len=max_subquery_len,
            max_context_len=None,
            max_seq_len=max_seq_len,
            max_prompt_len=max_prompt_len,
            subquery_start_loc=subquery_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens=context_lens_tensor,
@ -287,15 +290,15 @@ class ModelRunner:
            use_cuda_graph=False,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, input_metadata, prompt_lens,
        return (input_tokens, input_positions, attn_metadata, prompt_lens,
                subquery_lens, lora_index_mapping, lora_prompt_mapping,
                lora_requests)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
               Set[LoRARequest]]:
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               List[int], Set[LoRARequest]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
@ -401,7 +404,7 @@ class ModelRunner:
            device=self.device,
        )

        input_metadata = InputMetadata(
        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            prompt_lens=None,
@ -410,7 +413,7 @@ class ModelRunner:
            num_generation_tokens=len(input_tokens),
            max_subquery_len=None,
            max_context_len=max_context_len,
            max_seq_len=None,
            max_prompt_len=None,
            subquery_start_loc=None,
            seq_start_loc=None,
            context_lens=context_lens,
@ -418,7 +421,7 @@ class ModelRunner:
            use_cuda_graph=use_captured_graph,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, input_metadata,
        return (input_tokens, input_positions, attn_metadata,
                lora_index_mapping, lora_prompt_mapping, lora_requests)

    def _prepare_sample(
@ -522,7 +525,7 @@ class ModelRunner:
    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Set[int], LoRAMapping]:
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
@ -530,11 +533,11 @@ class ModelRunner:
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, input_metadata, prompt_lens,
                (input_tokens, input_positions, attn_metadata, prompt_lens,
                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
                 lora_requests) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions, input_metadata,
                (input_tokens, input_positions, attn_metadata,
                 lora_index_mapping, lora_prompt_mapping,
                 lora_requests) = self._prepare_decode(seq_group_metadata_list)
                prompt_lens = []
@ -560,7 +563,7 @@ class ModelRunner:
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
            }
            metadata_dict.update(input_metadata.asdict_zerocopy())
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
@ -570,7 +573,7 @@ class ModelRunner:
                "selected_token_indices")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            input_metadata = InputMetadata(**metadata_dict)
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
@ -581,16 +584,16 @@ class ModelRunner:
                perform_sampling=False,
            )

        return (input_tokens, input_positions, input_metadata,
        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, lora_requests, lora_mapping)
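On the driver, the freshly built AttentionMetadata is flattened with asdict_zerocopy() and broadcast; the other workers rebuild it through the backend's make_metadata. A condensed sketch of that round trip, with the sampling and LoRA dictionary entries elided:

    # Driver side (sketch): flatten the metadata without copying its tensors.
    metadata_dict = {"input_tokens": input_tokens,
                     "input_positions": input_positions}
    metadata_dict.update(attn_metadata.asdict_zerocopy())
    broadcast_tensor_dict(metadata_dict, src=0)

    # Non-driver worker (sketch): pop the extra keys, then reconstruct.
    metadata_dict = broadcast_tensor_dict(src=0)
    input_tokens = metadata_dict.pop("input_tokens")
    input_positions = metadata_dict.pop("input_positions")
    attn_metadata = attn_backend.make_metadata(**metadata_dict)

Because make_metadata is backend-provided, every worker reconstructs the concrete metadata class that its attention backend expects.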
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        (input_tokens, input_positions, input_metadata, sampling_metadata,
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests,
         lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)

@ -598,7 +601,7 @@ class ModelRunner:
            self.set_active_loras(lora_requests, lora_mapping)

        # Execute the model.
        if input_metadata.use_cuda_graph:
        if attn_metadata.use_cuda_graph:
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
@ -607,7 +610,7 @@ class ModelRunner:
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
            attn_metadata=attn_metadata,
        )

        # Compute the logits.
@ -673,7 +676,7 @@ class ModelRunner:

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [(None, None)] * num_layers
        kv_caches = [None] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()
        return
@ -705,7 +708,7 @@ class ModelRunner:
        return self.lora_manager.list_loras()

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[KVCache]) -> None:
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if number
@ -759,8 +762,8 @@ class ModelRunner:
        # NOTE: Capturing the largest batch size first may help reduce the
        # memory usage of CUDA graph.
        for batch_size in reversed(batch_size_capture_list):
            # Create dummy input_metadata.
            input_metadata = InputMetadata(
            # Create dummy attn_metadata.
            attn_metadata = self.attn_backend.make_metadata(
                is_prompt=False,
                slot_mapping=slot_mapping[:batch_size],
                prompt_lens=None,
@ -769,7 +772,7 @@ class ModelRunner:
                num_generation_tokens=batch_size,
                max_subquery_len=None,
                max_context_len=self.max_context_len_to_capture,
                max_seq_len=None,
                max_prompt_len=None,
                subquery_start_loc=None,
                seq_start_loc=None,
                context_lens=context_lens[:batch_size],
@ -790,7 +793,7 @@ class ModelRunner:
                input_tokens[:batch_size],
                input_positions[:batch_size],
                kv_caches,
                input_metadata,
                attn_metadata,
                memory_pool=self.graph_memory_pool,
            )
            self.graph_memory_pool = graph_runner.graph.pool()
@ -826,8 +829,8 @@ class CUDAGraphRunner:
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool,
    ) -> None:
        assert self.graph is None
@ -839,7 +842,7 @@ class CUDAGraphRunner:
                input_ids,
                positions,
                kv_caches,
                input_metadata,
                attn_metadata,
            )
        torch.cuda.synchronize()

@ -853,7 +856,7 @@ class CUDAGraphRunner:
                input_ids,
                positions,
                kv_caches,
                input_metadata,
                attn_metadata,
            )
        torch.cuda.synchronize()

@ -862,9 +865,9 @@ class CUDAGraphRunner:
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": input_metadata.slot_mapping,
            "context_lens": input_metadata.context_lens,
            "block_tables": input_metadata.block_tables,
            "slot_mapping": attn_metadata.slot_mapping,
            "context_lens": attn_metadata.context_lens,
            "block_tables": attn_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}
        return
@ -873,8 +876,8 @@ class CUDAGraphRunner:
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        input_metadata: InputMetadata,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches
@ -882,11 +885,11 @@ class CUDAGraphRunner:
        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
        self.input_buffers["context_lens"].copy_(attn_metadata.context_lens,
                                                 non_blocking=True)
        self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
        self.input_buffers["block_tables"].copy_(attn_metadata.block_tables,
                                                 non_blocking=True)
        # Run the graph.
        self.graph.replay()
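During replay, only the tensors that change between decode steps are copied into the captured graph's input buffers; the KV cache tensors were captured by address and never move. A condensed sketch of the replay path shown above, with input_buffers, output_buffers, and graph standing in for the CUDAGraphRunner attributes:

    # Sketch: copy the per-step inputs into the captured buffers, then replay.
    input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
    input_buffers["positions"].copy_(positions, non_blocking=True)
    for name in ("slot_mapping", "context_lens", "block_tables"):
        input_buffers[name].copy_(getattr(attn_metadata, name),
                                  non_blocking=True)
    graph.replay()
    hidden_states = output_buffers["hidden_states"]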
@ -128,6 +128,9 @@ class Worker:
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        cache_block_size = self.get_cache_block_size_bytes(
            block_size, cache_dtype)
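The new assertion guards the profiling arithmetic: peak_memory is the difference between the free memory recorded at engine start and the free memory after the profiling run, so it can only come out non-positive if some other process released GPU memory in between. A worked example with assumed numbers:

    # Assumed values, in bytes, for illustration only.
    init_gpu_memory = 80 * 1024**3   # free at engine initialization
    free_gpu_memory = 16 * 1024**3   # free after the profiling forward pass
    peak_memory = init_gpu_memory - free_gpu_memory   # 64 GiB in use
    assert peak_memory > 0  # trips if the GPU was not cleaned up properly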