Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-09 08:11:50 +08:00)

[torch.compile] support all attention backends (#10558)

Signed-off-by: youkaichao <youkaichao@gmail.com>

This commit is contained in:
parent db100c5cde
commit eebad39f26
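In short: instead of registering one torch.compile custom op per backend (vllm.unified_flash_attention, vllm.unified_flash_infer), each Attention layer now registers itself under its layer name in the compilation config's static_forward_context, and a single vllm.unified_attention custom op resolves the layer through the forward context at runtime, so every backend goes through the same opaque op. AttentionType also becomes a set of plain strings so attn_type can cross the custom-op boundary. The sketch below is a stripped-down, framework-free illustration of that dispatch pattern; every name in it is invented for the demo and is not vLLM API.

from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class DemoForwardContext:
    # layer name -> attention layer object, fixed for the lifetime of the model
    static_forward_context: Dict[str, Any] = field(default_factory=dict)
    # per-step state such as the attention metadata
    dynamic_forward_context: Any = None


_ctx = DemoForwardContext()


class DemoAttentionLayer:
    """Stand-in for vllm.attention.layer.Attention."""

    def run_backend(self, query, attn_metadata):
        # The real layer dispatches to whichever attention backend was selected.
        return [q * 2 for q in query]


def unified_attention_demo(query, layer_name: str):
    # The single opaque op: resolve the layer by name, then run its backend
    # with the metadata stored in the forward context.
    layer = _ctx.static_forward_context[layer_name]
    return layer.run_backend(query, _ctx.dynamic_forward_context)


_ctx.static_forward_context["model.layers.0.self_attn.attn"] = DemoAttentionLayer()
print(unified_attention_demo([1.0, 2.0], "model.layers.0.self_attn.attn"))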
@@ -18,8 +18,10 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
                                     global_force_attn_backend_context_manager)
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.platforms import current_platform
from vllm.plugins import set_current_vllm_config

# List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
@ -594,6 +596,7 @@ def _run_encoder_attention_test(
|
||||
encoder_test_params: PhaseTestParameters,
|
||||
attn_metadata: AttentionMetadata,
|
||||
test_pt: TestPoint,
|
||||
vllm_config: VllmConfig,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Run encoder attention.
|
||||
@ -623,7 +626,7 @@ def _run_encoder_attention_test(
|
||||
attn_type = AttentionType.ENCODER
|
||||
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
|
||||
assert packed_qkv is not None
|
||||
with set_forward_context(attn_metadata):
|
||||
with set_forward_context(attn_metadata, vllm_config):
|
||||
# In the test setup the shape of the query is
|
||||
# [batch_size, seq_len, num_heads, head_size]. However
|
||||
# the attention backend expect the shape to be
|
||||
@ -648,6 +651,7 @@ def _run_decoder_self_attention_test(
|
||||
decoder_test_params: PhaseTestParameters,
|
||||
attn_metadata: AttentionMetadata,
|
||||
test_pt: TestPoint,
|
||||
vllm_config: VllmConfig,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Run decoder self-attention test.
|
||||
@ -677,7 +681,7 @@ def _run_decoder_self_attention_test(
|
||||
kv_cache = test_rsrcs.kv_cache
|
||||
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
|
||||
assert packed_qkv is not None
|
||||
with set_forward_context(attn_metadata):
|
||||
with set_forward_context(attn_metadata, vllm_config):
|
||||
# In the test setup the shape of the query is
|
||||
# [batch_size, seq_len, num_heads, head_size]. However
|
||||
# the attention backend expect the shape to be
|
||||
@ -701,6 +705,7 @@ def _run_encoder_decoder_cross_attention_test(
|
||||
cross_test_params: Optional[PhaseTestParameters],
|
||||
attn_metadata: AttentionMetadata,
|
||||
test_pt: TestPoint,
|
||||
vllm_config: VllmConfig,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Run encoder/decoder cross-attention test.
|
||||
@ -748,7 +753,7 @@ def _run_encoder_decoder_cross_attention_test(
|
||||
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
|
||||
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
|
||||
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
|
||||
with set_forward_context(attn_metadata):
|
||||
with set_forward_context(attn_metadata, vllm_config):
|
||||
# In the test setup the shape of the query is
|
||||
# [batch_size, seq_len, num_heads, head_size]. However
|
||||
# the attention backend expect the shape to be
|
||||
@ -839,7 +844,9 @@ def test_encoder_only(
|
||||
|
||||
# Attention scale factor, attention backend instance, attention wrapper
|
||||
# instance, KV cache init
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
vllm_config = VllmConfig()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
|
||||
# Construct encoder attention test params (only used
|
||||
# during prefill)
|
||||
@ -863,7 +870,8 @@ def test_encoder_only(
|
||||
test_rsrcs.attn,
|
||||
enc_test_params,
|
||||
prephase_attn_metadata,
|
||||
test_pt=test_pt))
|
||||
test_pt=test_pt,
|
||||
vllm_config=vllm_config))
|
||||
|
||||
# - Is encoder attention result correct?
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
|
||||
@ -960,7 +968,9 @@ def test_e2e_enc_dec_attn(
|
||||
|
||||
# Attention scale factor, attention backend instance, attention wrapper
|
||||
# instance, KV cache init
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
vllm_config = VllmConfig()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
|
||||
# Construct encoder attention test params (only used
|
||||
# during prefill)
|
||||
@ -1011,7 +1021,8 @@ def test_e2e_enc_dec_attn(
|
||||
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
|
||||
enc_test_params,
|
||||
prephase_attn_metadata,
|
||||
test_pt=test_pt)
|
||||
test_pt=test_pt,
|
||||
vllm_config=vllm_config)
|
||||
|
||||
# - Is encoder attention result correct?
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
|
||||
@ -1023,7 +1034,8 @@ def test_e2e_enc_dec_attn(
|
||||
test_rsrcs,
|
||||
prephase_dec_test_params,
|
||||
prephase_attn_metadata,
|
||||
test_pt=test_pt)
|
||||
test_pt=test_pt,
|
||||
vllm_config=vllm_config)
|
||||
|
||||
# - Is prefill decoder self-attention correct?
|
||||
assert_actual_matches_ideal(prephase_dec_test_params,
|
||||
@ -1037,7 +1049,8 @@ def test_e2e_enc_dec_attn(
|
||||
prephase_dec_test_params,
|
||||
prephase_cross_test_params,
|
||||
prephase_attn_metadata,
|
||||
test_pt=test_pt)
|
||||
test_pt=test_pt,
|
||||
vllm_config=vllm_config)
|
||||
|
||||
# - Is prefill encoder/decoder cross-attention correct?
|
||||
assert_actual_matches_ideal(prephase_cross_test_params,
|
||||
@ -1061,7 +1074,8 @@ def test_e2e_enc_dec_attn(
|
||||
test_rsrcs,
|
||||
decphase_dec_test_params,
|
||||
decphase_attn_metadata,
|
||||
test_pt=test_pt)
|
||||
test_pt=test_pt,
|
||||
vllm_config=vllm_config)
|
||||
|
||||
# - Is decode-phase decoder self-attention correct?
|
||||
assert_actual_matches_ideal(decphase_dec_test_params,
|
||||
@ -1075,7 +1089,8 @@ def test_e2e_enc_dec_attn(
|
||||
decphase_dec_test_params,
|
||||
None,
|
||||
decphase_attn_metadata,
|
||||
test_pt=test_pt)
|
||||
test_pt=test_pt,
|
||||
vllm_config=vllm_config)
|
||||
|
||||
# - Is decode-phase encoder/decoder cross-attention correct?
|
||||
assert_actual_matches_ideal(decphase_cross_test_params,
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, fields
|
||||
from enum import Enum, auto
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
|
||||
Tuple, Type, TypeVar)
|
||||
|
||||
@@ -15,13 +14,19 @@ if TYPE_CHECKING:
                                      ModelRunnerInputBuilderBase)


class AttentionType(Enum):
    DECODER = auto()  # Decoder attention between previous layer Q/K/V
    ENCODER = auto(
    )  # Encoder attention between previous layer Q/K/V for encoder-decoder
    ENCODER_ONLY = auto()  # Encoder attention between previous layer Q/K/V
    ENCODER_DECODER = auto(
    )  # Attention between dec. Q and enc. K/V for encoder-decoder
class AttentionType:
    """
    Attention type.
    Use string to be compatible with `torch.compile`.
    """
    # Decoder attention between previous layer Q/K/V
    DECODER = "decoder"
    # Encoder attention between previous layer Q/K/V for encoder-decoder
    ENCODER = "encoder"
    # Encoder attention between previous layer Q/K/V
    ENCODER_ONLY = "encoder_only"
    # Attention between dec. Q and enc. K/V for encoder-decoder
    ENCODER_DECODER = "encoder_decoder"
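The docstring above gives the motivation: string constants survive torch.compile tracing as plain constants, which Enum members do not do as cleanly. The small standalone illustration below (not part of the diff; DemoAttentionType is an invented name) shows why call sites stay unchanged: comparisons still go through the class attributes, only the underlying values are now strings.

class DemoAttentionType:
    DECODER = "decoder"
    ENCODER = "encoder"


def needs_kv_cache_update(attn_type: str) -> bool:
    # Mirrors the check in the FlashAttention forward path further down:
    # pure encoder attention never writes to the KV cache.
    return attn_type != DemoAttentionType.ENCODER


assert needs_kv_cache_update(DemoAttentionType.DECODER)
assert not needs_kv_cache_update(DemoAttentionType.ENCODER)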


class AttentionBackend(ABC):
@@ -241,6 +246,6 @@ class AttentionImpl(ABC, Generic[T]):
        attn_metadata: T,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
        attn_type: str = AttentionType.DECODER,
    ) -> torch.Tensor:
        raise NotImplementedError

@ -354,7 +354,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
||||
attn_metadata: BlocksparseFlashAttentionMetadata,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
|
||||
|
||||
@ -16,10 +16,8 @@ from vllm.attention.backends.utils import (
|
||||
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
|
||||
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
|
||||
is_all_encoder_attn_metadata_set, is_block_tables_empty)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
|
||||
make_tensor_with_pad)
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
||||
@ -639,7 +637,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
@ -668,23 +666,174 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes.")
|
||||
|
||||
output = torch.ops.vllm.unified_flash_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.num_kv_heads,
|
||||
kv_cache,
|
||||
self.kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
self.scale,
|
||||
attn_type.value,
|
||||
self.sliding_window,
|
||||
self.alibi_slopes,
|
||||
self.logits_soft_cap,
|
||||
)
|
||||
num_heads: int = self.num_heads
|
||||
head_size: int = self.head_size
|
||||
num_kv_heads: int = self.num_kv_heads
|
||||
kv_cache_dtype: str = self.kv_cache_dtype
|
||||
softmax_scale: float = self.scale
|
||||
window_size = self.sliding_window
|
||||
alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
|
||||
logits_soft_cap: Optional[float] = self.logits_soft_cap
|
||||
|
||||
num_tokens, hidden_size = query.shape
|
||||
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, num_heads, head_size)
|
||||
if (key is not None) and (value is not None):
|
||||
key = key.view(-1, num_kv_heads, head_size)
|
||||
value = value.view(-1, num_kv_heads, head_size)
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache = kv_cache[0]
|
||||
value_cache = kv_cache[1]
|
||||
# We skip updating the KV cache under two conditions:
|
||||
# a. When the Attention Type is ENCODER. In this phase, we compute
|
||||
# only the encoder attention without updating the cache.
|
||||
# b. When both Key and Value are None. This occurs during
|
||||
# cross-attention computation in the decoding phase, where the
|
||||
# KV cache is already populated with the cross-attention
|
||||
# tensor. Thus, we skip cache updates during this time.
|
||||
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
|
||||
value is not None):
|
||||
if attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Update cross-attention KV cache (prefill-only)
|
||||
updated_slot_mapping = attn_metadata.cross_slot_mapping
|
||||
else:
|
||||
# Update self-attention KV cache (prefill/decode)
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory
|
||||
# profiling run.
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
updated_slot_mapping.flatten(), # type: ignore[union-attr]
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
(num_prefill_query_tokens, num_prefill_kv_tokens,
|
||||
num_decode_query_tokens) = \
|
||||
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
|
||||
decode_query = query[num_prefill_query_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_query_tokens]
|
||||
assert query.shape[0] == num_prefill_query_tokens
|
||||
assert decode_query.shape[0] == num_decode_query_tokens
|
||||
|
||||
prefill_output: Optional[torch.Tensor] = None
|
||||
decode_output: Optional[torch.Tensor] = None
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
|
||||
or prefill_meta.block_tables.numel() == 0):
|
||||
# normal attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
# prompt, and they have the same length.
|
||||
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
|
||||
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
|
||||
|
||||
key = key[:num_prefill_kv_tokens]
|
||||
value = value[:num_prefill_kv_tokens]
|
||||
|
||||
prefill_output = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=q_seq_start_loc,
|
||||
cu_seqlens_k=k_seq_start_loc,
|
||||
max_seqlen_q=q_seq_len,
|
||||
max_seqlen_k=k_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=_get_causal_option(attn_type),
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
)
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
assert attn_type == AttentionType.DECODER, (
|
||||
"Only decoder-only models support prefix caching")
|
||||
assert prefill_meta.seq_lens is not None
|
||||
max_seq_len = max(prefill_meta.seq_lens)
|
||||
prefill_output = flash_attn_varlen_func( # noqa
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=prefill_meta.query_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_query_len,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_k=max_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
block_table=prefill_meta.block_tables,
|
||||
softcap=logits_soft_cap,
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
# Decoding run.
|
||||
# Use flash_attn_varlen_func kernel for speculative decoding
|
||||
# because different queries might have different lengths.
|
||||
|
||||
assert decode_meta.max_decode_query_len is not None
|
||||
# use only for actual varlen decoding
|
||||
if decode_meta.max_decode_query_len > 1:
|
||||
assert attn_type == AttentionType.DECODER, (
|
||||
"Only decoder-only models support max_decode_query_len > 1"
|
||||
)
|
||||
decode_output = flash_attn_varlen_func(
|
||||
q=decode_query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=decode_meta.query_start_loc,
|
||||
max_seqlen_q=decode_meta.max_decode_query_len,
|
||||
cu_seqlens_k=decode_meta.seq_start_loc,
|
||||
max_seqlen_k=decode_meta.max_decode_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
block_table=decode_meta.block_tables,
|
||||
)
|
||||
else:
|
||||
# Use flash_attn_with_kvcache for normal decoding.
|
||||
(
|
||||
seq_lens_arg,
|
||||
_,
|
||||
block_tables_arg,
|
||||
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
|
||||
decode_output = flash_attn_with_kvcache(
|
||||
q=decode_query.unsqueeze(1),
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
block_table=block_tables_arg,
|
||||
cache_seqlens=seq_lens_arg,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
).squeeze(1)
|
||||
|
||||
if prefill_output is None:
|
||||
assert decode_output is not None
|
||||
return decode_output.view(num_decode_query_tokens, hidden_size)
|
||||
if decode_output is None:
|
||||
assert prefill_output is not None
|
||||
return prefill_output.view(num_prefill_query_tokens, hidden_size)
|
||||
|
||||
assert decode_meta is not None
|
||||
decode_output = decode_output.squeeze(1)
|
||||
output = torch.cat([prefill_output, decode_output], dim=0)
|
||||
return output.view(num_tokens, hidden_size)
|
||||
|
||||
return output
|
||||
|
||||
@ -692,7 +841,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
def _get_query_key_seq_metadata(
|
||||
attn_metadata,
|
||||
is_prompt: bool,
|
||||
attn_type: AttentionType,
|
||||
attn_type: str,
|
||||
) -> tuple:
|
||||
"""
|
||||
Returns sequence metadata for key and query based on the specified
|
||||
@ -754,7 +903,7 @@ def _get_query_key_seq_metadata(
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
|
||||
def _get_causal_option(attn_type: AttentionType) -> bool:
|
||||
def _get_causal_option(attn_type: str) -> bool:
|
||||
"""
|
||||
Determine whether the given attention type is suitable for causal
|
||||
attention mechanisms.
|
||||
@ -770,220 +919,3 @@ def _get_causal_option(attn_type: AttentionType) -> bool:
|
||||
return not (attn_type == AttentionType.ENCODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY
|
||||
or attn_type == AttentionType.ENCODER_DECODER)
|
||||
|
||||
|
||||
def unified_flash_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
num_kv_heads: int,
|
||||
kv_cache: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
softmax_scale: float,
|
||||
attn_type_int_val: int,
|
||||
window_size: Optional[List[int]] = None,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# Convert integer attn_type to enum
|
||||
try:
|
||||
attn_type = AttentionType(attn_type_int_val)
|
||||
except ValueError as err:
|
||||
raise AttributeError(
|
||||
f"Invalid attention type {str(attn_type_int_val)}") from err
|
||||
|
||||
current_metadata = get_forward_context()
|
||||
assert current_metadata is not None
|
||||
assert isinstance(current_metadata, FlashAttentionMetadata)
|
||||
attn_metadata: FlashAttentionMetadata = current_metadata
|
||||
|
||||
num_tokens, hidden_size = query.shape
|
||||
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, num_heads, head_size)
|
||||
if (key is not None) and (value is not None):
|
||||
key = key.view(-1, num_kv_heads, head_size)
|
||||
value = value.view(-1, num_kv_heads, head_size)
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache = kv_cache[0]
|
||||
value_cache = kv_cache[1]
|
||||
# We skip updating the KV cache under two conditions:
|
||||
# a. When the Attention Type is ENCODER. In this phase, we compute
|
||||
# only the encoder attention without updating the cache.
|
||||
# b. When both Key and Value are None. This occurs during
|
||||
# cross-attention computation in the decoding phase, where the KV
|
||||
# cache is already populated with the cross-attention tensor.
|
||||
# Thus, we skip cache updates during this time.
|
||||
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
|
||||
value is not None):
|
||||
if attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Update cross-attention KV cache (prefill-only)
|
||||
updated_slot_mapping = attn_metadata.cross_slot_mapping
|
||||
else:
|
||||
# Update self-attention KV cache (prefill/decode)
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory profiling run.
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
updated_slot_mapping.flatten(), # type: ignore[union-attr]
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
(num_prefill_query_tokens, num_prefill_kv_tokens,
|
||||
num_decode_query_tokens) = \
|
||||
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
|
||||
decode_query = query[num_prefill_query_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_query_tokens]
|
||||
assert query.shape[0] == num_prefill_query_tokens
|
||||
assert decode_query.shape[0] == num_decode_query_tokens
|
||||
|
||||
prefill_output: Optional[torch.Tensor] = None
|
||||
decode_output: Optional[torch.Tensor] = None
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
|
||||
or prefill_meta.block_tables.numel() == 0):
|
||||
# normal attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
# prompt, and they have the same length.
|
||||
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
|
||||
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
|
||||
|
||||
key = key[:num_prefill_kv_tokens]
|
||||
value = value[:num_prefill_kv_tokens]
|
||||
|
||||
prefill_output = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=q_seq_start_loc,
|
||||
cu_seqlens_k=k_seq_start_loc,
|
||||
max_seqlen_q=q_seq_len,
|
||||
max_seqlen_k=k_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=_get_causal_option(attn_type),
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
)
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
assert attn_type == AttentionType.DECODER, (
|
||||
"Only decoder-only models support prefix caching")
|
||||
assert prefill_meta.seq_lens is not None
|
||||
max_seq_len = max(prefill_meta.seq_lens)
|
||||
prefill_output = flash_attn_varlen_func( # noqa
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=prefill_meta.query_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_query_len,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_k=max_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
block_table=prefill_meta.block_tables,
|
||||
softcap=logits_soft_cap,
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
# Decoding run.
|
||||
# Use flash_attn_varlen_func kernel for speculative decoding
|
||||
# because different queries might have different lengths.
|
||||
|
||||
assert decode_meta.max_decode_query_len is not None
|
||||
# use only for actual varlen decoding
|
||||
if decode_meta.max_decode_query_len > 1:
|
||||
assert attn_type == AttentionType.DECODER, (
|
||||
"Only decoder-only models support max_decode_query_len > 1")
|
||||
decode_output = flash_attn_varlen_func(
|
||||
q=decode_query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=decode_meta.query_start_loc,
|
||||
max_seqlen_q=decode_meta.max_decode_query_len,
|
||||
cu_seqlens_k=decode_meta.seq_start_loc,
|
||||
max_seqlen_k=decode_meta.max_decode_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
block_table=decode_meta.block_tables,
|
||||
)
|
||||
else:
|
||||
# Use flash_attn_with_kvcache for normal decoding.
|
||||
(
|
||||
seq_lens_arg,
|
||||
_,
|
||||
block_tables_arg,
|
||||
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
|
||||
decode_output = flash_attn_with_kvcache(
|
||||
q=decode_query.unsqueeze(1),
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
block_table=block_tables_arg,
|
||||
cache_seqlens=seq_lens_arg,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
).squeeze(1)
|
||||
|
||||
if prefill_output is None:
|
||||
assert decode_output is not None
|
||||
return decode_output.view(num_decode_query_tokens, hidden_size)
|
||||
if decode_output is None:
|
||||
assert prefill_output is not None
|
||||
return prefill_output.view(num_prefill_query_tokens, hidden_size)
|
||||
|
||||
assert decode_meta is not None
|
||||
decode_output = decode_output.squeeze(1)
|
||||
output = torch.cat([prefill_output, decode_output], dim=0)
|
||||
return output.view(num_tokens, hidden_size)
|
||||
|
||||
|
||||
def unified_flash_attention_fake(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
num_kv_heads: int,
|
||||
kv_cache: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
softmax_scale: float,
|
||||
attn_type_int_val: int,
|
||||
window_size: Optional[List[int]] = None,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(query)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="unified_flash_attention",
|
||||
op_func=unified_flash_attention,
|
||||
mutates_args=["kv_cache"],
|
||||
fake_impl=unified_flash_attention_fake,
|
||||
)
|
||||
|
||||
@ -30,9 +30,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
|
||||
get_kv_cache_torch_dtype, make_tensor_with_pad)
|
||||
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
|
||||
make_tensor_with_pad)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
||||
@ -774,7 +773,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
attn_metadata: FlashInferMetadata,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
@ -782,174 +781,117 @@ class FlashInferImpl(AttentionImpl):
|
||||
"are not implemented for "
|
||||
"FlashInferImpl")
|
||||
|
||||
return torch.ops.vllm.unified_flash_infer(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.num_kv_heads,
|
||||
kv_cache,
|
||||
self.kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
self.scale,
|
||||
self.sliding_window,
|
||||
self.alibi_slopes,
|
||||
self.logits_soft_cap,
|
||||
)
|
||||
num_heads: int = self.num_heads
|
||||
head_size: int = self.head_size
|
||||
num_kv_heads: int = self.num_kv_heads
|
||||
kv_cache_dtype: str = self.kv_cache_dtype
|
||||
softmax_scale: float = self.scale
|
||||
window_size = self.sliding_window
|
||||
alibi_slopes = self.alibi_slopes
|
||||
logits_soft_cap = self.logits_soft_cap
|
||||
|
||||
num_tokens, hidden_size = query.shape
|
||||
query = query.view(-1, num_heads, head_size)
|
||||
key = key.view(-1, num_kv_heads, head_size)
|
||||
value = value.view(-1, num_kv_heads, head_size)
|
||||
|
||||
def unified_flash_infer(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
num_kv_heads: int,
|
||||
kv_cache: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
softmax_scale: float,
|
||||
window_size: Optional[List[int]] = None,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
current_metadata = get_forward_context()
|
||||
assert current_metadata is not None
|
||||
assert isinstance(current_metadata, FlashInferMetadata)
|
||||
attn_metadata: FlashInferMetadata = current_metadata
|
||||
|
||||
num_tokens, hidden_size = query.shape
|
||||
query = query.view(-1, num_heads, head_size)
|
||||
key = key.view(-1, num_kv_heads, head_size)
|
||||
value = value.view(-1, num_kv_heads, head_size)
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
# Use the same reshape and cache kernel as flash attention.
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
kv_cache[:, 0],
|
||||
kv_cache[:, 1],
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
|
||||
# to process the cache when the kv_cache_dtype is fp8
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||
kv_cache_dtype)
|
||||
kv_cache = kv_cache.view(torch_dtype)
|
||||
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
||||
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
||||
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
|
||||
query = query.contiguous() # Flashinfer requires query to be contiguous
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
# QKV for prefill.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
query = query[:num_prefill_tokens]
|
||||
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
assert query.shape[0] == num_prefill_tokens
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
|
||||
window_left = window_size[0] if window_size is not None else -1
|
||||
|
||||
prefill_output: Optional[torch.Tensor] = None
|
||||
decode_output: Optional[torch.Tensor] = None
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# We will use flash attention for prefill
|
||||
# when kv_cache is not provided.
|
||||
# This happens when vllm runs the profiling to
|
||||
# determine the number of blocks.
|
||||
if kv_cache.numel() == 0:
|
||||
prefill_output = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
if kv_cache.numel() > 0:
|
||||
# Use the same reshape and cache kernel as flash attention.
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
kv_cache[:, 0],
|
||||
kv_cache[:, 1],
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
else:
|
||||
assert prefill_meta is not None
|
||||
assert prefill_meta.prefill_wrapper is not None
|
||||
prefill_output = prefill_meta.prefill_wrapper.forward(
|
||||
query,
|
||||
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
|
||||
# to process the cache when the kv_cache_dtype is fp8
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||
kv_cache_dtype)
|
||||
kv_cache = kv_cache.view(torch_dtype)
|
||||
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
||||
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
||||
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
|
||||
query = query.contiguous(
|
||||
) # Flashinfer requires query to be contiguous
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
# QKV for prefill.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
query = query[:num_prefill_tokens]
|
||||
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
assert query.shape[0] == num_prefill_tokens
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
|
||||
window_left = window_size[0] if window_size is not None else -1
|
||||
|
||||
prefill_output: Optional[torch.Tensor] = None
|
||||
decode_output: Optional[torch.Tensor] = None
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# We will use flash attention for prefill
|
||||
# when kv_cache is not provided.
|
||||
# This happens when vllm runs the profiling to
|
||||
# determine the number of blocks.
|
||||
if kv_cache.numel() == 0:
|
||||
prefill_output = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
)
|
||||
else:
|
||||
assert prefill_meta is not None
|
||||
assert prefill_meta.prefill_wrapper is not None
|
||||
prefill_output = prefill_meta.prefill_wrapper.forward(
|
||||
query,
|
||||
kv_cache,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
causal=True,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
window_left=window_left)
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
assert decode_meta is not None
|
||||
assert decode_meta.decode_wrapper is not None
|
||||
decode_output = decode_meta.decode_wrapper.forward(
|
||||
decode_query,
|
||||
kv_cache,
|
||||
sm_scale=softmax_scale,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
causal=True,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
window_left=window_left)
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
assert attn_metadata.decode_metadata is not None
|
||||
assert attn_metadata.decode_metadata.decode_wrapper is not None
|
||||
decode_output = attn_metadata.decode_metadata.decode_wrapper.forward(
|
||||
decode_query,
|
||||
kv_cache,
|
||||
sm_scale=softmax_scale,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
window_left=window_left)
|
||||
|
||||
if prefill_output is None and decode_output is not None:
|
||||
# Decode only batch.
|
||||
output, num_tokens = decode_output, num_decode_tokens
|
||||
elif decode_output is None and prefill_output is not None:
|
||||
# Prefill only batch.
|
||||
output, num_tokens = prefill_output, num_prefill_tokens
|
||||
else:
|
||||
# Chunked prefill batch does not work with speculative decoding in
|
||||
# FlashInfer backend, so the query length for decode should be 1.
|
||||
assert prefill_output is not None
|
||||
assert decode_output is not None
|
||||
assert decode_meta is not None
|
||||
assert decode_meta.decode_query_len == 1
|
||||
decode_output = decode_output.squeeze(1)
|
||||
output = torch.cat([prefill_output, decode_output], dim=0)
|
||||
return output.view(num_tokens, hidden_size)
|
||||
|
||||
|
||||
def unified_flash_infer_fake(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
num_kv_heads: int,
|
||||
kv_cache: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
softmax_scale: float,
|
||||
window_size: Optional[List[int]] = None,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(query).contiguous()
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="unified_flash_infer",
|
||||
op_func=unified_flash_infer,
|
||||
mutates_args=["kv_cache"],
|
||||
fake_impl=unified_flash_infer_fake,
|
||||
)
|
||||
if prefill_output is None and decode_output is not None:
|
||||
# Decode only batch.
|
||||
output, num_tokens = decode_output, num_decode_tokens
|
||||
elif decode_output is None and prefill_output is not None:
|
||||
# Prefill only batch.
|
||||
output, num_tokens = prefill_output, num_prefill_tokens
|
||||
else:
|
||||
# Chunked prefill batch does not work with speculative decoding in
|
||||
# FlashInfer backend, so the query length for decode should be 1.
|
||||
assert prefill_output is not None
|
||||
assert decode_output is not None
|
||||
assert decode_meta is not None
|
||||
assert decode_meta.decode_query_len == 1
|
||||
decode_output = decode_output.squeeze(1)
|
||||
output = torch.cat([prefill_output, decode_output], dim=0)
|
||||
return output.view(num_tokens, hidden_size)
|
||||
|
||||
@ -140,7 +140,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
|
||||
attn_metadata: HPUAttentionMetadata,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with xFormers and PagedAttention.
|
||||
|
||||
|
||||
@ -172,7 +172,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
||||
attn_metadata: IpexAttnMetadata, # type: ignore
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with IPEX varlen_attention and PagedAttention.
|
||||
|
||||
|
||||
@ -150,7 +150,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
attn_metadata: PallasMetadata,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Pallas attention.
|
||||
|
||||
|
||||
@ -414,7 +414,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
attn_metadata: ROCmFlashAttentionMetadata,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
|
||||
|
||||
@ -141,7 +141,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
|
||||
def get_seq_lens(
|
||||
self,
|
||||
attn_type: AttentionType,
|
||||
attn_type: str,
|
||||
):
|
||||
'''
|
||||
Extract appropriate sequence lengths from attention metadata
|
||||
@ -174,7 +174,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
|
||||
def get_attn_bias(
|
||||
self,
|
||||
attn_type: AttentionType,
|
||||
attn_type: str,
|
||||
) -> Optional[List[torch.Tensor]]:
|
||||
'''
|
||||
Extract appropriate attention bias from attention metadata
|
||||
@ -203,7 +203,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
def set_attn_bias(
|
||||
self,
|
||||
attn_bias: List[torch.Tensor],
|
||||
attn_type: AttentionType,
|
||||
attn_type: str,
|
||||
) -> None:
|
||||
'''
|
||||
Update appropriate attention bias field of attention metadata,
|
||||
@ -229,7 +229,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
|
||||
def get_seq_len_block_table_args(
|
||||
self,
|
||||
attn_type: AttentionType,
|
||||
attn_type: str,
|
||||
) -> tuple:
|
||||
'''
|
||||
The particular choice of sequence-length- and block-table-related
|
||||
@ -426,7 +426,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
attn_metadata: TorchSDPAMetadata, # type: ignore
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with torch SDPA and PagedAttention.
|
||||
|
||||
@ -574,7 +574,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: TorchSDPAMetadata,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> None:
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
|
||||
|
||||
@ -478,7 +478,7 @@ def is_all_cross_attn_metadata_set(attn_metadata):
|
||||
def get_seq_len_block_table_args(
|
||||
attn_metadata,
|
||||
is_prompt: bool,
|
||||
attn_type: AttentionType,
|
||||
attn_type: str,
|
||||
) -> tuple:
|
||||
'''
|
||||
The particular choice of sequence-length- and block-table-related
|
||||
@ -529,7 +529,7 @@ def get_seq_len_block_table_args(
|
||||
|
||||
def get_num_prefill_decode_query_kv_tokens(
|
||||
attn_metadata,
|
||||
attn_type: AttentionType,
|
||||
attn_type: str,
|
||||
) -> Tuple[int, int, int]:
|
||||
"""
|
||||
Calculate the number of prefill and decode tokens for query, key/value
|
||||
|
||||
@ -284,7 +284,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
|
||||
def _get_attn_bias(
|
||||
attn_metadata: XFormersMetadata,
|
||||
attn_type: AttentionType,
|
||||
attn_type: str,
|
||||
) -> Optional[AttentionBias]:
|
||||
'''
|
||||
Extract appropriate attention bias from attention metadata
|
||||
@ -314,7 +314,7 @@ def _get_attn_bias(
|
||||
def _set_attn_bias(
|
||||
attn_metadata: XFormersMetadata,
|
||||
attn_bias: List[Optional[AttentionBias]],
|
||||
attn_type: AttentionType,
|
||||
attn_type: str,
|
||||
) -> None:
|
||||
'''
|
||||
Update appropriate attention bias field of attention metadata,
|
||||
@ -416,7 +416,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
attn_metadata: "XFormersMetadata",
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with xFormers and PagedAttention.
|
||||
|
||||
@ -617,7 +617,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: XFormersMetadata,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Attention for 1D query of multiple prompts. Multiple prompt
|
||||
tokens are flattened in to `query` input.
|
||||
|
||||
@@ -4,12 +4,17 @@ from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn

import vllm.envs as envs
from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform
from vllm.plugins import get_current_vllm_config
from vllm.utils import direct_register_custom_op


class Attention(nn.Module):
@@ -86,6 +91,18 @@ class Attention(nn.Module):
            alibi_slopes, sliding_window, kv_cache_dtype,
            blocksparse_params, logits_soft_cap)

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = envs.VLLM_USE_V1 or not (
            current_platform.is_cuda_alike() or current_platform.is_cpu())
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix

    def forward(
        self,
        query: torch.Tensor,
@@ -93,17 +110,22 @@ class Attention(nn.Module):
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        attn_type: AttentionType = AttentionType.DECODER,
        attn_type: str = AttentionType.DECODER,
    ) -> torch.Tensor:

        return self.impl.forward(query,
                                 key,
                                 value,
                                 kv_cache,
                                 attn_metadata,
                                 self._k_scale,
                                 self._v_scale,
                                 attn_type=attn_type)
        if self.use_direct_call:
            return self.impl.forward(query,
                                     key,
                                     value,
                                     kv_cache,
                                     attn_metadata,
                                     self._k_scale,
                                     self._v_scale,
                                     attn_type=attn_type)
        else:
            return torch.ops.vllm.unified_attention(query, key, value,
                                                    kv_cache, attn_type,
                                                    self.layer_name)

    def extra_repr(self) -> str:
        s = f"head_size={self.impl.head_size}"  # type: ignore
@@ -112,3 +134,44 @@ class Attention(nn.Module):
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s


def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_type: str,
    layer_name: str,
) -> torch.Tensor:
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.dynamic_forward_context
    self = forward_context.static_forward_context[layer_name]
    return self.impl.forward(query,
                             key,
                             value,
                             kv_cache,
                             attn_metadata,
                             self._k_scale,
                             self._v_scale,
                             attn_type=attn_type)


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_type: str,
    layer_name: str,
) -> torch.Tensor:
    return torch.empty_like(query).contiguous()


direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    mutates_args=["kv_cache"],
    fake_impl=unified_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)
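direct_register_custom_op (from vllm.utils, used above) registers unified_attention together with a fake implementation, so the op stays opaque to torch.compile while remaining traceable for shape propagation. Below is a minimal sketch of the underlying torch.library mechanism that this helper presumably wraps (PyTorch >= 2.4 style API; the op name and bodies are invented for the demo and are not vLLM code).

import torch
from torch.library import custom_op


@custom_op("demo::opaque_attention", mutates_args=())
def opaque_attention(query: torch.Tensor) -> torch.Tensor:
    # Real op: look up the layer via the forward context and run its backend.
    return query * 2.0


@opaque_attention.register_fake
def _(query: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used during tracing, analogous to
    # unified_attention_fake above; torch.compile never looks inside the op.
    return torch.empty_like(query)


print(torch.ops.demo.opaque_attention(torch.ones(4)))
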
@@ -2135,8 +2135,7 @@ class CompilationConfig(BaseModel):
    backend: str = ""
    custom_ops: List[str] = Field(default_factory=list)
    splitting_ops: List[str] = Field(default_factory=lambda: [
        "vllm.unified_flash_attention",
        "vllm.unified_flash_infer",
        "vllm.unified_attention",
        "vllm.unified_v1_flash_attention",
    ])
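splitting_ops tells the compilation backend where to split the captured graph; the new "vllm.unified_attention" op is added to that list so graphs are still cut at attention boundaries, alongside the per-backend ops it replaces. A hedged example of overriding the default, assuming CompilationConfig can be constructed directly as the pydantic model defined in this file:

from vllm.config import CompilationConfig

# Hypothetical override; the default list above already includes
# "vllm.unified_attention".
config = CompilationConfig(splitting_ops=["vllm.unified_attention"])
print(config.splitting_ops)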

@@ -2197,6 +2196,11 @@
    enabled_custom_ops: Counter[str] = PrivateAttr
    disabled_custom_ops: Counter[str] = PrivateAttr

    # Per-model forward context
    # Mainly used to store attention cls
    # Map from layer name to the attention cls
    static_forward_context: Dict[str, Any] = PrivateAttr

    @classmethod
    def from_cli(cls, cli_value: str) -> "CompilationConfig":
        """Parse the CLI value for the compilation config."""

@@ -2228,6 +2232,7 @@

        self.enabled_custom_ops = Counter()
        self.disabled_custom_ops = Counter()
        self.static_forward_context = {}

    def init_backend(self) -> Union[str, Callable]:
        if self.level == CompilationLevel.NO_COMPILATION:

@@ -1,21 +1,38 @@
from contextlib import contextmanager
from typing import Any
from dataclasses import dataclass
from typing import Any, Dict, Optional

_forward_context: Any = None
from vllm.config import VllmConfig


def get_forward_context() -> Any:
@dataclass
class ForwardContext:
    static_forward_context: Dict[str, Any]
    # TODO: extend to support per-layer dynamic forward context
    dynamic_forward_context: Any


_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context."""
    assert _forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return _forward_context


@contextmanager
def set_forward_context(context: Any):
def set_forward_context(context: Any, vllm_config: VllmConfig):
    """A context manager that stores the current forward context,
    can be attention metadata, etc."""
    global _forward_context
    prev_context = _forward_context
    _forward_context = context
    _forward_context = ForwardContext(
        static_forward_context=vllm_config.compilation_config.
        static_forward_context,
        dynamic_forward_context=context)
    try:
        yield
    finally:
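With this change, set_forward_context needs the VllmConfig so it can expose the per-model static_forward_context (layer name -> Attention module) next to the dynamic per-step state. Illustrative usage, mirroring what the updated tests in this diff do; here attn_metadata is a placeholder standing in for a real AttentionMetadata object:

from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context, set_forward_context

vllm_config = VllmConfig()
attn_metadata = None  # the tests pass a backend-specific AttentionMetadata here

with set_forward_context(attn_metadata, vllm_config):
    ctx = get_forward_context()
    assert ctx.dynamic_forward_context is attn_metadata
    # ctx.static_forward_context maps layer names to their Attention modules,
    # which is how torch.ops.vllm.unified_attention finds the right layer.
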
@ -223,6 +223,7 @@ class ArcticAttention(nn.Module):
|
||||
layer_idx: Optional[int] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -274,7 +275,8 @@ class ArcticAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -299,6 +301,7 @@ class ArcticDecoderLayer(nn.Module):
|
||||
layer_idx: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
@ -308,7 +311,8 @@ class ArcticDecoderLayer(nn.Module):
|
||||
self.self_attn = ArcticAttention(config,
|
||||
layer_idx,
|
||||
cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
self.block_sparse_moe = ArcticMoE(
|
||||
config,
|
||||
layer_id=layer_idx,
|
||||
@ -380,8 +384,11 @@ class ArcticModel(nn.Module):
|
||||
org_num_embeddings=self.vocab_size)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: ArcticDecoderLayer(config, int(
|
||||
prefix.split(".")[-1]), cache_config, quant_config),
|
||||
lambda prefix: ArcticDecoderLayer(config,
|
||||
int(prefix.split(".")[-1]),
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
self._attn_implementation = config._attn_implementation
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@ -116,6 +116,7 @@ class BaiChuanAttention(nn.Module):
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -158,7 +159,8 @@ class BaiChuanAttention(nn.Module):
|
||||
self.head_dim,
|
||||
scaling,
|
||||
alibi_slopes=alibi_slopes,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
else:
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@ -171,7 +173,8 @@ class BaiChuanAttention(nn.Module):
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -195,7 +198,8 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
position_embedding: str,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
@ -209,6 +213,7 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.mlp = BaiChuanMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
@ -275,8 +280,11 @@ class BaiChuanModel(nn.Module):
|
||||
)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: BaiChuanDecoderLayer(config, position_embedding,
|
||||
cache_config, quant_config),
|
||||
lambda prefix: BaiChuanDecoderLayer(config,
|
||||
position_embedding,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@ -126,6 +126,7 @@ class BartEncoderAttention(nn.Module):
|
||||
config: Optional[BartConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
@ -178,7 +179,8 @@ class BartEncoderAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||
@ -208,6 +210,7 @@ class BartDecoderSelfAttention(nn.Module):
|
||||
config: Optional[BartConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
@ -260,7 +263,8 @@ class BartDecoderSelfAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||
@ -290,6 +294,7 @@ class BartCrossAttention(nn.Module):
|
||||
config: Optional[BartConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
@ -342,7 +347,8 @@ class BartCrossAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -384,6 +390,7 @@ class BartEncoderLayer(nn.Module):
|
||||
config: BartConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
@ -393,7 +400,9 @@ class BartEncoderLayer(nn.Module):
|
||||
num_heads=config.encoder_attention_heads,
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.activation_fn = get_act_fn(config.activation_function)
|
||||
|
||||
@ -464,6 +473,7 @@ class BartDecoderLayer(nn.Module):
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.embed_dim = config.d_model
@ -473,7 +483,9 @@ class BartDecoderLayer(nn.Module):
num_heads=config.decoder_attention_heads,
config=config,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.activation_fn = get_act_fn(config.activation_function)

self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
@ -486,6 +498,7 @@ class BartDecoderLayer(nn.Module):
self.embed_dim,
config.decoder_attention_heads,
config=config,
prefix=f"{prefix}.encoder_attn",
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

@ -578,7 +591,8 @@ class BartEncoder(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None):
embed_tokens: Optional[nn.Embedding] = None,
prefix: str = ""):
super().__init__()

self.cache_config = cache_config
@ -599,9 +613,13 @@ class BartEncoder(nn.Module):
config.max_position_embeddings,
embed_dim,
)
self.layers = nn.ModuleList(
[BartEncoderLayer(config,cache_config,quant_config) \
for _ in range(config.encoder_layers)])
self.layers = nn.ModuleList([
BartEncoderLayer(config,
cache_config,
quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.encoder_layers)
])

self.layernorm_embedding = nn.LayerNorm(embed_dim)

@ -661,6 +679,7 @@ class BartDecoder(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None,
prefix: str = "",
):
super().__init__()
self.cache_config = cache_config
@ -683,8 +702,9 @@ class BartDecoder(nn.Module):
)

self.layers = nn.ModuleList(
[BartDecoderLayer(config,cache_config,quant_config) \
for _ in range(config.decoder_layers)])
[BartDecoderLayer(config,cache_config,quant_config,
prefix=f"{prefix}.layers.{layer_idx}") \
for layer_idx in range(config.decoder_layers)])

self.layernorm_embedding = nn.LayerNorm(config.d_model)

@ -759,10 +779,12 @@ class BartModel(nn.Module):

self.encoder = BartEncoder(config,
cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.encoder")
self.decoder = BartDecoder(config,
cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.decoder")

def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor,

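The BART hunks above show the pattern this commit repeats in every model file: each constructor gains a prefix argument, extends it with the attribute name of the submodule it builds (".self_attn", ".encoder_attn", ".layers.{layer_idx}"), and finally passes prefix=f"{prefix}.attn" into the Attention layer, so every attention instance carries a unique dotted name matching its position in the module tree. A minimal sketch of that threading, using plain torch.nn stand-ins rather than the real vLLM classes (all Toy* names below are illustrative only):

import torch
from torch import nn


class ToyAttention(nn.Module):
    """Stand-in for the real Attention layer; it only records its prefix."""

    def __init__(self, hidden_size: int, prefix: str = ""):
        super().__init__()
        self.prefix = prefix  # unique, stable dotted name for this layer
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class ToyDecoderLayer(nn.Module):

    def __init__(self, hidden_size: int, prefix: str = ""):
        super().__init__()
        # Extend the prefix with this submodule's attribute name, exactly
        # like prefix=f"{prefix}.self_attn" in the hunks above.
        self.self_attn = ToyAttention(hidden_size,
                                      prefix=f"{prefix}.self_attn")


class ToyModel(nn.Module):

    def __init__(self, hidden_size: int = 16, num_layers: int = 2,
                 prefix: str = "model"):
        super().__init__()
        # One dotted prefix per layer: "<prefix>.layers.<i>".
        self.layers = nn.ModuleList([
            ToyDecoderLayer(hidden_size, prefix=f"{prefix}.layers.{i}")
            for i in range(num_layers)
        ])


print([layer.self_attn.prefix for layer in ToyModel().layers])
# -> ['model.layers.0.self_attn', 'model.layers.1.self_attn']

The collected prefixes mirror the module hierarchy, which is what lets per-layer attention state be keyed by a stable name outside the module itself.
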
@ -78,6 +78,7 @@ class BloomAttention(nn.Module):
config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
@ -116,7 +117,8 @@ class BloomAttention(nn.Module):
scaling,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")

def forward(
self,
@ -168,14 +170,17 @@ class BloomBlock(nn.Module):
config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size

self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config, cache_config,
quant_config)
self.self_attention = BloomAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attention")
self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config, quant_config)
@ -242,7 +247,8 @@ class BloomModel(nn.Module):
# Transformer blocks
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: BloomBlock(config, cache_config, quant_config),
lambda prefix: BloomBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h")

# Final Layer Norm

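Models built on the make_layers helper need even less churn: the factory lambda has always been called with a per-layer prefix, so the change is simply to stop discarding that argument and forward it, as in the BloomBlock lambda above. A rough, self-contained sketch of what such a helper does, written as an illustration rather than the actual vllm.model_executor.models.utils implementation:

from typing import Callable, Tuple

from torch import nn


def make_layers_sketch(
    num_layers: int,
    layer_fn: Callable[[str], nn.Module],
    prefix: str,
) -> Tuple[int, int, nn.ModuleList]:
    """Illustrative stand-in for the make_layers helper: call the factory
    once per layer, handing it a dotted per-layer prefix."""
    layers = nn.ModuleList(
        [layer_fn(f"{prefix}.{i}") for i in range(num_layers)])
    # The real helper also computes a (start_layer, end_layer) slice for
    # pipeline parallelism; a single-rank sketch covers the whole range.
    return 0, num_layers, layers


# Before this change the lambdas ignored the prefix they were given:
#     lambda prefix: BloomBlock(config, cache_config, quant_config)
# Afterwards they forward it:
#     lambda prefix: BloomBlock(config, cache_config, quant_config,
#                               prefix=prefix)
start, end, blocks = make_layers_sketch(
    3, lambda layer_prefix: nn.Identity(), prefix="transformer.h")
print(start, end, len(blocks))  # 0 3 3
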
@ -223,6 +223,7 @@ class ChameleonAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
@ -276,7 +277,8 @@ class ChameleonAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")

def _apply_qk_norm(self, q: torch.Tensor,
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@ -313,6 +315,7 @@ class ChameleonDecoderLayer(nn.Module):
config: ChameleonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -336,6 +339,7 @@ class ChameleonDecoderLayer(nn.Module):
quant_config=quant_config,
bias=False,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = ChameleonMLP(
hidden_size=self.hidden_size,
@ -386,6 +390,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
config: ChameleonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -409,6 +414,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
quant_config=quant_config,
bias=False,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = ChameleonMLP(
hidden_size=self.hidden_size,
@ -855,7 +861,8 @@ class ChameleonModel(nn.Module):
config.num_hidden_layers,
lambda prefix: decoder_layer(config=config,
cache_config=cache_config,
quant_config=quant_config),
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)

@ -230,6 +230,7 @@ class GLMAttention(nn.Module):
|
||||
config: ChatGLMConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -285,7 +286,8 @@ class GLMAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -364,6 +366,7 @@ class GLMBlock(nn.Module):
|
||||
config: ChatGLMConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.apply_residual_connection_post_layernorm = (
|
||||
@ -377,7 +380,10 @@ class GLMBlock(nn.Module):
|
||||
eps=config.layernorm_epsilon)
|
||||
|
||||
# Self attention.
|
||||
self.self_attention = GLMAttention(config, cache_config, quant_config)
|
||||
self.self_attention = GLMAttention(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.self_attention")
|
||||
self.hidden_dropout = config.hidden_dropout
|
||||
|
||||
# Layernorm on the attention output
|
||||
@ -446,7 +452,8 @@ class GLMTransformer(nn.Module):
|
||||
# Transformer layers.
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
self.num_layers,
|
||||
lambda prefix: GLMBlock(config, cache_config, quant_config),
|
||||
lambda prefix: GLMBlock(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
|
||||
@ -500,16 +507,22 @@ class ChatGLMModel(nn.Module):
|
||||
self.num_layers = config.num_layers
|
||||
self.multi_query_group_num = config.multi_query_group_num
|
||||
self.kv_channels = config.kv_channels
|
||||
self.encoder = GLMTransformer(config, cache_config, quant_config)
|
||||
self.encoder = GLMTransformer(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.encoder")
|
||||
|
||||
self.output_layer = ParallelLMHead(config.padded_vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.output_layer")
|
||||
|
||||
vision_config_flag = getattr(config, 'vision_config', None)
|
||||
if vision_config_flag is not None:
|
||||
self.vision_config = Namespace(**config.vision_config)
|
||||
self.vision = EVA2CLIPModel(self.config, quant_config)
|
||||
self.vision = EVA2CLIPModel(self.config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.vision")
|
||||
else:
|
||||
self.vision = None
|
||||
|
||||
|
||||
@ -120,6 +120,7 @@ class CohereAttention(nn.Module):
|
||||
config: CohereConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
@ -175,7 +176,8 @@ class CohereAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = LayerNorm(param_shape=(self.num_heads,
|
||||
self.head_dim),
|
||||
@ -215,13 +217,15 @@ class CohereDecoderLayer(nn.Module):
|
||||
def __init__(self,
|
||||
config: CohereConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = CohereAttention(config,
|
||||
cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
|
||||
self.mlp = CohereMLP(config, quant_config=quant_config)
|
||||
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
|
||||
@ -271,8 +275,8 @@ class CohereModel(nn.Module):
|
||||
config.hidden_size)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: CohereDecoderLayer(config, cache_config,
|
||||
quant_config),
|
||||
lambda prefix: CohereDecoderLayer(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
self.norm = LayerNorm(param_shape=(config.hidden_size),
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
@ -154,6 +154,7 @@ class DbrxAttention(nn.Module):
|
||||
config: DbrxConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
@ -208,7 +209,8 @@ class DbrxAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -234,10 +236,14 @@ class DbrxFusedNormAttention(nn.Module):
|
||||
config: DbrxConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
self.attn = DbrxAttention(config, cache_config, quant_config)
|
||||
self.attn = DbrxAttention(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
self.norm_1 = nn.LayerNorm(self.d_model)
|
||||
self.norm_2 = nn.LayerNorm(self.d_model)
|
||||
|
||||
@ -269,10 +275,14 @@ class DbrxBlock(nn.Module):
|
||||
config: DbrxConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
|
||||
quant_config)
|
||||
self.norm_attn_norm = DbrxFusedNormAttention(
|
||||
config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.norm_attn_norm")
|
||||
self.ffn = DbrxMoE(config, quant_config)
|
||||
|
||||
def forward(
|
||||
@ -308,7 +318,8 @@ class DbrxModel(nn.Module):
|
||||
)
|
||||
self.start_layer, self.end_layer, self.blocks = make_layers(
|
||||
config.n_layers,
|
||||
lambda prefix: DbrxBlock(config, cache_config, quant_config),
|
||||
lambda prefix: DbrxBlock(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.blocks",
|
||||
)
|
||||
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
|
||||
|
||||
@ -184,6 +184,7 @@ class DeepseekAttention(nn.Module):
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -236,7 +237,8 @@ class DeepseekAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -261,6 +263,7 @@ class DeepseekDecoderLayer(nn.Module):
|
||||
layer_idx: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -277,6 +280,7 @@ class DeepseekDecoderLayer(nn.Module):
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
if (config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
@ -346,7 +350,8 @@ class DeepseekModel(nn.Module):
|
||||
lambda prefix: DeepseekDecoderLayer(config,
|
||||
int(prefix.split(".")[-1]),
|
||||
cache_config,
|
||||
quant_config=quant_config),
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
|
||||
@ -268,7 +268,8 @@ class DeepseekV2Attention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_local_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -174,6 +174,7 @@ class ExaoneAttention(nn.Module):
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
def forward(
|
||||
@ -219,7 +220,7 @@ class ExaoneBlockAttention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
bias=bias,
|
||||
cache_config=cache_config,
|
||||
prefix=prefix,
|
||||
prefix=f"{prefix}.attention",
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@ -84,6 +84,7 @@ class FalconAttention(nn.Module):
|
||||
config: FalconConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -158,7 +159,8 @@ class FalconAttention(nn.Module):
|
||||
self.head_dim,
|
||||
self.inv_norm_factor,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
elif self.use_alibi:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
head_start = tp_rank * self.num_heads
|
||||
@ -171,14 +173,16 @@ class FalconAttention(nn.Module):
|
||||
self.inv_norm_factor,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
alibi_slopes=alibi_slopes,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
else:
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
scale=self.inv_norm_factor,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -241,12 +245,16 @@ class FalconDecoderLayer(nn.Module):
|
||||
config: FalconConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.self_attention = FalconAttention(config, cache_config,
|
||||
quant_config)
|
||||
self.self_attention = FalconAttention(
|
||||
config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.self_attention")
|
||||
self.mlp = FalconMLP(config, quant_config)
|
||||
self.config = config
|
||||
|
||||
@ -357,8 +365,8 @@ class FalconModel(nn.Module):
|
||||
# Transformer blocks
|
||||
self.start_layer, self.end_layer, self.h = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: FalconDecoderLayer(config, cache_config,
|
||||
quant_config),
|
||||
lambda prefix: FalconDecoderLayer(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.h")
|
||||
|
||||
# Final Layer Norm
|
||||
|
||||
@ -35,10 +35,12 @@ class Florence2LanguageModel(nn.Module):
|
||||
self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)
|
||||
self.encoder = BartEncoder(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.encoder")
|
||||
self.decoder = BartDecoder(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.decoder")
|
||||
|
||||
if self.config.tie_word_embeddings:
|
||||
self.encoder.embed_tokens.weight = self.shared.weight
|
||||
@ -99,7 +101,7 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
|
||||
|
||||
self.config = config
|
||||
self.model = Florence2LanguageModel(vllm_config=vllm_config,
|
||||
prefix=prefix)
|
||||
prefix=f"{prefix}.model")
|
||||
embed_scale = math.sqrt(
|
||||
config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
@ -198,7 +200,7 @@ class Florence2ForConditionalGeneration(nn.Module):
|
||||
# TODO(Isotr0py): Add vision backbone
|
||||
self.language_model = Florence2LanguageForConditionalGeneration(
|
||||
vllm_config=vllm_config.with_hf_config(config.text_config),
|
||||
prefix=prefix,
|
||||
prefix=f"{prefix}.language_model",
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@ -174,7 +174,8 @@ class GemmaAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -95,7 +95,8 @@ class Gemma2Attention(nn.Module):
|
||||
rope_theta: float,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
attn_logits_soft_cap: Optional[float] = None) -> None:
|
||||
attn_logits_soft_cap: Optional[float] = None,
|
||||
prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.config = config
|
||||
@ -154,7 +155,8 @@ class Gemma2Attention(nn.Module):
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
logits_soft_cap=attn_logits_soft_cap)
|
||||
logits_soft_cap=attn_logits_soft_cap,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -179,6 +181,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
config: Gemma2Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -194,6 +197,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
attn_logits_soft_cap=config.attn_logit_softcapping,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.hidden_size = config.hidden_size
|
||||
self.mlp = Gemma2MLP(
|
||||
@ -257,8 +261,11 @@ class Gemma2Model(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[
-1]), config, cache_config, quant_config),
lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[-1]),
config,
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

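A small wrinkle in several of the make_layers call sites (Gemma2 above, as well as Deepseek, OLMoE and Phi-3-Small elsewhere in this diff) is that the layer constructor also wants the numeric layer index, which is recovered by parsing the dotted prefix with int(prefix.split(".")[-1]). A tiny illustration of that parsing, using a hypothetical prefix value:

def layer_index_from_prefix(prefix: str) -> int:
    """Recover the numeric layer index from a dotted prefix such as
    'model.layers.7' (the convention used by the lambdas above)."""
    return int(prefix.split(".")[-1])


assert layer_index_from_prefix("model.layers.7") == 7
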
@ -56,6 +56,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = '',
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -135,11 +136,14 @@ class TransformerLayer(nn.Module):
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = '',
|
||||
):
|
||||
super().__init__()
|
||||
self.input_layernorm = LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.attention = Attention(config, quant_config=quant_config)
|
||||
self.attention = Attention(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attention")
|
||||
self.mlp = MLP(config, quant_config=quant_config)
|
||||
self.post_attention_layernorm = LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
@ -161,11 +165,14 @@ class Transformer(nn.Module):
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = '',
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([
|
||||
TransformerLayer(config, quant_config=quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
TransformerLayer(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layer.{layer_idx}")
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
@ -252,12 +259,14 @@ class EVA2CLIPModel(nn.Module):
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = '',
|
||||
):
|
||||
super().__init__()
|
||||
vision_config = Namespace(**config.vision_config)
|
||||
self.patch_embedding = PatchEmbedding(vision_config)
|
||||
self.transformer = Transformer(vision_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.transformer")
|
||||
self.linear_proj = GLU(config,
|
||||
in_features=config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
|
||||
@ -84,7 +84,8 @@ class GPT2Attention(nn.Module):
|
||||
self.head_dim,
|
||||
scale=self.scale,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -52,6 +52,7 @@ class GPTBigCodeAttention(nn.Module):
|
||||
config: GPTBigCodeConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -92,7 +93,8 @@ class GPTBigCodeAttention(nn.Module):
|
||||
scale=self.scale,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -151,6 +153,7 @@ class GPTBigCodeBlock(nn.Module):
|
||||
config: GPTBigCodeConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@ -158,7 +161,10 @@ class GPTBigCodeBlock(nn.Module):
|
||||
hidden_size)
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPTBigCodeAttention(config, cache_config, quant_config)
|
||||
self.attn = GPTBigCodeAttention(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
|
||||
|
||||
@ -210,7 +216,8 @@ class GPTBigCodeModel(nn.Module):
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.start_layer, self.end_layer, self.h = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config),
|
||||
lambda prefix: GPTBigCodeBlock(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.h",
|
||||
)
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
@ -53,6 +53,7 @@ class GPTJAttention(nn.Module):
|
||||
config: GPTJConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
@ -94,7 +95,8 @@ class GPTJAttention(nn.Module):
|
||||
self.head_size,
|
||||
scaling,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -147,12 +149,16 @@ class GPTJBlock(nn.Module):
|
||||
config: GPTJConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = (4 * config.n_embd
|
||||
if config.n_inner is None else config.n_inner)
|
||||
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPTJAttention(config, cache_config, quant_config)
|
||||
self.attn = GPTJAttention(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
self.mlp = GPTJMLP(inner_dim, config, quant_config)
|
||||
|
||||
def forward(
|
||||
@ -193,7 +199,8 @@ class GPTJModel(nn.Module):
|
||||
)
|
||||
self.start_layer, self.end_layer, self.h = make_layers(
|
||||
config.n_layer,
|
||||
lambda prefix: GPTJBlock(config, cache_config, quant_config),
|
||||
lambda prefix: GPTJBlock(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.h",
|
||||
)
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
@ -52,6 +52,7 @@ class GPTNeoXAttention(nn.Module):
|
||||
config: GPTNeoXConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
@ -94,7 +95,8 @@ class GPTNeoXAttention(nn.Module):
|
||||
self.head_size,
|
||||
scaling,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -145,6 +147,7 @@ class GPTNeoXLayer(nn.Module):
|
||||
config: GPTNeoXConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.use_parallel_residual = config.use_parallel_residual
|
||||
@ -152,7 +155,10 @@ class GPTNeoXLayer(nn.Module):
|
||||
eps=config.layer_norm_eps)
|
||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.attention = GPTNeoXAttention(config, cache_config, quant_config)
|
||||
self.attention = GPTNeoXAttention(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.attention")
|
||||
self.mlp = GPTNeoXMLP(config, quant_config)
|
||||
|
||||
def forward(
|
||||
@ -205,7 +211,8 @@ class GPTNeoXModel(nn.Module):
|
||||
)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: GPTNeoXLayer(config, cache_config, quant_config),
|
||||
lambda prefix: GPTNeoXLayer(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
|
||||
|
||||
@ -161,7 +161,8 @@ class GraniteAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -164,7 +164,8 @@ class GraniteMoeAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -250,7 +250,12 @@ class InternLMDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class InternLM2Model(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
layer_type: Type[InternLMDecoderLayer] = InternLMDecoderLayer):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
@ -266,7 +271,7 @@ class InternLM2Model(nn.Module):
|
||||
)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: InternLMDecoderLayer(
|
||||
lambda prefix: layer_type(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -316,14 +321,18 @@ class InternLM2Model(nn.Module):
|
||||
|
||||
class InternLM2ForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
model_type: Type[InternLM2Model] = InternLM2Model):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = InternLM2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.model = model_type(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.output = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
|
||||
@ -14,8 +14,6 @@ from vllm.model_executor.models.internlm2 import (InternLM2Attention,
|
||||
InternLM2MLP, InternLM2Model)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .utils import make_layers, maybe_prefix
|
||||
|
||||
|
||||
class InternLM2VEDecoderLayer(nn.Module):
|
||||
|
||||
@ -105,17 +103,9 @@ class InternLM2VEDecoderLayer(nn.Module):
|
||||
class InternLM2VEModel(InternLM2Model):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: InternLM2VEDecoderLayer(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
layer_type=InternLM2VEDecoderLayer)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -159,7 +149,6 @@ class InternLM2VEModel(InternLM2Model):
class InternLM2VEForCausalLM(InternLM2ForCausalLM):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)

self.model = InternLM2VEModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
super().__init__(vllm_config=vllm_config,
prefix=prefix,
model_type=InternLM2VEModel)

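The InternLM2 hunks above replace the copy-pasted layer and model construction in the vision-enhanced (VE) variant with two constructor hooks: InternLM2Model now accepts a layer_type and InternLM2ForCausalLM a model_type, so the subclasses just pass their own classes instead of rebuilding self.layers and self.model after the fact. A generic, self-contained sketch of that pattern (the Base*/VE* names are placeholders, not the vLLM classes):

from typing import Type

from torch import nn


class BaseDecoderLayer(nn.Module):
    """Placeholder for InternLMDecoderLayer."""


class VEDecoderLayer(BaseDecoderLayer):
    """Placeholder for InternLM2VEDecoderLayer."""


class BaseModel(nn.Module):
    # Accepting the layer class as a constructor argument lets a subclass
    # swap in its own layer type instead of rebuilding self.layers later.
    def __init__(self, num_layers: int = 2,
                 layer_type: Type[BaseDecoderLayer] = BaseDecoderLayer):
        super().__init__()
        self.layers = nn.ModuleList(
            [layer_type() for _ in range(num_layers)])


class VEModel(BaseModel):
    def __init__(self, num_layers: int = 2):
        super().__init__(num_layers=num_layers, layer_type=VEDecoderLayer)


assert all(isinstance(layer, VEDecoderLayer) for layer in VEModel().layers)
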
@ -76,6 +76,7 @@ class JAISAttention(nn.Module):
|
||||
config: JAISConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -114,7 +115,8 @@ class JAISAttention(nn.Module):
|
||||
scale=self.scale,
|
||||
alibi_slopes=alibi_slopes,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -178,6 +180,7 @@ class JAISBlock(nn.Module):
|
||||
config: JAISConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@ -185,7 +188,10 @@ class JAISBlock(nn.Module):
|
||||
hidden_size)
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = JAISAttention(config, cache_config, quant_config)
|
||||
self.attn = JAISAttention(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = JAISMLP(inner_dim, config, quant_config)
|
||||
|
||||
@ -241,7 +247,8 @@ class JAISModel(nn.Module):
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: JAISBlock(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config),
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.h",
|
||||
)
|
||||
|
||||
|
||||
@ -102,7 +102,8 @@ class JambaMambaDecoderLayer(nn.Module):
|
||||
config: JambaConfig,
|
||||
layer_idx: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.mamba = MambaMixer(hidden_size= config.hidden_size,
|
||||
@ -157,6 +158,7 @@ class JambaAttentionDecoderLayer(nn.Module):
|
||||
layer_idx: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -198,6 +200,7 @@ class JambaAttentionDecoderLayer(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
num_experts = config.layers_num_experts[layer_idx]
|
||||
@ -287,7 +290,8 @@ class JambaModel(nn.Module):
|
||||
layer_class(config,
|
||||
layer_idx=i,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config))
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{i}"))
|
||||
self.layers = nn.ModuleList(decoder_layers)
|
||||
self.final_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
@ -174,6 +174,7 @@ class LlamaAttention(nn.Module):
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@ -192,6 +192,7 @@ class MiniCPMAttention(nn.Module):
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -246,7 +247,8 @@ class MiniCPMAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -273,6 +275,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -283,6 +286,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
||||
self.rope_scaling = getattr(config, "rope_scaling", None)
|
||||
self.max_position_embeddings = getattr(config,
|
||||
"max_position_embeddings", 8192)
|
||||
self.prefix = prefix
|
||||
self._init_attn_block()
|
||||
self._init_ffn_block()
|
||||
|
||||
@ -298,6 +302,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
cache_config=self.cache_config,
|
||||
quant_config=self.quant_config,
|
||||
prefix=f"{self.prefix}.self_attn",
|
||||
)
|
||||
|
||||
def _init_ffn_block(self):
|
||||
@ -388,8 +393,8 @@ class MiniCPMModel(nn.Module):
|
||||
):
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: MiniCPMDecoderLayer(config, cache_config,
|
||||
quant_config),
|
||||
lambda prefix: MiniCPMDecoderLayer(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -60,6 +60,7 @@ class MiniCPM3Attention(nn.Module):
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -119,7 +120,8 @@ class MiniCPM3Attention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_local_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -195,6 +197,7 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
cache_config=self.cache_config,
|
||||
quant_config=self.quant_config,
|
||||
prefix=f"{self.prefix}.self_attn",
|
||||
)
|
||||
|
||||
|
||||
@ -209,8 +212,8 @@ class MiniCPM3Model(MiniCPMModel):
|
||||
):
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: MiniCPM3DecoderLayer(config, cache_config,
|
||||
quant_config),
|
||||
lambda prefix: MiniCPM3DecoderLayer(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
|
||||
|
||||
@ -166,7 +166,8 @@ class MixtralAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -170,6 +170,7 @@ class MixtralAttention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -219,7 +220,8 @@ class MixtralAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -243,6 +245,7 @@ class MixtralDecoderLayer(nn.Module):
|
||||
config: MixtralConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -255,7 +258,9 @@ class MixtralDecoderLayer(nn.Module):
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.block_sparse_moe = MixtralMoE(config=config,
|
||||
quant_config=quant_config)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
@ -311,7 +316,8 @@ class MixtralModel(nn.Module):
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: MixtralDecoderLayer(
|
||||
config, cache_config, quant_config=quant_config),
|
||||
config, cache_config, quant_config=quant_config, prefix=prefix
|
||||
),
|
||||
prefix=f"{prefix}.layers")
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
|
||||
@ -370,6 +370,7 @@ class MolmoAttention(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -427,7 +428,8 @@ class MolmoAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
# Attention output projection.
|
||||
self.o_proj = RowParallelLinear(
|
||||
@ -517,10 +519,14 @@ class MolmoDecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# Attention block.
|
||||
self.self_attn = MolmoAttention(config, cache_config, quant_config)
|
||||
self.self_attn = MolmoAttention(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
|
||||
# MLP block.
|
||||
self.mlp = MolmoMLP(config, quant_config=quant_config)
|
||||
@ -738,7 +744,8 @@ class MolmoModel(nn.Module):
|
||||
else MolmoDecoderLayer
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: decoder_layer(config, cache_config, quant_config),
|
||||
lambda prefix: decoder_layer(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
|
||||
|
||||
@ -50,6 +50,7 @@ class MPTAttention(nn.Module):
|
||||
config: MPTConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
@ -115,7 +116,8 @@ class MPTAttention(nn.Module):
|
||||
alibi_slopes=alibi_slopes,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -176,11 +178,15 @@ class MPTBlock(nn.Module):
|
||||
config: MPTConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.d_model
|
||||
self.norm_1 = nn.LayerNorm(hidden_size)
|
||||
self.attn = MPTAttention(config, cache_config, quant_config)
|
||||
self.attn = MPTAttention(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
self.norm_2 = nn.LayerNorm(hidden_size)
|
||||
self.ffn = MPTMLP(config, quant_config)
|
||||
|
||||
@ -224,7 +230,8 @@ class MPTModel(nn.Module):
|
||||
)
|
||||
self.start_layer, self.end_layer, self.blocks = make_layers(
|
||||
config.n_layers,
|
||||
lambda prefix: MPTBlock(config, cache_config, quant_config),
|
||||
lambda prefix: MPTBlock(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.blocks")
|
||||
self.norm_f = nn.LayerNorm(config.d_model)
|
||||
if config.no_bias:
|
||||
|
||||
@ -195,7 +195,8 @@ class NemotronAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -62,6 +62,7 @@ class OlmoAttention(nn.Module):
|
||||
config: OlmoConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -101,7 +102,8 @@ class OlmoAttention(nn.Module):
|
||||
self.head_dim,
|
||||
scale=self.scaling,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
# Attention output projection.
|
||||
self.o_proj = RowParallelLinear(
|
||||
@ -184,10 +186,14 @@ class OlmoDecoderLayer(nn.Module):
|
||||
def __init__(self,
|
||||
config: OlmoConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
# Attention block.
|
||||
self.self_attn = OlmoAttention(config, cache_config, quant_config)
|
||||
self.self_attn = OlmoAttention(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
|
||||
# MLP block.
|
||||
self.mlp = OlmoMLP(config, quant_config)
|
||||
@ -238,8 +244,8 @@ class OlmoModel(nn.Module):
|
||||
config.hidden_size)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: OlmoDecoderLayer(config, cache_config, quant_config
|
||||
),
|
||||
lambda prefix: OlmoDecoderLayer(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
self.norm = nn.LayerNorm(config.hidden_size,
|
||||
elementwise_affine=False,
|
||||
|
||||
@ -102,6 +102,7 @@ class OlmoeAttention(nn.Module):
|
||||
max_position_embeddings: int = 4096,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -156,7 +157,8 @@ class OlmoeAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -182,6 +184,7 @@ class OlmoeDecoderLayer(nn.Module):
|
||||
layer_idx: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -199,6 +202,7 @@ class OlmoeDecoderLayer(nn.Module):
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
|
||||
self.mlp = OlmoeMoE(
|
||||
@ -260,8 +264,11 @@ class OlmoeModel(nn.Module):
|
||||
)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: OlmoeDecoderLayer(config, int(
|
||||
prefix.split(".")[-1]), cache_config, quant_config),
|
||||
lambda prefix: OlmoeDecoderLayer(config,
|
||||
int(prefix.split(".")[-1]),
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
|
||||
|
||||
@ -75,6 +75,7 @@ class OrionAttention(nn.Module):
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -126,7 +127,8 @@ class OrionAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -150,6 +152,7 @@ class OrionDecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -166,6 +169,7 @@ class OrionDecoderLayer(nn.Module):
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.mlp = OrionMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
@ -226,10 +230,7 @@ class OrionModel(nn.Module):
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: OrionDecoderLayer(
|
||||
config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
),
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
|
||||
@ -75,7 +75,8 @@ class PersimmonAttention(nn.Module):
|
||||
def __init__(self,
|
||||
config: PersimmonConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
tensor_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||
@ -122,7 +123,8 @@ class PersimmonAttention(nn.Module):
|
||||
self.head_dim,
|
||||
scale=self.scaling,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# [seq_length, hidden_size] -> [seq_length, num_heads, head_dim]
|
||||
@ -167,12 +169,14 @@ class PersimmonDecoderLayer(nn.Module):
|
||||
def __init__(self,
|
||||
config: PersimmonConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = PersimmonAttention(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
self.mlp = PersimmonMLP(config, quant_config=quant_config)
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
@ -226,8 +230,8 @@ class PersimmonModel(nn.Module):
|
||||
config.hidden_size)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: PersimmonDecoderLayer(config, cache_config,
|
||||
quant_config),
|
||||
lambda prefix: PersimmonDecoderLayer(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
self.final_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
@ -69,7 +69,8 @@ class PhiAttention(nn.Module):
|
||||
def __init__(self,
|
||||
config: PhiConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -116,7 +117,8 @@ class PhiAttention(nn.Module):
|
||||
self.head_size,
|
||||
scaling,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -167,11 +169,15 @@ class PhiLayer(nn.Module):
|
||||
def __init__(self,
|
||||
config: PhiConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.self_attn = PhiAttention(config, cache_config, quant_config)
|
||||
self.self_attn = PhiAttention(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
self.mlp = PhiMLP(config, quant_config)
|
||||
|
||||
def forward(
|
||||
@ -210,7 +216,8 @@ class PhiModel(nn.Module):
|
||||
config.hidden_size)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: PhiLayer(config, cache_config, quant_config),
|
||||
lambda prefix: PhiLayer(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
self.final_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
@ -117,6 +117,7 @@ class Phi3SmallSelfAttention(nn.Module):
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_idx = layer_idx

@ -214,15 +215,14 @@ class Phi3SmallSelfAttention(nn.Module):
"homo_head": self.homo_heads
}

self.attn = Attention(
self.num_heads_per_partition,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads_per_partion,
cache_config=cache_config,
quant_config=quant_config,
blocksparse_params=bs_params,
)
self.attn = Attention(self.num_heads_per_partition,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads_per_partion,
cache_config=cache_config,
quant_config=quant_config,
blocksparse_params=bs_params,
prefix=f"{prefix}.attn")

def forward(
self,

@ -259,13 +259,15 @@ class Phi3SmallDecoderLayer(nn.Module):
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Phi3SmallSelfAttention(config,
layer_idx,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = Phi3SmallMLP(config, quant_config)

self.input_layernorm = nn.LayerNorm(config.hidden_size,

@ -315,7 +317,9 @@ class Phi3SmallModel(nn.Module):
config.num_hidden_layers,
lambda prefix: Phi3SmallDecoderLayer(config,
int(prefix.split('.')[-1]),
cache_config, quant_config),
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")

self.final_layernorm = nn.LayerNorm(config.hidden_size,

@ -294,6 +294,7 @@ class PhiMoEAttention(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[dict] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size

@ -347,6 +348,7 @@ class PhiMoEAttention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)

def forward(

@ -371,6 +373,7 @@ class PhiMoEDecoderLayer(nn.Module):
config: PhiMoEConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size

@ -385,6 +388,7 @@ class PhiMoEDecoderLayer(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=config.rope_scaling,
prefix=f"{prefix}.self_attn",
)
self.block_sparse_moe = PhiMoE(
num_experts=config.num_local_experts,

@ -454,8 +458,8 @@ class PhiMoEModel(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: PhiMoEDecoderLayer(config, cache_config,
quant_config),
lambda prefix: PhiMoEDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size,
eps=config.rms_norm_eps,

@ -442,6 +442,7 @@ class QWenAttention(nn.Module):
rope_scaling: Optional[Dict[str, Any]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = hidden_size

@ -478,7 +479,8 @@ class QWenAttention(nn.Module):
self.head_dim,
self.scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")

def forward(
self,

@ -502,6 +504,7 @@ class QWenBlock(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

@ -514,7 +517,8 @@ class QWenBlock(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")

self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

@ -568,7 +572,8 @@ class QWenModel(nn.Module):
)
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: QWenBlock(config, cache_config, quant_config),
lambda prefix: QWenBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h")
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (

@ -168,6 +168,7 @@ class Qwen2MoeAttention(nn.Module):
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size

@ -220,7 +221,8 @@ class Qwen2MoeAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")

def forward(
self,

@ -245,6 +247,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size

@ -261,6 +264,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)

# Note: Qwen/Qwen2-57B-A14B-Instruct does not have

@ -336,7 +340,8 @@ class Qwen2MoeModel(nn.Module):
layer_idx=int(
prefix.split(".")[-1]),
cache_config=cache_config,
quant_config=quant_config),
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@ -167,6 +167,7 @@ class SolarAttention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)

def forward(

@ -77,7 +77,8 @@ class StablelmAttention(nn.Module):
def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size

@ -131,7 +132,8 @@ class StablelmAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_key_value_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")

def forward(
self,

@ -155,9 +157,13 @@ class StablelmDecoderLayer(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.self_attn = StablelmAttention(config, cache_config, quant_config)
self.self_attn = StablelmAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = StablelmMLP(config, quant_config)
norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05))

@ -207,8 +213,8 @@ class StableLMEpochModel(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: StablelmDecoderLayer(config, cache_config,
quant_config),
lambda prefix: StablelmDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
norm_eps = getattr(config, "norm_eps",

@ -52,7 +52,8 @@ class Starcoder2Attention(nn.Module):
def __init__(self,
config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.config = config

@ -105,7 +106,8 @@ class Starcoder2Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")

def forward(
self,

@ -154,12 +156,14 @@ class Starcoder2DecoderLayer(nn.Module):
def __init__(self,
config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Starcoder2Attention(config,
cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = Starcoder2MLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon)

@ -213,7 +217,8 @@ class Starcoder2Model(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Starcoder2DecoderLayer(
config, cache_config, quant_config=quant_config),
config, cache_config, quant_config=quant_config, prefix=prefix
),
prefix=f"{prefix}.layers",
)
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)

@ -93,6 +93,7 @@ class XverseAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size

@ -138,7 +139,8 @@ class XverseAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")

def forward(
self,

@ -162,6 +164,7 @@ class XverseDecoderLayer(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size

@ -180,6 +183,7 @@ class XverseDecoderLayer(nn.Module):
quant_config=quant_config,
bias=getattr(config, "bias", False),
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = XverseMLP(
hidden_size=self.hidden_size,

@ -243,8 +247,8 @@ class XverseModel(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: XverseDecoderLayer(config, cache_config,
quant_config),
lambda prefix: XverseDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@ -20,6 +20,7 @@ logger = init_logger(__name__)
class CpuPlatform(Platform):
_enum = PlatformEnum.CPU
device_type: str = "cpu"
dispatch_key: str = "CPU"

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:

@ -121,6 +121,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:
class CudaPlatform(Platform):
_enum = PlatformEnum.CUDA
device_type: str = "cuda"
dispatch_key: str = "CUDA"

@classmethod
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:

@ -13,6 +13,7 @@ else:
class HpuPlatform(Platform):
_enum = PlatformEnum.HPU
device_type: str = "hpu"
dispatch_key: str = "HPU"

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

@ -57,6 +57,10 @@ class DeviceCapability(NamedTuple):
class Platform:
_enum: PlatformEnum
device_type: str
# available dispatch keys:
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
# use "CPU" as a fallback for platforms not registered in PyTorch
dispatch_key: str = "CPU"

def is_cuda(self) -> bool:
return self._enum == PlatformEnum.CUDA

@ -18,6 +18,7 @@ logger = init_logger(__name__)
class OpenVinoPlatform(Platform):
_enum = PlatformEnum.OPENVINO
device_type: str = "openvino"
dispatch_key: str = "CPU"

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

@ -36,6 +36,7 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
device_type: str = "cuda"
dispatch_key: str = "CUDA"

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

@ -17,6 +17,7 @@ logger = init_logger(__name__)
class TpuPlatform(Platform):
_enum = PlatformEnum.TPU
device_type: str = "tpu"
dispatch_key: str = "XLA"

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

@ -17,6 +17,7 @@ logger = init_logger(__name__)
class XPUPlatform(Platform):
_enum = PlatformEnum.XPU
device_type: str = "xpu"
dispatch_key: str = "XPU"

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

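The platform hunks above all add a class-level `dispatch_key` next to `device_type`, with the base `Platform` defaulting to "CPU" as the fallback for platforms not registered in PyTorch. Below is a minimal, self-contained sketch of that shape using toy classes rather than vLLM's `vllm.platforms` package; the enum members and the `RocmPlatform` example are illustrative only.

```python
import enum


class PlatformEnum(enum.Enum):
    CUDA = enum.auto()
    ROCM = enum.auto()
    CPU = enum.auto()


class Platform:
    _enum: PlatformEnum
    device_type: str
    # Fallback for platforms that have no dedicated PyTorch dispatch key.
    dispatch_key: str = "CPU"


class CudaPlatform(Platform):
    _enum = PlatformEnum.CUDA
    device_type = "cuda"
    dispatch_key = "CUDA"


class RocmPlatform(Platform):
    _enum = PlatformEnum.ROCM
    device_type = "cuda"   # ROCm presents itself as a CUDA-like device
    dispatch_key = "CUDA"


current_platform = CudaPlatform()
# Downstream code can register custom ops under the platform's key instead of
# hard-coding "CUDA" (see the dispatch_key hunk in utils further down).
print(current_platform.device_type, current_platform.dispatch_key)  # cuda CUDA
```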
@ -273,7 +273,8 @@ class TP1DraftModelRunner(ModelRunner):
if previous_hidden_states is not None else {}

# Run model
with set_forward_context(model_input.attn_metadata):
with set_forward_context(model_input.attn_metadata,
self.vllm_config):
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,

@ -1573,6 +1573,7 @@ def direct_register_custom_op(
mutates_args: List[str],
fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA",
):
"""
`torch.library.custom_op` can have significant overhead because it

@ -1601,7 +1602,7 @@ def direct_register_custom_op(
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
my_lib = target_lib or vllm_lib
my_lib.define(op_name + schema_str)
my_lib.impl(op_name, op_func, "CUDA")
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)

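`direct_register_custom_op` now takes a `dispatch_key` parameter (defaulting to "CUDA") and forwards it to `Library.impl` instead of hard-coding the key, so the same helper can register ops for non-CUDA platforms as well. The sketch below shows the underlying `torch.library` calls with a made-up `toy_ops` namespace and `silly_add` op; it assumes a working PyTorch install, and the `dispatch_key` variable stands in for something like the platform's `dispatch_key` attribute from the hunks above.

```python
import torch
from torch.library import Library

# Hypothetical op namespace; "FRAGMENT" allows defining new ops in it.
toy_lib = Library("toy_ops", "FRAGMENT")


def silly_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y


dispatch_key = "CPU"  # would come from the current platform in vLLM

# Mirrors the define/impl calls in the hunk, with the key parameterized.
toy_lib.define("silly_add(Tensor x, Tensor y) -> Tensor")
toy_lib.impl("silly_add", silly_add, dispatch_key=dispatch_key)

out = torch.ops.toy_ops.silly_add(torch.ones(2), torch.ones(2))
print(out)  # tensor([2., 2.])
```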
@ -173,7 +173,8 @@ def unified_v1_flash_attention(
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
current_metadata = get_forward_context()
context = get_forward_context()
current_metadata = context.dynamic_forward_context
if current_metadata is None:
# Profiling run.
return

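In the V1 attention op above, `get_forward_context()` now returns a context object and the per-step attention metadata sits on its `dynamic_forward_context` attribute, with `None` signalling a profiling run. The snippet below is a self-contained toy of that pattern, not vLLM's actual `vllm.forward_context` module: a context manager installs a context carrying a static part (the config) and a dynamic part (the metadata), and the op reads the dynamic part at call time instead of receiving it as an argument.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Optional

_forward_context: Optional["ForwardContext"] = None


@dataclass
class ForwardContext:
    static_forward_context: Any   # fixed across steps, e.g. the config
    dynamic_forward_context: Any  # per-step data, e.g. attention metadata


@contextmanager
def set_forward_context(dynamic_ctx: Any, static_ctx: Any):
    # Install a context for the duration of one forward pass.
    global _forward_context
    prev = _forward_context
    _forward_context = ForwardContext(static_forward_context=static_ctx,
                                      dynamic_forward_context=dynamic_ctx)
    try:
        yield
    finally:
        _forward_context = prev


def get_forward_context() -> Optional[ForwardContext]:
    return _forward_context


def unified_attention_stub() -> None:
    context = get_forward_context()
    metadata = None if context is None else context.dynamic_forward_context
    if metadata is None:
        return  # profiling run: nothing to do
    print("attention metadata:", metadata)


with set_forward_context({"num_tokens": 4}, static_ctx="toy_vllm_config"):
    unified_attention_stub()  # prints the metadata installed above
```

Read this way, the runner hunks that follow, which all pass `self.vllm_config` as the second argument to `set_forward_context`, appear to be the producer side of the same context.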
@ -447,7 +447,7 @@ class GPUModelRunner:

# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata):
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=None,
positions=self.positions[:num_input_tokens],

@ -523,7 +523,7 @@ class GPUModelRunner:
num_tokens: int,
kv_caches: List[torch.Tensor],
) -> torch.Tensor:
with set_forward_context(None):
with set_forward_context(None, self.vllm_config):
hidden_states = model(
input_ids=None,
positions=self.positions[:num_tokens],

@ -97,7 +97,7 @@ class EmbeddingModelRunner(
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()

with set_forward_context(model_input.attn_metadata):
with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,

@ -176,7 +176,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
} if self.has_inner_state else {}

multi_modal_kwargs = model_input.multi_modal_kwargs or {}
with set_forward_context(model_input.attn_metadata):
with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,

@ -1503,7 +1503,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self._update_inputs_to_capture_for_enc_dec_model(
capture_inputs)

with set_forward_context(attn_metadata):
with set_forward_context(attn_metadata, self.vllm_config):
graph_runner.capture(**capture_inputs)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[virtual_engine][batch_size] = (

@ -1649,7 +1649,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()

with set_forward_context(model_input.attn_metadata):
with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,