mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 15:25:01 +08:00
Support BERTModel (first encoder-only embedding model) (#9056)
Signed-off-by: Max de Bayser <maxdebayser@gmail.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Co-authored-by: Andrew Feldman <afeldman@neuralmagic.com> Co-authored-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: laishzh <laishengzhang@gmail.com> Co-authored-by: Max de Bayser <maxdebayser@gmail.com> Co-authored-by: Max de Bayser <mbayser@br.ibm.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
bb76538bbd
commit
343f8e0905
@ -6,21 +6,31 @@ import pytest
|
|||||||
|
|
||||||
from ..utils import check_embeddings_close
|
from ..utils import check_embeddings_close
|
||||||
|
|
||||||
|
# Model, Guard
|
||||||
MODELS = [
|
MODELS = [
|
||||||
"intfloat/e5-mistral-7b-instruct",
|
"intfloat/e5-mistral-7b-instruct",
|
||||||
|
"BAAI/bge-base-en-v1.5",
|
||||||
"BAAI/bge-multilingual-gemma2",
|
"BAAI/bge-multilingual-gemma2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
ENCODER_ONLY = [
|
||||||
|
"BAAI/bge-base-en-v1.5",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
@pytest.mark.parametrize("dtype", ["half"])
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
def test_models(
|
def test_models(
|
||||||
|
monkeypatch,
|
||||||
hf_runner,
|
hf_runner,
|
||||||
vllm_runner,
|
vllm_runner,
|
||||||
example_prompts,
|
example_prompts,
|
||||||
model: str,
|
model,
|
||||||
dtype: str,
|
dtype: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if model in ENCODER_ONLY:
|
||||||
|
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
|
||||||
|
|
||||||
# The example_prompts has ending "\n", for example:
|
# The example_prompts has ending "\n", for example:
|
||||||
# "Write a short story about a robot that dreams for the first time.\n"
|
# "Write a short story about a robot that dreams for the first time.\n"
|
||||||
# sentence_transformers will strip the input texts, see:
|
# sentence_transformers will strip the input texts, see:
|
||||||
@ -33,7 +43,7 @@ def test_models(
|
|||||||
is_sentence_transformer=True) as hf_model:
|
is_sentence_transformer=True) as hf_model:
|
||||||
hf_outputs = hf_model.encode(example_prompts)
|
hf_outputs = hf_model.encode(example_prompts)
|
||||||
|
|
||||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
with vllm_runner(model, dtype=dtype, max_model_len=None) as vllm_model:
|
||||||
vllm_outputs = vllm_model.encode(example_prompts)
|
vllm_outputs = vllm_model.encode(example_prompts)
|
||||||
|
|
||||||
check_embeddings_close(
|
check_embeddings_close(
|
||||||
|
|||||||
@ -15,8 +15,11 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
class AttentionType(Enum):
|
class AttentionType(Enum):
|
||||||
DECODER = auto() # Decoder attention between previous layer Q/K/V
|
DECODER = auto() # Decoder attention between previous layer Q/K/V
|
||||||
ENCODER = auto() # Encoder attention between previous layer Q/K/V
|
ENCODER = auto(
|
||||||
ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V
|
) # Encoder attention between previous layer Q/K/V for encoder-decoder
|
||||||
|
ENCODER_ONLY = auto() # Encoder attention between previous layer Q/K/V
|
||||||
|
ENCODER_DECODER = auto(
|
||||||
|
) # Attention between dec. Q and enc. K/V for encoder-decoder
|
||||||
|
|
||||||
|
|
||||||
class AttentionBackend(ABC):
|
class AttentionBackend(ABC):
|
||||||
|
|||||||
@ -287,13 +287,15 @@ def _get_attn_bias(
|
|||||||
* Appropriate attention bias value given the attention type
|
* Appropriate attention bias value given the attention type
|
||||||
'''
|
'''
|
||||||
|
|
||||||
if attn_type == AttentionType.DECODER:
|
if (attn_type == AttentionType.DECODER
|
||||||
|
or attn_type == AttentionType.ENCODER_ONLY):
|
||||||
return attn_metadata.attn_bias
|
return attn_metadata.attn_bias
|
||||||
elif attn_type == AttentionType.ENCODER:
|
elif attn_type == AttentionType.ENCODER:
|
||||||
return attn_metadata.encoder_attn_bias
|
return attn_metadata.encoder_attn_bias
|
||||||
else:
|
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||||
# attn_type == AttentionType.ENCODER_DECODER
|
|
||||||
return attn_metadata.cross_attn_bias
|
return attn_metadata.cross_attn_bias
|
||||||
|
else:
|
||||||
|
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||||
|
|
||||||
|
|
||||||
def _set_attn_bias(
|
def _set_attn_bias(
|
||||||
@ -313,7 +315,8 @@ def _set_attn_bias(
|
|||||||
encoder/decoder cross-attention
|
encoder/decoder cross-attention
|
||||||
'''
|
'''
|
||||||
|
|
||||||
if attn_type == AttentionType.DECODER:
|
if (attn_type == AttentionType.DECODER
|
||||||
|
or attn_type == AttentionType.ENCODER_ONLY):
|
||||||
attn_metadata.attn_bias = attn_bias
|
attn_metadata.attn_bias = attn_bias
|
||||||
elif attn_type == AttentionType.ENCODER:
|
elif attn_type == AttentionType.ENCODER:
|
||||||
attn_metadata.encoder_attn_bias = attn_bias
|
attn_metadata.encoder_attn_bias = attn_bias
|
||||||
@ -371,6 +374,12 @@ def _get_seq_len_block_table_args(
|
|||||||
# No block tables associated with encoder attention
|
# No block tables associated with encoder attention
|
||||||
return (attn_metadata.encoder_seq_lens_tensor,
|
return (attn_metadata.encoder_seq_lens_tensor,
|
||||||
attn_metadata.max_encoder_seq_len, None)
|
attn_metadata.max_encoder_seq_len, None)
|
||||||
|
elif attn_type == AttentionType.ENCODER_ONLY:
|
||||||
|
assert is_prompt, "Should not have decode for encoder only model."
|
||||||
|
|
||||||
|
# No block tables associated with encoder attention
|
||||||
|
return (attn_metadata.seq_lens_tensor,
|
||||||
|
attn_metadata.max_prefill_seq_len, None)
|
||||||
else:
|
else:
|
||||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||||
|
|
||||||
@ -479,7 +488,10 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
* ENCODER: no KV caching; pass encoder sequence
|
* ENCODER: no KV caching; pass encoder sequence
|
||||||
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
|
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
|
||||||
max_encoder_seq_len) to kernel, in lieu of decoder
|
max_encoder_seq_len) to kernel, in lieu of decoder
|
||||||
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
|
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len).
|
||||||
|
Used for encoder branch of encoder-decoder models.
|
||||||
|
* ENCODER_ONLY: no kv_caching, uses the normal attention
|
||||||
|
attributes (seq_lens/seq_lens_tensor/max_seq_len).
|
||||||
* ENCODER_DECODER: cross-attention behavior;
|
* ENCODER_DECODER: cross-attention behavior;
|
||||||
use cross-attention block table for caching KVs derived
|
use cross-attention block table for caching KVs derived
|
||||||
from encoder hidden states; since KV sequence lengths
|
from encoder hidden states; since KV sequence lengths
|
||||||
@ -509,6 +521,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||||
raise AttributeError("Encoder attention requires setting "
|
raise AttributeError("Encoder attention requires setting "
|
||||||
"encoder metadata attributes.")
|
"encoder metadata attributes.")
|
||||||
|
|
||||||
elif (attn_type == AttentionType.ENCODER_DECODER
|
elif (attn_type == AttentionType.ENCODER_DECODER
|
||||||
and (not attn_metadata.is_all_cross_attn_metadata_set)):
|
and (not attn_metadata.is_all_cross_attn_metadata_set)):
|
||||||
raise AttributeError("Encoder/decoder cross-attention "
|
raise AttributeError("Encoder/decoder cross-attention "
|
||||||
@ -609,6 +622,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
assert out.shape == output[:num_prefill_tokens].shape
|
assert out.shape == output[:num_prefill_tokens].shape
|
||||||
output[:num_prefill_tokens] = out
|
output[:num_prefill_tokens] = out
|
||||||
else:
|
else:
|
||||||
|
assert attn_type != AttentionType.ENCODER_ONLY, (
|
||||||
|
"Encoder-only models should not have prefix attention.")
|
||||||
|
|
||||||
assert prefill_meta.query_start_loc is not None
|
assert prefill_meta.query_start_loc is not None
|
||||||
assert prefill_meta.max_query_len is not None
|
assert prefill_meta.max_query_len is not None
|
||||||
@ -638,6 +653,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
output[:num_prefill_tokens] = out
|
output[:num_prefill_tokens] = out
|
||||||
|
|
||||||
if decode_meta := attn_metadata.decode_metadata:
|
if decode_meta := attn_metadata.decode_metadata:
|
||||||
|
assert attn_type != AttentionType.ENCODER_ONLY, (
|
||||||
|
"Encoder-only models should not have decode metadata.")
|
||||||
|
|
||||||
(
|
(
|
||||||
seq_lens_arg,
|
seq_lens_arg,
|
||||||
@ -703,36 +720,60 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
None, :].expand(value.shape[0], self.num_kv_heads,
|
None, :].expand(value.shape[0], self.num_kv_heads,
|
||||||
self.num_queries_per_kv,
|
self.num_queries_per_kv,
|
||||||
value.shape[-1])
|
value.shape[-1])
|
||||||
|
|
||||||
# Set attention bias if not provided. This typically happens at
|
# Set attention bias if not provided. This typically happens at
|
||||||
# the very attention layer of every iteration.
|
# the very attention layer of every iteration.
|
||||||
# FIXME(woosuk): This is a hack.
|
# FIXME(woosuk): This is a hack.
|
||||||
attn_bias = _get_attn_bias(attn_metadata, attn_type)
|
attn_bias = _get_attn_bias(attn_metadata, attn_type)
|
||||||
if attn_bias is None:
|
if attn_bias is None:
|
||||||
if self.alibi_slopes is None:
|
if self.alibi_slopes is None:
|
||||||
|
|
||||||
|
# Cross attention block of decoder branch of encoder-decoder
|
||||||
|
# model uses seq_lens for dec / encoder_seq_lens for enc
|
||||||
if (attn_type == AttentionType.ENCODER_DECODER):
|
if (attn_type == AttentionType.ENCODER_DECODER):
|
||||||
assert attn_metadata.seq_lens is not None
|
assert attn_metadata.seq_lens is not None
|
||||||
assert attn_metadata.encoder_seq_lens is not None
|
assert attn_metadata.encoder_seq_lens is not None
|
||||||
|
|
||||||
# Default enc/dec cross-attention mask is non-causal
|
# Cross-attention mask is non-causal
|
||||||
attn_bias = BlockDiagonalMask.from_seqlens(
|
attn_bias = BlockDiagonalMask.from_seqlens(
|
||||||
attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
|
attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
|
||||||
|
|
||||||
|
# Encoder branch of encoder-decoder model uses
|
||||||
|
# attn_metadata.encoder_seq_lens
|
||||||
elif attn_type == AttentionType.ENCODER:
|
elif attn_type == AttentionType.ENCODER:
|
||||||
|
|
||||||
assert attn_metadata.encoder_seq_lens is not None
|
assert attn_metadata.encoder_seq_lens is not None
|
||||||
|
|
||||||
# Default encoder self-attention mask is non-causal
|
# Encoder self-attention mask is non-causal
|
||||||
attn_bias = BlockDiagonalMask.from_seqlens(
|
attn_bias = BlockDiagonalMask.from_seqlens(
|
||||||
attn_metadata.encoder_seq_lens)
|
attn_metadata.encoder_seq_lens)
|
||||||
else:
|
|
||||||
|
# Self-attention block of encoder-only model just
|
||||||
|
# uses the seq_lens directly.
|
||||||
|
elif attn_type == AttentionType.ENCODER_ONLY:
|
||||||
assert attn_metadata.seq_lens is not None
|
assert attn_metadata.seq_lens is not None
|
||||||
|
|
||||||
# Default decoder self-attention mask is causal
|
# Encoder self-attention mask is non-causal
|
||||||
|
attn_bias = BlockDiagonalMask.from_seqlens(
|
||||||
|
attn_metadata.seq_lens)
|
||||||
|
|
||||||
|
# Self-attention block of decoder branch just
|
||||||
|
# uses the seq_lens directly
|
||||||
|
elif attn_type == AttentionType.DECODER:
|
||||||
|
assert attn_metadata.seq_lens is not None
|
||||||
|
|
||||||
|
# Decoder self-attention mask is causal
|
||||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(
|
attn_bias = BlockDiagonalCausalMask.from_seqlens(
|
||||||
attn_metadata.seq_lens)
|
attn_metadata.seq_lens)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown AttentionType: %s", attn_type)
|
||||||
|
|
||||||
if self.sliding_window is not None:
|
if self.sliding_window is not None:
|
||||||
attn_bias = attn_bias.make_local_attention(
|
attn_bias = attn_bias.make_local_attention(
|
||||||
self.sliding_window)
|
self.sliding_window)
|
||||||
attn_bias = [attn_bias]
|
attn_bias = [attn_bias]
|
||||||
else:
|
else:
|
||||||
|
assert attn_type == AttentionType.DECODER
|
||||||
assert attn_metadata.seq_lens is not None
|
assert attn_metadata.seq_lens is not None
|
||||||
attn_bias = _make_alibi_bias(self.alibi_slopes,
|
attn_bias = _make_alibi_bias(self.alibi_slopes,
|
||||||
self.num_kv_heads, query.dtype,
|
self.num_kv_heads, query.dtype,
|
||||||
|
|||||||
@ -12,6 +12,7 @@ class PoolingType(IntEnum):
|
|||||||
"""Enumeration for different types of pooling methods."""
|
"""Enumeration for different types of pooling methods."""
|
||||||
LAST = 0
|
LAST = 0
|
||||||
ALL = 1
|
ALL = 1
|
||||||
|
CLS = 2
|
||||||
|
|
||||||
|
|
||||||
class Pooler(nn.Module):
|
class Pooler(nn.Module):
|
||||||
@ -23,12 +24,13 @@ class Pooler(nn.Module):
|
|||||||
3. Returns structured results as `PoolerOutput`.
|
3. Returns structured results as `PoolerOutput`.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
pooling_type: The type of pooling to use (LAST, AVERAGE, MAX).
|
pooling_type: The type of pooling to use (LAST, ALL, CLS).
|
||||||
normalize: Whether to normalize the pooled data.
|
normalize: Whether to normalize the pooled data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, pooling_type: PoolingType, normalize: bool):
|
def __init__(self, pooling_type: PoolingType, normalize: bool):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.pooling_type = pooling_type
|
self.pooling_type = pooling_type
|
||||||
self.normalize = normalize
|
self.normalize = normalize
|
||||||
|
|
||||||
@ -38,10 +40,16 @@ class Pooler(nn.Module):
|
|||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> PoolerOutput:
|
) -> PoolerOutput:
|
||||||
"""Pools specific information from hidden states based on metadata."""
|
"""Pools specific information from hidden states based on metadata."""
|
||||||
|
|
||||||
prompt_lens = PoolingTensors.from_pooling_metadata(
|
prompt_lens = PoolingTensors.from_pooling_metadata(
|
||||||
pooling_metadata, hidden_states.device).prompt_lens
|
pooling_metadata, hidden_states.device).prompt_lens
|
||||||
|
|
||||||
if self.pooling_type == PoolingType.LAST:
|
if self.pooling_type is PoolingType.CLS:
|
||||||
|
first_token_flat_indices = torch.zeros_like(prompt_lens)
|
||||||
|
first_token_flat_indices[1:] += torch.cumsum(prompt_lens,
|
||||||
|
dim=0)[:-1]
|
||||||
|
pooled_data = hidden_states[first_token_flat_indices]
|
||||||
|
elif self.pooling_type == PoolingType.LAST:
|
||||||
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
|
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
|
||||||
pooled_data = hidden_states[last_token_flat_indices]
|
pooled_data = hidden_states[last_token_flat_indices]
|
||||||
elif self.pooling_type == PoolingType.ALL:
|
elif self.pooling_type == PoolingType.ALL:
|
||||||
|
|||||||
419
vllm/model_executor/models/bert.py
Normal file
419
vllm/model_executor/models/bert.py
Normal file
@ -0,0 +1,419 @@
|
|||||||
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import BertConfig
|
||||||
|
|
||||||
|
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||||
|
from vllm.attention.backends.xformers import XFormersImpl
|
||||||
|
from vllm.config import CacheConfig
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
|
QKVParallelLinear,
|
||||||
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||||
|
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||||
|
|
||||||
|
|
||||||
|
class BertEmbedding(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: BertConfig):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.size = config.hidden_size
|
||||||
|
self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
|
||||||
|
config.hidden_size)
|
||||||
|
self.position_embeddings = VocabParallelEmbedding(
|
||||||
|
config.max_position_embeddings, config.hidden_size)
|
||||||
|
self.token_type_embeddings = VocabParallelEmbedding(
|
||||||
|
config.type_vocab_size, config.hidden_size)
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size,
|
||||||
|
eps=config.layer_norm_eps)
|
||||||
|
self.position_ids = nn.Parameter(
|
||||||
|
torch.empty((1, config.max_position_embeddings)), )
|
||||||
|
|
||||||
|
self.position_embedding_type = config.position_embedding_type
|
||||||
|
if self.position_embedding_type != "absolute":
|
||||||
|
raise ValueError("Only 'absolute' position_embedding_type" +
|
||||||
|
" is supported")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
|
||||||
|
# Input embeddings.
|
||||||
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
|
# Position embeddings.
|
||||||
|
position_embeddings = self.position_embeddings(position_ids)
|
||||||
|
|
||||||
|
# Token type embeddings. (TODO: move off hotpath?)
|
||||||
|
token_type_embeddings = self.token_type_embeddings(
|
||||||
|
torch.zeros(input_shape,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=inputs_embeds.device))
|
||||||
|
|
||||||
|
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
|
||||||
|
embeddings = self.LayerNorm(embeddings)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class BertEncoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: BertConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
self.layer = nn.ModuleList([
|
||||||
|
BertLayer(config=config,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.layer.{layer_idx}")
|
||||||
|
for layer_idx in range(config.num_hidden_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
for i in range(len(self.layer)):
|
||||||
|
layer = self.layer[i]
|
||||||
|
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: BertConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.attention = BertAttention(
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
num_attention_heads=config.num_attention_heads,
|
||||||
|
layer_norm_eps=config.layer_norm_eps,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attention")
|
||||||
|
|
||||||
|
self.intermediate = BertIntermediate(
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.intermediate")
|
||||||
|
|
||||||
|
self.output = BertOutput(hidden_size=config.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
layer_norm_eps=config.layer_norm_eps,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.output")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: Optional[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
):
|
||||||
|
attn_output = self.attention(hidden_states, kv_cache, attn_metadata)
|
||||||
|
intermediate_output = self.intermediate(attn_output)
|
||||||
|
output = self.output(intermediate_output, attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class BertAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_attention_heads: int,
|
||||||
|
layer_norm_eps: float,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.self = BertSelfAttention(hidden_size=hidden_size,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.output")
|
||||||
|
|
||||||
|
self.output = BertSelfOutput(hidden_size=hidden_size,
|
||||||
|
layer_norm_eps=layer_norm_eps,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.output")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
self_output = self.self(hidden_states, kv_cache, attn_metadata)
|
||||||
|
return self.output(self_output, hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
class BertSelfAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_attention_heads: int,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
|
self.total_num_heads = num_attention_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.total_num_kv_heads = self.total_num_heads
|
||||||
|
self.head_dim = self.hidden_size // self.total_num_heads
|
||||||
|
assert self.head_dim * self.total_num_heads == self.hidden_size
|
||||||
|
|
||||||
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||||
|
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
head_size=self.head_dim,
|
||||||
|
total_num_heads=self.total_num_heads,
|
||||||
|
total_num_kv_heads=self.total_num_kv_heads,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.qkv_proj")
|
||||||
|
|
||||||
|
self.attn = Attention(num_heads=self.num_heads,
|
||||||
|
head_size=self.head_dim,
|
||||||
|
scale=self.scaling,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
|
if not isinstance(self.attn.impl, XFormersImpl):
|
||||||
|
raise ValueError(
|
||||||
|
"Encoder-only models currently require XFORMERS attention "
|
||||||
|
"backend. Set VLLM_ATTENTION_BACKEND=XFORMERS to use BERT.")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
output = self.attn(q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
kv_cache,
|
||||||
|
attn_metadata,
|
||||||
|
attn_type=AttentionType.ENCODER_ONLY)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class BertSelfOutput(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
hidden_size: int,
|
||||||
|
layer_norm_eps: float,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = RowParallelLinear(input_size=hidden_size,
|
||||||
|
output_size=hidden_size,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.dense")
|
||||||
|
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
hidden_states, _ = self.dense(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertIntermediate(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
hidden_act: str,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = ColumnParallelLinear(input_size=hidden_size,
|
||||||
|
output_size=intermediate_size,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.dense")
|
||||||
|
self.intermediate_act_fn = get_act_fn(hidden_act)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
hidden_states, _ = self.dense(hidden_states)
|
||||||
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertOutput(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
layer_norm_eps: float,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dense = RowParallelLinear(input_size=intermediate_size,
|
||||||
|
output_size=hidden_size,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.dense")
|
||||||
|
|
||||||
|
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
hidden_states, _ = self.dense(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: BertConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
self.embeddings = BertEmbedding(config)
|
||||||
|
self.encoder = BertEncoder(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.encoder")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
else:
|
||||||
|
hidden_states = self.embeddings(input_ids=input_ids,
|
||||||
|
position_ids=position_ids)
|
||||||
|
|
||||||
|
return self.encoder(hidden_states, kv_caches, attn_metadata)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "query", "q"),
|
||||||
|
("qkv_proj", "key", "k"),
|
||||||
|
("qkv_proj", "value", "v"),
|
||||||
|
]
|
||||||
|
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "pooler" in name:
|
||||||
|
continue
|
||||||
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
class BertEmbeddingModel(nn.Module):
|
||||||
|
"""A model that uses Bert to provide embedding functionalities.
|
||||||
|
|
||||||
|
This class encapsulates the BertModel and provides an interface for
|
||||||
|
embedding operations and customized pooling functions.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
model: An instance of BertModel used for forward operations.
|
||||||
|
_pooler: An instance of Pooler used for pooling operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: BertConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.model = BertModel(config, cache_config, quant_config)
|
||||||
|
self._pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor],
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self.model(input_ids=input_ids,
|
||||||
|
position_ids=positions,
|
||||||
|
kv_caches=kv_caches,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
attn_metadata=attn_metadata)
|
||||||
|
|
||||||
|
def pooler(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> Optional[PoolerOutput]:
|
||||||
|
return self._pooler(hidden_states, pooling_metadata)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
self.model.load_weights(weights)
|
||||||
@ -87,6 +87,7 @@ _TEXT_GENERATION_MODELS = {
|
|||||||
|
|
||||||
_EMBEDDING_MODELS = {
|
_EMBEDDING_MODELS = {
|
||||||
# [Text-only]
|
# [Text-only]
|
||||||
|
"BertModel": ("bert", "BertEmbeddingModel"),
|
||||||
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
|
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
|
||||||
"MistralModel": ("llama", "LlamaEmbeddingModel"),
|
"MistralModel": ("llama", "LlamaEmbeddingModel"),
|
||||||
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
|
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user