[torch.compile] support all attention backends (#10558)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-11-22 14:04:42 -08:00 committed by GitHub
parent db100c5cde
commit eebad39f26
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
77 changed files with 876 additions and 648 deletions

View File

@ -18,8 +18,10 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, _cached_get_attn_backend, from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager) global_force_attn_backend_context_manager)
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.plugins import set_current_vllm_config
# List of support backends for encoder/decoder models # List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
@ -594,6 +596,7 @@ def _run_encoder_attention_test(
encoder_test_params: PhaseTestParameters, encoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
test_pt: TestPoint, test_pt: TestPoint,
vllm_config: VllmConfig,
) -> torch.Tensor: ) -> torch.Tensor:
''' '''
Run encoder attention. Run encoder attention.
@ -623,7 +626,7 @@ def _run_encoder_attention_test(
attn_type = AttentionType.ENCODER attn_type = AttentionType.ENCODER
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None assert packed_qkv is not None
with set_forward_context(attn_metadata): with set_forward_context(attn_metadata, vllm_config):
# In the test setup the shape of the query is # In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However # [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be # the attention backend expect the shape to be
@ -648,6 +651,7 @@ def _run_decoder_self_attention_test(
decoder_test_params: PhaseTestParameters, decoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
test_pt: TestPoint, test_pt: TestPoint,
vllm_config: VllmConfig,
) -> torch.Tensor: ) -> torch.Tensor:
''' '''
Run decoder self-attention test. Run decoder self-attention test.
@ -677,7 +681,7 @@ def _run_decoder_self_attention_test(
kv_cache = test_rsrcs.kv_cache kv_cache = test_rsrcs.kv_cache
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None assert packed_qkv is not None
with set_forward_context(attn_metadata): with set_forward_context(attn_metadata, vllm_config):
# In the test setup the shape of the query is # In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However # [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be # the attention backend expect the shape to be
@ -701,6 +705,7 @@ def _run_encoder_decoder_cross_attention_test(
cross_test_params: Optional[PhaseTestParameters], cross_test_params: Optional[PhaseTestParameters],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
test_pt: TestPoint, test_pt: TestPoint,
vllm_config: VllmConfig,
) -> torch.Tensor: ) -> torch.Tensor:
''' '''
Run encoder/decoder cross-attention test. Run encoder/decoder cross-attention test.
@ -748,7 +753,7 @@ def _run_encoder_decoder_cross_attention_test(
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
with set_forward_context(attn_metadata): with set_forward_context(attn_metadata, vllm_config):
# In the test setup the shape of the query is # In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However # [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be # the attention backend expect the shape to be
@ -839,7 +844,9 @@ def test_encoder_only(
# Attention scale factor, attention backend instance, attention wrapper # Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init # instance, KV cache init
test_rsrcs = _make_test_resources(test_pt) vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
test_rsrcs = _make_test_resources(test_pt)
# Construct encoder attention test params (only used # Construct encoder attention test params (only used
# during prefill) # during prefill)
@ -863,7 +870,8 @@ def test_encoder_only(
test_rsrcs.attn, test_rsrcs.attn,
enc_test_params, enc_test_params,
prephase_attn_metadata, prephase_attn_metadata,
test_pt=test_pt)) test_pt=test_pt,
vllm_config=vllm_config))
# - Is encoder attention result correct? # - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
@ -960,7 +968,9 @@ def test_e2e_enc_dec_attn(
# Attention scale factor, attention backend instance, attention wrapper # Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init # instance, KV cache init
test_rsrcs = _make_test_resources(test_pt) vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
test_rsrcs = _make_test_resources(test_pt)
# Construct encoder attention test params (only used # Construct encoder attention test params (only used
# during prefill) # during prefill)
@ -1011,7 +1021,8 @@ def test_e2e_enc_dec_attn(
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
enc_test_params, enc_test_params,
prephase_attn_metadata, prephase_attn_metadata,
test_pt=test_pt) test_pt=test_pt,
vllm_config=vllm_config)
# - Is encoder attention result correct? # - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
@ -1023,7 +1034,8 @@ def test_e2e_enc_dec_attn(
test_rsrcs, test_rsrcs,
prephase_dec_test_params, prephase_dec_test_params,
prephase_attn_metadata, prephase_attn_metadata,
test_pt=test_pt) test_pt=test_pt,
vllm_config=vllm_config)
# - Is prefill decoder self-attention correct? # - Is prefill decoder self-attention correct?
assert_actual_matches_ideal(prephase_dec_test_params, assert_actual_matches_ideal(prephase_dec_test_params,
@ -1037,7 +1049,8 @@ def test_e2e_enc_dec_attn(
prephase_dec_test_params, prephase_dec_test_params,
prephase_cross_test_params, prephase_cross_test_params,
prephase_attn_metadata, prephase_attn_metadata,
test_pt=test_pt) test_pt=test_pt,
vllm_config=vllm_config)
# - Is prefill encoder/decoder cross-attention correct? # - Is prefill encoder/decoder cross-attention correct?
assert_actual_matches_ideal(prephase_cross_test_params, assert_actual_matches_ideal(prephase_cross_test_params,
@ -1061,7 +1074,8 @@ def test_e2e_enc_dec_attn(
test_rsrcs, test_rsrcs,
decphase_dec_test_params, decphase_dec_test_params,
decphase_attn_metadata, decphase_attn_metadata,
test_pt=test_pt) test_pt=test_pt,
vllm_config=vllm_config)
# - Is decode-phase decoder self-attention correct? # - Is decode-phase decoder self-attention correct?
assert_actual_matches_ideal(decphase_dec_test_params, assert_actual_matches_ideal(decphase_dec_test_params,
@ -1075,7 +1089,8 @@ def test_e2e_enc_dec_attn(
decphase_dec_test_params, decphase_dec_test_params,
None, None,
decphase_attn_metadata, decphase_attn_metadata,
test_pt=test_pt) test_pt=test_pt,
vllm_config=vllm_config)
# - Is decode-phase encoder/decoder cross-attention correct? # - Is decode-phase encoder/decoder cross-attention correct?
assert_actual_matches_ideal(decphase_cross_test_params, assert_actual_matches_ideal(decphase_cross_test_params,

View File

@ -1,7 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
Tuple, Type, TypeVar) Tuple, Type, TypeVar)
@ -15,13 +14,19 @@ if TYPE_CHECKING:
ModelRunnerInputBuilderBase) ModelRunnerInputBuilderBase)
class AttentionType(Enum): class AttentionType:
DECODER = auto() # Decoder attention between previous layer Q/K/V """
ENCODER = auto( Attention type.
) # Encoder attention between previous layer Q/K/V for encoder-decoder Use string to be compatible with `torch.compile`.
ENCODER_ONLY = auto() # Encoder attention between previous layer Q/K/V """
ENCODER_DECODER = auto( # Decoder attention between previous layer Q/K/V
) # Attention between dec. Q and enc. K/V for encoder-decoder DECODER = "decoder"
# Encoder attention between previous layer Q/K/V for encoder-decoder
ENCODER = "encoder"
# Encoder attention between previous layer Q/K/V
ENCODER_ONLY = "encoder_only"
# Attention between dec. Q and enc. K/V for encoder-decoder
ENCODER_DECODER = "encoder_decoder"
class AttentionBackend(ABC): class AttentionBackend(ABC):
@ -241,6 +246,6 @@ class AttentionImpl(ABC, Generic[T]):
attn_metadata: T, attn_metadata: T,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError

View File

@ -354,7 +354,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
attn_metadata: BlocksparseFlashAttentionMetadata, attn_metadata: BlocksparseFlashAttentionMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.

View File

@ -16,10 +16,8 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set, get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty) is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.forward_context import get_forward_context
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import (async_tensor_h2d, direct_register_custom_op, from vllm.utils import async_tensor_h2d, make_tensor_with_pad
make_tensor_with_pad)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@ -639,7 +637,7 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata, attn_metadata: FlashAttentionMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention. """Forward pass with FlashAttention.
@ -668,23 +666,174 @@ class FlashAttentionImpl(AttentionImpl):
"requires setting cross-attention " "requires setting cross-attention "
"metadata attributes.") "metadata attributes.")
output = torch.ops.vllm.unified_flash_attention( num_heads: int = self.num_heads
query, head_size: int = self.head_size
key, num_kv_heads: int = self.num_kv_heads
value, kv_cache_dtype: str = self.kv_cache_dtype
self.num_heads, softmax_scale: float = self.scale
self.head_size, window_size = self.sliding_window
self.num_kv_heads, alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
kv_cache, logits_soft_cap: Optional[float] = self.logits_soft_cap
self.kv_cache_dtype,
k_scale, num_tokens, hidden_size = query.shape
v_scale,
self.scale, # Reshape the query, key, and value tensors.
attn_type.value, query = query.view(-1, num_heads, head_size)
self.sliding_window, if (key is not None) and (value is not None):
self.alibi_slopes, key = key.view(-1, num_kv_heads, head_size)
self.logits_soft_cap, value = value.view(-1, num_kv_heads, head_size)
)
if kv_cache.numel() > 0:
key_cache = kv_cache[0]
value_cache = kv_cache[1]
# We skip updating the KV cache under two conditions:
# a. When the Attention Type is ENCODER. In this phase, we compute
# only the encoder attention without updating the cache.
# b. When both Key and Value are None. This occurs during
# cross-attention computation in the decoding phase, where the
# KV cache is already populated with the cross-attention
# tensor. Thus, we skip cache updates during this time.
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
# profiling run.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[0],
kv_cache[1],
updated_slot_mapping.flatten(), # type: ignore[union-attr]
kv_cache_dtype,
k_scale,
v_scale,
)
(num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens) = \
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
decode_query = query[num_prefill_query_tokens:]
# QKV for prefill.
query = query[:num_prefill_query_tokens]
assert query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_query_tokens
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
key = key[:num_prefill_kv_tokens]
value = value[:num_prefill_kv_tokens]
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=q_seq_start_loc,
cu_seqlens_k=k_seq_start_loc,
max_seqlen_q=q_seq_len,
max_seqlen_k=k_seq_len,
softmax_scale=softmax_scale,
causal=_get_causal_option(attn_type),
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
)
else:
# prefix-enabled attention
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support prefix caching")
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
prefill_output = flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_k=max_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
# Use flash_attn_varlen_func kernel for speculative decoding
# because different queries might have different lengths.
assert decode_meta.max_decode_query_len is not None
# use only for actual varlen decoding
if decode_meta.max_decode_query_len > 1:
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support max_decode_query_len > 1"
)
decode_output = flash_attn_varlen_func(
q=decode_query,
k=key_cache,
v=value_cache,
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_decode_query_len,
cu_seqlens_k=decode_meta.seq_start_loc,
max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
)
else:
# Use flash_attn_with_kvcache for normal decoding.
(
seq_lens_arg,
_,
block_tables_arg,
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
decode_output = flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
block_table=block_tables_arg,
cache_seqlens=seq_lens_arg,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
).squeeze(1)
if prefill_output is None:
assert decode_output is not None
return decode_output.view(num_decode_query_tokens, hidden_size)
if decode_output is None:
assert prefill_output is not None
return prefill_output.view(num_prefill_query_tokens, hidden_size)
assert decode_meta is not None
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)
return output return output
@ -692,7 +841,7 @@ class FlashAttentionImpl(AttentionImpl):
def _get_query_key_seq_metadata( def _get_query_key_seq_metadata(
attn_metadata, attn_metadata,
is_prompt: bool, is_prompt: bool,
attn_type: AttentionType, attn_type: str,
) -> tuple: ) -> tuple:
""" """
Returns sequence metadata for key and query based on the specified Returns sequence metadata for key and query based on the specified
@ -754,7 +903,7 @@ def _get_query_key_seq_metadata(
raise AttributeError(f"Invalid attention type {str(attn_type)}") raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _get_causal_option(attn_type: AttentionType) -> bool: def _get_causal_option(attn_type: str) -> bool:
""" """
Determine whether the given attention type is suitable for causal Determine whether the given attention type is suitable for causal
attention mechanisms. attention mechanisms.
@ -770,220 +919,3 @@ def _get_causal_option(attn_type: AttentionType) -> bool:
return not (attn_type == AttentionType.ENCODER return not (attn_type == AttentionType.ENCODER
or attn_type == AttentionType.ENCODER_ONLY or attn_type == AttentionType.ENCODER_ONLY
or attn_type == AttentionType.ENCODER_DECODER) or attn_type == AttentionType.ENCODER_DECODER)
def unified_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
attn_type_int_val: int,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
# Convert integer attn_type to enum
try:
attn_type = AttentionType(attn_type_int_val)
except ValueError as err:
raise AttributeError(
f"Invalid attention type {str(attn_type_int_val)}") from err
current_metadata = get_forward_context()
assert current_metadata is not None
assert isinstance(current_metadata, FlashAttentionMetadata)
attn_metadata: FlashAttentionMetadata = current_metadata
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size)
if (key is not None) and (value is not None):
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
if kv_cache.numel() > 0:
key_cache = kv_cache[0]
value_cache = kv_cache[1]
# We skip updating the KV cache under two conditions:
# a. When the Attention Type is ENCODER. In this phase, we compute
# only the encoder attention without updating the cache.
# b. When both Key and Value are None. This occurs during
# cross-attention computation in the decoding phase, where the KV
# cache is already populated with the cross-attention tensor.
# Thus, we skip cache updates during this time.
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[0],
kv_cache[1],
updated_slot_mapping.flatten(), # type: ignore[union-attr]
kv_cache_dtype,
k_scale,
v_scale,
)
(num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens) = \
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
decode_query = query[num_prefill_query_tokens:]
# QKV for prefill.
query = query[:num_prefill_query_tokens]
assert query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_query_tokens
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
key = key[:num_prefill_kv_tokens]
value = value[:num_prefill_kv_tokens]
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=q_seq_start_loc,
cu_seqlens_k=k_seq_start_loc,
max_seqlen_q=q_seq_len,
max_seqlen_k=k_seq_len,
softmax_scale=softmax_scale,
causal=_get_causal_option(attn_type),
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
)
else:
# prefix-enabled attention
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support prefix caching")
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
prefill_output = flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_k=max_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
# Use flash_attn_varlen_func kernel for speculative decoding
# because different queries might have different lengths.
assert decode_meta.max_decode_query_len is not None
# use only for actual varlen decoding
if decode_meta.max_decode_query_len > 1:
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support max_decode_query_len > 1")
decode_output = flash_attn_varlen_func(
q=decode_query,
k=key_cache,
v=value_cache,
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_decode_query_len,
cu_seqlens_k=decode_meta.seq_start_loc,
max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
)
else:
# Use flash_attn_with_kvcache for normal decoding.
(
seq_lens_arg,
_,
block_tables_arg,
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
decode_output = flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
block_table=block_tables_arg,
cache_seqlens=seq_lens_arg,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
).squeeze(1)
if prefill_output is None:
assert decode_output is not None
return decode_output.view(num_decode_query_tokens, hidden_size)
if decode_output is None:
assert prefill_output is not None
return prefill_output.view(num_prefill_query_tokens, hidden_size)
assert decode_meta is not None
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)
def unified_flash_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
attn_type_int_val: int,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query)
direct_register_custom_op(
op_name="unified_flash_attention",
op_func=unified_flash_attention,
mutates_args=["kv_cache"],
fake_impl=unified_flash_attention_fake,
)

View File

@ -30,9 +30,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx, compute_slot_mapping_start_idx,
is_block_tables_empty) is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.forward_context import get_forward_context from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
from vllm.utils import (async_tensor_h2d, direct_register_custom_op, make_tensor_with_pad)
get_kv_cache_torch_dtype, make_tensor_with_pad)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@ -774,7 +773,7 @@ class FlashInferImpl(AttentionImpl):
attn_metadata: FlashInferMetadata, attn_metadata: FlashInferMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
@ -782,174 +781,117 @@ class FlashInferImpl(AttentionImpl):
"are not implemented for " "are not implemented for "
"FlashInferImpl") "FlashInferImpl")
return torch.ops.vllm.unified_flash_infer( num_heads: int = self.num_heads
query, head_size: int = self.head_size
key, num_kv_heads: int = self.num_kv_heads
value, kv_cache_dtype: str = self.kv_cache_dtype
self.num_heads, softmax_scale: float = self.scale
self.head_size, window_size = self.sliding_window
self.num_kv_heads, alibi_slopes = self.alibi_slopes
kv_cache, logits_soft_cap = self.logits_soft_cap
self.kv_cache_dtype,
k_scale,
v_scale,
self.scale,
self.sliding_window,
self.alibi_slopes,
self.logits_soft_cap,
)
num_tokens, hidden_size = query.shape
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
def unified_flash_infer( if kv_cache.numel() > 0:
query: torch.Tensor, # Use the same reshape and cache kernel as flash attention.
key: torch.Tensor, ops.reshape_and_cache_flash(
value: torch.Tensor, key,
num_heads: int, value,
head_size: int, kv_cache[:, 0],
num_kv_heads: int, kv_cache[:, 1],
kv_cache: torch.Tensor, attn_metadata.slot_mapping.flatten(),
kv_cache_dtype: str, kv_cache_dtype,
k_scale: float, k_scale,
v_scale: float, v_scale,
softmax_scale: float,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
current_metadata = get_forward_context()
assert current_metadata is not None
assert isinstance(current_metadata, FlashInferMetadata)
attn_metadata: FlashInferMetadata = current_metadata
num_tokens, hidden_size = query.shape
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
if kv_cache.numel() > 0:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if kv_cache_dtype.startswith("fp8"):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
query = query.contiguous() # Flashinfer requires query to be contiguous
# Query for decode. KV is not needed because it is already cached.
# QKV for prefill.
decode_query = query[num_prefill_tokens:]
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
window_left = window_size[0] if window_size is not None else -1
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache.numel() == 0:
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
) )
else: # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
assert prefill_meta is not None # to process the cache when the kv_cache_dtype is fp8
assert prefill_meta.prefill_wrapper is not None if kv_cache_dtype.startswith("fp8"):
prefill_output = prefill_meta.prefill_wrapper.forward( torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
query, kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
query = query.contiguous(
) # Flashinfer requires query to be contiguous
# Query for decode. KV is not needed because it is already cached.
# QKV for prefill.
decode_query = query[num_prefill_tokens:]
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
window_left = window_size[0] if window_size is not None else -1
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache.numel() == 0:
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
)
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
prefill_output = prefill_meta.prefill_wrapper.forward(
query,
kv_cache,
logits_soft_cap=logits_soft_cap,
causal=True,
k_scale=k_scale,
v_scale=v_scale,
window_left=window_left)
if decode_meta := attn_metadata.decode_metadata:
assert decode_meta is not None
assert decode_meta.decode_wrapper is not None
decode_output = decode_meta.decode_wrapper.forward(
decode_query,
kv_cache, kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap, logits_soft_cap=logits_soft_cap,
causal=True,
k_scale=k_scale, k_scale=k_scale,
v_scale=v_scale, v_scale=v_scale,
window_left=window_left) window_left=window_left)
if decode_meta := attn_metadata.decode_metadata:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
decode_output = attn_metadata.decode_metadata.decode_wrapper.forward(
decode_query,
kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale,
window_left=window_left)
if prefill_output is None and decode_output is not None: if prefill_output is None and decode_output is not None:
# Decode only batch. # Decode only batch.
output, num_tokens = decode_output, num_decode_tokens output, num_tokens = decode_output, num_decode_tokens
elif decode_output is None and prefill_output is not None: elif decode_output is None and prefill_output is not None:
# Prefill only batch. # Prefill only batch.
output, num_tokens = prefill_output, num_prefill_tokens output, num_tokens = prefill_output, num_prefill_tokens
else: else:
# Chunked prefill batch does not work with speculative decoding in # Chunked prefill batch does not work with speculative decoding in
# FlashInfer backend, so the query length for decode should be 1. # FlashInfer backend, so the query length for decode should be 1.
assert prefill_output is not None assert prefill_output is not None
assert decode_output is not None assert decode_output is not None
assert decode_meta is not None assert decode_meta is not None
assert decode_meta.decode_query_len == 1 assert decode_meta.decode_query_len == 1
decode_output = decode_output.squeeze(1) decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0) output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size) return output.view(num_tokens, hidden_size)
def unified_flash_infer_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query).contiguous()
direct_register_custom_op(
op_name="unified_flash_infer",
op_func=unified_flash_infer,
mutates_args=["kv_cache"],
fake_impl=unified_flash_infer_fake,
)

View File

@ -140,7 +140,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
attn_metadata: HPUAttentionMetadata, attn_metadata: HPUAttentionMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.

View File

@ -172,7 +172,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
attn_metadata: IpexAttnMetadata, # type: ignore attn_metadata: IpexAttnMetadata, # type: ignore
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention. """Forward pass with IPEX varlen_attention and PagedAttention.

View File

@ -150,7 +150,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
attn_metadata: PallasMetadata, attn_metadata: PallasMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with Pallas attention. """Forward pass with Pallas attention.

View File

@ -414,7 +414,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_metadata: ROCmFlashAttentionMetadata, attn_metadata: ROCmFlashAttentionMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.

View File

@ -141,7 +141,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def get_seq_lens( def get_seq_lens(
self, self,
attn_type: AttentionType, attn_type: str,
): ):
''' '''
Extract appropriate sequence lengths from attention metadata Extract appropriate sequence lengths from attention metadata
@ -174,7 +174,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def get_attn_bias( def get_attn_bias(
self, self,
attn_type: AttentionType, attn_type: str,
) -> Optional[List[torch.Tensor]]: ) -> Optional[List[torch.Tensor]]:
''' '''
Extract appropriate attention bias from attention metadata Extract appropriate attention bias from attention metadata
@ -203,7 +203,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def set_attn_bias( def set_attn_bias(
self, self,
attn_bias: List[torch.Tensor], attn_bias: List[torch.Tensor],
attn_type: AttentionType, attn_type: str,
) -> None: ) -> None:
''' '''
Update appropriate attention bias field of attention metadata, Update appropriate attention bias field of attention metadata,
@ -229,7 +229,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def get_seq_len_block_table_args( def get_seq_len_block_table_args(
self, self,
attn_type: AttentionType, attn_type: str,
) -> tuple: ) -> tuple:
''' '''
The particular choice of sequence-length- and block-table-related The particular choice of sequence-length- and block-table-related
@ -426,7 +426,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_metadata: TorchSDPAMetadata, # type: ignore attn_metadata: TorchSDPAMetadata, # type: ignore
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention. """Forward pass with torch SDPA and PagedAttention.
@ -574,7 +574,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
attn_metadata: TorchSDPAMetadata, attn_metadata: TorchSDPAMetadata,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> None: ) -> None:
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1) key = key.repeat_interleave(self.num_queries_per_kv, dim=1)

View File

@ -478,7 +478,7 @@ def is_all_cross_attn_metadata_set(attn_metadata):
def get_seq_len_block_table_args( def get_seq_len_block_table_args(
attn_metadata, attn_metadata,
is_prompt: bool, is_prompt: bool,
attn_type: AttentionType, attn_type: str,
) -> tuple: ) -> tuple:
''' '''
The particular choice of sequence-length- and block-table-related The particular choice of sequence-length- and block-table-related
@ -529,7 +529,7 @@ def get_seq_len_block_table_args(
def get_num_prefill_decode_query_kv_tokens( def get_num_prefill_decode_query_kv_tokens(
attn_metadata, attn_metadata,
attn_type: AttentionType, attn_type: str,
) -> Tuple[int, int, int]: ) -> Tuple[int, int, int]:
""" """
Calculate the number of prefill and decode tokens for query, key/value Calculate the number of prefill and decode tokens for query, key/value

View File

@ -284,7 +284,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
def _get_attn_bias( def _get_attn_bias(
attn_metadata: XFormersMetadata, attn_metadata: XFormersMetadata,
attn_type: AttentionType, attn_type: str,
) -> Optional[AttentionBias]: ) -> Optional[AttentionBias]:
''' '''
Extract appropriate attention bias from attention metadata Extract appropriate attention bias from attention metadata
@ -314,7 +314,7 @@ def _get_attn_bias(
def _set_attn_bias( def _set_attn_bias(
attn_metadata: XFormersMetadata, attn_metadata: XFormersMetadata,
attn_bias: List[Optional[AttentionBias]], attn_bias: List[Optional[AttentionBias]],
attn_type: AttentionType, attn_type: str,
) -> None: ) -> None:
''' '''
Update appropriate attention bias field of attention metadata, Update appropriate attention bias field of attention metadata,
@ -416,7 +416,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
attn_metadata: "XFormersMetadata", attn_metadata: "XFormersMetadata",
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
@ -617,7 +617,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
attn_metadata: XFormersMetadata, attn_metadata: XFormersMetadata,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Attention for 1D query of multiple prompts. Multiple prompt """Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input. tokens are flattened in to `query` input.

View File

@ -4,12 +4,17 @@ from typing import Any, Dict, List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import vllm.envs as envs
from vllm.attention import AttentionMetadata, AttentionType from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform
from vllm.plugins import get_current_vllm_config
from vllm.utils import direct_register_custom_op
class Attention(nn.Module): class Attention(nn.Module):
@ -86,6 +91,18 @@ class Attention(nn.Module):
alibi_slopes, sliding_window, kv_cache_dtype, alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap) blocksparse_params, logits_soft_cap)
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = envs.VLLM_USE_V1 or not (
current_platform.is_cuda_alike() or current_platform.is_cpu())
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
def forward( def forward(
self, self,
query: torch.Tensor, query: torch.Tensor,
@ -93,17 +110,22 @@ class Attention(nn.Module):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
return self.impl.forward(query, if self.use_direct_call:
key, return self.impl.forward(query,
value, key,
kv_cache, value,
attn_metadata, kv_cache,
self._k_scale, attn_metadata,
self._v_scale, self._k_scale,
attn_type=attn_type) self._v_scale,
attn_type=attn_type)
else:
return torch.ops.vllm.unified_attention(query, key, value,
kv_cache, attn_type,
self.layer_name)
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore s = f"head_size={self.impl.head_size}" # type: ignore
@ -112,3 +134,44 @@ class Attention(nn.Module):
s += f", scale={self.impl.scale}" # type: ignore s += f", scale={self.impl.scale}" # type: ignore
s += f", backend={self.impl.__class__.__name__}" s += f", backend={self.impl.__class__.__name__}"
return s return s
def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.dynamic_forward_context
self = forward_context.static_forward_context[layer_name]
return self.impl.forward(query,
key,
value,
kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
attn_type=attn_type)
def unified_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(query).contiguous()
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
mutates_args=["kv_cache"],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -2135,8 +2135,7 @@ class CompilationConfig(BaseModel):
backend: str = "" backend: str = ""
custom_ops: List[str] = Field(default_factory=list) custom_ops: List[str] = Field(default_factory=list)
splitting_ops: List[str] = Field(default_factory=lambda: [ splitting_ops: List[str] = Field(default_factory=lambda: [
"vllm.unified_flash_attention", "vllm.unified_attention",
"vllm.unified_flash_infer",
"vllm.unified_v1_flash_attention", "vllm.unified_v1_flash_attention",
]) ])
@ -2197,6 +2196,11 @@ class CompilationConfig(BaseModel):
enabled_custom_ops: Counter[str] = PrivateAttr enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr disabled_custom_ops: Counter[str] = PrivateAttr
# Per-model forward context
# Mainly used to store attention cls
# Map from layer name to the attention cls
static_forward_context: Dict[str, Any] = PrivateAttr
@classmethod @classmethod
def from_cli(cls, cli_value: str) -> "CompilationConfig": def from_cli(cls, cli_value: str) -> "CompilationConfig":
"""Parse the CLI value for the compilation config.""" """Parse the CLI value for the compilation config."""
@ -2228,6 +2232,7 @@ class CompilationConfig(BaseModel):
self.enabled_custom_ops = Counter() self.enabled_custom_ops = Counter()
self.disabled_custom_ops = Counter() self.disabled_custom_ops = Counter()
self.static_forward_context = {}
def init_backend(self) -> Union[str, Callable]: def init_backend(self) -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION: if self.level == CompilationLevel.NO_COMPILATION:

View File

@ -1,21 +1,38 @@
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any from dataclasses import dataclass
from typing import Any, Dict, Optional
_forward_context: Any = None from vllm.config import VllmConfig
def get_forward_context() -> Any: @dataclass
class ForwardContext:
static_forward_context: Dict[str, Any]
# TODO: extend to support per-layer dynamic forward context
dynamic_forward_context: Any
_forward_context: Optional[ForwardContext] = None
def get_forward_context() -> ForwardContext:
"""Get the current forward context.""" """Get the current forward context."""
assert _forward_context is not None, (
"Forward context is not set. "
"Please use `set_forward_context` to set the forward context.")
return _forward_context return _forward_context
@contextmanager @contextmanager
def set_forward_context(context: Any): def set_forward_context(context: Any, vllm_config: VllmConfig):
"""A context manager that stores the current forward context, """A context manager that stores the current forward context,
can be attention metadata, etc.""" can be attention metadata, etc."""
global _forward_context global _forward_context
prev_context = _forward_context prev_context = _forward_context
_forward_context = context _forward_context = ForwardContext(
static_forward_context=vllm_config.compilation_config.
static_forward_context,
dynamic_forward_context=context)
try: try:
yield yield
finally: finally:

View File

@ -223,6 +223,7 @@ class ArcticAttention(nn.Module):
layer_idx: Optional[int] = None, layer_idx: Optional[int] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
@ -274,7 +275,8 @@ class ArcticAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -299,6 +301,7 @@ class ArcticDecoderLayer(nn.Module):
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.layer_idx = layer_idx self.layer_idx = layer_idx
@ -308,7 +311,8 @@ class ArcticDecoderLayer(nn.Module):
self.self_attn = ArcticAttention(config, self.self_attn = ArcticAttention(config,
layer_idx, layer_idx,
cache_config, cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.block_sparse_moe = ArcticMoE( self.block_sparse_moe = ArcticMoE(
config, config,
layer_id=layer_idx, layer_id=layer_idx,
@ -380,8 +384,11 @@ class ArcticModel(nn.Module):
org_num_embeddings=self.vocab_size) org_num_embeddings=self.vocab_size)
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: ArcticDecoderLayer(config, int( lambda prefix: ArcticDecoderLayer(config,
prefix.split(".")[-1]), cache_config, quant_config), int(prefix.split(".")[-1]),
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self._attn_implementation = config._attn_implementation self._attn_implementation = config._attn_implementation
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@ -116,6 +116,7 @@ class BaiChuanAttention(nn.Module):
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -158,7 +159,8 @@ class BaiChuanAttention(nn.Module):
self.head_dim, self.head_dim,
scaling, scaling,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
else: else:
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
@ -171,7 +173,8 @@ class BaiChuanAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -195,7 +198,8 @@ class BaiChuanDecoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
position_embedding: str, position_embedding: str,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
@ -209,6 +213,7 @@ class BaiChuanDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn",
) )
self.mlp = BaiChuanMLP( self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
@ -275,8 +280,11 @@ class BaiChuanModel(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: BaiChuanDecoderLayer(config, position_embedding, lambda prefix: BaiChuanDecoderLayer(config,
cache_config, quant_config), position_embedding,
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@ -126,6 +126,7 @@ class BartEncoderAttention(nn.Module):
config: Optional[BartConfig] = None, config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
@ -178,7 +179,8 @@ class BartEncoderAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor: attn_metadata: AttentionMetadata) -> torch.Tensor:
@ -208,6 +210,7 @@ class BartDecoderSelfAttention(nn.Module):
config: Optional[BartConfig] = None, config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
@ -260,7 +263,8 @@ class BartDecoderSelfAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor: attn_metadata: AttentionMetadata) -> torch.Tensor:
@ -290,6 +294,7 @@ class BartCrossAttention(nn.Module):
config: Optional[BartConfig] = None, config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
@ -342,7 +347,8 @@ class BartCrossAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -384,6 +390,7 @@ class BartEncoderLayer(nn.Module):
config: BartConfig, config: BartConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.embed_dim = config.d_model self.embed_dim = config.d_model
@ -393,7 +400,9 @@ class BartEncoderLayer(nn.Module):
num_heads=config.encoder_attention_heads, num_heads=config.encoder_attention_heads,
config=config, config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.activation_fn = get_act_fn(config.activation_function) self.activation_fn = get_act_fn(config.activation_function)
@ -464,6 +473,7 @@ class BartDecoderLayer(nn.Module):
config: BartConfig, config: BartConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.embed_dim = config.d_model self.embed_dim = config.d_model
@ -473,7 +483,9 @@ class BartDecoderLayer(nn.Module):
num_heads=config.decoder_attention_heads, num_heads=config.decoder_attention_heads,
config=config, config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.activation_fn = get_act_fn(config.activation_function) self.activation_fn = get_act_fn(config.activation_function)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
@ -486,6 +498,7 @@ class BartDecoderLayer(nn.Module):
self.embed_dim, self.embed_dim,
config.decoder_attention_heads, config.decoder_attention_heads,
config=config, config=config,
prefix=f"{prefix}.encoder_attn",
) )
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
@ -578,7 +591,8 @@ class BartEncoder(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None): embed_tokens: Optional[nn.Embedding] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.cache_config = cache_config self.cache_config = cache_config
@ -599,9 +613,13 @@ class BartEncoder(nn.Module):
config.max_position_embeddings, config.max_position_embeddings,
embed_dim, embed_dim,
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList([
[BartEncoderLayer(config,cache_config,quant_config) \ BartEncoderLayer(config,
for _ in range(config.encoder_layers)]) cache_config,
quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.encoder_layers)
])
self.layernorm_embedding = nn.LayerNorm(embed_dim) self.layernorm_embedding = nn.LayerNorm(embed_dim)
@ -661,6 +679,7 @@ class BartDecoder(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None, embed_tokens: Optional[nn.Embedding] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.cache_config = cache_config self.cache_config = cache_config
@ -683,8 +702,9 @@ class BartDecoder(nn.Module):
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[BartDecoderLayer(config,cache_config,quant_config) \ [BartDecoderLayer(config,cache_config,quant_config,
for _ in range(config.decoder_layers)]) prefix=f"{prefix}.layers.{layer_idx}") \
for layer_idx in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model) self.layernorm_embedding = nn.LayerNorm(config.d_model)
@ -759,10 +779,12 @@ class BartModel(nn.Module):
self.encoder = BartEncoder(config, self.encoder = BartEncoder(config,
cache_config, cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.encoder")
self.decoder = BartDecoder(config, self.decoder = BartDecoder(config,
cache_config, cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.decoder")
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor, encoder_input_ids: torch.Tensor,

View File

@ -78,6 +78,7 @@ class BloomAttention(nn.Module):
config: BloomConfig, config: BloomConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -116,7 +117,8 @@ class BloomAttention(nn.Module):
scaling, scaling,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -168,14 +170,17 @@ class BloomBlock(nn.Module):
config: BloomConfig, config: BloomConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(hidden_size, self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon) eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config, cache_config, self.self_attention = BloomAttention(config,
quant_config) cache_config,
quant_config,
prefix=f"{prefix}.self_attention")
self.post_attention_layernorm = nn.LayerNorm( self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon) hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config, quant_config) self.mlp = BloomMLP(config, quant_config)
@ -242,7 +247,8 @@ class BloomModel(nn.Module):
# Transformer blocks # Transformer blocks
self.start_layer, self.end_layer, self.h = make_layers( self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: BloomBlock(config, cache_config, quant_config), lambda prefix: BloomBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h") prefix=f"{prefix}.h")
# Final Layer Norm # Final Layer Norm

View File

@ -223,6 +223,7 @@ class ChameleonAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -276,7 +277,8 @@ class ChameleonAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def _apply_qk_norm(self, q: torch.Tensor, def _apply_qk_norm(self, q: torch.Tensor,
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@ -313,6 +315,7 @@ class ChameleonDecoderLayer(nn.Module):
config: ChameleonConfig, config: ChameleonConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -336,6 +339,7 @@ class ChameleonDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
bias=False, bias=False,
cache_config=cache_config, cache_config=cache_config,
prefix=f"{prefix}.self_attn",
) )
self.mlp = ChameleonMLP( self.mlp = ChameleonMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
@ -386,6 +390,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
config: ChameleonConfig, config: ChameleonConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -409,6 +414,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
bias=False, bias=False,
cache_config=cache_config, cache_config=cache_config,
prefix=f"{prefix}.self_attn",
) )
self.mlp = ChameleonMLP( self.mlp = ChameleonMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
@ -855,7 +861,8 @@ class ChameleonModel(nn.Module):
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: decoder_layer(config=config, lambda prefix: decoder_layer(config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config), quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )

View File

@ -230,6 +230,7 @@ class GLMAttention(nn.Module):
config: ChatGLMConfig, config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -285,7 +286,8 @@ class GLMAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -364,6 +366,7 @@ class GLMBlock(nn.Module):
config: ChatGLMConfig, config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.apply_residual_connection_post_layernorm = ( self.apply_residual_connection_post_layernorm = (
@ -377,7 +380,10 @@ class GLMBlock(nn.Module):
eps=config.layernorm_epsilon) eps=config.layernorm_epsilon)
# Self attention. # Self attention.
self.self_attention = GLMAttention(config, cache_config, quant_config) self.self_attention = GLMAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attention")
self.hidden_dropout = config.hidden_dropout self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output # Layernorm on the attention output
@ -446,7 +452,8 @@ class GLMTransformer(nn.Module):
# Transformer layers. # Transformer layers.
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
self.num_layers, self.num_layers,
lambda prefix: GLMBlock(config, cache_config, quant_config), lambda prefix: GLMBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
@ -500,16 +507,22 @@ class ChatGLMModel(nn.Module):
self.num_layers = config.num_layers self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, cache_config, quant_config) self.encoder = GLMTransformer(config,
cache_config,
quant_config,
prefix=f"{prefix}.encoder")
self.output_layer = ParallelLMHead(config.padded_vocab_size, self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.output_layer")
vision_config_flag = getattr(config, 'vision_config', None) vision_config_flag = getattr(config, 'vision_config', None)
if vision_config_flag is not None: if vision_config_flag is not None:
self.vision_config = Namespace(**config.vision_config) self.vision_config = Namespace(**config.vision_config)
self.vision = EVA2CLIPModel(self.config, quant_config) self.vision = EVA2CLIPModel(self.config,
quant_config,
prefix=f"{prefix}.vision")
else: else:
self.vision = None self.vision = None

View File

@ -120,6 +120,7 @@ class CohereAttention(nn.Module):
config: CohereConfig, config: CohereConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
@ -175,7 +176,8 @@ class CohereAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
if self.use_qk_norm: if self.use_qk_norm:
self.q_norm = LayerNorm(param_shape=(self.num_heads, self.q_norm = LayerNorm(param_shape=(self.num_heads,
self.head_dim), self.head_dim),
@ -215,13 +217,15 @@ class CohereDecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: CohereConfig, config: CohereConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = CohereAttention(config, self.self_attn = CohereAttention(config,
cache_config, cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = CohereMLP(config, quant_config=quant_config) self.mlp = CohereMLP(config, quant_config=quant_config)
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
@ -271,8 +275,8 @@ class CohereModel(nn.Module):
config.hidden_size) config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: CohereDecoderLayer(config, cache_config, lambda prefix: CohereDecoderLayer(
quant_config), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = LayerNorm(param_shape=(config.hidden_size), self.norm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps) eps=config.layer_norm_eps)

View File

@ -154,6 +154,7 @@ class DbrxAttention(nn.Module):
config: DbrxConfig, config: DbrxConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
@ -208,7 +209,8 @@ class DbrxAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -234,10 +236,14 @@ class DbrxFusedNormAttention(nn.Module):
config: DbrxConfig, config: DbrxConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
self.attn = DbrxAttention(config, cache_config, quant_config) self.attn = DbrxAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.norm_1 = nn.LayerNorm(self.d_model) self.norm_1 = nn.LayerNorm(self.d_model)
self.norm_2 = nn.LayerNorm(self.d_model) self.norm_2 = nn.LayerNorm(self.d_model)
@ -269,10 +275,14 @@ class DbrxBlock(nn.Module):
config: DbrxConfig, config: DbrxConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config, self.norm_attn_norm = DbrxFusedNormAttention(
quant_config) config,
cache_config,
quant_config,
prefix=f"{prefix}.norm_attn_norm")
self.ffn = DbrxMoE(config, quant_config) self.ffn = DbrxMoE(config, quant_config)
def forward( def forward(
@ -308,7 +318,8 @@ class DbrxModel(nn.Module):
) )
self.start_layer, self.end_layer, self.blocks = make_layers( self.start_layer, self.end_layer, self.blocks = make_layers(
config.n_layers, config.n_layers,
lambda prefix: DbrxBlock(config, cache_config, quant_config), lambda prefix: DbrxBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.blocks", prefix=f"{prefix}.blocks",
) )
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)

View File

@ -184,6 +184,7 @@ class DeepseekAttention(nn.Module):
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -236,7 +237,8 @@ class DeepseekAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -261,6 +263,7 @@ class DeepseekDecoderLayer(nn.Module):
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -277,6 +280,7 @@ class DeepseekDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn",
) )
if (config.n_routed_experts is not None if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace and layer_idx >= config.first_k_dense_replace
@ -346,7 +350,8 @@ class DeepseekModel(nn.Module):
lambda prefix: DeepseekDecoderLayer(config, lambda prefix: DeepseekDecoderLayer(config,
int(prefix.split(".")[-1]), int(prefix.split(".")[-1]),
cache_config, cache_config,
quant_config=quant_config), quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (

View File

@ -268,7 +268,8 @@ class DeepseekV2Attention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_local_heads, num_kv_heads=self.num_local_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,

View File

@ -174,6 +174,7 @@ class ExaoneAttention(nn.Module):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn",
) )
def forward( def forward(
@ -219,7 +220,7 @@ class ExaoneBlockAttention(nn.Module):
quant_config=quant_config, quant_config=quant_config,
bias=bias, bias=bias,
cache_config=cache_config, cache_config=cache_config,
prefix=prefix, prefix=f"{prefix}.attention",
) )
def forward( def forward(

View File

@ -84,6 +84,7 @@ class FalconAttention(nn.Module):
config: FalconConfig, config: FalconConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
@ -158,7 +159,8 @@ class FalconAttention(nn.Module):
self.head_dim, self.head_dim,
self.inv_norm_factor, self.inv_norm_factor,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
elif self.use_alibi: elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads head_start = tp_rank * self.num_heads
@ -171,14 +173,16 @@ class FalconAttention(nn.Module):
self.inv_norm_factor, self.inv_norm_factor,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
else: else:
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scale=self.inv_norm_factor, scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -241,12 +245,16 @@ class FalconDecoderLayer(nn.Module):
config: FalconConfig, config: FalconConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config, cache_config, self.self_attention = FalconAttention(
quant_config) config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attention")
self.mlp = FalconMLP(config, quant_config) self.mlp = FalconMLP(config, quant_config)
self.config = config self.config = config
@ -357,8 +365,8 @@ class FalconModel(nn.Module):
# Transformer blocks # Transformer blocks
self.start_layer, self.end_layer, self.h = make_layers( self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: FalconDecoderLayer(config, cache_config, lambda prefix: FalconDecoderLayer(
quant_config), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h") prefix=f"{prefix}.h")
# Final Layer Norm # Final Layer Norm

View File

@ -35,10 +35,12 @@ class Florence2LanguageModel(nn.Module):
self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model) self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)
self.encoder = BartEncoder(config, self.encoder = BartEncoder(config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.encoder")
self.decoder = BartDecoder(config, self.decoder = BartDecoder(config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.decoder")
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.encoder.embed_tokens.weight = self.shared.weight self.encoder.embed_tokens.weight = self.shared.weight
@ -99,7 +101,7 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
self.config = config self.config = config
self.model = Florence2LanguageModel(vllm_config=vllm_config, self.model = Florence2LanguageModel(vllm_config=vllm_config,
prefix=prefix) prefix=f"{prefix}.model")
embed_scale = math.sqrt( embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0 config.d_model) if config.scale_embedding else 1.0
@ -198,7 +200,7 @@ class Florence2ForConditionalGeneration(nn.Module):
# TODO(Isotr0py): Add vision backbone # TODO(Isotr0py): Add vision backbone
self.language_model = Florence2LanguageForConditionalGeneration( self.language_model = Florence2LanguageForConditionalGeneration(
vllm_config=vllm_config.with_hf_config(config.text_config), vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=prefix, prefix=f"{prefix}.language_model",
) )
@property @property

View File

@ -174,7 +174,8 @@ class GemmaAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,

View File

@ -95,7 +95,8 @@ class Gemma2Attention(nn.Module):
rope_theta: float, rope_theta: float,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
attn_logits_soft_cap: Optional[float] = None) -> None: attn_logits_soft_cap: Optional[float] = None,
prefix: str = "") -> None:
super().__init__() super().__init__()
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.config = config self.config = config
@ -154,7 +155,8 @@ class Gemma2Attention(nn.Module):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
logits_soft_cap=attn_logits_soft_cap) logits_soft_cap=attn_logits_soft_cap,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -179,6 +181,7 @@ class Gemma2DecoderLayer(nn.Module):
config: Gemma2Config, config: Gemma2Config,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -194,6 +197,7 @@ class Gemma2DecoderLayer(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
attn_logits_soft_cap=config.attn_logit_softcapping, attn_logits_soft_cap=config.attn_logit_softcapping,
prefix=f"{prefix}.self_attn",
) )
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.mlp = Gemma2MLP( self.mlp = Gemma2MLP(
@ -257,8 +261,11 @@ class Gemma2Model(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[ lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[-1]),
-1]), config, cache_config, quant_config), config,
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@ -56,6 +56,7 @@ class Attention(nn.Module):
self, self,
config, config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -135,11 +136,14 @@ class TransformerLayer(nn.Module):
self, self,
config, config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
): ):
super().__init__() super().__init__()
self.input_layernorm = LayerNorm(config.hidden_size, self.input_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.attention = Attention(config, quant_config=quant_config) self.attention = Attention(config,
quant_config=quant_config,
prefix=f"{prefix}.attention")
self.mlp = MLP(config, quant_config=quant_config) self.mlp = MLP(config, quant_config=quant_config)
self.post_attention_layernorm = LayerNorm(config.hidden_size, self.post_attention_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
@ -161,11 +165,14 @@ class Transformer(nn.Module):
self, self,
config, config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
): ):
super().__init__() super().__init__()
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
TransformerLayer(config, quant_config=quant_config) TransformerLayer(config,
for _ in range(config.num_hidden_layers) quant_config=quant_config,
prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
]) ])
def forward(self, hidden_states): def forward(self, hidden_states):
@ -252,12 +259,14 @@ class EVA2CLIPModel(nn.Module):
self, self,
config, config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
): ):
super().__init__() super().__init__()
vision_config = Namespace(**config.vision_config) vision_config = Namespace(**config.vision_config)
self.patch_embedding = PatchEmbedding(vision_config) self.patch_embedding = PatchEmbedding(vision_config)
self.transformer = Transformer(vision_config, self.transformer = Transformer(vision_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.transformer")
self.linear_proj = GLU(config, self.linear_proj = GLU(config,
in_features=config.hidden_size, in_features=config.hidden_size,
quant_config=quant_config) quant_config=quant_config)

View File

@ -84,7 +84,8 @@ class GPT2Attention(nn.Module):
self.head_dim, self.head_dim,
scale=self.scale, scale=self.scale,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,

View File

@ -52,6 +52,7 @@ class GPTBigCodeAttention(nn.Module):
config: GPTBigCodeConfig, config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -92,7 +93,8 @@ class GPTBigCodeAttention(nn.Module):
scale=self.scale, scale=self.scale,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -151,6 +153,7 @@ class GPTBigCodeBlock(nn.Module):
config: GPTBigCodeConfig, config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
@ -158,7 +161,10 @@ class GPTBigCodeBlock(nn.Module):
hidden_size) hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config, cache_config, quant_config) self.attn = GPTBigCodeAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config, quant_config) self.mlp = GPTBigMLP(inner_dim, config, quant_config)
@ -210,7 +216,8 @@ class GPTBigCodeModel(nn.Module):
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.start_layer, self.end_layer, self.h = make_layers( self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config), lambda prefix: GPTBigCodeBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h", prefix=f"{prefix}.h",
) )
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

View File

@ -53,6 +53,7 @@ class GPTJAttention(nn.Module):
config: GPTJConfig, config: GPTJConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.total_num_heads = config.num_attention_heads self.total_num_heads = config.num_attention_heads
@ -94,7 +95,8 @@ class GPTJAttention(nn.Module):
self.head_size, self.head_size,
scaling, scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -147,12 +149,16 @@ class GPTJBlock(nn.Module):
config: GPTJConfig, config: GPTJConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
inner_dim = (4 * config.n_embd inner_dim = (4 * config.n_embd
if config.n_inner is None else config.n_inner) if config.n_inner is None else config.n_inner)
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config, cache_config, quant_config) self.attn = GPTJAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.mlp = GPTJMLP(inner_dim, config, quant_config) self.mlp = GPTJMLP(inner_dim, config, quant_config)
def forward( def forward(
@ -193,7 +199,8 @@ class GPTJModel(nn.Module):
) )
self.start_layer, self.end_layer, self.h = make_layers( self.start_layer, self.end_layer, self.h = make_layers(
config.n_layer, config.n_layer,
lambda prefix: GPTJBlock(config, cache_config, quant_config), lambda prefix: GPTJBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h", prefix=f"{prefix}.h",
) )
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

View File

@ -52,6 +52,7 @@ class GPTNeoXAttention(nn.Module):
config: GPTNeoXConfig, config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.total_num_heads = config.num_attention_heads self.total_num_heads = config.num_attention_heads
@ -94,7 +95,8 @@ class GPTNeoXAttention(nn.Module):
self.head_size, self.head_size,
scaling, scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -145,6 +147,7 @@ class GPTNeoXLayer(nn.Module):
config: GPTNeoXConfig, config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.use_parallel_residual = config.use_parallel_residual self.use_parallel_residual = config.use_parallel_residual
@ -152,7 +155,10 @@ class GPTNeoXLayer(nn.Module):
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config, cache_config, quant_config) self.attention = GPTNeoXAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attention")
self.mlp = GPTNeoXMLP(config, quant_config) self.mlp = GPTNeoXMLP(config, quant_config)
def forward( def forward(
@ -205,7 +211,8 @@ class GPTNeoXModel(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: GPTNeoXLayer(config, cache_config, quant_config), lambda prefix: GPTNeoXLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.final_layer_norm = nn.LayerNorm(config.hidden_size, self.final_layer_norm = nn.LayerNorm(config.hidden_size,

View File

@ -161,7 +161,8 @@ class GraniteAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,

View File

@ -164,7 +164,8 @@ class GraniteMoeAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,

View File

@ -1,5 +1,5 @@
from functools import partial from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
import torch import torch
from torch import nn from torch import nn
@ -250,7 +250,12 @@ class InternLMDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class InternLM2Model(nn.Module): class InternLM2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[InternLMDecoderLayer] = InternLMDecoderLayer):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
@ -266,7 +271,7 @@ class InternLM2Model(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: InternLMDecoderLayer( lambda prefix: layer_type(
config, cache_config, quant_config, prefix=prefix), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -316,14 +321,18 @@ class InternLM2Model(nn.Module):
class InternLM2ForCausalLM(nn.Module, SupportsPP): class InternLM2ForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
model_type: Type[InternLM2Model] = InternLM2Model):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = InternLM2Model(vllm_config=vllm_config, self.model = model_type(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self.output = ParallelLMHead(config.vocab_size, self.output = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,

View File

@ -14,8 +14,6 @@ from vllm.model_executor.models.internlm2 import (InternLM2Attention,
InternLM2MLP, InternLM2Model) InternLM2MLP, InternLM2Model)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .utils import make_layers, maybe_prefix
class InternLM2VEDecoderLayer(nn.Module): class InternLM2VEDecoderLayer(nn.Module):
@ -105,17 +103,9 @@ class InternLM2VEDecoderLayer(nn.Module):
class InternLM2VEModel(InternLM2Model): class InternLM2VEModel(InternLM2Model):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config,
prefix=prefix,
config = vllm_config.model_config.hf_config layer_type=InternLM2VEDecoderLayer)
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: InternLM2VEDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
def forward( def forward(
self, self,
@ -159,7 +149,6 @@ class InternLM2VEModel(InternLM2Model):
class InternLM2VEForCausalLM(InternLM2ForCausalLM): class InternLM2VEForCausalLM(InternLM2ForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config,
prefix=prefix,
self.model = InternLM2VEModel(vllm_config=vllm_config, model_type=InternLM2VEModel)
prefix=maybe_prefix(prefix, "model"))

View File

@ -76,6 +76,7 @@ class JAISAttention(nn.Module):
config: JAISConfig, config: JAISConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -114,7 +115,8 @@ class JAISAttention(nn.Module):
scale=self.scale, scale=self.scale,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -178,6 +180,7 @@ class JAISBlock(nn.Module):
config: JAISConfig, config: JAISConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
@ -185,7 +188,10 @@ class JAISBlock(nn.Module):
hidden_size) hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = JAISAttention(config, cache_config, quant_config) self.attn = JAISAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = JAISMLP(inner_dim, config, quant_config) self.mlp = JAISMLP(inner_dim, config, quant_config)
@ -241,7 +247,8 @@ class JAISModel(nn.Module):
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: JAISBlock(config=config, lambda prefix: JAISBlock(config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config), quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.h", prefix=f"{prefix}.h",
) )

View File

@ -102,7 +102,8 @@ class JambaMambaDecoderLayer(nn.Module):
config: JambaConfig, config: JambaConfig,
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.mamba = MambaMixer(hidden_size= config.hidden_size, self.mamba = MambaMixer(hidden_size= config.hidden_size,
@ -157,6 +158,7 @@ class JambaAttentionDecoderLayer(nn.Module):
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -198,6 +200,7 @@ class JambaAttentionDecoderLayer(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
prefix=f"{prefix}.attn",
) )
num_experts = config.layers_num_experts[layer_idx] num_experts = config.layers_num_experts[layer_idx]
@ -287,7 +290,8 @@ class JambaModel(nn.Module):
layer_class(config, layer_class(config,
layer_idx=i, layer_idx=i,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config)) quant_config=quant_config,
prefix=f"{prefix}.layers.{i}"))
self.layers = nn.ModuleList(decoder_layers) self.layers = nn.ModuleList(decoder_layers)
self.final_layernorm = RMSNorm(config.hidden_size, self.final_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)

View File

@ -174,6 +174,7 @@ class LlamaAttention(nn.Module):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn",
) )
def forward( def forward(

View File

@ -192,6 +192,7 @@ class MiniCPMAttention(nn.Module):
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -246,7 +247,8 @@ class MiniCPMAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -273,6 +275,7 @@ class MiniCPMDecoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
@ -283,6 +286,7 @@ class MiniCPMDecoderLayer(nn.Module):
self.rope_scaling = getattr(config, "rope_scaling", None) self.rope_scaling = getattr(config, "rope_scaling", None)
self.max_position_embeddings = getattr(config, self.max_position_embeddings = getattr(config,
"max_position_embeddings", 8192) "max_position_embeddings", 8192)
self.prefix = prefix
self._init_attn_block() self._init_attn_block()
self._init_ffn_block() self._init_ffn_block()
@ -298,6 +302,7 @@ class MiniCPMDecoderLayer(nn.Module):
max_position_embeddings=self.max_position_embeddings, max_position_embeddings=self.max_position_embeddings,
cache_config=self.cache_config, cache_config=self.cache_config,
quant_config=self.quant_config, quant_config=self.quant_config,
prefix=f"{self.prefix}.self_attn",
) )
def _init_ffn_block(self): def _init_ffn_block(self):
@ -388,8 +393,8 @@ class MiniCPMModel(nn.Module):
): ):
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: MiniCPMDecoderLayer(config, cache_config, lambda prefix: MiniCPMDecoderLayer(
quant_config), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:

View File

@ -60,6 +60,7 @@ class MiniCPM3Attention(nn.Module):
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -119,7 +120,8 @@ class MiniCPM3Attention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_local_heads, num_kv_heads=self.num_local_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -195,6 +197,7 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):
max_position_embeddings=self.max_position_embeddings, max_position_embeddings=self.max_position_embeddings,
cache_config=self.cache_config, cache_config=self.cache_config,
quant_config=self.quant_config, quant_config=self.quant_config,
prefix=f"{self.prefix}.self_attn",
) )
@ -209,8 +212,8 @@ class MiniCPM3Model(MiniCPMModel):
): ):
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: MiniCPM3DecoderLayer(config, cache_config, lambda prefix: MiniCPM3DecoderLayer(
quant_config), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")

View File

@ -166,7 +166,8 @@ class MixtralAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,

View File

@ -170,6 +170,7 @@ class MixtralAttention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -219,7 +220,8 @@ class MixtralAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -243,6 +245,7 @@ class MixtralDecoderLayer(nn.Module):
config: MixtralConfig, config: MixtralConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -255,7 +258,9 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.block_sparse_moe = MixtralMoE(config=config, self.block_sparse_moe = MixtralMoE(config=config,
quant_config=quant_config) quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
@ -311,7 +316,8 @@ class MixtralModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: MixtralDecoderLayer( lambda prefix: MixtralDecoderLayer(
config, cache_config, quant_config=quant_config), config, cache_config, quant_config=quant_config, prefix=prefix
),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (

View File

@ -370,6 +370,7 @@ class MolmoAttention(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -427,7 +428,8 @@ class MolmoAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
# Attention output projection. # Attention output projection.
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
@ -517,10 +519,14 @@ class MolmoDecoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
# Attention block. # Attention block.
self.self_attn = MolmoAttention(config, cache_config, quant_config) self.self_attn = MolmoAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
# MLP block. # MLP block.
self.mlp = MolmoMLP(config, quant_config=quant_config) self.mlp = MolmoMLP(config, quant_config=quant_config)
@ -738,7 +744,8 @@ class MolmoModel(nn.Module):
else MolmoDecoderLayer else MolmoDecoderLayer
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: decoder_layer(config, cache_config, quant_config), lambda prefix: decoder_layer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )

View File

@ -50,6 +50,7 @@ class MPTAttention(nn.Module):
config: MPTConfig, config: MPTConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
@ -115,7 +116,8 @@ class MPTAttention(nn.Module):
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -176,11 +178,15 @@ class MPTBlock(nn.Module):
config: MPTConfig, config: MPTConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
hidden_size = config.d_model hidden_size = config.d_model
self.norm_1 = nn.LayerNorm(hidden_size) self.norm_1 = nn.LayerNorm(hidden_size)
self.attn = MPTAttention(config, cache_config, quant_config) self.attn = MPTAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.norm_2 = nn.LayerNorm(hidden_size) self.norm_2 = nn.LayerNorm(hidden_size)
self.ffn = MPTMLP(config, quant_config) self.ffn = MPTMLP(config, quant_config)
@ -224,7 +230,8 @@ class MPTModel(nn.Module):
) )
self.start_layer, self.end_layer, self.blocks = make_layers( self.start_layer, self.end_layer, self.blocks = make_layers(
config.n_layers, config.n_layers,
lambda prefix: MPTBlock(config, cache_config, quant_config), lambda prefix: MPTBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.blocks") prefix=f"{prefix}.blocks")
self.norm_f = nn.LayerNorm(config.d_model) self.norm_f = nn.LayerNorm(config.d_model)
if config.no_bias: if config.no_bias:

View File

@ -195,7 +195,8 @@ class NemotronAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,

View File

@ -62,6 +62,7 @@ class OlmoAttention(nn.Module):
config: OlmoConfig, config: OlmoConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
@ -101,7 +102,8 @@ class OlmoAttention(nn.Module):
self.head_dim, self.head_dim,
scale=self.scaling, scale=self.scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
# Attention output projection. # Attention output projection.
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
@ -184,10 +186,14 @@ class OlmoDecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: OlmoConfig, config: OlmoConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
# Attention block. # Attention block.
self.self_attn = OlmoAttention(config, cache_config, quant_config) self.self_attn = OlmoAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
# MLP block. # MLP block.
self.mlp = OlmoMLP(config, quant_config) self.mlp = OlmoMLP(config, quant_config)
@ -238,8 +244,8 @@ class OlmoModel(nn.Module):
config.hidden_size) config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: OlmoDecoderLayer(config, cache_config, quant_config lambda prefix: OlmoDecoderLayer(
), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size, self.norm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False, elementwise_affine=False,

View File

@ -102,6 +102,7 @@ class OlmoeAttention(nn.Module):
max_position_embeddings: int = 4096, max_position_embeddings: int = 4096,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -156,7 +157,8 @@ class OlmoeAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -182,6 +184,7 @@ class OlmoeDecoderLayer(nn.Module):
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -199,6 +202,7 @@ class OlmoeDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn",
) )
self.mlp = OlmoeMoE( self.mlp = OlmoeMoE(
@ -260,8 +264,11 @@ class OlmoeModel(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: OlmoeDecoderLayer(config, int( lambda prefix: OlmoeDecoderLayer(config,
prefix.split(".")[-1]), cache_config, quant_config), int(prefix.split(".")[-1]),
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=1e-5) self.norm = RMSNorm(config.hidden_size, eps=1e-5)

View File

@ -75,6 +75,7 @@ class OrionAttention(nn.Module):
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -126,7 +127,8 @@ class OrionAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -150,6 +152,7 @@ class OrionDecoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -166,6 +169,7 @@ class OrionDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn",
) )
self.mlp = OrionMLP( self.mlp = OrionMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
@ -226,10 +230,7 @@ class OrionModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: OrionDecoderLayer( lambda prefix: OrionDecoderLayer(
config, config, cache_config, quant_config, prefix=prefix),
cache_config,
quant_config,
),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (

View File

@ -75,7 +75,8 @@ class PersimmonAttention(nn.Module):
def __init__(self, def __init__(self,
config: PersimmonConfig, config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.config = config self.config = config
tensor_parallel_world_size = get_tensor_model_parallel_world_size() tensor_parallel_world_size = get_tensor_model_parallel_world_size()
@ -122,7 +123,8 @@ class PersimmonAttention(nn.Module):
self.head_dim, self.head_dim,
scale=self.scaling, scale=self.scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def _split_heads(self, x: torch.Tensor) -> torch.Tensor: def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
# [seq_length, hidden_size] -> [seq_length, num_heads, head_dim] # [seq_length, hidden_size] -> [seq_length, num_heads, head_dim]
@ -167,12 +169,14 @@ class PersimmonDecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: PersimmonConfig, config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = PersimmonAttention(config=config, self.self_attn = PersimmonAttention(config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = PersimmonMLP(config, quant_config=quant_config) self.mlp = PersimmonMLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size, self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
@ -226,8 +230,8 @@ class PersimmonModel(nn.Module):
config.hidden_size) config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: PersimmonDecoderLayer(config, cache_config, lambda prefix: PersimmonDecoderLayer(
quant_config), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size, self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)

View File

@ -69,7 +69,8 @@ class PhiAttention(nn.Module):
def __init__(self, def __init__(self,
config: PhiConfig, config: PhiConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.total_num_heads = config.num_attention_heads self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -116,7 +117,8 @@ class PhiAttention(nn.Module):
self.head_size, self.head_size,
scaling, scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -167,11 +169,15 @@ class PhiLayer(nn.Module):
def __init__(self, def __init__(self,
config: PhiConfig, config: PhiConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.input_layernorm = nn.LayerNorm(config.hidden_size, self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.self_attn = PhiAttention(config, cache_config, quant_config) self.self_attn = PhiAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = PhiMLP(config, quant_config) self.mlp = PhiMLP(config, quant_config)
def forward( def forward(
@ -210,7 +216,8 @@ class PhiModel(nn.Module):
config.hidden_size) config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: PhiLayer(config, cache_config, quant_config), lambda prefix: PhiLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size, self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)

View File

@ -117,6 +117,7 @@ class Phi3SmallSelfAttention(nn.Module):
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.layer_idx = layer_idx self.layer_idx = layer_idx
@ -214,15 +215,14 @@ class Phi3SmallSelfAttention(nn.Module):
"homo_head": self.homo_heads "homo_head": self.homo_heads
} }
self.attn = Attention( self.attn = Attention(self.num_heads_per_partition,
self.num_heads_per_partition, self.head_dim,
self.head_dim, self.scale,
self.scale, num_kv_heads=self.num_kv_heads_per_partion,
num_kv_heads=self.num_kv_heads_per_partion, cache_config=cache_config,
cache_config=cache_config, quant_config=quant_config,
quant_config=quant_config, blocksparse_params=bs_params,
blocksparse_params=bs_params, prefix=f"{prefix}.attn")
)
def forward( def forward(
self, self,
@ -259,13 +259,15 @@ class Phi3SmallDecoderLayer(nn.Module):
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = Phi3SmallSelfAttention(config, self.self_attn = Phi3SmallSelfAttention(config,
layer_idx, layer_idx,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = Phi3SmallMLP(config, quant_config) self.mlp = Phi3SmallMLP(config, quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size, self.input_layernorm = nn.LayerNorm(config.hidden_size,
@ -315,7 +317,9 @@ class Phi3SmallModel(nn.Module):
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: Phi3SmallDecoderLayer(config, lambda prefix: Phi3SmallDecoderLayer(config,
int(prefix.split('.')[-1]), int(prefix.split('.')[-1]),
cache_config, quant_config), cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size, self.final_layernorm = nn.LayerNorm(config.hidden_size,

View File

@ -294,6 +294,7 @@ class PhiMoEAttention(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[dict] = None, rope_scaling: Optional[dict] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -347,6 +348,7 @@ class PhiMoEAttention(nn.Module):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn",
) )
def forward( def forward(
@ -371,6 +373,7 @@ class PhiMoEDecoderLayer(nn.Module):
config: PhiMoEConfig, config: PhiMoEConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -385,6 +388,7 @@ class PhiMoEDecoderLayer(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
rope_scaling=config.rope_scaling, rope_scaling=config.rope_scaling,
prefix=f"{prefix}.self_attn",
) )
self.block_sparse_moe = PhiMoE( self.block_sparse_moe = PhiMoE(
num_experts=config.num_local_experts, num_experts=config.num_local_experts,
@ -454,8 +458,8 @@ class PhiMoEModel(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: PhiMoEDecoderLayer(config, cache_config, lambda prefix: PhiMoEDecoderLayer(
quant_config), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size, self.norm = nn.LayerNorm(config.hidden_size,
eps=config.rms_norm_eps, eps=config.rms_norm_eps,

View File

@ -442,6 +442,7 @@ class QWenAttention(nn.Module):
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -478,7 +479,8 @@ class QWenAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -502,6 +504,7 @@ class QWenBlock(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@ -514,7 +517,8 @@ class QWenBlock(nn.Module):
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@ -568,7 +572,8 @@ class QWenModel(nn.Module):
) )
self.start_layer, self.end_layer, self.h = make_layers( self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: QWenBlock(config, cache_config, quant_config), lambda prefix: QWenBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h") prefix=f"{prefix}.h")
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (

View File

@ -168,6 +168,7 @@ class Qwen2MoeAttention(nn.Module):
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -220,7 +221,8 @@ class Qwen2MoeAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -245,6 +247,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -261,6 +264,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn",
) )
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
@ -336,7 +340,8 @@ class Qwen2MoeModel(nn.Module):
layer_idx=int( layer_idx=int(
prefix.split(".")[-1]), prefix.split(".")[-1]),
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config), quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@ -167,6 +167,7 @@ class SolarAttention(nn.Module):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn",
) )
def forward( def forward(

View File

@ -77,7 +77,8 @@ class StablelmAttention(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -131,7 +132,8 @@ class StablelmAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_key_value_heads, num_kv_heads=self.num_key_value_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -155,9 +157,13 @@ class StablelmDecoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.self_attn = StablelmAttention(config, cache_config, quant_config) self.self_attn = StablelmAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = StablelmMLP(config, quant_config) self.mlp = StablelmMLP(config, quant_config)
norm_eps = getattr(config, "norm_eps", norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05)) getattr(config, "layer_norm_eps", 1e-05))
@ -207,8 +213,8 @@ class StableLMEpochModel(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: StablelmDecoderLayer(config, cache_config, lambda prefix: StablelmDecoderLayer(
quant_config), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
norm_eps = getattr(config, "norm_eps", norm_eps = getattr(config, "norm_eps",

View File

@ -52,7 +52,8 @@ class Starcoder2Attention(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.config = config self.config = config
@ -105,7 +106,8 @@ class Starcoder2Attention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -154,12 +156,14 @@ class Starcoder2DecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = Starcoder2Attention(config, self.self_attn = Starcoder2Attention(config,
cache_config, cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = Starcoder2MLP(config, quant_config=quant_config) self.mlp = Starcoder2MLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size, self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon) eps=config.norm_epsilon)
@ -213,7 +217,8 @@ class Starcoder2Model(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: Starcoder2DecoderLayer( lambda prefix: Starcoder2DecoderLayer(
config, cache_config, quant_config=quant_config), config, cache_config, quant_config=quant_config, prefix=prefix
),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)

View File

@ -93,6 +93,7 @@ class XverseAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -138,7 +139,8 @@ class XverseAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
@ -162,6 +164,7 @@ class XverseDecoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -180,6 +183,7 @@ class XverseDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
bias=getattr(config, "bias", False), bias=getattr(config, "bias", False),
cache_config=cache_config, cache_config=cache_config,
prefix=f"{prefix}.self_attn",
) )
self.mlp = XverseMLP( self.mlp = XverseMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
@ -243,8 +247,8 @@ class XverseModel(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: XverseDecoderLayer(config, cache_config, lambda prefix: XverseDecoderLayer(
quant_config), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@ -20,6 +20,7 @@ logger = init_logger(__name__)
class CpuPlatform(Platform): class CpuPlatform(Platform):
_enum = PlatformEnum.CPU _enum = PlatformEnum.CPU
device_type: str = "cpu" device_type: str = "cpu"
dispatch_key: str = "CPU"
@classmethod @classmethod
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:

View File

@ -121,6 +121,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:
class CudaPlatform(Platform): class CudaPlatform(Platform):
_enum = PlatformEnum.CUDA _enum = PlatformEnum.CUDA
device_type: str = "cuda" device_type: str = "cuda"
dispatch_key: str = "CUDA"
@classmethod @classmethod
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:

View File

@ -13,6 +13,7 @@ else:
class HpuPlatform(Platform): class HpuPlatform(Platform):
_enum = PlatformEnum.HPU _enum = PlatformEnum.HPU
device_type: str = "hpu" device_type: str = "hpu"
dispatch_key: str = "HPU"
@classmethod @classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

View File

@ -57,6 +57,10 @@ class DeviceCapability(NamedTuple):
class Platform: class Platform:
_enum: PlatformEnum _enum: PlatformEnum
device_type: str device_type: str
# available dispatch keys:
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
# use "CPU" as a fallback for platforms not registered in PyTorch
dispatch_key: str = "CPU"
def is_cuda(self) -> bool: def is_cuda(self) -> bool:
return self._enum == PlatformEnum.CUDA return self._enum == PlatformEnum.CUDA

View File

@ -18,6 +18,7 @@ logger = init_logger(__name__)
class OpenVinoPlatform(Platform): class OpenVinoPlatform(Platform):
_enum = PlatformEnum.OPENVINO _enum = PlatformEnum.OPENVINO
device_type: str = "openvino" device_type: str = "openvino"
dispatch_key: str = "CPU"
@classmethod @classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

View File

@ -36,6 +36,7 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
class RocmPlatform(Platform): class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM _enum = PlatformEnum.ROCM
device_type: str = "cuda" device_type: str = "cuda"
dispatch_key: str = "CUDA"
@classmethod @classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

View File

@ -17,6 +17,7 @@ logger = init_logger(__name__)
class TpuPlatform(Platform): class TpuPlatform(Platform):
_enum = PlatformEnum.TPU _enum = PlatformEnum.TPU
device_type: str = "tpu" device_type: str = "tpu"
dispatch_key: str = "XLA"
@classmethod @classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

View File

@ -17,6 +17,7 @@ logger = init_logger(__name__)
class XPUPlatform(Platform): class XPUPlatform(Platform):
_enum = PlatformEnum.XPU _enum = PlatformEnum.XPU
device_type: str = "xpu" device_type: str = "xpu"
dispatch_key: str = "XPU"
@classmethod @classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

View File

@ -273,7 +273,8 @@ class TP1DraftModelRunner(ModelRunner):
if previous_hidden_states is not None else {} if previous_hidden_states is not None else {}
# Run model # Run model
with set_forward_context(model_input.attn_metadata): with set_forward_context(model_input.attn_metadata,
self.vllm_config):
hidden_states = model_executable( hidden_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,

View File

@ -1573,6 +1573,7 @@ def direct_register_custom_op(
mutates_args: List[str], mutates_args: List[str],
fake_impl: Optional[Callable] = None, fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None, target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA",
): ):
""" """
`torch.library.custom_op` can have significant overhead because it `torch.library.custom_op` can have significant overhead because it
@ -1601,7 +1602,7 @@ def direct_register_custom_op(
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
my_lib = target_lib or vllm_lib my_lib = target_lib or vllm_lib
my_lib.define(op_name + schema_str) my_lib.define(op_name + schema_str)
my_lib.impl(op_name, op_func, "CUDA") my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
if fake_impl is not None: if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl) my_lib._register_fake(op_name, fake_impl)

View File

@ -173,7 +173,8 @@ def unified_v1_flash_attention(
alibi_slopes: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None, logits_soft_cap: Optional[float] = None,
) -> None: ) -> None:
current_metadata = get_forward_context() context = get_forward_context()
current_metadata = context.dynamic_forward_context
if current_metadata is None: if current_metadata is None:
# Profiling run. # Profiling run.
return return

View File

@ -447,7 +447,7 @@ class GPUModelRunner:
# Run the decoder. # Run the decoder.
# Use persistent buffers for CUDA graphs. # Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata): with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model( hidden_states = self.model(
input_ids=None, input_ids=None,
positions=self.positions[:num_input_tokens], positions=self.positions[:num_input_tokens],
@ -523,7 +523,7 @@ class GPUModelRunner:
num_tokens: int, num_tokens: int,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
with set_forward_context(None): with set_forward_context(None, self.vllm_config):
hidden_states = model( hidden_states = model(
input_ids=None, input_ids=None,
positions=self.positions[:num_tokens], positions=self.positions[:num_tokens],

View File

@ -97,7 +97,7 @@ class EmbeddingModelRunner(
model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record() model_forward_start.record()
with set_forward_context(model_input.attn_metadata): with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,

View File

@ -176,7 +176,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
} if self.has_inner_state else {} } if self.has_inner_state else {}
multi_modal_kwargs = model_input.multi_modal_kwargs or {} multi_modal_kwargs = model_input.multi_modal_kwargs or {}
with set_forward_context(model_input.attn_metadata): with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,

View File

@ -1503,7 +1503,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self._update_inputs_to_capture_for_enc_dec_model( self._update_inputs_to_capture_for_enc_dec_model(
capture_inputs) capture_inputs)
with set_forward_context(attn_metadata): with set_forward_context(attn_metadata, self.vllm_config):
graph_runner.capture(**capture_inputs) graph_runner.capture(**capture_inputs)
self.graph_memory_pool = graph_runner.graph.pool() self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[virtual_engine][batch_size] = ( self.graph_runners[virtual_engine][batch_size] = (
@ -1649,7 +1649,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record() model_forward_start.record()
with set_forward_context(model_input.attn_metadata): with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,