mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 12:49:09 +08:00
[torch.compile] support all attention backends (#10558)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
db100c5cde
commit
eebad39f26
@ -18,8 +18,10 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
|
|||||||
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
|
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
|
||||||
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
|
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
|
||||||
global_force_attn_backend_context_manager)
|
global_force_attn_backend_context_manager)
|
||||||
|
from vllm.config import VllmConfig
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.plugins import set_current_vllm_config
|
||||||
|
|
||||||
# List of supported backends for encoder/decoder models
|
# List of supported backends for encoder/decoder models
|
||||||
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
|
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
|
||||||
@ -594,6 +596,7 @@ def _run_encoder_attention_test(
|
|||||||
encoder_test_params: PhaseTestParameters,
|
encoder_test_params: PhaseTestParameters,
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
test_pt: TestPoint,
|
test_pt: TestPoint,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
'''
|
'''
|
||||||
Run encoder attention.
|
Run encoder attention.
|
||||||
@ -623,7 +626,7 @@ def _run_encoder_attention_test(
|
|||||||
attn_type = AttentionType.ENCODER
|
attn_type = AttentionType.ENCODER
|
||||||
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
|
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
|
||||||
assert packed_qkv is not None
|
assert packed_qkv is not None
|
||||||
with set_forward_context(attn_metadata):
|
with set_forward_context(attn_metadata, vllm_config):
|
||||||
# In the test setup the shape of the query is
|
# In the test setup the shape of the query is
|
||||||
# [batch_size, seq_len, num_heads, head_size]. However
|
# [batch_size, seq_len, num_heads, head_size]. However
|
||||||
# the attention backend expects the shape to be
|
# the attention backend expects the shape to be
|
||||||
@ -648,6 +651,7 @@ def _run_decoder_self_attention_test(
|
|||||||
decoder_test_params: PhaseTestParameters,
|
decoder_test_params: PhaseTestParameters,
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
test_pt: TestPoint,
|
test_pt: TestPoint,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
'''
|
'''
|
||||||
Run decoder self-attention test.
|
Run decoder self-attention test.
|
||||||
@ -677,7 +681,7 @@ def _run_decoder_self_attention_test(
|
|||||||
kv_cache = test_rsrcs.kv_cache
|
kv_cache = test_rsrcs.kv_cache
|
||||||
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
|
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
|
||||||
assert packed_qkv is not None
|
assert packed_qkv is not None
|
||||||
with set_forward_context(attn_metadata):
|
with set_forward_context(attn_metadata, vllm_config):
|
||||||
# In the test setup the shape of the query is
|
# In the test setup the shape of the query is
|
||||||
# [batch_size, seq_len, num_heads, head_size]. However
|
# [batch_size, seq_len, num_heads, head_size]. However
|
||||||
# the attention backend expects the shape to be
|
# the attention backend expects the shape to be
|
||||||
@ -701,6 +705,7 @@ def _run_encoder_decoder_cross_attention_test(
|
|||||||
cross_test_params: Optional[PhaseTestParameters],
|
cross_test_params: Optional[PhaseTestParameters],
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
test_pt: TestPoint,
|
test_pt: TestPoint,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
'''
|
'''
|
||||||
Run encoder/decoder cross-attention test.
|
Run encoder/decoder cross-attention test.
|
||||||
@ -748,7 +753,7 @@ def _run_encoder_decoder_cross_attention_test(
|
|||||||
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
|
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
|
||||||
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
|
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
|
||||||
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
|
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
|
||||||
with set_forward_context(attn_metadata):
|
with set_forward_context(attn_metadata, vllm_config):
|
||||||
# In the test setup the shape of the query is
|
# In the test setup the shape of the query is
|
||||||
# [batch_size, seq_len, num_heads, head_size]. However
|
# [batch_size, seq_len, num_heads, head_size]. However
|
||||||
# the attention backend expects the shape to be
|
# the attention backend expects the shape to be
|
||||||
@ -839,7 +844,9 @@ def test_encoder_only(
|
|||||||
|
|
||||||
# Attention scale factor, attention backend instance, attention wrapper
|
# Attention scale factor, attention backend instance, attention wrapper
|
||||||
# instance, KV cache init
|
# instance, KV cache init
|
||||||
test_rsrcs = _make_test_resources(test_pt)
|
vllm_config = VllmConfig()
|
||||||
|
with set_current_vllm_config(vllm_config):
|
||||||
|
test_rsrcs = _make_test_resources(test_pt)
|
||||||
|
|
||||||
# Construct encoder attention test params (only used
|
# Construct encoder attention test params (only used
|
||||||
# during prefill)
|
# during prefill)
|
||||||
@ -863,7 +870,8 @@ def test_encoder_only(
|
|||||||
test_rsrcs.attn,
|
test_rsrcs.attn,
|
||||||
enc_test_params,
|
enc_test_params,
|
||||||
prephase_attn_metadata,
|
prephase_attn_metadata,
|
||||||
test_pt=test_pt))
|
test_pt=test_pt,
|
||||||
|
vllm_config=vllm_config))
|
||||||
|
|
||||||
# - Is encoder attention result correct?
|
# - Is encoder attention result correct?
|
||||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
|
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
|
||||||
@ -960,7 +968,9 @@ def test_e2e_enc_dec_attn(
|
|||||||
|
|
||||||
# Attention scale factor, attention backend instance, attention wrapper
|
# Attention scale factor, attention backend instance, attention wrapper
|
||||||
# instance, KV cache init
|
# instance, KV cache init
|
||||||
test_rsrcs = _make_test_resources(test_pt)
|
vllm_config = VllmConfig()
|
||||||
|
with set_current_vllm_config(vllm_config):
|
||||||
|
test_rsrcs = _make_test_resources(test_pt)
|
||||||
|
|
||||||
# Construct encoder attention test params (only used
|
# Construct encoder attention test params (only used
|
||||||
# during prefill)
|
# during prefill)
|
||||||
@ -1011,7 +1021,8 @@ def test_e2e_enc_dec_attn(
|
|||||||
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
|
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
|
||||||
enc_test_params,
|
enc_test_params,
|
||||||
prephase_attn_metadata,
|
prephase_attn_metadata,
|
||||||
test_pt=test_pt)
|
test_pt=test_pt,
|
||||||
|
vllm_config=vllm_config)
|
||||||
|
|
||||||
# - Is encoder attention result correct?
|
# - Is encoder attention result correct?
|
||||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
|
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
|
||||||
@ -1023,7 +1034,8 @@ def test_e2e_enc_dec_attn(
|
|||||||
test_rsrcs,
|
test_rsrcs,
|
||||||
prephase_dec_test_params,
|
prephase_dec_test_params,
|
||||||
prephase_attn_metadata,
|
prephase_attn_metadata,
|
||||||
test_pt=test_pt)
|
test_pt=test_pt,
|
||||||
|
vllm_config=vllm_config)
|
||||||
|
|
||||||
# - Is prefill decoder self-attention correct?
|
# - Is prefill decoder self-attention correct?
|
||||||
assert_actual_matches_ideal(prephase_dec_test_params,
|
assert_actual_matches_ideal(prephase_dec_test_params,
|
||||||
@ -1037,7 +1049,8 @@ def test_e2e_enc_dec_attn(
|
|||||||
prephase_dec_test_params,
|
prephase_dec_test_params,
|
||||||
prephase_cross_test_params,
|
prephase_cross_test_params,
|
||||||
prephase_attn_metadata,
|
prephase_attn_metadata,
|
||||||
test_pt=test_pt)
|
test_pt=test_pt,
|
||||||
|
vllm_config=vllm_config)
|
||||||
|
|
||||||
# - Is prefill encoder/decoder cross-attention correct?
|
# - Is prefill encoder/decoder cross-attention correct?
|
||||||
assert_actual_matches_ideal(prephase_cross_test_params,
|
assert_actual_matches_ideal(prephase_cross_test_params,
|
||||||
@ -1061,7 +1074,8 @@ def test_e2e_enc_dec_attn(
|
|||||||
test_rsrcs,
|
test_rsrcs,
|
||||||
decphase_dec_test_params,
|
decphase_dec_test_params,
|
||||||
decphase_attn_metadata,
|
decphase_attn_metadata,
|
||||||
test_pt=test_pt)
|
test_pt=test_pt,
|
||||||
|
vllm_config=vllm_config)
|
||||||
|
|
||||||
# - Is decode-phase decoder self-attention correct?
|
# - Is decode-phase decoder self-attention correct?
|
||||||
assert_actual_matches_ideal(decphase_dec_test_params,
|
assert_actual_matches_ideal(decphase_dec_test_params,
|
||||||
@ -1075,7 +1089,8 @@ def test_e2e_enc_dec_attn(
|
|||||||
decphase_dec_test_params,
|
decphase_dec_test_params,
|
||||||
None,
|
None,
|
||||||
decphase_attn_metadata,
|
decphase_attn_metadata,
|
||||||
test_pt=test_pt)
|
test_pt=test_pt,
|
||||||
|
vllm_config=vllm_config)
|
||||||
|
|
||||||
# - Is decode-phase encoder/decoder cross-attention correct?
|
# - Is decode-phase encoder/decoder cross-attention correct?
|
||||||
assert_actual_matches_ideal(decphase_cross_test_params,
|
assert_actual_matches_ideal(decphase_cross_test_params,
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass, fields
|
||||||
from enum import Enum, auto
|
|
||||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
|
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
|
||||||
Tuple, Type, TypeVar)
|
Tuple, Type, TypeVar)
|
||||||
|
|
||||||
@ -15,13 +14,19 @@ if TYPE_CHECKING:
|
|||||||
ModelRunnerInputBuilderBase)
|
ModelRunnerInputBuilderBase)
|
||||||
|
|
||||||
|
|
||||||
class AttentionType(Enum):
|
class AttentionType:
|
||||||
DECODER = auto() # Decoder attention between previous layer Q/K/V
|
"""
|
||||||
ENCODER = auto(
|
Attention type.
|
||||||
) # Encoder attention between previous layer Q/K/V for encoder-decoder
|
Use string to be compatible with `torch.compile`.
|
||||||
ENCODER_ONLY = auto() # Encoder attention between previous layer Q/K/V
|
"""
|
||||||
ENCODER_DECODER = auto(
|
# Decoder attention between previous layer Q/K/V
|
||||||
) # Attention between dec. Q and enc. K/V for encoder-decoder
|
DECODER = "decoder"
|
||||||
|
# Encoder attention between previous layer Q/K/V for encoder-decoder
|
||||||
|
ENCODER = "encoder"
|
||||||
|
# Encoder attention between previous layer Q/K/V
|
||||||
|
ENCODER_ONLY = "encoder_only"
|
||||||
|
# Attention between dec. Q and enc. K/V for encoder-decoder
|
||||||
|
ENCODER_DECODER = "encoder_decoder"
|
||||||
|
|
||||||
|
|
||||||
class AttentionBackend(ABC):
|
class AttentionBackend(ABC):
|
||||||
@ -241,6 +246,6 @@ class AttentionImpl(ABC, Generic[T]):
|
|||||||
attn_metadata: T,
|
attn_metadata: T,
|
||||||
k_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
v_scale: float = 1.0,
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@ -354,7 +354,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
|||||||
attn_metadata: BlocksparseFlashAttentionMetadata,
|
attn_metadata: BlocksparseFlashAttentionMetadata,
|
||||||
k_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
v_scale: float = 1.0,
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention and PagedAttention.
|
"""Forward pass with FlashAttention and PagedAttention.
|
||||||
|
|
||||||
|
|||||||
@ -16,10 +16,8 @@ from vllm.attention.backends.utils import (
|
|||||||
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
|
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
|
||||||
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
|
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
|
||||||
is_all_encoder_attn_metadata_set, is_block_tables_empty)
|
is_all_encoder_attn_metadata_set, is_block_tables_empty)
|
||||||
from vllm.forward_context import get_forward_context
|
|
||||||
from vllm.multimodal import MultiModalPlaceholderMap
|
from vllm.multimodal import MultiModalPlaceholderMap
|
||||||
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
|
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||||
make_tensor_with_pad)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
||||||
@ -639,7 +637,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
attn_metadata: FlashAttentionMetadata,
|
attn_metadata: FlashAttentionMetadata,
|
||||||
k_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
v_scale: float = 1.0,
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention.
|
"""Forward pass with FlashAttention.
|
||||||
|
|
||||||
@ -668,23 +666,174 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
"requires setting cross-attention "
|
"requires setting cross-attention "
|
||||||
"metadata attributes.")
|
"metadata attributes.")
|
||||||
|
|
||||||
output = torch.ops.vllm.unified_flash_attention(
|
num_heads: int = self.num_heads
|
||||||
query,
|
head_size: int = self.head_size
|
||||||
key,
|
num_kv_heads: int = self.num_kv_heads
|
||||||
value,
|
kv_cache_dtype: str = self.kv_cache_dtype
|
||||||
self.num_heads,
|
softmax_scale: float = self.scale
|
||||||
self.head_size,
|
window_size = self.sliding_window
|
||||||
self.num_kv_heads,
|
alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
|
||||||
kv_cache,
|
logits_soft_cap: Optional[float] = self.logits_soft_cap
|
||||||
self.kv_cache_dtype,
|
|
||||||
k_scale,
|
num_tokens, hidden_size = query.shape
|
||||||
v_scale,
|
|
||||||
self.scale,
|
# Reshape the query, key, and value tensors.
|
||||||
attn_type.value,
|
query = query.view(-1, num_heads, head_size)
|
||||||
self.sliding_window,
|
if (key is not None) and (value is not None):
|
||||||
self.alibi_slopes,
|
key = key.view(-1, num_kv_heads, head_size)
|
||||||
self.logits_soft_cap,
|
value = value.view(-1, num_kv_heads, head_size)
|
||||||
)
|
|
||||||
|
if kv_cache.numel() > 0:
|
||||||
|
key_cache = kv_cache[0]
|
||||||
|
value_cache = kv_cache[1]
|
||||||
|
# We skip updating the KV cache under two conditions:
|
||||||
|
# a. When the Attention Type is ENCODER. In this phase, we compute
|
||||||
|
# only the encoder attention without updating the cache.
|
||||||
|
# b. When both Key and Value are None. This occurs during
|
||||||
|
# cross-attention computation in the decoding phase, where the
|
||||||
|
# KV cache is already populated with the cross-attention
|
||||||
|
# tensor. Thus, we skip cache updates during this time.
|
||||||
|
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
|
||||||
|
value is not None):
|
||||||
|
if attn_type == AttentionType.ENCODER_DECODER:
|
||||||
|
# Update cross-attention KV cache (prefill-only)
|
||||||
|
updated_slot_mapping = attn_metadata.cross_slot_mapping
|
||||||
|
else:
|
||||||
|
# Update self-attention KV cache (prefill/decode)
|
||||||
|
updated_slot_mapping = attn_metadata.slot_mapping
|
||||||
|
|
||||||
|
# Reshape the input keys and values and store them in the cache.
|
||||||
|
# If kv_cache is not provided, the new key and value tensors are
|
||||||
|
# not cached. This happens during the initial memory
|
||||||
|
# profiling run.
|
||||||
|
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
|
updated_slot_mapping.flatten(), # type: ignore[union-attr]
|
||||||
|
kv_cache_dtype,
|
||||||
|
k_scale,
|
||||||
|
v_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
(num_prefill_query_tokens, num_prefill_kv_tokens,
|
||||||
|
num_decode_query_tokens) = \
|
||||||
|
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
|
||||||
|
decode_query = query[num_prefill_query_tokens:]
|
||||||
|
# QKV for prefill.
|
||||||
|
query = query[:num_prefill_query_tokens]
|
||||||
|
assert query.shape[0] == num_prefill_query_tokens
|
||||||
|
assert decode_query.shape[0] == num_decode_query_tokens
|
||||||
|
|
||||||
|
prefill_output: Optional[torch.Tensor] = None
|
||||||
|
decode_output: Optional[torch.Tensor] = None
|
||||||
|
if prefill_meta := attn_metadata.prefill_metadata:
|
||||||
|
# Prompt run.
|
||||||
|
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
|
||||||
|
or prefill_meta.block_tables.numel() == 0):
|
||||||
|
# normal attention
|
||||||
|
# When block_tables are not filled, it means q and k are the
|
||||||
|
# prompt, and they have the same length.
|
||||||
|
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
|
||||||
|
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
|
||||||
|
|
||||||
|
key = key[:num_prefill_kv_tokens]
|
||||||
|
value = value[:num_prefill_kv_tokens]
|
||||||
|
|
||||||
|
prefill_output = flash_attn_varlen_func(
|
||||||
|
q=query,
|
||||||
|
k=key,
|
||||||
|
v=value,
|
||||||
|
cu_seqlens_q=q_seq_start_loc,
|
||||||
|
cu_seqlens_k=k_seq_start_loc,
|
||||||
|
max_seqlen_q=q_seq_len,
|
||||||
|
max_seqlen_k=k_seq_len,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=_get_causal_option(attn_type),
|
||||||
|
window_size=window_size,
|
||||||
|
alibi_slopes=alibi_slopes,
|
||||||
|
softcap=logits_soft_cap,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# prefix-enabled attention
|
||||||
|
assert attn_type == AttentionType.DECODER, (
|
||||||
|
"Only decoder-only models support prefix caching")
|
||||||
|
assert prefill_meta.seq_lens is not None
|
||||||
|
max_seq_len = max(prefill_meta.seq_lens)
|
||||||
|
prefill_output = flash_attn_varlen_func( # noqa
|
||||||
|
q=query,
|
||||||
|
k=key_cache,
|
||||||
|
v=value_cache,
|
||||||
|
cu_seqlens_q=prefill_meta.query_start_loc,
|
||||||
|
max_seqlen_q=prefill_meta.max_query_len,
|
||||||
|
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||||
|
max_seqlen_k=max_seq_len,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
window_size=window_size,
|
||||||
|
alibi_slopes=alibi_slopes,
|
||||||
|
block_table=prefill_meta.block_tables,
|
||||||
|
softcap=logits_soft_cap,
|
||||||
|
)
|
||||||
|
|
||||||
|
if decode_meta := attn_metadata.decode_metadata:
|
||||||
|
# Decoding run.
|
||||||
|
# Use flash_attn_varlen_func kernel for speculative decoding
|
||||||
|
# because different queries might have different lengths.
|
||||||
|
|
||||||
|
assert decode_meta.max_decode_query_len is not None
|
||||||
|
# use only for actual varlen decoding
|
||||||
|
if decode_meta.max_decode_query_len > 1:
|
||||||
|
assert attn_type == AttentionType.DECODER, (
|
||||||
|
"Only decoder-only models support max_decode_query_len > 1"
|
||||||
|
)
|
||||||
|
decode_output = flash_attn_varlen_func(
|
||||||
|
q=decode_query,
|
||||||
|
k=key_cache,
|
||||||
|
v=value_cache,
|
||||||
|
cu_seqlens_q=decode_meta.query_start_loc,
|
||||||
|
max_seqlen_q=decode_meta.max_decode_query_len,
|
||||||
|
cu_seqlens_k=decode_meta.seq_start_loc,
|
||||||
|
max_seqlen_k=decode_meta.max_decode_seq_len,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
window_size=window_size,
|
||||||
|
alibi_slopes=alibi_slopes,
|
||||||
|
softcap=logits_soft_cap,
|
||||||
|
block_table=decode_meta.block_tables,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use flash_attn_with_kvcache for normal decoding.
|
||||||
|
(
|
||||||
|
seq_lens_arg,
|
||||||
|
_,
|
||||||
|
block_tables_arg,
|
||||||
|
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
|
||||||
|
decode_output = flash_attn_with_kvcache(
|
||||||
|
q=decode_query.unsqueeze(1),
|
||||||
|
k_cache=key_cache,
|
||||||
|
v_cache=value_cache,
|
||||||
|
block_table=block_tables_arg,
|
||||||
|
cache_seqlens=seq_lens_arg,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
window_size=window_size,
|
||||||
|
alibi_slopes=alibi_slopes,
|
||||||
|
softcap=logits_soft_cap,
|
||||||
|
).squeeze(1)
|
||||||
|
|
||||||
|
if prefill_output is None:
|
||||||
|
assert decode_output is not None
|
||||||
|
return decode_output.view(num_decode_query_tokens, hidden_size)
|
||||||
|
if decode_output is None:
|
||||||
|
assert prefill_output is not None
|
||||||
|
return prefill_output.view(num_prefill_query_tokens, hidden_size)
|
||||||
|
|
||||||
|
assert decode_meta is not None
|
||||||
|
decode_output = decode_output.squeeze(1)
|
||||||
|
output = torch.cat([prefill_output, decode_output], dim=0)
|
||||||
|
return output.view(num_tokens, hidden_size)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -692,7 +841,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
def _get_query_key_seq_metadata(
|
def _get_query_key_seq_metadata(
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
is_prompt: bool,
|
is_prompt: bool,
|
||||||
attn_type: AttentionType,
|
attn_type: str,
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
"""
|
"""
|
||||||
Returns sequence metadata for key and query based on the specified
|
Returns sequence metadata for key and query based on the specified
|
||||||
@ -754,7 +903,7 @@ def _get_query_key_seq_metadata(
|
|||||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||||
|
|
||||||
|
|
||||||
def _get_causal_option(attn_type: AttentionType) -> bool:
|
def _get_causal_option(attn_type: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Determine whether the given attention type is suitable for causal
|
Determine whether the given attention type is suitable for causal
|
||||||
attention mechanisms.
|
attention mechanisms.
|
||||||
@ -770,220 +919,3 @@ def _get_causal_option(attn_type: AttentionType) -> bool:
|
|||||||
return not (attn_type == AttentionType.ENCODER
|
return not (attn_type == AttentionType.ENCODER
|
||||||
or attn_type == AttentionType.ENCODER_ONLY
|
or attn_type == AttentionType.ENCODER_ONLY
|
||||||
or attn_type == AttentionType.ENCODER_DECODER)
|
or attn_type == AttentionType.ENCODER_DECODER)
|
||||||
|
|
||||||
|
|
||||||
def unified_flash_attention(
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
num_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
num_kv_heads: int,
|
|
||||||
kv_cache: torch.Tensor,
|
|
||||||
kv_cache_dtype: str,
|
|
||||||
k_scale: float,
|
|
||||||
v_scale: float,
|
|
||||||
softmax_scale: float,
|
|
||||||
attn_type_int_val: int,
|
|
||||||
window_size: Optional[List[int]] = None,
|
|
||||||
alibi_slopes: Optional[torch.Tensor] = None,
|
|
||||||
logits_soft_cap: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
# Convert integer attn_type to enum
|
|
||||||
try:
|
|
||||||
attn_type = AttentionType(attn_type_int_val)
|
|
||||||
except ValueError as err:
|
|
||||||
raise AttributeError(
|
|
||||||
f"Invalid attention type {str(attn_type_int_val)}") from err
|
|
||||||
|
|
||||||
current_metadata = get_forward_context()
|
|
||||||
assert current_metadata is not None
|
|
||||||
assert isinstance(current_metadata, FlashAttentionMetadata)
|
|
||||||
attn_metadata: FlashAttentionMetadata = current_metadata
|
|
||||||
|
|
||||||
num_tokens, hidden_size = query.shape
|
|
||||||
|
|
||||||
# Reshape the query, key, and value tensors.
|
|
||||||
query = query.view(-1, num_heads, head_size)
|
|
||||||
if (key is not None) and (value is not None):
|
|
||||||
key = key.view(-1, num_kv_heads, head_size)
|
|
||||||
value = value.view(-1, num_kv_heads, head_size)
|
|
||||||
|
|
||||||
if kv_cache.numel() > 0:
|
|
||||||
key_cache = kv_cache[0]
|
|
||||||
value_cache = kv_cache[1]
|
|
||||||
# We skip updating the KV cache under two conditions:
|
|
||||||
# a. When the Attention Type is ENCODER. In this phase, we compute
|
|
||||||
# only the encoder attention without updating the cache.
|
|
||||||
# b. When both Key and Value are None. This occurs during
|
|
||||||
# cross-attention computation in the decoding phase, where the KV
|
|
||||||
# cache is already populated with the cross-attention tensor.
|
|
||||||
# Thus, we skip cache updates during this time.
|
|
||||||
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
|
|
||||||
value is not None):
|
|
||||||
if attn_type == AttentionType.ENCODER_DECODER:
|
|
||||||
# Update cross-attention KV cache (prefill-only)
|
|
||||||
updated_slot_mapping = attn_metadata.cross_slot_mapping
|
|
||||||
else:
|
|
||||||
# Update self-attention KV cache (prefill/decode)
|
|
||||||
updated_slot_mapping = attn_metadata.slot_mapping
|
|
||||||
|
|
||||||
# Reshape the input keys and values and store them in the cache.
|
|
||||||
# If kv_cache is not provided, the new key and value tensors are
|
|
||||||
# not cached. This happens during the initial memory profiling run.
|
|
||||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
kv_cache[0],
|
|
||||||
kv_cache[1],
|
|
||||||
updated_slot_mapping.flatten(), # type: ignore[union-attr]
|
|
||||||
kv_cache_dtype,
|
|
||||||
k_scale,
|
|
||||||
v_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
(num_prefill_query_tokens, num_prefill_kv_tokens,
|
|
||||||
num_decode_query_tokens) = \
|
|
||||||
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
|
|
||||||
decode_query = query[num_prefill_query_tokens:]
|
|
||||||
# QKV for prefill.
|
|
||||||
query = query[:num_prefill_query_tokens]
|
|
||||||
assert query.shape[0] == num_prefill_query_tokens
|
|
||||||
assert decode_query.shape[0] == num_decode_query_tokens
|
|
||||||
|
|
||||||
prefill_output: Optional[torch.Tensor] = None
|
|
||||||
decode_output: Optional[torch.Tensor] = None
|
|
||||||
if prefill_meta := attn_metadata.prefill_metadata:
|
|
||||||
# Prompt run.
|
|
||||||
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
|
|
||||||
or prefill_meta.block_tables.numel() == 0):
|
|
||||||
# normal attention
|
|
||||||
# When block_tables are not filled, it means q and k are the
|
|
||||||
# prompt, and they have the same length.
|
|
||||||
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
|
|
||||||
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
|
|
||||||
|
|
||||||
key = key[:num_prefill_kv_tokens]
|
|
||||||
value = value[:num_prefill_kv_tokens]
|
|
||||||
|
|
||||||
prefill_output = flash_attn_varlen_func(
|
|
||||||
q=query,
|
|
||||||
k=key,
|
|
||||||
v=value,
|
|
||||||
cu_seqlens_q=q_seq_start_loc,
|
|
||||||
cu_seqlens_k=k_seq_start_loc,
|
|
||||||
max_seqlen_q=q_seq_len,
|
|
||||||
max_seqlen_k=k_seq_len,
|
|
||||||
softmax_scale=softmax_scale,
|
|
||||||
causal=_get_causal_option(attn_type),
|
|
||||||
window_size=window_size,
|
|
||||||
alibi_slopes=alibi_slopes,
|
|
||||||
softcap=logits_soft_cap,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# prefix-enabled attention
|
|
||||||
assert attn_type == AttentionType.DECODER, (
|
|
||||||
"Only decoder-only models support prefix caching")
|
|
||||||
assert prefill_meta.seq_lens is not None
|
|
||||||
max_seq_len = max(prefill_meta.seq_lens)
|
|
||||||
prefill_output = flash_attn_varlen_func( # noqa
|
|
||||||
q=query,
|
|
||||||
k=key_cache,
|
|
||||||
v=value_cache,
|
|
||||||
cu_seqlens_q=prefill_meta.query_start_loc,
|
|
||||||
max_seqlen_q=prefill_meta.max_query_len,
|
|
||||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
|
||||||
max_seqlen_k=max_seq_len,
|
|
||||||
softmax_scale=softmax_scale,
|
|
||||||
causal=True,
|
|
||||||
window_size=window_size,
|
|
||||||
alibi_slopes=alibi_slopes,
|
|
||||||
block_table=prefill_meta.block_tables,
|
|
||||||
softcap=logits_soft_cap,
|
|
||||||
)
|
|
||||||
|
|
||||||
if decode_meta := attn_metadata.decode_metadata:
|
|
||||||
# Decoding run.
|
|
||||||
# Use flash_attn_varlen_func kernel for speculative decoding
|
|
||||||
# because different queries might have different lengths.
|
|
||||||
|
|
||||||
assert decode_meta.max_decode_query_len is not None
|
|
||||||
# use only for actual varlen decoding
|
|
||||||
if decode_meta.max_decode_query_len > 1:
|
|
||||||
assert attn_type == AttentionType.DECODER, (
|
|
||||||
"Only decoder-only models support max_decode_query_len > 1")
|
|
||||||
decode_output = flash_attn_varlen_func(
|
|
||||||
q=decode_query,
|
|
||||||
k=key_cache,
|
|
||||||
v=value_cache,
|
|
||||||
cu_seqlens_q=decode_meta.query_start_loc,
|
|
||||||
max_seqlen_q=decode_meta.max_decode_query_len,
|
|
||||||
cu_seqlens_k=decode_meta.seq_start_loc,
|
|
||||||
max_seqlen_k=decode_meta.max_decode_seq_len,
|
|
||||||
softmax_scale=softmax_scale,
|
|
||||||
causal=True,
|
|
||||||
window_size=window_size,
|
|
||||||
alibi_slopes=alibi_slopes,
|
|
||||||
softcap=logits_soft_cap,
|
|
||||||
block_table=decode_meta.block_tables,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Use flash_attn_with_kvcache for normal decoding.
|
|
||||||
(
|
|
||||||
seq_lens_arg,
|
|
||||||
_,
|
|
||||||
block_tables_arg,
|
|
||||||
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
|
|
||||||
decode_output = flash_attn_with_kvcache(
|
|
||||||
q=decode_query.unsqueeze(1),
|
|
||||||
k_cache=key_cache,
|
|
||||||
v_cache=value_cache,
|
|
||||||
block_table=block_tables_arg,
|
|
||||||
cache_seqlens=seq_lens_arg,
|
|
||||||
softmax_scale=softmax_scale,
|
|
||||||
causal=True,
|
|
||||||
window_size=window_size,
|
|
||||||
alibi_slopes=alibi_slopes,
|
|
||||||
softcap=logits_soft_cap,
|
|
||||||
).squeeze(1)
|
|
||||||
|
|
||||||
if prefill_output is None:
|
|
||||||
assert decode_output is not None
|
|
||||||
return decode_output.view(num_decode_query_tokens, hidden_size)
|
|
||||||
if decode_output is None:
|
|
||||||
assert prefill_output is not None
|
|
||||||
return prefill_output.view(num_prefill_query_tokens, hidden_size)
|
|
||||||
|
|
||||||
assert decode_meta is not None
|
|
||||||
decode_output = decode_output.squeeze(1)
|
|
||||||
output = torch.cat([prefill_output, decode_output], dim=0)
|
|
||||||
return output.view(num_tokens, hidden_size)
|
|
||||||
|
|
||||||
|
|
||||||
def unified_flash_attention_fake(
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
num_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
num_kv_heads: int,
|
|
||||||
kv_cache: torch.Tensor,
|
|
||||||
kv_cache_dtype: str,
|
|
||||||
k_scale: float,
|
|
||||||
v_scale: float,
|
|
||||||
softmax_scale: float,
|
|
||||||
attn_type_int_val: int,
|
|
||||||
window_size: Optional[List[int]] = None,
|
|
||||||
alibi_slopes: Optional[torch.Tensor] = None,
|
|
||||||
logits_soft_cap: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return torch.empty_like(query)
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="unified_flash_attention",
|
|
||||||
op_func=unified_flash_attention,
|
|
||||||
mutates_args=["kv_cache"],
|
|
||||||
fake_impl=unified_flash_attention_fake,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -30,9 +30,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
|||||||
compute_slot_mapping_start_idx,
|
compute_slot_mapping_start_idx,
|
||||||
is_block_tables_empty)
|
is_block_tables_empty)
|
||||||
from vllm.attention.ops.paged_attn import PagedAttention
|
from vllm.attention.ops.paged_attn import PagedAttention
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
|
||||||
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
|
make_tensor_with_pad)
|
||||||
get_kv_cache_torch_dtype, make_tensor_with_pad)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
||||||
@ -774,7 +773,7 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
attn_metadata: FlashInferMetadata,
|
attn_metadata: FlashInferMetadata,
|
||||||
k_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
v_scale: float = 1.0,
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if attn_type != AttentionType.DECODER:
|
if attn_type != AttentionType.DECODER:
|
||||||
raise NotImplementedError("Encoder self-attention and "
|
raise NotImplementedError("Encoder self-attention and "
|
||||||
@ -782,174 +781,117 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
"are not implemented for "
|
"are not implemented for "
|
||||||
"FlashInferImpl")
|
"FlashInferImpl")
|
||||||
|
|
||||||
return torch.ops.vllm.unified_flash_infer(
|
num_heads: int = self.num_heads
|
||||||
query,
|
head_size: int = self.head_size
|
||||||
key,
|
num_kv_heads: int = self.num_kv_heads
|
||||||
value,
|
kv_cache_dtype: str = self.kv_cache_dtype
|
||||||
self.num_heads,
|
softmax_scale: float = self.scale
|
||||||
self.head_size,
|
window_size = self.sliding_window
|
||||||
self.num_kv_heads,
|
alibi_slopes = self.alibi_slopes
|
||||||
kv_cache,
|
logits_soft_cap = self.logits_soft_cap
|
||||||
self.kv_cache_dtype,
|
|
||||||
k_scale,
|
|
||||||
v_scale,
|
|
||||||
self.scale,
|
|
||||||
self.sliding_window,
|
|
||||||
self.alibi_slopes,
|
|
||||||
self.logits_soft_cap,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
num_tokens, hidden_size = query.shape
|
||||||
|
query = query.view(-1, num_heads, head_size)
|
||||||
|
key = key.view(-1, num_kv_heads, head_size)
|
||||||
|
value = value.view(-1, num_kv_heads, head_size)
|
||||||
|
|
||||||
def unified_flash_infer(
|
if kv_cache.numel() > 0:
|
||||||
query: torch.Tensor,
|
# Use the same reshape and cache kernel as flash attention.
|
||||||
key: torch.Tensor,
|
ops.reshape_and_cache_flash(
|
||||||
value: torch.Tensor,
|
key,
|
||||||
num_heads: int,
|
value,
|
||||||
head_size: int,
|
kv_cache[:, 0],
|
||||||
num_kv_heads: int,
|
kv_cache[:, 1],
|
||||||
kv_cache: torch.Tensor,
|
attn_metadata.slot_mapping.flatten(),
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype,
|
||||||
k_scale: float,
|
k_scale,
|
||||||
v_scale: float,
|
v_scale,
|
||||||
softmax_scale: float,
|
|
||||||
window_size: Optional[List[int]] = None,
|
|
||||||
alibi_slopes: Optional[torch.Tensor] = None,
|
|
||||||
logits_soft_cap: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
current_metadata = get_forward_context()
|
|
||||||
assert current_metadata is not None
|
|
||||||
assert isinstance(current_metadata, FlashInferMetadata)
|
|
||||||
attn_metadata: FlashInferMetadata = current_metadata
|
|
||||||
|
|
||||||
num_tokens, hidden_size = query.shape
|
|
||||||
query = query.view(-1, num_heads, head_size)
|
|
||||||
key = key.view(-1, num_kv_heads, head_size)
|
|
||||||
value = value.view(-1, num_kv_heads, head_size)
|
|
||||||
|
|
||||||
if kv_cache.numel() > 0:
|
|
||||||
# Use the same reshape and cache kernel as flash attention.
|
|
||||||
ops.reshape_and_cache_flash(
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
kv_cache[:, 0],
|
|
||||||
kv_cache[:, 1],
|
|
||||||
attn_metadata.slot_mapping.flatten(),
|
|
||||||
kv_cache_dtype,
|
|
||||||
k_scale,
|
|
||||||
v_scale,
|
|
||||||
)
|
|
||||||
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
|
|
||||||
# to process the cache when the kv_cache_dtype is fp8
|
|
||||||
if kv_cache_dtype.startswith("fp8"):
|
|
||||||
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
|
||||||
kv_cache_dtype)
|
|
||||||
kv_cache = kv_cache.view(torch_dtype)
|
|
||||||
|
|
||||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
|
||||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
|
||||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
|
||||||
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
|
|
||||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
|
||||||
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
|
|
||||||
query = query.contiguous() # Flashinfer requires query to be contiguous
|
|
||||||
# Query for decode. KV is not needed because it is already cached.
|
|
||||||
# QKV for prefill.
|
|
||||||
decode_query = query[num_prefill_tokens:]
|
|
||||||
query = query[:num_prefill_tokens]
|
|
||||||
|
|
||||||
key = key[:num_prefill_tokens]
|
|
||||||
value = value[:num_prefill_tokens]
|
|
||||||
|
|
||||||
assert query.shape[0] == num_prefill_tokens
|
|
||||||
assert decode_query.shape[0] == num_decode_tokens
|
|
||||||
|
|
||||||
window_left = window_size[0] if window_size is not None else -1
|
|
||||||
|
|
||||||
prefill_output: Optional[torch.Tensor] = None
|
|
||||||
decode_output: Optional[torch.Tensor] = None
|
|
||||||
if prefill_meta := attn_metadata.prefill_metadata:
|
|
||||||
# We will use flash attention for prefill
|
|
||||||
# when kv_cache is not provided.
|
|
||||||
# This happens when vllm runs the profiling to
|
|
||||||
# determine the number of blocks.
|
|
||||||
if kv_cache.numel() == 0:
|
|
||||||
prefill_output = flash_attn_varlen_func(
|
|
||||||
q=query,
|
|
||||||
k=key,
|
|
||||||
v=value,
|
|
||||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
|
||||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
|
||||||
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
|
||||||
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
|
||||||
softmax_scale=softmax_scale,
|
|
||||||
causal=True,
|
|
||||||
window_size=window_size,
|
|
||||||
alibi_slopes=alibi_slopes,
|
|
||||||
)
|
)
|
||||||
else:
|
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
|
||||||
assert prefill_meta is not None
|
# to process the cache when the kv_cache_dtype is fp8
|
||||||
assert prefill_meta.prefill_wrapper is not None
|
if kv_cache_dtype.startswith("fp8"):
|
||||||
prefill_output = prefill_meta.prefill_wrapper.forward(
|
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||||
query,
|
kv_cache_dtype)
|
||||||
|
kv_cache = kv_cache.view(torch_dtype)
|
||||||
|
|
||||||
|
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||||
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||||
|
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
||||||
|
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
|
||||||
|
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
||||||
|
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
|
||||||
|
query = query.contiguous(
|
||||||
|
) # Flashinfer requires query to be contiguous
|
||||||
|
# Query for decode. KV is not needed because it is already cached.
|
||||||
|
# QKV for prefill.
|
||||||
|
decode_query = query[num_prefill_tokens:]
|
||||||
|
query = query[:num_prefill_tokens]
|
||||||
|
|
||||||
|
key = key[:num_prefill_tokens]
|
||||||
|
value = value[:num_prefill_tokens]
|
||||||
|
|
||||||
|
assert query.shape[0] == num_prefill_tokens
|
||||||
|
assert decode_query.shape[0] == num_decode_tokens
|
||||||
|
|
||||||
|
window_left = window_size[0] if window_size is not None else -1
|
||||||
|
|
||||||
|
prefill_output: Optional[torch.Tensor] = None
|
||||||
|
decode_output: Optional[torch.Tensor] = None
|
||||||
|
if prefill_meta := attn_metadata.prefill_metadata:
|
||||||
|
# We will use flash attention for prefill
|
||||||
|
# when kv_cache is not provided.
|
||||||
|
# This happens when vllm runs the profiling to
|
||||||
|
# determine the number of blocks.
|
||||||
|
if kv_cache.numel() == 0:
|
||||||
|
prefill_output = flash_attn_varlen_func(
|
||||||
|
q=query,
|
||||||
|
k=key,
|
||||||
|
v=value,
|
||||||
|
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||||
|
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||||
|
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||||
|
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
window_size=window_size,
|
||||||
|
alibi_slopes=alibi_slopes,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert prefill_meta is not None
|
||||||
|
assert prefill_meta.prefill_wrapper is not None
|
||||||
|
prefill_output = prefill_meta.prefill_wrapper.forward(
|
||||||
|
query,
|
||||||
|
kv_cache,
|
||||||
|
logits_soft_cap=logits_soft_cap,
|
||||||
|
causal=True,
|
||||||
|
k_scale=k_scale,
|
||||||
|
v_scale=v_scale,
|
||||||
|
window_left=window_left)
|
||||||
|
if decode_meta := attn_metadata.decode_metadata:
|
||||||
|
assert decode_meta is not None
|
||||||
|
assert decode_meta.decode_wrapper is not None
|
||||||
|
decode_output = decode_meta.decode_wrapper.forward(
|
||||||
|
decode_query,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
|
sm_scale=softmax_scale,
|
||||||
logits_soft_cap=logits_soft_cap,
|
logits_soft_cap=logits_soft_cap,
|
||||||
causal=True,
|
|
||||||
k_scale=k_scale,
|
k_scale=k_scale,
|
||||||
v_scale=v_scale,
|
v_scale=v_scale,
|
||||||
window_left=window_left)
|
window_left=window_left)
|
||||||
if decode_meta := attn_metadata.decode_metadata:
|
|
||||||
assert attn_metadata.decode_metadata is not None
|
|
||||||
assert attn_metadata.decode_metadata.decode_wrapper is not None
|
|
||||||
decode_output = attn_metadata.decode_metadata.decode_wrapper.forward(
|
|
||||||
decode_query,
|
|
||||||
kv_cache,
|
|
||||||
sm_scale=softmax_scale,
|
|
||||||
logits_soft_cap=logits_soft_cap,
|
|
||||||
k_scale=k_scale,
|
|
||||||
v_scale=v_scale,
|
|
||||||
window_left=window_left)
|
|
||||||
|
|
||||||
if prefill_output is None and decode_output is not None:
|
if prefill_output is None and decode_output is not None:
|
||||||
# Decode only batch.
|
# Decode only batch.
|
||||||
output, num_tokens = decode_output, num_decode_tokens
|
output, num_tokens = decode_output, num_decode_tokens
|
||||||
elif decode_output is None and prefill_output is not None:
|
elif decode_output is None and prefill_output is not None:
|
||||||
# Prefill only batch.
|
# Prefill only batch.
|
||||||
output, num_tokens = prefill_output, num_prefill_tokens
|
output, num_tokens = prefill_output, num_prefill_tokens
|
||||||
else:
|
else:
|
||||||
# Chunked prefill batch does not work with speculative decoding in
|
# Chunked prefill batch does not work with speculative decoding in
|
||||||
# FlashInfer backend, so the query length for decode should be 1.
|
# FlashInfer backend, so the query length for decode should be 1.
|
||||||
assert prefill_output is not None
|
assert prefill_output is not None
|
||||||
assert decode_output is not None
|
assert decode_output is not None
|
||||||
assert decode_meta is not None
|
assert decode_meta is not None
|
||||||
assert decode_meta.decode_query_len == 1
|
assert decode_meta.decode_query_len == 1
|
||||||
decode_output = decode_output.squeeze(1)
|
decode_output = decode_output.squeeze(1)
|
||||||
output = torch.cat([prefill_output, decode_output], dim=0)
|
output = torch.cat([prefill_output, decode_output], dim=0)
|
||||||
return output.view(num_tokens, hidden_size)
|
return output.view(num_tokens, hidden_size)
|
||||||
|
|
||||||
|
|
||||||
def unified_flash_infer_fake(
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
num_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
num_kv_heads: int,
|
|
||||||
kv_cache: torch.Tensor,
|
|
||||||
kv_cache_dtype: str,
|
|
||||||
k_scale: float,
|
|
||||||
v_scale: float,
|
|
||||||
softmax_scale: float,
|
|
||||||
window_size: Optional[List[int]] = None,
|
|
||||||
alibi_slopes: Optional[torch.Tensor] = None,
|
|
||||||
logits_soft_cap: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return torch.empty_like(query).contiguous()
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="unified_flash_infer",
|
|
||||||
op_func=unified_flash_infer,
|
|
||||||
mutates_args=["kv_cache"],
|
|
||||||
fake_impl=unified_flash_infer_fake,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -140,7 +140,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
|
|||||||
attn_metadata: HPUAttentionMetadata,
|
attn_metadata: HPUAttentionMetadata,
|
||||||
k_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
v_scale: float = 1.0,
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with xFormers and PagedAttention.
|
"""Forward pass with xFormers and PagedAttention.
|
||||||
|
|
||||||
|
|||||||
@ -172,7 +172,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
|||||||
attn_metadata: IpexAttnMetadata, # type: ignore
|
attn_metadata: IpexAttnMetadata, # type: ignore
|
||||||
k_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
v_scale: float = 1.0,
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with IPEX varlen_attention and PagedAttention.
|
"""Forward pass with IPEX varlen_attention and PagedAttention.
|
||||||
|
|
||||||
|
|||||||
@ -150,7 +150,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
attn_metadata: PallasMetadata,
|
attn_metadata: PallasMetadata,
|
||||||
k_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
v_scale: float = 1.0,
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with Pallas attention.
|
"""Forward pass with Pallas attention.
|
||||||
|
|
||||||
|
|||||||
@ -414,7 +414,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
attn_metadata: ROCmFlashAttentionMetadata,
|
attn_metadata: ROCmFlashAttentionMetadata,
|
||||||
k_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
v_scale: float = 1.0,
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention and PagedAttention.
|
"""Forward pass with FlashAttention and PagedAttention.
|
||||||
|
|
||||||
|
|||||||
@ -141,7 +141,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
|
|
||||||
def get_seq_lens(
|
def get_seq_lens(
|
||||||
self,
|
self,
|
||||||
attn_type: AttentionType,
|
attn_type: str,
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
Extract appropriate sequence lengths from attention metadata
|
Extract appropriate sequence lengths from attention metadata
|
||||||
@ -174,7 +174,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
|
|
||||||
def get_attn_bias(
|
def get_attn_bias(
|
||||||
self,
|
self,
|
||||||
attn_type: AttentionType,
|
attn_type: str,
|
||||||
) -> Optional[List[torch.Tensor]]:
|
) -> Optional[List[torch.Tensor]]:
|
||||||
'''
|
'''
|
||||||
Extract appropriate attention bias from attention metadata
|
Extract appropriate attention bias from attention metadata
|
||||||
@ -203,7 +203,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
def set_attn_bias(
|
def set_attn_bias(
|
||||||
self,
|
self,
|
||||||
attn_bias: List[torch.Tensor],
|
attn_bias: List[torch.Tensor],
|
||||||
attn_type: AttentionType,
|
attn_type: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
'''
|
'''
|
||||||
Update appropriate attention bias field of attention metadata,
|
Update appropriate attention bias field of attention metadata,
|
||||||
@ -229,7 +229,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
|
|
||||||
def get_seq_len_block_table_args(
|
def get_seq_len_block_table_args(
|
||||||
self,
|
self,
|
||||||
attn_type: AttentionType,
|
attn_type: str,
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
'''
|
'''
|
||||||
The particular choice of sequence-length- and block-table-related
|
The particular choice of sequence-length- and block-table-related
|
||||||
@ -426,7 +426,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
|||||||
attn_metadata: TorchSDPAMetadata, # type: ignore
|
attn_metadata: TorchSDPAMetadata, # type: ignore
|
||||||
k_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
v_scale: float = 1.0,
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with torch SDPA and PagedAttention.
|
"""Forward pass with torch SDPA and PagedAttention.
|
||||||
|
|
||||||
@ -574,7 +574,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
|||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
attn_metadata: TorchSDPAMetadata,
|
attn_metadata: TorchSDPAMetadata,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
) -> None:
|
) -> None:
|
||||||
if self.num_kv_heads != self.num_heads:
|
if self.num_kv_heads != self.num_heads:
|
||||||
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
|
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
|
||||||
|
|||||||
@ -478,7 +478,7 @@ def is_all_cross_attn_metadata_set(attn_metadata):
|
|||||||
def get_seq_len_block_table_args(
|
def get_seq_len_block_table_args(
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
is_prompt: bool,
|
is_prompt: bool,
|
||||||
attn_type: AttentionType,
|
attn_type: str,
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
'''
|
'''
|
||||||
The particular choice of sequence-length- and block-table-related
|
The particular choice of sequence-length- and block-table-related
|
||||||
@ -529,7 +529,7 @@ def get_seq_len_block_table_args(
|
|||||||
|
|
||||||
def get_num_prefill_decode_query_kv_tokens(
|
def get_num_prefill_decode_query_kv_tokens(
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
attn_type: AttentionType,
|
attn_type: str,
|
||||||
) -> Tuple[int, int, int]:
|
) -> Tuple[int, int, int]:
|
||||||
"""
|
"""
|
||||||
Calculate the number of prefill and decode tokens for query, key/value
|
Calculate the number of prefill and decode tokens for query, key/value
|
||||||
|
|||||||
@ -284,7 +284,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
|
|
||||||
def _get_attn_bias(
|
def _get_attn_bias(
|
||||||
attn_metadata: XFormersMetadata,
|
attn_metadata: XFormersMetadata,
|
||||||
attn_type: AttentionType,
|
attn_type: str,
|
||||||
) -> Optional[AttentionBias]:
|
) -> Optional[AttentionBias]:
|
||||||
'''
|
'''
|
||||||
Extract appropriate attention bias from attention metadata
|
Extract appropriate attention bias from attention metadata
|
||||||
@ -314,7 +314,7 @@ def _get_attn_bias(
|
|||||||
def _set_attn_bias(
|
def _set_attn_bias(
|
||||||
attn_metadata: XFormersMetadata,
|
attn_metadata: XFormersMetadata,
|
||||||
attn_bias: List[Optional[AttentionBias]],
|
attn_bias: List[Optional[AttentionBias]],
|
||||||
attn_type: AttentionType,
|
attn_type: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
'''
|
'''
|
||||||
Update appropriate attention bias field of attention metadata,
|
Update appropriate attention bias field of attention metadata,
|
||||||
@ -416,7 +416,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
attn_metadata: "XFormersMetadata",
|
attn_metadata: "XFormersMetadata",
|
||||||
k_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
v_scale: float = 1.0,
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with xFormers and PagedAttention.
|
"""Forward pass with xFormers and PagedAttention.
|
||||||
|
|
||||||
@ -617,7 +617,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
attn_metadata: XFormersMetadata,
|
attn_metadata: XFormersMetadata,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Attention for 1D query of multiple prompts. Multiple prompt
|
"""Attention for 1D query of multiple prompts. Multiple prompt
|
||||||
tokens are flattened in to `query` input.
|
tokens are flattened in to `query` input.
|
||||||
|
|||||||
@ -4,12 +4,17 @@ from typing import Any, Dict, List, Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.attention import AttentionMetadata, AttentionType
|
from vllm.attention import AttentionMetadata, AttentionType
|
||||||
from vllm.attention.selector import get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.plugins import get_current_vllm_config
|
||||||
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
@ -86,6 +91,18 @@ class Attention(nn.Module):
|
|||||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||||
blocksparse_params, logits_soft_cap)
|
blocksparse_params, logits_soft_cap)
|
||||||
|
|
||||||
|
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
|
||||||
|
# torch.compile works by registering the attention as one giant
|
||||||
|
# opaque custom op. For other platforms, we directly call them
|
||||||
|
# and let torch.compile handle them.
|
||||||
|
self.use_direct_call = envs.VLLM_USE_V1 or not (
|
||||||
|
current_platform.is_cuda_alike() or current_platform.is_cpu())
|
||||||
|
compilation_config = get_current_vllm_config().compilation_config
|
||||||
|
if prefix in compilation_config.static_forward_context:
|
||||||
|
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||||
|
compilation_config.static_forward_context[prefix] = self
|
||||||
|
self.layer_name = prefix
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@ -93,17 +110,22 @@ class Attention(nn.Module):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
return self.impl.forward(query,
|
if self.use_direct_call:
|
||||||
key,
|
return self.impl.forward(query,
|
||||||
value,
|
key,
|
||||||
kv_cache,
|
value,
|
||||||
attn_metadata,
|
kv_cache,
|
||||||
self._k_scale,
|
attn_metadata,
|
||||||
self._v_scale,
|
self._k_scale,
|
||||||
attn_type=attn_type)
|
self._v_scale,
|
||||||
|
attn_type=attn_type)
|
||||||
|
else:
|
||||||
|
return torch.ops.vllm.unified_attention(query, key, value,
|
||||||
|
kv_cache, attn_type,
|
||||||
|
self.layer_name)
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
s = f"head_size={self.impl.head_size}" # type: ignore
|
s = f"head_size={self.impl.head_size}" # type: ignore
|
||||||
@ -112,3 +134,44 @@ class Attention(nn.Module):
|
|||||||
s += f", scale={self.impl.scale}" # type: ignore
|
s += f", scale={self.impl.scale}" # type: ignore
|
||||||
s += f", backend={self.impl.__class__.__name__}"
|
s += f", backend={self.impl.__class__.__name__}"
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def unified_attention(
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_type: str,
|
||||||
|
layer_name: str,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
forward_context: ForwardContext = get_forward_context()
|
||||||
|
attn_metadata = forward_context.dynamic_forward_context
|
||||||
|
self = forward_context.static_forward_context[layer_name]
|
||||||
|
return self.impl.forward(query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
kv_cache,
|
||||||
|
attn_metadata,
|
||||||
|
self._k_scale,
|
||||||
|
self._v_scale,
|
||||||
|
attn_type=attn_type)
|
||||||
|
|
||||||
|
|
||||||
|
def unified_attention_fake(
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_type: str,
|
||||||
|
layer_name: str,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.empty_like(query).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="unified_attention",
|
||||||
|
op_func=unified_attention,
|
||||||
|
mutates_args=["kv_cache"],
|
||||||
|
fake_impl=unified_attention_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
|
)
|
||||||
|
|||||||
@ -2135,8 +2135,7 @@ class CompilationConfig(BaseModel):
|
|||||||
backend: str = ""
|
backend: str = ""
|
||||||
custom_ops: List[str] = Field(default_factory=list)
|
custom_ops: List[str] = Field(default_factory=list)
|
||||||
splitting_ops: List[str] = Field(default_factory=lambda: [
|
splitting_ops: List[str] = Field(default_factory=lambda: [
|
||||||
"vllm.unified_flash_attention",
|
"vllm.unified_attention",
|
||||||
"vllm.unified_flash_infer",
|
|
||||||
"vllm.unified_v1_flash_attention",
|
"vllm.unified_v1_flash_attention",
|
||||||
])
|
])
|
||||||
|
|
||||||
@ -2197,6 +2196,11 @@ class CompilationConfig(BaseModel):
|
|||||||
enabled_custom_ops: Counter[str] = PrivateAttr
|
enabled_custom_ops: Counter[str] = PrivateAttr
|
||||||
disabled_custom_ops: Counter[str] = PrivateAttr
|
disabled_custom_ops: Counter[str] = PrivateAttr
|
||||||
|
|
||||||
|
# Per-model forward context
|
||||||
|
# Mainly used to store attention cls
|
||||||
|
# Map from layer name to the attention cls
|
||||||
|
static_forward_context: Dict[str, Any] = PrivateAttr
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli(cls, cli_value: str) -> "CompilationConfig":
|
def from_cli(cls, cli_value: str) -> "CompilationConfig":
|
||||||
"""Parse the CLI value for the compilation config."""
|
"""Parse the CLI value for the compilation config."""
|
||||||
@ -2228,6 +2232,7 @@ class CompilationConfig(BaseModel):
|
|||||||
|
|
||||||
self.enabled_custom_ops = Counter()
|
self.enabled_custom_ops = Counter()
|
||||||
self.disabled_custom_ops = Counter()
|
self.disabled_custom_ops = Counter()
|
||||||
|
self.static_forward_context = {}
|
||||||
|
|
||||||
def init_backend(self) -> Union[str, Callable]:
|
def init_backend(self) -> Union[str, Callable]:
|
||||||
if self.level == CompilationLevel.NO_COMPILATION:
|
if self.level == CompilationLevel.NO_COMPILATION:
|
||||||
|
|||||||
@ -1,21 +1,38 @@
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
_forward_context: Any = None
|
from vllm.config import VllmConfig
|
||||||
|
|
||||||
|
|
||||||
def get_forward_context() -> Any:
|
@dataclass
|
||||||
|
class ForwardContext:
|
||||||
|
static_forward_context: Dict[str, Any]
|
||||||
|
# TODO: extend to support per-layer dynamic forward context
|
||||||
|
dynamic_forward_context: Any
|
||||||
|
|
||||||
|
|
||||||
|
_forward_context: Optional[ForwardContext] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_forward_context() -> ForwardContext:
|
||||||
"""Get the current forward context."""
|
"""Get the current forward context."""
|
||||||
|
assert _forward_context is not None, (
|
||||||
|
"Forward context is not set. "
|
||||||
|
"Please use `set_forward_context` to set the forward context.")
|
||||||
return _forward_context
|
return _forward_context
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def set_forward_context(context: Any):
|
def set_forward_context(context: Any, vllm_config: VllmConfig):
|
||||||
"""A context manager that stores the current forward context,
|
"""A context manager that stores the current forward context,
|
||||||
can be attention metadata, etc."""
|
can be attention metadata, etc."""
|
||||||
global _forward_context
|
global _forward_context
|
||||||
prev_context = _forward_context
|
prev_context = _forward_context
|
||||||
_forward_context = context
|
_forward_context = ForwardContext(
|
||||||
|
static_forward_context=vllm_config.compilation_config.
|
||||||
|
static_forward_context,
|
||||||
|
dynamic_forward_context=context)
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@ -223,6 +223,7 @@ class ArcticAttention(nn.Module):
|
|||||||
layer_idx: Optional[int] = None,
|
layer_idx: Optional[int] = None,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -274,7 +275,8 @@ class ArcticAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -299,6 +301,7 @@ class ArcticDecoderLayer(nn.Module):
|
|||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
@ -308,7 +311,8 @@ class ArcticDecoderLayer(nn.Module):
|
|||||||
self.self_attn = ArcticAttention(config,
|
self.self_attn = ArcticAttention(config,
|
||||||
layer_idx,
|
layer_idx,
|
||||||
cache_config,
|
cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn")
|
||||||
self.block_sparse_moe = ArcticMoE(
|
self.block_sparse_moe = ArcticMoE(
|
||||||
config,
|
config,
|
||||||
layer_id=layer_idx,
|
layer_id=layer_idx,
|
||||||
@ -380,8 +384,11 @@ class ArcticModel(nn.Module):
|
|||||||
org_num_embeddings=self.vocab_size)
|
org_num_embeddings=self.vocab_size)
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: ArcticDecoderLayer(config, int(
|
lambda prefix: ArcticDecoderLayer(config,
|
||||||
prefix.split(".")[-1]), cache_config, quant_config),
|
int(prefix.split(".")[-1]),
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=prefix),
|
||||||
prefix=f"{prefix}.layers")
|
prefix=f"{prefix}.layers")
|
||||||
self._attn_implementation = config._attn_implementation
|
self._attn_implementation = config._attn_implementation
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|||||||
@ -116,6 +116,7 @@ class BaiChuanAttention(nn.Module):
|
|||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -158,7 +159,8 @@ class BaiChuanAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
scaling,
|
scaling,
|
||||||
alibi_slopes=alibi_slopes,
|
alibi_slopes=alibi_slopes,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
else:
|
else:
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@ -171,7 +173,8 @@ class BaiChuanAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -195,7 +198,8 @@ class BaiChuanDecoderLayer(nn.Module):
|
|||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
position_embedding: str,
|
position_embedding: str,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
rope_theta = getattr(config, "rope_theta", 10000)
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
@ -209,6 +213,7 @@ class BaiChuanDecoderLayer(nn.Module):
|
|||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
)
|
)
|
||||||
self.mlp = BaiChuanMLP(
|
self.mlp = BaiChuanMLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
@ -275,8 +280,11 @@ class BaiChuanModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: BaiChuanDecoderLayer(config, position_embedding,
|
lambda prefix: BaiChuanDecoderLayer(config,
|
||||||
cache_config, quant_config),
|
position_embedding,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=prefix),
|
||||||
prefix=f"{prefix}.layers",
|
prefix=f"{prefix}.layers",
|
||||||
)
|
)
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|||||||
@ -126,6 +126,7 @@ class BartEncoderAttention(nn.Module):
|
|||||||
config: Optional[BartConfig] = None,
|
config: Optional[BartConfig] = None,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.d_model = config.d_model
|
self.d_model = config.d_model
|
||||||
@ -178,7 +179,8 @@ class BartEncoderAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
|
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
|
||||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||||
@ -208,6 +210,7 @@ class BartDecoderSelfAttention(nn.Module):
|
|||||||
config: Optional[BartConfig] = None,
|
config: Optional[BartConfig] = None,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.d_model = config.d_model
|
self.d_model = config.d_model
|
||||||
@ -260,7 +263,8 @@ class BartDecoderSelfAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
|
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
|
||||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||||
@ -290,6 +294,7 @@ class BartCrossAttention(nn.Module):
|
|||||||
config: Optional[BartConfig] = None,
|
config: Optional[BartConfig] = None,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.d_model = config.d_model
|
self.d_model = config.d_model
|
||||||
@ -342,7 +347,8 @@ class BartCrossAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -384,6 +390,7 @@ class BartEncoderLayer(nn.Module):
|
|||||||
config: BartConfig,
|
config: BartConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_dim = config.d_model
|
self.embed_dim = config.d_model
|
||||||
@ -393,7 +400,9 @@ class BartEncoderLayer(nn.Module):
|
|||||||
num_heads=config.encoder_attention_heads,
|
num_heads=config.encoder_attention_heads,
|
||||||
config=config,
|
config=config,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
|
)
|
||||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
self.activation_fn = get_act_fn(config.activation_function)
|
self.activation_fn = get_act_fn(config.activation_function)
|
||||||
|
|
||||||
@ -464,6 +473,7 @@ class BartDecoderLayer(nn.Module):
|
|||||||
config: BartConfig,
|
config: BartConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_dim = config.d_model
|
self.embed_dim = config.d_model
|
||||||
@ -473,7 +483,9 @@ class BartDecoderLayer(nn.Module):
|
|||||||
num_heads=config.decoder_attention_heads,
|
num_heads=config.decoder_attention_heads,
|
||||||
config=config,
|
config=config,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
|
)
|
||||||
self.activation_fn = get_act_fn(config.activation_function)
|
self.activation_fn = get_act_fn(config.activation_function)
|
||||||
|
|
||||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
@ -486,6 +498,7 @@ class BartDecoderLayer(nn.Module):
|
|||||||
self.embed_dim,
|
self.embed_dim,
|
||||||
config.decoder_attention_heads,
|
config.decoder_attention_heads,
|
||||||
config=config,
|
config=config,
|
||||||
|
prefix=f"{prefix}.encoder_attn",
|
||||||
)
|
)
|
||||||
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
|
|
||||||
@ -578,7 +591,8 @@ class BartEncoder(nn.Module):
|
|||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
embed_tokens: Optional[nn.Embedding] = None):
|
embed_tokens: Optional[nn.Embedding] = None,
|
||||||
|
prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.cache_config = cache_config
|
self.cache_config = cache_config
|
||||||
@ -599,9 +613,13 @@ class BartEncoder(nn.Module):
|
|||||||
config.max_position_embeddings,
|
config.max_position_embeddings,
|
||||||
embed_dim,
|
embed_dim,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList([
|
||||||
[BartEncoderLayer(config,cache_config,quant_config) \
|
BartEncoderLayer(config,
|
||||||
for _ in range(config.encoder_layers)])
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.layers.{layer_idx}")
|
||||||
|
for layer_idx in range(config.encoder_layers)
|
||||||
|
])
|
||||||
|
|
||||||
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
||||||
|
|
||||||
@ -661,6 +679,7 @@ class BartDecoder(nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
embed_tokens: Optional[nn.Embedding] = None,
|
embed_tokens: Optional[nn.Embedding] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.cache_config = cache_config
|
self.cache_config = cache_config
|
||||||
@ -683,8 +702,9 @@ class BartDecoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[BartDecoderLayer(config,cache_config,quant_config) \
|
[BartDecoderLayer(config,cache_config,quant_config,
|
||||||
for _ in range(config.decoder_layers)])
|
prefix=f"{prefix}.layers.{layer_idx}") \
|
||||||
|
for layer_idx in range(config.decoder_layers)])
|
||||||
|
|
||||||
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
||||||
|
|
||||||
@ -759,10 +779,12 @@ class BartModel(nn.Module):
|
|||||||
|
|
||||||
self.encoder = BartEncoder(config,
|
self.encoder = BartEncoder(config,
|
||||||
cache_config,
|
cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.encoder")
|
||||||
self.decoder = BartDecoder(config,
|
self.decoder = BartDecoder(config,
|
||||||
cache_config,
|
cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.decoder")
|
||||||
|
|
||||||
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
|
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
|
||||||
encoder_input_ids: torch.Tensor,
|
encoder_input_ids: torch.Tensor,
|
||||||
|
|||||||
@ -78,6 +78,7 @@ class BloomAttention(nn.Module):
|
|||||||
config: BloomConfig,
|
config: BloomConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -116,7 +117,8 @@ class BloomAttention(nn.Module):
|
|||||||
scaling,
|
scaling,
|
||||||
alibi_slopes=alibi_slopes,
|
alibi_slopes=alibi_slopes,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -168,14 +170,17 @@ class BloomBlock(nn.Module):
|
|||||||
config: BloomConfig,
|
config: BloomConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
hidden_size = config.hidden_size
|
hidden_size = config.hidden_size
|
||||||
|
|
||||||
self.input_layernorm = nn.LayerNorm(hidden_size,
|
self.input_layernorm = nn.LayerNorm(hidden_size,
|
||||||
eps=config.layer_norm_epsilon)
|
eps=config.layer_norm_epsilon)
|
||||||
self.self_attention = BloomAttention(config, cache_config,
|
self.self_attention = BloomAttention(config,
|
||||||
quant_config)
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.self_attention")
|
||||||
self.post_attention_layernorm = nn.LayerNorm(
|
self.post_attention_layernorm = nn.LayerNorm(
|
||||||
hidden_size, eps=config.layer_norm_epsilon)
|
hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.mlp = BloomMLP(config, quant_config)
|
self.mlp = BloomMLP(config, quant_config)
|
||||||
@ -242,7 +247,8 @@ class BloomModel(nn.Module):
|
|||||||
# Transformer blocks
|
# Transformer blocks
|
||||||
self.start_layer, self.end_layer, self.h = make_layers(
|
self.start_layer, self.end_layer, self.h = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: BloomBlock(config, cache_config, quant_config),
|
lambda prefix: BloomBlock(
|
||||||
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.h")
|
prefix=f"{prefix}.h")
|
||||||
|
|
||||||
# Final Layer Norm
|
# Final Layer Norm
|
||||||
|
|||||||
@ -223,6 +223,7 @@ class ChameleonAttention(nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
bias: bool = False,
|
bias: bool = False,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -276,7 +277,8 @@ class ChameleonAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def _apply_qk_norm(self, q: torch.Tensor,
|
def _apply_qk_norm(self, q: torch.Tensor,
|
||||||
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
@ -313,6 +315,7 @@ class ChameleonDecoderLayer(nn.Module):
|
|||||||
config: ChameleonConfig,
|
config: ChameleonConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -336,6 +339,7 @@ class ChameleonDecoderLayer(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
bias=False,
|
bias=False,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
)
|
)
|
||||||
self.mlp = ChameleonMLP(
|
self.mlp = ChameleonMLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
@ -386,6 +390,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
|
|||||||
config: ChameleonConfig,
|
config: ChameleonConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -409,6 +414,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
bias=False,
|
bias=False,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
)
|
)
|
||||||
self.mlp = ChameleonMLP(
|
self.mlp = ChameleonMLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
@ -855,7 +861,8 @@ class ChameleonModel(nn.Module):
|
|||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: decoder_layer(config=config,
|
lambda prefix: decoder_layer(config=config,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config),
|
quant_config=quant_config,
|
||||||
|
prefix=prefix),
|
||||||
prefix=f"{prefix}.layers",
|
prefix=f"{prefix}.layers",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -230,6 +230,7 @@ class GLMAttention(nn.Module):
|
|||||||
config: ChatGLMConfig,
|
config: ChatGLMConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -285,7 +286,8 @@ class GLMAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -364,6 +366,7 @@ class GLMBlock(nn.Module):
|
|||||||
config: ChatGLMConfig,
|
config: ChatGLMConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.apply_residual_connection_post_layernorm = (
|
self.apply_residual_connection_post_layernorm = (
|
||||||
@ -377,7 +380,10 @@ class GLMBlock(nn.Module):
|
|||||||
eps=config.layernorm_epsilon)
|
eps=config.layernorm_epsilon)
|
||||||
|
|
||||||
# Self attention.
|
# Self attention.
|
||||||
self.self_attention = GLMAttention(config, cache_config, quant_config)
|
self.self_attention = GLMAttention(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.self_attention")
|
||||||
self.hidden_dropout = config.hidden_dropout
|
self.hidden_dropout = config.hidden_dropout
|
||||||
|
|
||||||
# Layernorm on the attention output
|
# Layernorm on the attention output
|
||||||
@ -446,7 +452,8 @@ class GLMTransformer(nn.Module):
|
|||||||
# Transformer layers.
|
# Transformer layers.
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
self.num_layers,
|
self.num_layers,
|
||||||
lambda prefix: GLMBlock(config, cache_config, quant_config),
|
lambda prefix: GLMBlock(
|
||||||
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.layers",
|
prefix=f"{prefix}.layers",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -500,16 +507,22 @@ class ChatGLMModel(nn.Module):
|
|||||||
self.num_layers = config.num_layers
|
self.num_layers = config.num_layers
|
||||||
self.multi_query_group_num = config.multi_query_group_num
|
self.multi_query_group_num = config.multi_query_group_num
|
||||||
self.kv_channels = config.kv_channels
|
self.kv_channels = config.kv_channels
|
||||||
self.encoder = GLMTransformer(config, cache_config, quant_config)
|
self.encoder = GLMTransformer(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.encoder")
|
||||||
|
|
||||||
self.output_layer = ParallelLMHead(config.padded_vocab_size,
|
self.output_layer = ParallelLMHead(config.padded_vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.output_layer")
|
||||||
|
|
||||||
vision_config_flag = getattr(config, 'vision_config', None)
|
vision_config_flag = getattr(config, 'vision_config', None)
|
||||||
if vision_config_flag is not None:
|
if vision_config_flag is not None:
|
||||||
self.vision_config = Namespace(**config.vision_config)
|
self.vision_config = Namespace(**config.vision_config)
|
||||||
self.vision = EVA2CLIPModel(self.config, quant_config)
|
self.vision = EVA2CLIPModel(self.config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.vision")
|
||||||
else:
|
else:
|
||||||
self.vision = None
|
self.vision = None
|
||||||
|
|
||||||
|
|||||||
@ -120,6 +120,7 @@ class CohereAttention(nn.Module):
|
|||||||
config: CohereConfig,
|
config: CohereConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
@ -175,7 +176,8 @@ class CohereAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
if self.use_qk_norm:
|
if self.use_qk_norm:
|
||||||
self.q_norm = LayerNorm(param_shape=(self.num_heads,
|
self.q_norm = LayerNorm(param_shape=(self.num_heads,
|
||||||
self.head_dim),
|
self.head_dim),
|
||||||
@ -215,13 +217,15 @@ class CohereDecoderLayer(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: CohereConfig,
|
config: CohereConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
self.self_attn = CohereAttention(config,
|
self.self_attn = CohereAttention(config,
|
||||||
cache_config,
|
cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn")
|
||||||
|
|
||||||
self.mlp = CohereMLP(config, quant_config=quant_config)
|
self.mlp = CohereMLP(config, quant_config=quant_config)
|
||||||
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
|
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
|
||||||
@ -271,8 +275,8 @@ class CohereModel(nn.Module):
|
|||||||
config.hidden_size)
|
config.hidden_size)
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: CohereDecoderLayer(config, cache_config,
|
lambda prefix: CohereDecoderLayer(
|
||||||
quant_config),
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.layers")
|
prefix=f"{prefix}.layers")
|
||||||
self.norm = LayerNorm(param_shape=(config.hidden_size),
|
self.norm = LayerNorm(param_shape=(config.hidden_size),
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
|
|||||||
@ -154,6 +154,7 @@ class DbrxAttention(nn.Module):
|
|||||||
config: DbrxConfig,
|
config: DbrxConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.d_model = config.d_model
|
self.d_model = config.d_model
|
||||||
@ -208,7 +209,8 @@ class DbrxAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -234,10 +236,14 @@ class DbrxFusedNormAttention(nn.Module):
|
|||||||
config: DbrxConfig,
|
config: DbrxConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.d_model = config.d_model
|
self.d_model = config.d_model
|
||||||
self.attn = DbrxAttention(config, cache_config, quant_config)
|
self.attn = DbrxAttention(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
self.norm_1 = nn.LayerNorm(self.d_model)
|
self.norm_1 = nn.LayerNorm(self.d_model)
|
||||||
self.norm_2 = nn.LayerNorm(self.d_model)
|
self.norm_2 = nn.LayerNorm(self.d_model)
|
||||||
|
|
||||||
@ -269,10 +275,14 @@ class DbrxBlock(nn.Module):
|
|||||||
config: DbrxConfig,
|
config: DbrxConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
|
self.norm_attn_norm = DbrxFusedNormAttention(
|
||||||
quant_config)
|
config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.norm_attn_norm")
|
||||||
self.ffn = DbrxMoE(config, quant_config)
|
self.ffn = DbrxMoE(config, quant_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -308,7 +318,8 @@ class DbrxModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.start_layer, self.end_layer, self.blocks = make_layers(
|
self.start_layer, self.end_layer, self.blocks = make_layers(
|
||||||
config.n_layers,
|
config.n_layers,
|
||||||
lambda prefix: DbrxBlock(config, cache_config, quant_config),
|
lambda prefix: DbrxBlock(
|
||||||
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.blocks",
|
prefix=f"{prefix}.blocks",
|
||||||
)
|
)
|
||||||
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
|
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
|
||||||
|
|||||||
@ -184,6 +184,7 @@ class DeepseekAttention(nn.Module):
|
|||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -236,7 +237,8 @@ class DeepseekAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -261,6 +263,7 @@ class DeepseekDecoderLayer(nn.Module):
|
|||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -277,6 +280,7 @@ class DeepseekDecoderLayer(nn.Module):
|
|||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
)
|
)
|
||||||
if (config.n_routed_experts is not None
|
if (config.n_routed_experts is not None
|
||||||
and layer_idx >= config.first_k_dense_replace
|
and layer_idx >= config.first_k_dense_replace
|
||||||
@ -346,7 +350,8 @@ class DeepseekModel(nn.Module):
|
|||||||
lambda prefix: DeepseekDecoderLayer(config,
|
lambda prefix: DeepseekDecoderLayer(config,
|
||||||
int(prefix.split(".")[-1]),
|
int(prefix.split(".")[-1]),
|
||||||
cache_config,
|
cache_config,
|
||||||
quant_config=quant_config),
|
quant_config=quant_config,
|
||||||
|
prefix=prefix),
|
||||||
prefix=f"{prefix}.layers")
|
prefix=f"{prefix}.layers")
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
|
|||||||
@ -268,7 +268,8 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_local_heads,
|
num_kv_heads=self.num_local_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -174,6 +174,7 @@ class ExaoneAttention(nn.Module):
|
|||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn",
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -219,7 +220,7 @@ class ExaoneBlockAttention(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
prefix=prefix,
|
prefix=f"{prefix}.attention",
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@ -84,6 +84,7 @@ class FalconAttention(nn.Module):
|
|||||||
config: FalconConfig,
|
config: FalconConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -158,7 +159,8 @@ class FalconAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.inv_norm_factor,
|
self.inv_norm_factor,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
elif self.use_alibi:
|
elif self.use_alibi:
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
head_start = tp_rank * self.num_heads
|
head_start = tp_rank * self.num_heads
|
||||||
@ -171,14 +173,16 @@ class FalconAttention(nn.Module):
|
|||||||
self.inv_norm_factor,
|
self.inv_norm_factor,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
alibi_slopes=alibi_slopes,
|
alibi_slopes=alibi_slopes,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
else:
|
else:
|
||||||
self.attn = Attention(self.num_heads,
|
self.attn = Attention(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
scale=self.inv_norm_factor,
|
scale=self.inv_norm_factor,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -241,12 +245,16 @@ class FalconDecoderLayer(nn.Module):
|
|||||||
config: FalconConfig,
|
config: FalconConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
hidden_size = config.hidden_size
|
hidden_size = config.hidden_size
|
||||||
self.num_heads = config.num_attention_heads
|
self.num_heads = config.num_attention_heads
|
||||||
self.self_attention = FalconAttention(config, cache_config,
|
self.self_attention = FalconAttention(
|
||||||
quant_config)
|
config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.self_attention")
|
||||||
self.mlp = FalconMLP(config, quant_config)
|
self.mlp = FalconMLP(config, quant_config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@ -357,8 +365,8 @@ class FalconModel(nn.Module):
|
|||||||
# Transformer blocks
|
# Transformer blocks
|
||||||
self.start_layer, self.end_layer, self.h = make_layers(
|
self.start_layer, self.end_layer, self.h = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: FalconDecoderLayer(config, cache_config,
|
lambda prefix: FalconDecoderLayer(
|
||||||
quant_config),
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.h")
|
prefix=f"{prefix}.h")
|
||||||
|
|
||||||
# Final Layer Norm
|
# Final Layer Norm
|
||||||
|
|||||||
@ -35,10 +35,12 @@ class Florence2LanguageModel(nn.Module):
|
|||||||
self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)
|
self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)
|
||||||
self.encoder = BartEncoder(config,
|
self.encoder = BartEncoder(config,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.encoder")
|
||||||
self.decoder = BartDecoder(config,
|
self.decoder = BartDecoder(config,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.decoder")
|
||||||
|
|
||||||
if self.config.tie_word_embeddings:
|
if self.config.tie_word_embeddings:
|
||||||
self.encoder.embed_tokens.weight = self.shared.weight
|
self.encoder.embed_tokens.weight = self.shared.weight
|
||||||
@ -99,7 +101,7 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
|
|||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.model = Florence2LanguageModel(vllm_config=vllm_config,
|
self.model = Florence2LanguageModel(vllm_config=vllm_config,
|
||||||
prefix=prefix)
|
prefix=f"{prefix}.model")
|
||||||
embed_scale = math.sqrt(
|
embed_scale = math.sqrt(
|
||||||
config.d_model) if config.scale_embedding else 1.0
|
config.d_model) if config.scale_embedding else 1.0
|
||||||
|
|
||||||
@ -198,7 +200,7 @@ class Florence2ForConditionalGeneration(nn.Module):
|
|||||||
# TODO(Isotr0py): Add vision backbone
|
# TODO(Isotr0py): Add vision backbone
|
||||||
self.language_model = Florence2LanguageForConditionalGeneration(
|
self.language_model = Florence2LanguageForConditionalGeneration(
|
||||||
vllm_config=vllm_config.with_hf_config(config.text_config),
|
vllm_config=vllm_config.with_hf_config(config.text_config),
|
||||||
prefix=prefix,
|
prefix=f"{prefix}.language_model",
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@ -174,7 +174,8 @@ class GemmaAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -95,7 +95,8 @@ class Gemma2Attention(nn.Module):
|
|||||||
rope_theta: float,
|
rope_theta: float,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
attn_logits_soft_cap: Optional[float] = None) -> None:
|
attn_logits_soft_cap: Optional[float] = None,
|
||||||
|
prefix: str = "") -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -154,7 +155,8 @@ class Gemma2Attention(nn.Module):
|
|||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
logits_soft_cap=attn_logits_soft_cap)
|
logits_soft_cap=attn_logits_soft_cap,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -179,6 +181,7 @@ class Gemma2DecoderLayer(nn.Module):
|
|||||||
config: Gemma2Config,
|
config: Gemma2Config,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -194,6 +197,7 @@ class Gemma2DecoderLayer(nn.Module):
|
|||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
attn_logits_soft_cap=config.attn_logit_softcapping,
|
attn_logits_soft_cap=config.attn_logit_softcapping,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
)
|
)
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.mlp = Gemma2MLP(
|
self.mlp = Gemma2MLP(
|
||||||
@ -257,8 +261,11 @@ class Gemma2Model(nn.Module):
|
|||||||
)
|
)
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[
|
lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[-1]),
|
||||||
-1]), config, cache_config, quant_config),
|
config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=prefix),
|
||||||
prefix=f"{prefix}.layers")
|
prefix=f"{prefix}.layers")
|
||||||
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
|||||||
@ -56,6 +56,7 @@ class Attention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = '',
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -135,11 +136,14 @@ class TransformerLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = '',
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input_layernorm = LayerNorm(config.hidden_size,
|
self.input_layernorm = LayerNorm(config.hidden_size,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
self.attention = Attention(config, quant_config=quant_config)
|
self.attention = Attention(config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attention")
|
||||||
self.mlp = MLP(config, quant_config=quant_config)
|
self.mlp = MLP(config, quant_config=quant_config)
|
||||||
self.post_attention_layernorm = LayerNorm(config.hidden_size,
|
self.post_attention_layernorm = LayerNorm(config.hidden_size,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
@ -161,11 +165,14 @@ class Transformer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = '',
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
TransformerLayer(config, quant_config=quant_config)
|
TransformerLayer(config,
|
||||||
for _ in range(config.num_hidden_layers)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.layer.{layer_idx}")
|
||||||
|
for layer_idx in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
@ -252,12 +259,14 @@ class EVA2CLIPModel(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = '',
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
vision_config = Namespace(**config.vision_config)
|
vision_config = Namespace(**config.vision_config)
|
||||||
self.patch_embedding = PatchEmbedding(vision_config)
|
self.patch_embedding = PatchEmbedding(vision_config)
|
||||||
self.transformer = Transformer(vision_config,
|
self.transformer = Transformer(vision_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.transformer")
|
||||||
self.linear_proj = GLU(config,
|
self.linear_proj = GLU(config,
|
||||||
in_features=config.hidden_size,
|
in_features=config.hidden_size,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config)
|
||||||
|
|||||||
@ -84,7 +84,8 @@ class GPT2Attention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -52,6 +52,7 @@ class GPTBigCodeAttention(nn.Module):
|
|||||||
config: GPTBigCodeConfig,
|
config: GPTBigCodeConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -92,7 +93,8 @@ class GPTBigCodeAttention(nn.Module):
|
|||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -151,6 +153,7 @@ class GPTBigCodeBlock(nn.Module):
|
|||||||
config: GPTBigCodeConfig,
|
config: GPTBigCodeConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
hidden_size = config.hidden_size
|
hidden_size = config.hidden_size
|
||||||
@ -158,7 +161,10 @@ class GPTBigCodeBlock(nn.Module):
|
|||||||
hidden_size)
|
hidden_size)
|
||||||
|
|
||||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.attn = GPTBigCodeAttention(config, cache_config, quant_config)
|
self.attn = GPTBigCodeAttention(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
|
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
|
||||||
|
|
||||||
@ -210,7 +216,8 @@ class GPTBigCodeModel(nn.Module):
|
|||||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||||
self.start_layer, self.end_layer, self.h = make_layers(
|
self.start_layer, self.end_layer, self.h = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config),
|
lambda prefix: GPTBigCodeBlock(
|
||||||
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.h",
|
prefix=f"{prefix}.h",
|
||||||
)
|
)
|
||||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||||
|
|||||||
@ -53,6 +53,7 @@ class GPTJAttention(nn.Module):
|
|||||||
config: GPTJConfig,
|
config: GPTJConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.total_num_heads = config.num_attention_heads
|
self.total_num_heads = config.num_attention_heads
|
||||||
@ -94,7 +95,8 @@ class GPTJAttention(nn.Module):
|
|||||||
self.head_size,
|
self.head_size,
|
||||||
scaling,
|
scaling,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -147,12 +149,16 @@ class GPTJBlock(nn.Module):
|
|||||||
config: GPTJConfig,
|
config: GPTJConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = (4 * config.n_embd
|
inner_dim = (4 * config.n_embd
|
||||||
if config.n_inner is None else config.n_inner)
|
if config.n_inner is None else config.n_inner)
|
||||||
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||||
self.attn = GPTJAttention(config, cache_config, quant_config)
|
self.attn = GPTJAttention(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
self.mlp = GPTJMLP(inner_dim, config, quant_config)
|
self.mlp = GPTJMLP(inner_dim, config, quant_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -193,7 +199,8 @@ class GPTJModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.start_layer, self.end_layer, self.h = make_layers(
|
self.start_layer, self.end_layer, self.h = make_layers(
|
||||||
config.n_layer,
|
config.n_layer,
|
||||||
lambda prefix: GPTJBlock(config, cache_config, quant_config),
|
lambda prefix: GPTJBlock(
|
||||||
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.h",
|
prefix=f"{prefix}.h",
|
||||||
)
|
)
|
||||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||||
|
|||||||
@ -52,6 +52,7 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
config: GPTNeoXConfig,
|
config: GPTNeoXConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.total_num_heads = config.num_attention_heads
|
self.total_num_heads = config.num_attention_heads
|
||||||
@ -94,7 +95,8 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
self.head_size,
|
self.head_size,
|
||||||
scaling,
|
scaling,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -145,6 +147,7 @@ class GPTNeoXLayer(nn.Module):
|
|||||||
config: GPTNeoXConfig,
|
config: GPTNeoXConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.use_parallel_residual = config.use_parallel_residual
|
self.use_parallel_residual = config.use_parallel_residual
|
||||||
@ -152,7 +155,10 @@ class GPTNeoXLayer(nn.Module):
|
|||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
|
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
self.attention = GPTNeoXAttention(config, cache_config, quant_config)
|
self.attention = GPTNeoXAttention(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.attention")
|
||||||
self.mlp = GPTNeoXMLP(config, quant_config)
|
self.mlp = GPTNeoXMLP(config, quant_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -205,7 +211,8 @@ class GPTNeoXModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: GPTNeoXLayer(config, cache_config, quant_config),
|
lambda prefix: GPTNeoXLayer(
|
||||||
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.layers",
|
prefix=f"{prefix}.layers",
|
||||||
)
|
)
|
||||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
|
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
|
||||||
|
|||||||
@ -161,7 +161,8 @@ class GraniteAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -164,7 +164,8 @@ class GraniteMoeAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -250,7 +250,12 @@ class InternLMDecoderLayer(nn.Module):
|
|||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class InternLM2Model(nn.Module):
|
class InternLM2Model(nn.Module):
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
layer_type: Type[InternLMDecoderLayer] = InternLMDecoderLayer):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
@ -266,7 +271,7 @@ class InternLM2Model(nn.Module):
|
|||||||
)
|
)
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: InternLMDecoderLayer(
|
lambda prefix: layer_type(
|
||||||
config, cache_config, quant_config, prefix=prefix),
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.layers")
|
prefix=f"{prefix}.layers")
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@ -316,14 +321,18 @@ class InternLM2Model(nn.Module):
|
|||||||
|
|
||||||
class InternLM2ForCausalLM(nn.Module, SupportsPP):
|
class InternLM2ForCausalLM(nn.Module, SupportsPP):
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self,
|
||||||
|
*,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
model_type: Type[InternLM2Model] = InternLM2Model):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
quant_config = vllm_config.quant_config
|
quant_config = vllm_config.quant_config
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = InternLM2Model(vllm_config=vllm_config,
|
self.model = model_type(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "model"))
|
prefix=maybe_prefix(prefix, "model"))
|
||||||
self.output = ParallelLMHead(config.vocab_size,
|
self.output = ParallelLMHead(config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
|||||||
@ -14,8 +14,6 @@ from vllm.model_executor.models.internlm2 import (InternLM2Attention,
|
|||||||
InternLM2MLP, InternLM2Model)
|
InternLM2MLP, InternLM2Model)
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .utils import make_layers, maybe_prefix
|
|
||||||
|
|
||||||
|
|
||||||
class InternLM2VEDecoderLayer(nn.Module):
|
class InternLM2VEDecoderLayer(nn.Module):
|
||||||
|
|
||||||
@ -105,17 +103,9 @@ class InternLM2VEDecoderLayer(nn.Module):
|
|||||||
class InternLM2VEModel(InternLM2Model):
|
class InternLM2VEModel(InternLM2Model):
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
super().__init__(vllm_config=vllm_config,
|
||||||
|
prefix=prefix,
|
||||||
config = vllm_config.model_config.hf_config
|
layer_type=InternLM2VEDecoderLayer)
|
||||||
cache_config = vllm_config.cache_config
|
|
||||||
quant_config = vllm_config.quant_config
|
|
||||||
|
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
|
||||||
config.num_hidden_layers,
|
|
||||||
lambda prefix: InternLM2VEDecoderLayer(
|
|
||||||
config, cache_config, quant_config, prefix=prefix),
|
|
||||||
prefix=f"{prefix}.layers")
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -159,7 +149,6 @@ class InternLM2VEModel(InternLM2Model):
|
|||||||
class InternLM2VEForCausalLM(InternLM2ForCausalLM):
|
class InternLM2VEForCausalLM(InternLM2ForCausalLM):
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
super().__init__(vllm_config=vllm_config,
|
||||||
|
prefix=prefix,
|
||||||
self.model = InternLM2VEModel(vllm_config=vllm_config,
|
model_type=InternLM2VEModel)
|
||||||
prefix=maybe_prefix(prefix, "model"))
|
|
||||||
|
|||||||
@ -76,6 +76,7 @@ class JAISAttention(nn.Module):
|
|||||||
config: JAISConfig,
|
config: JAISConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -114,7 +115,8 @@ class JAISAttention(nn.Module):
|
|||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
alibi_slopes=alibi_slopes,
|
alibi_slopes=alibi_slopes,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -178,6 +180,7 @@ class JAISBlock(nn.Module):
|
|||||||
config: JAISConfig,
|
config: JAISConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
hidden_size = config.hidden_size
|
hidden_size = config.hidden_size
|
||||||
@ -185,7 +188,10 @@ class JAISBlock(nn.Module):
|
|||||||
hidden_size)
|
hidden_size)
|
||||||
|
|
||||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.attn = JAISAttention(config, cache_config, quant_config)
|
self.attn = JAISAttention(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.mlp = JAISMLP(inner_dim, config, quant_config)
|
self.mlp = JAISMLP(inner_dim, config, quant_config)
|
||||||
|
|
||||||
@ -241,7 +247,8 @@ class JAISModel(nn.Module):
|
|||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: JAISBlock(config=config,
|
lambda prefix: JAISBlock(config=config,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config),
|
quant_config=quant_config,
|
||||||
|
prefix=prefix),
|
||||||
prefix=f"{prefix}.h",
|
prefix=f"{prefix}.h",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -102,7 +102,8 @@ class JambaMambaDecoderLayer(nn.Module):
|
|||||||
config: JambaConfig,
|
config: JambaConfig,
|
||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None) -> None:
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "") -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.mamba = MambaMixer(hidden_size= config.hidden_size,
|
self.mamba = MambaMixer(hidden_size= config.hidden_size,
|
||||||
@ -157,6 +158,7 @@ class JambaAttentionDecoderLayer(nn.Module):
|
|||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -198,6 +200,7 @@ class JambaAttentionDecoderLayer(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
|
prefix=f"{prefix}.attn",
|
||||||
)
|
)
|
||||||
|
|
||||||
num_experts = config.layers_num_experts[layer_idx]
|
num_experts = config.layers_num_experts[layer_idx]
|
||||||
@ -287,7 +290,8 @@ class JambaModel(nn.Module):
|
|||||||
layer_class(config,
|
layer_class(config,
|
||||||
layer_idx=i,
|
layer_idx=i,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config))
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.layers.{i}"))
|
||||||
self.layers = nn.ModuleList(decoder_layers)
|
self.layers = nn.ModuleList(decoder_layers)
|
||||||
self.final_layernorm = RMSNorm(config.hidden_size,
|
self.final_layernorm = RMSNorm(config.hidden_size,
|
||||||
eps=config.rms_norm_eps)
|
eps=config.rms_norm_eps)
|
||||||
|
|||||||
@ -174,6 +174,7 @@ class LlamaAttention(nn.Module):
|
|||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn",
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@ -192,6 +192,7 @@ class MiniCPMAttention(nn.Module):
|
|||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -246,7 +247,8 @@ class MiniCPMAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -273,6 +275,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
|||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -283,6 +286,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
|||||||
self.rope_scaling = getattr(config, "rope_scaling", None)
|
self.rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
self.max_position_embeddings = getattr(config,
|
self.max_position_embeddings = getattr(config,
|
||||||
"max_position_embeddings", 8192)
|
"max_position_embeddings", 8192)
|
||||||
|
self.prefix = prefix
|
||||||
self._init_attn_block()
|
self._init_attn_block()
|
||||||
self._init_ffn_block()
|
self._init_ffn_block()
|
||||||
|
|
||||||
@ -298,6 +302,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
|||||||
max_position_embeddings=self.max_position_embeddings,
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
cache_config=self.cache_config,
|
cache_config=self.cache_config,
|
||||||
quant_config=self.quant_config,
|
quant_config=self.quant_config,
|
||||||
|
prefix=f"{self.prefix}.self_attn",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_ffn_block(self):
|
def _init_ffn_block(self):
|
||||||
@ -388,8 +393,8 @@ class MiniCPMModel(nn.Module):
|
|||||||
):
|
):
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: MiniCPMDecoderLayer(config, cache_config,
|
lambda prefix: MiniCPMDecoderLayer(
|
||||||
quant_config),
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.layers")
|
prefix=f"{prefix}.layers")
|
||||||
|
|
||||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
|||||||
@ -60,6 +60,7 @@ class MiniCPM3Attention(nn.Module):
|
|||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -119,7 +120,8 @@ class MiniCPM3Attention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_local_heads,
|
num_kv_heads=self.num_local_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -195,6 +197,7 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):
|
|||||||
max_position_embeddings=self.max_position_embeddings,
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
cache_config=self.cache_config,
|
cache_config=self.cache_config,
|
||||||
quant_config=self.quant_config,
|
quant_config=self.quant_config,
|
||||||
|
prefix=f"{self.prefix}.self_attn",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -209,8 +212,8 @@ class MiniCPM3Model(MiniCPMModel):
|
|||||||
):
|
):
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: MiniCPM3DecoderLayer(config, cache_config,
|
lambda prefix: MiniCPM3DecoderLayer(
|
||||||
quant_config),
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.layers")
|
prefix=f"{prefix}.layers")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -166,7 +166,8 @@ class MixtralAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -170,6 +170,7 @@ class MixtralAttention(nn.Module):
|
|||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -219,7 +220,8 @@ class MixtralAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -243,6 +245,7 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
config: MixtralConfig,
|
config: MixtralConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -255,7 +258,9 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
num_kv_heads=config.num_key_value_heads,
|
num_kv_heads=config.num_key_value_heads,
|
||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
|
)
|
||||||
self.block_sparse_moe = MixtralMoE(config=config,
|
self.block_sparse_moe = MixtralMoE(config=config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config)
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
@ -311,7 +316,8 @@ class MixtralModel(nn.Module):
|
|||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: MixtralDecoderLayer(
|
lambda prefix: MixtralDecoderLayer(
|
||||||
config, cache_config, quant_config=quant_config),
|
config, cache_config, quant_config=quant_config, prefix=prefix
|
||||||
|
),
|
||||||
prefix=f"{prefix}.layers")
|
prefix=f"{prefix}.layers")
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
|
|||||||
@ -370,6 +370,7 @@ class MolmoAttention(nn.Module):
|
|||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -427,7 +428,8 @@ class MolmoAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
# Attention output projection.
|
# Attention output projection.
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = RowParallelLinear(
|
||||||
@ -517,10 +519,14 @@ class MolmoDecoderLayer(nn.Module):
|
|||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Attention block.
|
# Attention block.
|
||||||
self.self_attn = MolmoAttention(config, cache_config, quant_config)
|
self.self_attn = MolmoAttention(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn")
|
||||||
|
|
||||||
# MLP block.
|
# MLP block.
|
||||||
self.mlp = MolmoMLP(config, quant_config=quant_config)
|
self.mlp = MolmoMLP(config, quant_config=quant_config)
|
||||||
@ -738,7 +744,8 @@ class MolmoModel(nn.Module):
|
|||||||
else MolmoDecoderLayer
|
else MolmoDecoderLayer
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: decoder_layer(config, cache_config, quant_config),
|
lambda prefix: decoder_layer(
|
||||||
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.layers",
|
prefix=f"{prefix}.layers",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -50,6 +50,7 @@ class MPTAttention(nn.Module):
|
|||||||
config: MPTConfig,
|
config: MPTConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.d_model = config.d_model
|
self.d_model = config.d_model
|
||||||
@ -115,7 +116,8 @@ class MPTAttention(nn.Module):
|
|||||||
alibi_slopes=alibi_slopes,
|
alibi_slopes=alibi_slopes,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -176,11 +178,15 @@ class MPTBlock(nn.Module):
|
|||||||
config: MPTConfig,
|
config: MPTConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
hidden_size = config.d_model
|
hidden_size = config.d_model
|
||||||
self.norm_1 = nn.LayerNorm(hidden_size)
|
self.norm_1 = nn.LayerNorm(hidden_size)
|
||||||
self.attn = MPTAttention(config, cache_config, quant_config)
|
self.attn = MPTAttention(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
self.norm_2 = nn.LayerNorm(hidden_size)
|
self.norm_2 = nn.LayerNorm(hidden_size)
|
||||||
self.ffn = MPTMLP(config, quant_config)
|
self.ffn = MPTMLP(config, quant_config)
|
||||||
|
|
||||||
@ -224,7 +230,8 @@ class MPTModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.start_layer, self.end_layer, self.blocks = make_layers(
|
self.start_layer, self.end_layer, self.blocks = make_layers(
|
||||||
config.n_layers,
|
config.n_layers,
|
||||||
lambda prefix: MPTBlock(config, cache_config, quant_config),
|
lambda prefix: MPTBlock(
|
||||||
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.blocks")
|
prefix=f"{prefix}.blocks")
|
||||||
self.norm_f = nn.LayerNorm(config.d_model)
|
self.norm_f = nn.LayerNorm(config.d_model)
|
||||||
if config.no_bias:
|
if config.no_bias:
|
||||||
|
|||||||
@ -195,7 +195,8 @@ class NemotronAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -62,6 +62,7 @@ class OlmoAttention(nn.Module):
|
|||||||
config: OlmoConfig,
|
config: OlmoConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -101,7 +102,8 @@ class OlmoAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
scale=self.scaling,
|
scale=self.scaling,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
# Attention output projection.
|
# Attention output projection.
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = RowParallelLinear(
|
||||||
@ -184,10 +186,14 @@ class OlmoDecoderLayer(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: OlmoConfig,
|
config: OlmoConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Attention block.
|
# Attention block.
|
||||||
self.self_attn = OlmoAttention(config, cache_config, quant_config)
|
self.self_attn = OlmoAttention(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn")
|
||||||
|
|
||||||
# MLP block.
|
# MLP block.
|
||||||
self.mlp = OlmoMLP(config, quant_config)
|
self.mlp = OlmoMLP(config, quant_config)
|
||||||
@ -238,8 +244,8 @@ class OlmoModel(nn.Module):
|
|||||||
config.hidden_size)
|
config.hidden_size)
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: OlmoDecoderLayer(config, cache_config, quant_config
|
lambda prefix: OlmoDecoderLayer(
|
||||||
),
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.layers")
|
prefix=f"{prefix}.layers")
|
||||||
self.norm = nn.LayerNorm(config.hidden_size,
|
self.norm = nn.LayerNorm(config.hidden_size,
|
||||||
elementwise_affine=False,
|
elementwise_affine=False,
|
||||||
|
|||||||
@ -102,6 +102,7 @@ class OlmoeAttention(nn.Module):
|
|||||||
max_position_embeddings: int = 4096,
|
max_position_embeddings: int = 4096,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -156,7 +157,8 @@ class OlmoeAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -182,6 +184,7 @@ class OlmoeDecoderLayer(nn.Module):
|
|||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -199,6 +202,7 @@ class OlmoeDecoderLayer(nn.Module):
|
|||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mlp = OlmoeMoE(
|
self.mlp = OlmoeMoE(
|
||||||
@ -260,8 +264,11 @@ class OlmoeModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: OlmoeDecoderLayer(config, int(
|
lambda prefix: OlmoeDecoderLayer(config,
|
||||||
prefix.split(".")[-1]), cache_config, quant_config),
|
int(prefix.split(".")[-1]),
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=prefix),
|
||||||
prefix=f"{prefix}.layers")
|
prefix=f"{prefix}.layers")
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
|
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||||
|
|
||||||
|
|||||||
@ -75,6 +75,7 @@ class OrionAttention(nn.Module):
|
|||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -126,7 +127,8 @@ class OrionAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -150,6 +152,7 @@ class OrionDecoderLayer(nn.Module):
|
|||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -166,6 +169,7 @@ class OrionDecoderLayer(nn.Module):
|
|||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
)
|
)
|
||||||
self.mlp = OrionMLP(
|
self.mlp = OrionMLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
@ -226,10 +230,7 @@ class OrionModel(nn.Module):
|
|||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: OrionDecoderLayer(
|
lambda prefix: OrionDecoderLayer(
|
||||||
config,
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
cache_config,
|
|
||||||
quant_config,
|
|
||||||
),
|
|
||||||
prefix=f"{prefix}.layers")
|
prefix=f"{prefix}.layers")
|
||||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
|
|||||||
@ -75,7 +75,8 @@ class PersimmonAttention(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: PersimmonConfig,
|
config: PersimmonConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
tensor_parallel_world_size = get_tensor_model_parallel_world_size()
|
tensor_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||||
@ -122,7 +123,8 @@ class PersimmonAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
scale=self.scaling,
|
scale=self.scaling,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
|
def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
# [seq_length, hidden_size] -> [seq_length, num_heads, head_dim]
|
# [seq_length, hidden_size] -> [seq_length, num_heads, head_dim]
|
||||||
@ -167,12 +169,14 @@ class PersimmonDecoderLayer(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: PersimmonConfig,
|
config: PersimmonConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.self_attn = PersimmonAttention(config=config,
|
self.self_attn = PersimmonAttention(config=config,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn")
|
||||||
self.mlp = PersimmonMLP(config, quant_config=quant_config)
|
self.mlp = PersimmonMLP(config, quant_config=quant_config)
|
||||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
@ -226,8 +230,8 @@ class PersimmonModel(nn.Module):
|
|||||||
config.hidden_size)
|
config.hidden_size)
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: PersimmonDecoderLayer(config, cache_config,
|
lambda prefix: PersimmonDecoderLayer(
|
||||||
quant_config),
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.layers")
|
prefix=f"{prefix}.layers")
|
||||||
self.final_layernorm = nn.LayerNorm(config.hidden_size,
|
self.final_layernorm = nn.LayerNorm(config.hidden_size,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
|
|||||||
@ -69,7 +69,8 @@ class PhiAttention(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: PhiConfig,
|
config: PhiConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.total_num_heads = config.num_attention_heads
|
self.total_num_heads = config.num_attention_heads
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -116,7 +117,8 @@ class PhiAttention(nn.Module):
|
|||||||
self.head_size,
|
self.head_size,
|
||||||
scaling,
|
scaling,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -167,11 +169,15 @@ class PhiLayer(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: PhiConfig,
|
config: PhiConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
self.self_attn = PhiAttention(config, cache_config, quant_config)
|
self.self_attn = PhiAttention(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn")
|
||||||
self.mlp = PhiMLP(config, quant_config)
|
self.mlp = PhiMLP(config, quant_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -210,7 +216,8 @@ class PhiModel(nn.Module):
|
|||||||
config.hidden_size)
|
config.hidden_size)
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: PhiLayer(config, cache_config, quant_config),
|
lambda prefix: PhiLayer(
|
||||||
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.layers")
|
prefix=f"{prefix}.layers")
|
||||||
self.final_layernorm = nn.LayerNorm(config.hidden_size,
|
self.final_layernorm = nn.LayerNorm(config.hidden_size,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
|
|||||||
@ -117,6 +117,7 @@ class Phi3SmallSelfAttention(nn.Module):
|
|||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
@ -214,15 +215,14 @@ class Phi3SmallSelfAttention(nn.Module):
|
|||||||
"homo_head": self.homo_heads
|
"homo_head": self.homo_heads
|
||||||
}
|
}
|
||||||
|
|
||||||
self.attn = Attention(
|
self.attn = Attention(self.num_heads_per_partition,
|
||||||
self.num_heads_per_partition,
|
self.head_dim,
|
||||||
self.head_dim,
|
self.scale,
|
||||||
self.scale,
|
num_kv_heads=self.num_kv_heads_per_partion,
|
||||||
num_kv_heads=self.num_kv_heads_per_partion,
|
cache_config=cache_config,
|
||||||
cache_config=cache_config,
|
quant_config=quant_config,
|
||||||
quant_config=quant_config,
|
blocksparse_params=bs_params,
|
||||||
blocksparse_params=bs_params,
|
prefix=f"{prefix}.attn")
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -259,13 +259,15 @@ class Phi3SmallDecoderLayer(nn.Module):
|
|||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.self_attn = Phi3SmallSelfAttention(config,
|
self.self_attn = Phi3SmallSelfAttention(config,
|
||||||
layer_idx,
|
layer_idx,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn")
|
||||||
self.mlp = Phi3SmallMLP(config, quant_config)
|
self.mlp = Phi3SmallMLP(config, quant_config)
|
||||||
|
|
||||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||||
@ -315,7 +317,9 @@ class Phi3SmallModel(nn.Module):
|
|||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: Phi3SmallDecoderLayer(config,
|
lambda prefix: Phi3SmallDecoderLayer(config,
|
||||||
int(prefix.split('.')[-1]),
|
int(prefix.split('.')[-1]),
|
||||||
cache_config, quant_config),
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=prefix),
|
||||||
prefix=f"{prefix}.layers")
|
prefix=f"{prefix}.layers")
|
||||||
|
|
||||||
self.final_layernorm = nn.LayerNorm(config.hidden_size,
|
self.final_layernorm = nn.LayerNorm(config.hidden_size,
|
||||||
|
|||||||
@ -294,6 +294,7 @@ class PhiMoEAttention(nn.Module):
|
|||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
rope_scaling: Optional[dict] = None,
|
rope_scaling: Optional[dict] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -347,6 +348,7 @@ class PhiMoEAttention(nn.Module):
|
|||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn",
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -371,6 +373,7 @@ class PhiMoEDecoderLayer(nn.Module):
|
|||||||
config: PhiMoEConfig,
|
config: PhiMoEConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -385,6 +388,7 @@ class PhiMoEDecoderLayer(nn.Module):
|
|||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
rope_scaling=config.rope_scaling,
|
rope_scaling=config.rope_scaling,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
)
|
)
|
||||||
self.block_sparse_moe = PhiMoE(
|
self.block_sparse_moe = PhiMoE(
|
||||||
num_experts=config.num_local_experts,
|
num_experts=config.num_local_experts,
|
||||||
@ -454,8 +458,8 @@ class PhiMoEModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: PhiMoEDecoderLayer(config, cache_config,
|
lambda prefix: PhiMoEDecoderLayer(
|
||||||
quant_config),
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.layers")
|
prefix=f"{prefix}.layers")
|
||||||
self.norm = nn.LayerNorm(config.hidden_size,
|
self.norm = nn.LayerNorm(config.hidden_size,
|
||||||
eps=config.rms_norm_eps,
|
eps=config.rms_norm_eps,
|
||||||
|
|||||||
@ -442,6 +442,7 @@ class QWenAttention(nn.Module):
|
|||||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -478,7 +479,8 @@ class QWenAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -502,6 +504,7 @@ class QWenBlock(nn.Module):
|
|||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
@ -514,7 +517,8 @@ class QWenBlock(nn.Module):
|
|||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
@ -568,7 +572,8 @@ class QWenModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.start_layer, self.end_layer, self.h = make_layers(
|
self.start_layer, self.end_layer, self.h = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: QWenBlock(config, cache_config, quant_config),
|
lambda prefix: QWenBlock(
|
||||||
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.h")
|
prefix=f"{prefix}.h")
|
||||||
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
|
|||||||
@ -168,6 +168,7 @@ class Qwen2MoeAttention(nn.Module):
|
|||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -220,7 +221,8 @@ class Qwen2MoeAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -245,6 +247,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
|||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -261,6 +264,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
|||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
|
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
|
||||||
@ -336,7 +340,8 @@ class Qwen2MoeModel(nn.Module):
|
|||||||
layer_idx=int(
|
layer_idx=int(
|
||||||
prefix.split(".")[-1]),
|
prefix.split(".")[-1]),
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config),
|
quant_config=quant_config,
|
||||||
|
prefix=prefix),
|
||||||
prefix=f"{prefix}.layers",
|
prefix=f"{prefix}.layers",
|
||||||
)
|
)
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|||||||
@ -167,6 +167,7 @@ class SolarAttention(nn.Module):
|
|||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn",
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@ -77,7 +77,8 @@ class StablelmAttention(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None) -> None:
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "") -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -131,7 +132,8 @@ class StablelmAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_key_value_heads,
|
num_kv_heads=self.num_key_value_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -155,9 +157,13 @@ class StablelmDecoderLayer(nn.Module):
|
|||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self_attn = StablelmAttention(config, cache_config, quant_config)
|
self.self_attn = StablelmAttention(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn")
|
||||||
self.mlp = StablelmMLP(config, quant_config)
|
self.mlp = StablelmMLP(config, quant_config)
|
||||||
norm_eps = getattr(config, "norm_eps",
|
norm_eps = getattr(config, "norm_eps",
|
||||||
getattr(config, "layer_norm_eps", 1e-05))
|
getattr(config, "layer_norm_eps", 1e-05))
|
||||||
@ -207,8 +213,8 @@ class StableLMEpochModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: StablelmDecoderLayer(config, cache_config,
|
lambda prefix: StablelmDecoderLayer(
|
||||||
quant_config),
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.layers",
|
prefix=f"{prefix}.layers",
|
||||||
)
|
)
|
||||||
norm_eps = getattr(config, "norm_eps",
|
norm_eps = getattr(config, "norm_eps",
|
||||||
|
|||||||
@ -52,7 +52,8 @@ class Starcoder2Attention(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: Starcoder2Config,
|
config: Starcoder2Config,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@ -105,7 +106,8 @@ class Starcoder2Attention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -154,12 +156,14 @@ class Starcoder2DecoderLayer(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: Starcoder2Config,
|
config: Starcoder2Config,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.self_attn = Starcoder2Attention(config,
|
self.self_attn = Starcoder2Attention(config,
|
||||||
cache_config,
|
cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn")
|
||||||
self.mlp = Starcoder2MLP(config, quant_config=quant_config)
|
self.mlp = Starcoder2MLP(config, quant_config=quant_config)
|
||||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||||
eps=config.norm_epsilon)
|
eps=config.norm_epsilon)
|
||||||
@ -213,7 +217,8 @@ class Starcoder2Model(nn.Module):
|
|||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: Starcoder2DecoderLayer(
|
lambda prefix: Starcoder2DecoderLayer(
|
||||||
config, cache_config, quant_config=quant_config),
|
config, cache_config, quant_config=quant_config, prefix=prefix
|
||||||
|
),
|
||||||
prefix=f"{prefix}.layers",
|
prefix=f"{prefix}.layers",
|
||||||
)
|
)
|
||||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
|
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
|
||||||
|
|||||||
@ -93,6 +93,7 @@ class XverseAttention(nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
bias: bool = False,
|
bias: bool = False,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -138,7 +139,8 @@ class XverseAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -162,6 +164,7 @@ class XverseDecoderLayer(nn.Module):
|
|||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -180,6 +183,7 @@ class XverseDecoderLayer(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
bias=getattr(config, "bias", False),
|
bias=getattr(config, "bias", False),
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
)
|
)
|
||||||
self.mlp = XverseMLP(
|
self.mlp = XverseMLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
@ -243,8 +247,8 @@ class XverseModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: XverseDecoderLayer(config, cache_config,
|
lambda prefix: XverseDecoderLayer(
|
||||||
quant_config),
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
prefix=f"{prefix}.layers",
|
prefix=f"{prefix}.layers",
|
||||||
)
|
)
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|||||||
@ -20,6 +20,7 @@ logger = init_logger(__name__)
|
|||||||
class CpuPlatform(Platform):
|
class CpuPlatform(Platform):
|
||||||
_enum = PlatformEnum.CPU
|
_enum = PlatformEnum.CPU
|
||||||
device_type: str = "cpu"
|
device_type: str = "cpu"
|
||||||
|
dispatch_key: str = "CPU"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_device_name(cls, device_id: int = 0) -> str:
|
def get_device_name(cls, device_id: int = 0) -> str:
|
||||||
|
|||||||
@ -121,6 +121,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:
|
|||||||
class CudaPlatform(Platform):
|
class CudaPlatform(Platform):
|
||||||
_enum = PlatformEnum.CUDA
|
_enum = PlatformEnum.CUDA
|
||||||
device_type: str = "cuda"
|
device_type: str = "cuda"
|
||||||
|
dispatch_key: str = "CUDA"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
|
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
|
||||||
|
|||||||
@ -13,6 +13,7 @@ else:
|
|||||||
class HpuPlatform(Platform):
|
class HpuPlatform(Platform):
|
||||||
_enum = PlatformEnum.HPU
|
_enum = PlatformEnum.HPU
|
||||||
device_type: str = "hpu"
|
device_type: str = "hpu"
|
||||||
|
dispatch_key: str = "HPU"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
||||||
|
|||||||
@ -57,6 +57,10 @@ class DeviceCapability(NamedTuple):
|
|||||||
class Platform:
|
class Platform:
|
||||||
_enum: PlatformEnum
|
_enum: PlatformEnum
|
||||||
device_type: str
|
device_type: str
|
||||||
|
# available dispatch keys:
|
||||||
|
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
|
||||||
|
# use "CPU" as a fallback for platforms not registered in PyTorch
|
||||||
|
dispatch_key: str = "CPU"
|
||||||
|
|
||||||
def is_cuda(self) -> bool:
|
def is_cuda(self) -> bool:
|
||||||
return self._enum == PlatformEnum.CUDA
|
return self._enum == PlatformEnum.CUDA
|
||||||
|
|||||||
@ -18,6 +18,7 @@ logger = init_logger(__name__)
|
|||||||
class OpenVinoPlatform(Platform):
|
class OpenVinoPlatform(Platform):
|
||||||
_enum = PlatformEnum.OPENVINO
|
_enum = PlatformEnum.OPENVINO
|
||||||
device_type: str = "openvino"
|
device_type: str = "openvino"
|
||||||
|
dispatch_key: str = "CPU"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
||||||
|
|||||||
@ -36,6 +36,7 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
|
|||||||
class RocmPlatform(Platform):
|
class RocmPlatform(Platform):
|
||||||
_enum = PlatformEnum.ROCM
|
_enum = PlatformEnum.ROCM
|
||||||
device_type: str = "cuda"
|
device_type: str = "cuda"
|
||||||
|
dispatch_key: str = "CUDA"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
||||||
|
|||||||
@ -17,6 +17,7 @@ logger = init_logger(__name__)
|
|||||||
class TpuPlatform(Platform):
|
class TpuPlatform(Platform):
|
||||||
_enum = PlatformEnum.TPU
|
_enum = PlatformEnum.TPU
|
||||||
device_type: str = "tpu"
|
device_type: str = "tpu"
|
||||||
|
dispatch_key: str = "XLA"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
||||||
|
|||||||
@ -17,6 +17,7 @@ logger = init_logger(__name__)
|
|||||||
class XPUPlatform(Platform):
|
class XPUPlatform(Platform):
|
||||||
_enum = PlatformEnum.XPU
|
_enum = PlatformEnum.XPU
|
||||||
device_type: str = "xpu"
|
device_type: str = "xpu"
|
||||||
|
dispatch_key: str = "XPU"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
||||||
|
|||||||
@ -273,7 +273,8 @@ class TP1DraftModelRunner(ModelRunner):
|
|||||||
if previous_hidden_states is not None else {}
|
if previous_hidden_states is not None else {}
|
||||||
|
|
||||||
# Run model
|
# Run model
|
||||||
with set_forward_context(model_input.attn_metadata):
|
with set_forward_context(model_input.attn_metadata,
|
||||||
|
self.vllm_config):
|
||||||
hidden_states = model_executable(
|
hidden_states = model_executable(
|
||||||
input_ids=model_input.input_tokens,
|
input_ids=model_input.input_tokens,
|
||||||
positions=model_input.input_positions,
|
positions=model_input.input_positions,
|
||||||
|
|||||||
@ -1573,6 +1573,7 @@ def direct_register_custom_op(
|
|||||||
mutates_args: List[str],
|
mutates_args: List[str],
|
||||||
fake_impl: Optional[Callable] = None,
|
fake_impl: Optional[Callable] = None,
|
||||||
target_lib: Optional[Library] = None,
|
target_lib: Optional[Library] = None,
|
||||||
|
dispatch_key: str = "CUDA",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
`torch.library.custom_op` can have significant overhead because it
|
`torch.library.custom_op` can have significant overhead because it
|
||||||
@ -1601,7 +1602,7 @@ def direct_register_custom_op(
|
|||||||
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
|
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
|
||||||
my_lib = target_lib or vllm_lib
|
my_lib = target_lib or vllm_lib
|
||||||
my_lib.define(op_name + schema_str)
|
my_lib.define(op_name + schema_str)
|
||||||
my_lib.impl(op_name, op_func, "CUDA")
|
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
|
||||||
if fake_impl is not None:
|
if fake_impl is not None:
|
||||||
my_lib._register_fake(op_name, fake_impl)
|
my_lib._register_fake(op_name, fake_impl)
|
||||||
|
|
||||||
|
|||||||
@ -173,7 +173,8 @@ def unified_v1_flash_attention(
|
|||||||
alibi_slopes: Optional[torch.Tensor] = None,
|
alibi_slopes: Optional[torch.Tensor] = None,
|
||||||
logits_soft_cap: Optional[float] = None,
|
logits_soft_cap: Optional[float] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
current_metadata = get_forward_context()
|
context = get_forward_context()
|
||||||
|
current_metadata = context.dynamic_forward_context
|
||||||
if current_metadata is None:
|
if current_metadata is None:
|
||||||
# Profiling run.
|
# Profiling run.
|
||||||
return
|
return
|
||||||
|
|||||||
@ -447,7 +447,7 @@ class GPUModelRunner:
|
|||||||
|
|
||||||
# Run the decoder.
|
# Run the decoder.
|
||||||
# Use persistent buffers for CUDA graphs.
|
# Use persistent buffers for CUDA graphs.
|
||||||
with set_forward_context(attn_metadata):
|
with set_forward_context(attn_metadata, self.vllm_config):
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
positions=self.positions[:num_input_tokens],
|
positions=self.positions[:num_input_tokens],
|
||||||
@ -523,7 +523,7 @@ class GPUModelRunner:
|
|||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
kv_caches: List[torch.Tensor],
|
kv_caches: List[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
with set_forward_context(None):
|
with set_forward_context(None, self.vllm_config):
|
||||||
hidden_states = model(
|
hidden_states = model(
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
positions=self.positions[:num_tokens],
|
positions=self.positions[:num_tokens],
|
||||||
|
|||||||
@ -97,7 +97,7 @@ class EmbeddingModelRunner(
|
|||||||
model_forward_end = torch.cuda.Event(enable_timing=True)
|
model_forward_end = torch.cuda.Event(enable_timing=True)
|
||||||
model_forward_start.record()
|
model_forward_start.record()
|
||||||
|
|
||||||
with set_forward_context(model_input.attn_metadata):
|
with set_forward_context(model_input.attn_metadata, self.vllm_config):
|
||||||
hidden_or_intermediate_states = model_executable(
|
hidden_or_intermediate_states = model_executable(
|
||||||
input_ids=model_input.input_tokens,
|
input_ids=model_input.input_tokens,
|
||||||
positions=model_input.input_positions,
|
positions=model_input.input_positions,
|
||||||
|
|||||||
@ -176,7 +176,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
|||||||
} if self.has_inner_state else {}
|
} if self.has_inner_state else {}
|
||||||
|
|
||||||
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
||||||
with set_forward_context(model_input.attn_metadata):
|
with set_forward_context(model_input.attn_metadata, self.vllm_config):
|
||||||
hidden_or_intermediate_states = model_executable(
|
hidden_or_intermediate_states = model_executable(
|
||||||
input_ids=model_input.input_tokens,
|
input_ids=model_input.input_tokens,
|
||||||
positions=model_input.input_positions,
|
positions=model_input.input_positions,
|
||||||
|
|||||||
@ -1503,7 +1503,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|||||||
self._update_inputs_to_capture_for_enc_dec_model(
|
self._update_inputs_to_capture_for_enc_dec_model(
|
||||||
capture_inputs)
|
capture_inputs)
|
||||||
|
|
||||||
with set_forward_context(attn_metadata):
|
with set_forward_context(attn_metadata, self.vllm_config):
|
||||||
graph_runner.capture(**capture_inputs)
|
graph_runner.capture(**capture_inputs)
|
||||||
self.graph_memory_pool = graph_runner.graph.pool()
|
self.graph_memory_pool = graph_runner.graph.pool()
|
||||||
self.graph_runners[virtual_engine][batch_size] = (
|
self.graph_runners[virtual_engine][batch_size] = (
|
||||||
@ -1649,7 +1649,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
|||||||
model_forward_end = torch.cuda.Event(enable_timing=True)
|
model_forward_end = torch.cuda.Event(enable_timing=True)
|
||||||
model_forward_start.record()
|
model_forward_start.record()
|
||||||
|
|
||||||
with set_forward_context(model_input.attn_metadata):
|
with set_forward_context(model_input.attn_metadata, self.vllm_config):
|
||||||
hidden_or_intermediate_states = model_executable(
|
hidden_or_intermediate_states = model_executable(
|
||||||
input_ids=model_input.input_tokens,
|
input_ids=model_input.input_tokens,
|
||||||
positions=model_input.input_positions,
|
positions=model_input.input_positions,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user