[torch.compile] support all attention backends (#10558)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao, 2024-11-22 14:04:42 -08:00 (committed by GitHub)
Parent: db100c5cde
Commit: eebad39f26
77 changed files with 876 additions and 648 deletions

View File

@ -18,8 +18,10 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.platforms import current_platform
from vllm.plugins import set_current_vllm_config
# List of supported backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
@ -594,6 +596,7 @@ def _run_encoder_attention_test(
encoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata,
test_pt: TestPoint,
vllm_config: VllmConfig,
) -> torch.Tensor:
'''
Run encoder attention.
@ -623,7 +626,7 @@ def _run_encoder_attention_test(
attn_type = AttentionType.ENCODER
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None
with set_forward_context(attn_metadata):
with set_forward_context(attn_metadata, vllm_config):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be
@ -648,6 +651,7 @@ def _run_decoder_self_attention_test(
decoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata,
test_pt: TestPoint,
vllm_config: VllmConfig,
) -> torch.Tensor:
'''
Run decoder self-attention test.
@ -677,7 +681,7 @@ def _run_decoder_self_attention_test(
kv_cache = test_rsrcs.kv_cache
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None
with set_forward_context(attn_metadata):
with set_forward_context(attn_metadata, vllm_config):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be
@ -701,6 +705,7 @@ def _run_encoder_decoder_cross_attention_test(
cross_test_params: Optional[PhaseTestParameters],
attn_metadata: AttentionMetadata,
test_pt: TestPoint,
vllm_config: VllmConfig,
) -> torch.Tensor:
'''
Run encoder/decoder cross-attention test.
@ -748,7 +753,7 @@ def _run_encoder_decoder_cross_attention_test(
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
with set_forward_context(attn_metadata):
with set_forward_context(attn_metadata, vllm_config):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be
@ -839,7 +844,9 @@ def test_encoder_only(
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
test_rsrcs = _make_test_resources(test_pt)
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
test_rsrcs = _make_test_resources(test_pt)
# Construct encoder attention test params (only used
# during prefill)
@ -863,7 +870,8 @@ def test_encoder_only(
test_rsrcs.attn,
enc_test_params,
prephase_attn_metadata,
test_pt=test_pt))
test_pt=test_pt,
vllm_config=vllm_config))
# - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
@ -960,7 +968,9 @@ def test_e2e_enc_dec_attn(
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
test_rsrcs = _make_test_resources(test_pt)
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
test_rsrcs = _make_test_resources(test_pt)
# Construct encoder attention test params (only used
# during prefill)
@ -1011,7 +1021,8 @@ def test_e2e_enc_dec_attn(
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
enc_test_params,
prephase_attn_metadata,
test_pt=test_pt)
test_pt=test_pt,
vllm_config=vllm_config)
# - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
@ -1023,7 +1034,8 @@ def test_e2e_enc_dec_attn(
test_rsrcs,
prephase_dec_test_params,
prephase_attn_metadata,
test_pt=test_pt)
test_pt=test_pt,
vllm_config=vllm_config)
# - Is prefill decoder self-attention correct?
assert_actual_matches_ideal(prephase_dec_test_params,
@ -1037,7 +1049,8 @@ def test_e2e_enc_dec_attn(
prephase_dec_test_params,
prephase_cross_test_params,
prephase_attn_metadata,
test_pt=test_pt)
test_pt=test_pt,
vllm_config=vllm_config)
# - Is prefill encoder/decoder cross-attention correct?
assert_actual_matches_ideal(prephase_cross_test_params,
@ -1061,7 +1074,8 @@ def test_e2e_enc_dec_attn(
test_rsrcs,
decphase_dec_test_params,
decphase_attn_metadata,
test_pt=test_pt)
test_pt=test_pt,
vllm_config=vllm_config)
# - Is decode-phase decoder self-attention correct?
assert_actual_matches_ideal(decphase_dec_test_params,
@ -1075,7 +1089,8 @@ def test_e2e_enc_dec_attn(
decphase_dec_test_params,
None,
decphase_attn_metadata,
test_pt=test_pt)
test_pt=test_pt,
vllm_config=vllm_config)
# - Is decode-phase encoder/decoder cross-attention correct?
assert_actual_matches_ideal(decphase_cross_test_params,
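Note on the test plumbing above: the updated tests build a default VllmConfig, make it current while the attention resources are constructed, and then thread the same config into set_forward_context via the helper functions. A minimal sketch of that setup pattern (illustrative only, not a copy of the test code):

from vllm.config import VllmConfig
from vllm.plugins import set_current_vllm_config

vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
    # Attention layers built here register themselves in
    # vllm_config.compilation_config.static_forward_context.
    pass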

View File

@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
Tuple, Type, TypeVar)
@ -15,13 +14,19 @@ if TYPE_CHECKING:
ModelRunnerInputBuilderBase)
class AttentionType(Enum):
DECODER = auto() # Decoder attention between previous layer Q/K/V
ENCODER = auto(
) # Encoder attention between previous layer Q/K/V for encoder-decoder
ENCODER_ONLY = auto() # Encoder attention between previous layer Q/K/V
ENCODER_DECODER = auto(
) # Attention between dec. Q and enc. K/V for encoder-decoder
class AttentionType:
"""
Attention type.
Use string to be compatible with `torch.compile`.
"""
# Decoder attention between previous layer Q/K/V
DECODER = "decoder"
# Encoder attention between previous layer Q/K/V for encoder-decoder
ENCODER = "encoder"
# Encoder attention between previous layer Q/K/V
ENCODER_ONLY = "encoder_only"
# Attention between dec. Q and enc. K/V for encoder-decoder
ENCODER_DECODER = "encoder_decoder"
class AttentionBackend(ABC):
@ -241,6 +246,6 @@ class AttentionImpl(ABC, Generic[T]):
attn_metadata: T,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
raise NotImplementedError
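A quick illustration of what the switch from Enum to plain string constants buys (a hedged sketch, not vLLM code): string values can be passed straight through a custom-op schema and compared with ordinary equality, which torch.compile handles without special casing.

from vllm.attention import AttentionType

assert isinstance(AttentionType.DECODER, str)

def skip_kv_cache_update(attn_type: str, key, value) -> bool:
    # Mirrors the skip condition used in the flash-attention forward path.
    return attn_type == AttentionType.ENCODER or key is None or value is None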

View File

@ -354,7 +354,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
attn_metadata: BlocksparseFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.

View File

@ -16,10 +16,8 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.forward_context import get_forward_context
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
make_tensor_with_pad)
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@ -639,7 +637,7 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@ -668,23 +666,174 @@ class FlashAttentionImpl(AttentionImpl):
"requires setting cross-attention "
"metadata attributes.")
output = torch.ops.vllm.unified_flash_attention(
query,
key,
value,
self.num_heads,
self.head_size,
self.num_kv_heads,
kv_cache,
self.kv_cache_dtype,
k_scale,
v_scale,
self.scale,
attn_type.value,
self.sliding_window,
self.alibi_slopes,
self.logits_soft_cap,
)
num_heads: int = self.num_heads
head_size: int = self.head_size
num_kv_heads: int = self.num_kv_heads
kv_cache_dtype: str = self.kv_cache_dtype
softmax_scale: float = self.scale
window_size = self.sliding_window
alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
logits_soft_cap: Optional[float] = self.logits_soft_cap
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size)
if (key is not None) and (value is not None):
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
if kv_cache.numel() > 0:
key_cache = kv_cache[0]
value_cache = kv_cache[1]
# We skip updating the KV cache under two conditions:
# a. When the Attention Type is ENCODER. In this phase, we compute
# only the encoder attention without updating the cache.
# b. When both Key and Value are None. This occurs during
# cross-attention computation in the decoding phase, where the
# KV cache is already populated with the cross-attention
# tensor. Thus, we skip cache updates during this time.
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
# profiling run.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[0],
kv_cache[1],
updated_slot_mapping.flatten(), # type: ignore[union-attr]
kv_cache_dtype,
k_scale,
v_scale,
)
(num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens) = \
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
decode_query = query[num_prefill_query_tokens:]
# QKV for prefill.
query = query[:num_prefill_query_tokens]
assert query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_query_tokens
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
key = key[:num_prefill_kv_tokens]
value = value[:num_prefill_kv_tokens]
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=q_seq_start_loc,
cu_seqlens_k=k_seq_start_loc,
max_seqlen_q=q_seq_len,
max_seqlen_k=k_seq_len,
softmax_scale=softmax_scale,
causal=_get_causal_option(attn_type),
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
)
else:
# prefix-enabled attention
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support prefix caching")
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
prefill_output = flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_k=max_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
# Use flash_attn_varlen_func kernel for speculative decoding
# because different queries might have different lengths.
assert decode_meta.max_decode_query_len is not None
# use only for actual varlen decoding
if decode_meta.max_decode_query_len > 1:
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support max_decode_query_len > 1"
)
decode_output = flash_attn_varlen_func(
q=decode_query,
k=key_cache,
v=value_cache,
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_decode_query_len,
cu_seqlens_k=decode_meta.seq_start_loc,
max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
)
else:
# Use flash_attn_with_kvcache for normal decoding.
(
seq_lens_arg,
_,
block_tables_arg,
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
decode_output = flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
block_table=block_tables_arg,
cache_seqlens=seq_lens_arg,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
).squeeze(1)
if prefill_output is None:
assert decode_output is not None
return decode_output.view(num_decode_query_tokens, hidden_size)
if decode_output is None:
assert prefill_output is not None
return prefill_output.view(num_prefill_query_tokens, hidden_size)
assert decode_meta is not None
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)
return output
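To make the token bookkeeping above concrete, here is a tiny standalone sketch (made-up numbers, illustrative shapes only) of how the flattened query batch is split into its prefill and decode segments before the two kernels run:

import torch

num_prefill_query_tokens, num_decode_query_tokens = 6, 2
query = torch.randn(8, 4, 64)  # [num_tokens, num_heads, head_size]
decode_query = query[num_prefill_query_tokens:]
query = query[:num_prefill_query_tokens]
assert query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_query_tokens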
@ -692,7 +841,7 @@ class FlashAttentionImpl(AttentionImpl):
def _get_query_key_seq_metadata(
attn_metadata,
is_prompt: bool,
attn_type: AttentionType,
attn_type: str,
) -> tuple:
"""
Returns sequence metadata for key and query based on the specified
@ -754,7 +903,7 @@ def _get_query_key_seq_metadata(
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _get_causal_option(attn_type: AttentionType) -> bool:
def _get_causal_option(attn_type: str) -> bool:
"""
Determine whether the given attention type is suitable for causal
attention mechanisms.
@ -770,220 +919,3 @@ def _get_causal_option(attn_type: AttentionType) -> bool:
return not (attn_type == AttentionType.ENCODER
or attn_type == AttentionType.ENCODER_ONLY
or attn_type == AttentionType.ENCODER_DECODER)
def unified_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
attn_type_int_val: int,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
# Convert integer attn_type to enum
try:
attn_type = AttentionType(attn_type_int_val)
except ValueError as err:
raise AttributeError(
f"Invalid attention type {str(attn_type_int_val)}") from err
current_metadata = get_forward_context()
assert current_metadata is not None
assert isinstance(current_metadata, FlashAttentionMetadata)
attn_metadata: FlashAttentionMetadata = current_metadata
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size)
if (key is not None) and (value is not None):
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
if kv_cache.numel() > 0:
key_cache = kv_cache[0]
value_cache = kv_cache[1]
# We skip updating the KV cache under two conditions:
# a. When the Attention Type is ENCODER. In this phase, we compute
# only the encoder attention without updating the cache.
# b. When both Key and Value are None. This occurs during
# cross-attention computation in the decoding phase, where the KV
# cache is already populated with the cross-attention tensor.
# Thus, we skip cache updates during this time.
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[0],
kv_cache[1],
updated_slot_mapping.flatten(), # type: ignore[union-attr]
kv_cache_dtype,
k_scale,
v_scale,
)
(num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens) = \
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
decode_query = query[num_prefill_query_tokens:]
# QKV for prefill.
query = query[:num_prefill_query_tokens]
assert query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_query_tokens
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
key = key[:num_prefill_kv_tokens]
value = value[:num_prefill_kv_tokens]
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=q_seq_start_loc,
cu_seqlens_k=k_seq_start_loc,
max_seqlen_q=q_seq_len,
max_seqlen_k=k_seq_len,
softmax_scale=softmax_scale,
causal=_get_causal_option(attn_type),
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
)
else:
# prefix-enabled attention
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support prefix caching")
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
prefill_output = flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_k=max_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
# Use flash_attn_varlen_func kernel for speculative decoding
# because different queries might have different lengths.
assert decode_meta.max_decode_query_len is not None
# use only for actual varlen decoding
if decode_meta.max_decode_query_len > 1:
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support max_decode_query_len > 1")
decode_output = flash_attn_varlen_func(
q=decode_query,
k=key_cache,
v=value_cache,
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_decode_query_len,
cu_seqlens_k=decode_meta.seq_start_loc,
max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
)
else:
# Use flash_attn_with_kvcache for normal decoding.
(
seq_lens_arg,
_,
block_tables_arg,
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
decode_output = flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
block_table=block_tables_arg,
cache_seqlens=seq_lens_arg,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
).squeeze(1)
if prefill_output is None:
assert decode_output is not None
return decode_output.view(num_decode_query_tokens, hidden_size)
if decode_output is None:
assert prefill_output is not None
return prefill_output.view(num_prefill_query_tokens, hidden_size)
assert decode_meta is not None
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)
def unified_flash_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
attn_type_int_val: int,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query)
direct_register_custom_op(
op_name="unified_flash_attention",
op_func=unified_flash_attention,
mutates_args=["kv_cache"],
fake_impl=unified_flash_attention_fake,
)

View File

@ -30,9 +30,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.forward_context import get_forward_context
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
get_kv_cache_torch_dtype, make_tensor_with_pad)
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad)
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@ -774,7 +773,7 @@ class FlashInferImpl(AttentionImpl):
attn_metadata: FlashInferMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
@ -782,174 +781,117 @@ class FlashInferImpl(AttentionImpl):
"are not implemented for "
"FlashInferImpl")
return torch.ops.vllm.unified_flash_infer(
query,
key,
value,
self.num_heads,
self.head_size,
self.num_kv_heads,
kv_cache,
self.kv_cache_dtype,
k_scale,
v_scale,
self.scale,
self.sliding_window,
self.alibi_slopes,
self.logits_soft_cap,
)
num_heads: int = self.num_heads
head_size: int = self.head_size
num_kv_heads: int = self.num_kv_heads
kv_cache_dtype: str = self.kv_cache_dtype
softmax_scale: float = self.scale
window_size = self.sliding_window
alibi_slopes = self.alibi_slopes
logits_soft_cap = self.logits_soft_cap
num_tokens, hidden_size = query.shape
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
def unified_flash_infer(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
current_metadata = get_forward_context()
assert current_metadata is not None
assert isinstance(current_metadata, FlashInferMetadata)
attn_metadata: FlashInferMetadata = current_metadata
num_tokens, hidden_size = query.shape
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
if kv_cache.numel() > 0:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if kv_cache_dtype.startswith("fp8"):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
query = query.contiguous() # Flashinfer requires query to be contiguous
# Query for decode. KV is not needed because it is already cached.
# QKV for prefill.
decode_query = query[num_prefill_tokens:]
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
window_left = window_size[0] if window_size is not None else -1
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache.numel() == 0:
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
if kv_cache.numel() > 0:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
prefill_output = prefill_meta.prefill_wrapper.forward(
query,
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if kv_cache_dtype.startswith("fp8"):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
query = query.contiguous(
) # Flashinfer requires query to be contiguous
# Query for decode. KV is not needed because it is already cached.
# QKV for prefill.
decode_query = query[num_prefill_tokens:]
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
window_left = window_size[0] if window_size is not None else -1
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache.numel() == 0:
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
)
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
prefill_output = prefill_meta.prefill_wrapper.forward(
query,
kv_cache,
logits_soft_cap=logits_soft_cap,
causal=True,
k_scale=k_scale,
v_scale=v_scale,
window_left=window_left)
if decode_meta := attn_metadata.decode_metadata:
assert decode_meta is not None
assert decode_meta.decode_wrapper is not None
decode_output = decode_meta.decode_wrapper.forward(
decode_query,
kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap,
causal=True,
k_scale=k_scale,
v_scale=v_scale,
window_left=window_left)
if decode_meta := attn_metadata.decode_metadata:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
decode_output = attn_metadata.decode_metadata.decode_wrapper.forward(
decode_query,
kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale,
window_left=window_left)
if prefill_output is None and decode_output is not None:
# Decode only batch.
output, num_tokens = decode_output, num_decode_tokens
elif decode_output is None and prefill_output is not None:
# Prefill only batch.
output, num_tokens = prefill_output, num_prefill_tokens
else:
# Chunked prefill batch does not work with speculative decoding in
# FlashInfer backend, so the query length for decode should be 1.
assert prefill_output is not None
assert decode_output is not None
assert decode_meta is not None
assert decode_meta.decode_query_len == 1
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)
def unified_flash_infer_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query).contiguous()
direct_register_custom_op(
op_name="unified_flash_infer",
op_func=unified_flash_infer,
mutates_args=["kv_cache"],
fake_impl=unified_flash_infer_fake,
)
if prefill_output is None and decode_output is not None:
# Decode only batch.
output, num_tokens = decode_output, num_decode_tokens
elif decode_output is None and prefill_output is not None:
# Prefill only batch.
output, num_tokens = prefill_output, num_prefill_tokens
else:
# Chunked prefill batch does not work with speculative decoding in
# FlashInfer backend, so the query length for decode should be 1.
assert prefill_output is not None
assert decode_output is not None
assert decode_meta is not None
assert decode_meta.decode_query_len == 1
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)
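As a side note on the fp8 branch above, a hedged standalone sketch of the dtype reinterpretation it performs (arbitrary shapes; assumes a PyTorch build with float8 support):

import torch

kv_cache = torch.zeros(2, 16, 8, dtype=torch.uint8)  # stand-in cache buffer
kv_cache_fp8 = kv_cache.view(torch.float8_e4m3fn)    # same bytes, fp8 view, no copy
assert kv_cache_fp8.shape == kv_cache.shape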

View File

@ -140,7 +140,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
attn_metadata: HPUAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.

View File

@ -172,7 +172,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
attn_metadata: IpexAttnMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.

View File

@ -150,7 +150,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
attn_metadata: PallasMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with Pallas attention.

View File

@ -414,7 +414,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_metadata: ROCmFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.

View File

@ -141,7 +141,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def get_seq_lens(
self,
attn_type: AttentionType,
attn_type: str,
):
'''
Extract appropriate sequence lengths from attention metadata
@ -174,7 +174,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def get_attn_bias(
self,
attn_type: AttentionType,
attn_type: str,
) -> Optional[List[torch.Tensor]]:
'''
Extract appropriate attention bias from attention metadata
@ -203,7 +203,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def set_attn_bias(
self,
attn_bias: List[torch.Tensor],
attn_type: AttentionType,
attn_type: str,
) -> None:
'''
Update appropriate attention bias field of attention metadata,
@ -229,7 +229,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def get_seq_len_block_table_args(
self,
attn_type: AttentionType,
attn_type: str,
) -> tuple:
'''
The particular choice of sequence-length- and block-table-related
@ -426,7 +426,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_metadata: TorchSDPAMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
@ -574,7 +574,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: TorchSDPAMetadata,
attn_type: AttentionType = AttentionType.DECODER,
attn_type: str = AttentionType.DECODER,
) -> None:
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)

View File

@ -478,7 +478,7 @@ def is_all_cross_attn_metadata_set(attn_metadata):
def get_seq_len_block_table_args(
attn_metadata,
is_prompt: bool,
attn_type: AttentionType,
attn_type: str,
) -> tuple:
'''
The particular choice of sequence-length- and block-table-related
@ -529,7 +529,7 @@ def get_seq_len_block_table_args(
def get_num_prefill_decode_query_kv_tokens(
attn_metadata,
attn_type: AttentionType,
attn_type: str,
) -> Tuple[int, int, int]:
"""
Calculate the number of prefill and decode tokens for query, key/value

View File

@ -284,7 +284,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
def _get_attn_bias(
attn_metadata: XFormersMetadata,
attn_type: AttentionType,
attn_type: str,
) -> Optional[AttentionBias]:
'''
Extract appropriate attention bias from attention metadata
@ -314,7 +314,7 @@ def _get_attn_bias(
def _set_attn_bias(
attn_metadata: XFormersMetadata,
attn_bias: List[Optional[AttentionBias]],
attn_type: AttentionType,
attn_type: str,
) -> None:
'''
Update appropriate attention bias field of attention metadata,
@ -416,7 +416,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
attn_metadata: "XFormersMetadata",
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
@ -617,7 +617,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: XFormersMetadata,
attn_type: AttentionType = AttentionType.DECODER,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input.

View File

@ -4,12 +4,17 @@ from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform
from vllm.plugins import get_current_vllm_config
from vllm.utils import direct_register_custom_op
class Attention(nn.Module):
@ -86,6 +91,18 @@ class Attention(nn.Module):
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap)
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = envs.VLLM_USE_V1 or not (
current_platform.is_cuda_alike() or current_platform.is_cpu())
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
def forward(
self,
query: torch.Tensor,
@ -93,17 +110,22 @@ class Attention(nn.Module):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
return self.impl.forward(query,
key,
value,
kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
attn_type=attn_type)
if self.use_direct_call:
return self.impl.forward(query,
key,
value,
kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
attn_type=attn_type)
else:
return torch.ops.vllm.unified_attention(query, key, value,
kv_cache, attn_type,
self.layer_name)
def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
@ -112,3 +134,44 @@ class Attention(nn.Module):
s += f", scale={self.impl.scale}" # type: ignore
s += f", backend={self.impl.__class__.__name__}"
return s
def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.dynamic_forward_context
self = forward_context.static_forward_context[layer_name]
return self.impl.forward(query,
key,
value,
kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
attn_type=attn_type)
def unified_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(query).contiguous()
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
mutates_args=["kv_cache"],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
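The registration above is the core of this PR: on CUDA-like and CPU platforms, attention runs through one opaque custom op that torch.compile treats as a black box, with the fake impl supplying output shapes for tracing. A hedged, standalone sketch of the same pattern using torch.library.custom_op directly (PyTorch 2.4+; the demo:: namespace and names are illustrative, not vLLM API):

import torch
from torch.library import custom_op

@custom_op("demo::opaque_attention", mutates_args=())
def opaque_attention(query: torch.Tensor, layer_name: str) -> torch.Tensor:
    # A real implementation would look the layer up by name in the forward
    # context and dispatch to its attention backend; here we just copy.
    return query.clone()

@opaque_attention.register_fake
def _(query: torch.Tensor, layer_name: str) -> torch.Tensor:
    # Fake (meta) implementation: torch.compile only needs the output shape.
    return torch.empty_like(query)

compiled_fn = torch.compile(
    lambda q: torch.ops.demo.opaque_attention(q, "model.layers.0.self_attn.attn"))

vLLM's direct_register_custom_op helper wraps similar torch.library machinery; the essential shape of the pattern (real impl, fake impl, call through torch.ops) is the same.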

View File

@ -2135,8 +2135,7 @@ class CompilationConfig(BaseModel):
backend: str = ""
custom_ops: List[str] = Field(default_factory=list)
splitting_ops: List[str] = Field(default_factory=lambda: [
"vllm.unified_flash_attention",
"vllm.unified_flash_infer",
"vllm.unified_attention",
"vllm.unified_v1_flash_attention",
])
@ -2197,6 +2196,11 @@ class CompilationConfig(BaseModel):
enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr
# Per-model forward context
# Mainly used to store attention cls
# Map from layer name to the attention cls
static_forward_context: Dict[str, Any] = PrivateAttr
@classmethod
def from_cli(cls, cli_value: str) -> "CompilationConfig":
"""Parse the CLI value for the compilation config."""
@ -2228,6 +2232,7 @@ class CompilationConfig(BaseModel):
self.enabled_custom_ops = Counter()
self.disabled_custom_ops = Counter()
self.static_forward_context = {}
def init_backend(self) -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:

View File

@ -1,21 +1,38 @@
from contextlib import contextmanager
from typing import Any
from dataclasses import dataclass
from typing import Any, Dict, Optional
_forward_context: Any = None
from vllm.config import VllmConfig
def get_forward_context() -> Any:
@dataclass
class ForwardContext:
static_forward_context: Dict[str, Any]
# TODO: extend to support per-layer dynamic forward context
dynamic_forward_context: Any
_forward_context: Optional[ForwardContext] = None
def get_forward_context() -> ForwardContext:
"""Get the current forward context."""
assert _forward_context is not None, (
"Forward context is not set. "
"Please use `set_forward_context` to set the forward context.")
return _forward_context
@contextmanager
def set_forward_context(context: Any):
def set_forward_context(context: Any, vllm_config: VllmConfig):
"""A context manager that stores the current forward context,
can be attention metadata, etc."""
global _forward_context
prev_context = _forward_context
_forward_context = context
_forward_context = ForwardContext(
static_forward_context=vllm_config.compilation_config.
static_forward_context,
dynamic_forward_context=context)
try:
yield
finally:
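A hedged usage sketch of the reworked context manager (the metadata object below is a placeholder; in real runs it is the backend's AttentionMetadata):

from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context, set_forward_context

vllm_config = VllmConfig()
attn_metadata = None  # placeholder for a real AttentionMetadata instance

with set_forward_context(attn_metadata, vllm_config):
    ctx = get_forward_context()
    assert ctx.dynamic_forward_context is attn_metadata
    assert (ctx.static_forward_context
            is vllm_config.compilation_config.static_forward_context)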

View File

@ -223,6 +223,7 @@ class ArcticAttention(nn.Module):
layer_idx: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
@ -274,7 +275,8 @@ class ArcticAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -299,6 +301,7 @@ class ArcticDecoderLayer(nn.Module):
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_idx = layer_idx
@ -308,7 +311,8 @@ class ArcticDecoderLayer(nn.Module):
self.self_attn = ArcticAttention(config,
layer_idx,
cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.block_sparse_moe = ArcticMoE(
config,
layer_id=layer_idx,
@ -380,8 +384,11 @@ class ArcticModel(nn.Module):
org_num_embeddings=self.vocab_size)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: ArcticDecoderLayer(config, int(
prefix.split(".")[-1]), cache_config, quant_config),
lambda prefix: ArcticDecoderLayer(config,
int(prefix.split(".")[-1]),
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")
self._attn_implementation = config._attn_implementation
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
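This and the remaining model-file diffs below all follow one convention: thread a dotted prefix down to every Attention so each instance gets a unique name, and therefore a unique key in the static forward context. A tiny illustration of the resulting names (illustrative prefixes, not taken from any specific model):

def attn_layer_names(num_layers: int, root: str = "model.layers"):
    # Prefix threading as in the Arctic changes above, purely for illustration.
    return [f"{root}.{i}.self_attn.attn" for i in range(num_layers)]

print(attn_layer_names(2))
# ['model.layers.0.self_attn.attn', 'model.layers.1.self_attn.attn']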

View File

@ -116,6 +116,7 @@ class BaiChuanAttention(nn.Module):
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = hidden_size
@ -158,7 +159,8 @@ class BaiChuanAttention(nn.Module):
self.head_dim,
scaling,
alibi_slopes=alibi_slopes,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
else:
self.rotary_emb = get_rope(
self.head_dim,
@ -171,7 +173,8 @@ class BaiChuanAttention(nn.Module):
self.head_dim,
self.scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -195,7 +198,8 @@ class BaiChuanDecoderLayer(nn.Module):
config: PretrainedConfig,
position_embedding: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
@ -209,6 +213,7 @@ class BaiChuanDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size,
@ -275,8 +280,11 @@ class BaiChuanModel(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: BaiChuanDecoderLayer(config, position_embedding,
cache_config, quant_config),
lambda prefix: BaiChuanDecoderLayer(config,
position_embedding,
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@ -126,6 +126,7 @@ class BartEncoderAttention(nn.Module):
config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.d_model = config.d_model
@ -178,7 +179,8 @@ class BartEncoderAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor:
@ -208,6 +210,7 @@ class BartDecoderSelfAttention(nn.Module):
config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.d_model = config.d_model
@ -260,7 +263,8 @@ class BartDecoderSelfAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor:
@ -290,6 +294,7 @@ class BartCrossAttention(nn.Module):
config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.d_model = config.d_model
@ -342,7 +347,8 @@ class BartCrossAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -384,6 +390,7 @@ class BartEncoderLayer(nn.Module):
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.embed_dim = config.d_model
@ -393,7 +400,9 @@ class BartEncoderLayer(nn.Module):
num_heads=config.encoder_attention_heads,
config=config,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.activation_fn = get_act_fn(config.activation_function)
@ -464,6 +473,7 @@ class BartDecoderLayer(nn.Module):
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.embed_dim = config.d_model
@ -473,7 +483,9 @@ class BartDecoderLayer(nn.Module):
num_heads=config.decoder_attention_heads,
config=config,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.activation_fn = get_act_fn(config.activation_function)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
@ -486,6 +498,7 @@ class BartDecoderLayer(nn.Module):
self.embed_dim,
config.decoder_attention_heads,
config=config,
prefix=f"{prefix}.encoder_attn",
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
@ -578,7 +591,8 @@ class BartEncoder(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None):
embed_tokens: Optional[nn.Embedding] = None,
prefix: str = ""):
super().__init__()
self.cache_config = cache_config
@ -599,9 +613,13 @@ class BartEncoder(nn.Module):
config.max_position_embeddings,
embed_dim,
)
self.layers = nn.ModuleList(
[BartEncoderLayer(config,cache_config,quant_config) \
for _ in range(config.encoder_layers)])
self.layers = nn.ModuleList([
BartEncoderLayer(config,
cache_config,
quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.encoder_layers)
])
self.layernorm_embedding = nn.LayerNorm(embed_dim)
@ -661,6 +679,7 @@ class BartDecoder(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None,
prefix: str = "",
):
super().__init__()
self.cache_config = cache_config
@ -683,8 +702,9 @@ class BartDecoder(nn.Module):
)
self.layers = nn.ModuleList(
[BartDecoderLayer(config,cache_config,quant_config) \
for _ in range(config.decoder_layers)])
[BartDecoderLayer(config,cache_config,quant_config,
prefix=f"{prefix}.layers.{layer_idx}") \
for layer_idx in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)
@ -759,10 +779,12 @@ class BartModel(nn.Module):
self.encoder = BartEncoder(config,
cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.encoder")
self.decoder = BartDecoder(config,
cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.decoder")
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor,

View File

@ -78,6 +78,7 @@ class BloomAttention(nn.Module):
config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
@ -116,7 +117,8 @@ class BloomAttention(nn.Module):
scaling,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -168,14 +170,17 @@ class BloomBlock(nn.Module):
config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config, cache_config,
quant_config)
self.self_attention = BloomAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attention")
self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config, quant_config)
@ -242,7 +247,8 @@ class BloomModel(nn.Module):
# Transformer blocks
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: BloomBlock(config, cache_config, quant_config),
lambda prefix: BloomBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h")
# Final Layer Norm

View File

@ -223,6 +223,7 @@ class ChameleonAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
@ -276,7 +277,8 @@ class ChameleonAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def _apply_qk_norm(self, q: torch.Tensor,
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@ -313,6 +315,7 @@ class ChameleonDecoderLayer(nn.Module):
config: ChameleonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -336,6 +339,7 @@ class ChameleonDecoderLayer(nn.Module):
quant_config=quant_config,
bias=False,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = ChameleonMLP(
hidden_size=self.hidden_size,
@ -386,6 +390,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
config: ChameleonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -409,6 +414,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
quant_config=quant_config,
bias=False,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = ChameleonMLP(
hidden_size=self.hidden_size,
@ -855,7 +861,8 @@ class ChameleonModel(nn.Module):
config.num_hidden_layers,
lambda prefix: decoder_layer(config=config,
cache_config=cache_config,
quant_config=quant_config),
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)

View File

@ -230,6 +230,7 @@ class GLMAttention(nn.Module):
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
@ -285,7 +286,8 @@ class GLMAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -364,6 +366,7 @@ class GLMBlock(nn.Module):
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.apply_residual_connection_post_layernorm = (
@ -377,7 +380,10 @@ class GLMBlock(nn.Module):
eps=config.layernorm_epsilon)
# Self attention.
self.self_attention = GLMAttention(config, cache_config, quant_config)
self.self_attention = GLMAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attention")
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
@ -446,7 +452,8 @@ class GLMTransformer(nn.Module):
# Transformer layers.
self.start_layer, self.end_layer, self.layers = make_layers(
self.num_layers,
lambda prefix: GLMBlock(config, cache_config, quant_config),
lambda prefix: GLMBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
@ -500,16 +507,22 @@ class ChatGLMModel(nn.Module):
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, cache_config, quant_config)
self.encoder = GLMTransformer(config,
cache_config,
quant_config,
prefix=f"{prefix}.encoder")
self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.output_layer")
vision_config_flag = getattr(config, 'vision_config', None)
if vision_config_flag is not None:
self.vision_config = Namespace(**config.vision_config)
self.vision = EVA2CLIPModel(self.config, quant_config)
self.vision = EVA2CLIPModel(self.config,
quant_config,
prefix=f"{prefix}.vision")
else:
self.vision = None

View File

@ -120,6 +120,7 @@ class CohereAttention(nn.Module):
config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
tp_size = get_tensor_model_parallel_world_size()
@ -175,7 +176,8 @@ class CohereAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
if self.use_qk_norm:
self.q_norm = LayerNorm(param_shape=(self.num_heads,
self.head_dim),
@ -215,13 +217,15 @@ class CohereDecoderLayer(nn.Module):
def __init__(self,
config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = CohereAttention(config,
cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = CohereMLP(config, quant_config=quant_config)
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
@ -271,8 +275,8 @@ class CohereModel(nn.Module):
config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: CohereDecoderLayer(config, cache_config,
quant_config),
lambda prefix: CohereDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps)

View File

@ -154,6 +154,7 @@ class DbrxAttention(nn.Module):
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.d_model = config.d_model
@ -208,7 +209,8 @@ class DbrxAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -234,10 +236,14 @@ class DbrxFusedNormAttention(nn.Module):
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.d_model = config.d_model
self.attn = DbrxAttention(config, cache_config, quant_config)
self.attn = DbrxAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.norm_1 = nn.LayerNorm(self.d_model)
self.norm_2 = nn.LayerNorm(self.d_model)
@ -269,10 +275,14 @@ class DbrxBlock(nn.Module):
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
quant_config)
self.norm_attn_norm = DbrxFusedNormAttention(
config,
cache_config,
quant_config,
prefix=f"{prefix}.norm_attn_norm")
self.ffn = DbrxMoE(config, quant_config)
def forward(
@ -308,7 +318,8 @@ class DbrxModel(nn.Module):
)
self.start_layer, self.end_layer, self.blocks = make_layers(
config.n_layers,
lambda prefix: DbrxBlock(config, cache_config, quant_config),
lambda prefix: DbrxBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.blocks",
)
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)

View File

@ -184,6 +184,7 @@ class DeepseekAttention(nn.Module):
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
@ -236,7 +237,8 @@ class DeepseekAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -261,6 +263,7 @@ class DeepseekDecoderLayer(nn.Module):
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -277,6 +280,7 @@ class DeepseekDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
@ -346,7 +350,8 @@ class DeepseekModel(nn.Module):
lambda prefix: DeepseekDecoderLayer(config,
int(prefix.split(".")[-1]),
cache_config,
quant_config=quant_config),
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
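
The hunks above and below all apply one mechanical change: every module that owns an Attention layer now accepts a `prefix` argument and forwards a child-specific suffix (for example f"{prefix}.attn" or f"{prefix}.self_attn") down the module tree, while make_layers hands each layer a prefix ending in its index, which is why several models recover the layer index via int(prefix.split(".")[-1]). The minimal sketch below illustrates that plumbing; ToyAttention, ToyBlock and toy_make_layers are illustrative stand-ins, not vLLM classes.

# Illustrative sketch of the prefix threading applied throughout this commit.
# ToyAttention / ToyBlock / toy_make_layers are stand-ins, not vLLM code.
from typing import Callable, List

import torch.nn as nn


class ToyAttention(nn.Module):

    def __init__(self, hidden_size: int, prefix: str = ""):
        super().__init__()
        # In vLLM the unique prefix (e.g. "model.layers.3.self_attn.attn")
        # identifies this attention layer across the whole model.
        self.prefix = prefix
        self.proj = nn.Linear(hidden_size, hidden_size)


class ToyBlock(nn.Module):

    def __init__(self, hidden_size: int, prefix: str = ""):
        super().__init__()
        # Forward a child-specific suffix, mirroring f"{prefix}.self_attn".
        self.self_attn = ToyAttention(hidden_size,
                                      prefix=f"{prefix}.self_attn.attn")


def toy_make_layers(num_layers: int, layer_fn: Callable[[str], nn.Module],
                    prefix: str) -> List[nn.Module]:
    # Each layer receives a prefix ending in its index, which is why hunks
    # above can recover the layer index via int(prefix.split(".")[-1]).
    return [layer_fn(f"{prefix}.{i}") for i in range(num_layers)]


layers = toy_make_layers(2, lambda p: ToyBlock(16, prefix=p), "model.layers")
print([blk.self_attn.prefix for blk in layers])
# ['model.layers.0.self_attn.attn', 'model.layers.1.self_attn.attn']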

View File

@ -268,7 +268,8 @@ class DeepseekV2Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_local_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,

View File

@ -174,6 +174,7 @@ class ExaoneAttention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def forward(
@ -219,7 +220,7 @@ class ExaoneBlockAttention(nn.Module):
quant_config=quant_config,
bias=bias,
cache_config=cache_config,
prefix=prefix,
prefix=f"{prefix}.attention",
)
def forward(

View File

@ -84,6 +84,7 @@ class FalconAttention(nn.Module):
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
@ -158,7 +159,8 @@ class FalconAttention(nn.Module):
self.head_dim,
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
@ -171,14 +173,16 @@ class FalconAttention(nn.Module):
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
alibi_slopes=alibi_slopes,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
else:
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -241,12 +245,16 @@ class FalconDecoderLayer(nn.Module):
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config, cache_config,
quant_config)
self.self_attention = FalconAttention(
config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attention")
self.mlp = FalconMLP(config, quant_config)
self.config = config
@ -357,8 +365,8 @@ class FalconModel(nn.Module):
# Transformer blocks
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: FalconDecoderLayer(config, cache_config,
quant_config),
lambda prefix: FalconDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h")
# Final Layer Norm

View File

@ -35,10 +35,12 @@ class Florence2LanguageModel(nn.Module):
self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)
self.encoder = BartEncoder(config,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.encoder")
self.decoder = BartDecoder(config,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.decoder")
if self.config.tie_word_embeddings:
self.encoder.embed_tokens.weight = self.shared.weight
@ -99,7 +101,7 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
self.config = config
self.model = Florence2LanguageModel(vllm_config=vllm_config,
prefix=prefix)
prefix=f"{prefix}.model")
embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0
@ -198,7 +200,7 @@ class Florence2ForConditionalGeneration(nn.Module):
# TODO(Isotr0py): Add vision backbone
self.language_model = Florence2LanguageForConditionalGeneration(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=prefix,
prefix=f"{prefix}.language_model",
)
@property

View File

@ -174,7 +174,8 @@ class GemmaAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,

View File

@ -95,7 +95,8 @@ class Gemma2Attention(nn.Module):
rope_theta: float,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
attn_logits_soft_cap: Optional[float] = None) -> None:
attn_logits_soft_cap: Optional[float] = None,
prefix: str = "") -> None:
super().__init__()
self.layer_idx = layer_idx
self.config = config
@ -154,7 +155,8 @@ class Gemma2Attention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
logits_soft_cap=attn_logits_soft_cap)
logits_soft_cap=attn_logits_soft_cap,
prefix=f"{prefix}.attn")
def forward(
self,
@ -179,6 +181,7 @@ class Gemma2DecoderLayer(nn.Module):
config: Gemma2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -194,6 +197,7 @@ class Gemma2DecoderLayer(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
attn_logits_soft_cap=config.attn_logit_softcapping,
prefix=f"{prefix}.self_attn",
)
self.hidden_size = config.hidden_size
self.mlp = Gemma2MLP(
@ -257,8 +261,11 @@ class Gemma2Model(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[
-1]), config, cache_config, quant_config),
lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[-1]),
config,
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@ -56,6 +56,7 @@ class Attention(nn.Module):
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
super().__init__()
self.hidden_size = config.hidden_size
@ -135,11 +136,14 @@ class TransformerLayer(nn.Module):
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
super().__init__()
self.input_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.attention = Attention(config, quant_config=quant_config)
self.attention = Attention(config,
quant_config=quant_config,
prefix=f"{prefix}.attention")
self.mlp = MLP(config, quant_config=quant_config)
self.post_attention_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
@ -161,11 +165,14 @@ class Transformer(nn.Module):
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
super().__init__()
self.layers = nn.ModuleList([
TransformerLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
TransformerLayer(config,
quant_config=quant_config,
prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
])
def forward(self, hidden_states):
@ -252,12 +259,14 @@ class EVA2CLIPModel(nn.Module):
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
super().__init__()
vision_config = Namespace(**config.vision_config)
self.patch_embedding = PatchEmbedding(vision_config)
self.transformer = Transformer(vision_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.transformer")
self.linear_proj = GLU(config,
in_features=config.hidden_size,
quant_config=quant_config)

View File

@ -84,7 +84,8 @@ class GPT2Attention(nn.Module):
self.head_dim,
scale=self.scale,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,

View File

@ -52,6 +52,7 @@ class GPTBigCodeAttention(nn.Module):
config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
@ -92,7 +93,8 @@ class GPTBigCodeAttention(nn.Module):
scale=self.scale,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -151,6 +153,7 @@ class GPTBigCodeBlock(nn.Module):
config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
@ -158,7 +161,10 @@ class GPTBigCodeBlock(nn.Module):
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config, cache_config, quant_config)
self.attn = GPTBigCodeAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
@ -210,7 +216,8 @@ class GPTBigCodeModel(nn.Module):
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config),
lambda prefix: GPTBigCodeBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

View File

@ -53,6 +53,7 @@ class GPTJAttention(nn.Module):
config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.total_num_heads = config.num_attention_heads
@ -94,7 +95,8 @@ class GPTJAttention(nn.Module):
self.head_size,
scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -147,12 +149,16 @@ class GPTJBlock(nn.Module):
config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
inner_dim = (4 * config.n_embd
if config.n_inner is None else config.n_inner)
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config, cache_config, quant_config)
self.attn = GPTJAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.mlp = GPTJMLP(inner_dim, config, quant_config)
def forward(
@ -193,7 +199,8 @@ class GPTJModel(nn.Module):
)
self.start_layer, self.end_layer, self.h = make_layers(
config.n_layer,
lambda prefix: GPTJBlock(config, cache_config, quant_config),
lambda prefix: GPTJBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

View File

@ -52,6 +52,7 @@ class GPTNeoXAttention(nn.Module):
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.total_num_heads = config.num_attention_heads
@ -94,7 +95,8 @@ class GPTNeoXAttention(nn.Module):
self.head_size,
scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -145,6 +147,7 @@ class GPTNeoXLayer(nn.Module):
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
@ -152,7 +155,10 @@ class GPTNeoXLayer(nn.Module):
eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config, cache_config, quant_config)
self.attention = GPTNeoXAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attention")
self.mlp = GPTNeoXMLP(config, quant_config)
def forward(
@ -205,7 +211,8 @@ class GPTNeoXModel(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: GPTNeoXLayer(config, cache_config, quant_config),
lambda prefix: GPTNeoXLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
self.final_layer_norm = nn.LayerNorm(config.hidden_size,

View File

@ -161,7 +161,8 @@ class GraniteAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,

View File

@ -164,7 +164,8 @@ class GraniteMoeAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,

View File

@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
import torch
from torch import nn
@ -250,7 +250,12 @@ class InternLMDecoderLayer(nn.Module):
@support_torch_compile
class InternLM2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[InternLMDecoderLayer] = InternLMDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
@ -266,7 +271,7 @@ class InternLM2Model(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: InternLMDecoderLayer(
lambda prefix: layer_type(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -316,14 +321,18 @@ class InternLM2Model(nn.Module):
class InternLM2ForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
model_type: Type[InternLM2Model] = InternLM2Model):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = InternLM2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.model = model_type(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.output = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,

View File

@ -14,8 +14,6 @@ from vllm.model_executor.models.internlm2 import (InternLM2Attention,
InternLM2MLP, InternLM2Model)
from vllm.sequence import IntermediateTensors
from .utils import make_layers, maybe_prefix
class InternLM2VEDecoderLayer(nn.Module):
@ -105,17 +103,9 @@ class InternLM2VEDecoderLayer(nn.Module):
class InternLM2VEModel(InternLM2Model):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: InternLM2VEDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
super().__init__(vllm_config=vllm_config,
prefix=prefix,
layer_type=InternLM2VEDecoderLayer)
def forward(
self,
@ -159,7 +149,6 @@ class InternLM2VEModel(InternLM2Model):
class InternLM2VEForCausalLM(InternLM2ForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.model = InternLM2VEModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
super().__init__(vllm_config=vllm_config,
prefix=prefix,
model_type=InternLM2VEModel)
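
The InternLM2 refactor above replaces hard-coded classes with `layer_type` and `model_type` constructor parameters, so the vision-expert variant can reuse the base `__init__` instead of rebuilding its layer list. A generic sketch of that idiom follows (toy class names, not vLLM code):

# Toy illustration of the layer_type/model_type idiom used above.
from typing import Type

import torch.nn as nn


class BaseLayer(nn.Module):

    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)


class VELayer(BaseLayer):
    """Variant layer; only the pieces that differ need overriding."""


class BaseModel(nn.Module):

    def __init__(self,
                 hidden_size: int = 16,
                 num_layers: int = 2,
                 layer_type: Type[BaseLayer] = BaseLayer):
        super().__init__()
        # The subclass chooses the layer class; the construction loop is
        # written once here instead of being duplicated in every subclass.
        self.layers = nn.ModuleList(
            layer_type(hidden_size) for _ in range(num_layers))


class VEModel(BaseModel):

    def __init__(self, **kwargs):
        super().__init__(layer_type=VELayer, **kwargs)


print(type(VEModel().layers[0]).__name__)  # VELayer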

View File

@ -76,6 +76,7 @@ class JAISAttention(nn.Module):
config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
@ -114,7 +115,8 @@ class JAISAttention(nn.Module):
scale=self.scale,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -178,6 +180,7 @@ class JAISBlock(nn.Module):
config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
@ -185,7 +188,10 @@ class JAISBlock(nn.Module):
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = JAISAttention(config, cache_config, quant_config)
self.attn = JAISAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = JAISMLP(inner_dim, config, quant_config)
@ -241,7 +247,8 @@ class JAISModel(nn.Module):
config.num_hidden_layers,
lambda prefix: JAISBlock(config=config,
cache_config=cache_config,
quant_config=quant_config),
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.h",
)

View File

@ -102,7 +102,8 @@ class JambaMambaDecoderLayer(nn.Module):
config: JambaConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__()
self.config = config
self.mamba = MambaMixer(hidden_size= config.hidden_size,
@ -157,6 +158,7 @@ class JambaAttentionDecoderLayer(nn.Module):
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -198,6 +200,7 @@ class JambaAttentionDecoderLayer(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
prefix=f"{prefix}.attn",
)
num_experts = config.layers_num_experts[layer_idx]
@ -287,7 +290,8 @@ class JambaModel(nn.Module):
layer_class(config,
layer_idx=i,
cache_config=cache_config,
quant_config=quant_config))
quant_config=quant_config,
prefix=f"{prefix}.layers.{i}"))
self.layers = nn.ModuleList(decoder_layers)
self.final_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)

View File

@ -174,6 +174,7 @@ class LlamaAttention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def forward(

View File

@ -192,6 +192,7 @@ class MiniCPMAttention(nn.Module):
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
@ -246,7 +247,8 @@ class MiniCPMAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -273,6 +275,7 @@ class MiniCPMDecoderLayer(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
@ -283,6 +286,7 @@ class MiniCPMDecoderLayer(nn.Module):
self.rope_scaling = getattr(config, "rope_scaling", None)
self.max_position_embeddings = getattr(config,
"max_position_embeddings", 8192)
self.prefix = prefix
self._init_attn_block()
self._init_ffn_block()
@ -298,6 +302,7 @@ class MiniCPMDecoderLayer(nn.Module):
max_position_embeddings=self.max_position_embeddings,
cache_config=self.cache_config,
quant_config=self.quant_config,
prefix=f"{self.prefix}.self_attn",
)
def _init_ffn_block(self):
@ -388,8 +393,8 @@ class MiniCPMModel(nn.Module):
):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MiniCPMDecoderLayer(config, cache_config,
quant_config),
lambda prefix: MiniCPMDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:

View File

@ -60,6 +60,7 @@ class MiniCPM3Attention(nn.Module):
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
@ -119,7 +120,8 @@ class MiniCPM3Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_local_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -195,6 +197,7 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):
max_position_embeddings=self.max_position_embeddings,
cache_config=self.cache_config,
quant_config=self.quant_config,
prefix=f"{self.prefix}.self_attn",
)
@ -209,8 +212,8 @@ class MiniCPM3Model(MiniCPMModel):
):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MiniCPM3DecoderLayer(config, cache_config,
quant_config),
lambda prefix: MiniCPM3DecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")

View File

@ -166,7 +166,8 @@ class MixtralAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,

View File

@ -170,6 +170,7 @@ class MixtralAttention(nn.Module):
rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
@ -219,7 +220,8 @@ class MixtralAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -243,6 +245,7 @@ class MixtralDecoderLayer(nn.Module):
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -255,7 +258,9 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.block_sparse_moe = MixtralMoE(config=config,
quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
@ -311,7 +316,8 @@ class MixtralModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MixtralDecoderLayer(
config, cache_config, quant_config=quant_config),
config, cache_config, quant_config=quant_config, prefix=prefix
),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (

View File

@ -370,6 +370,7 @@ class MolmoAttention(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -427,7 +428,8 @@ class MolmoAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
# Attention output projection.
self.o_proj = RowParallelLinear(
@ -517,10 +519,14 @@ class MolmoDecoderLayer(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
# Attention block.
self.self_attn = MolmoAttention(config, cache_config, quant_config)
self.self_attn = MolmoAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
# MLP block.
self.mlp = MolmoMLP(config, quant_config=quant_config)
@ -738,7 +744,8 @@ class MolmoModel(nn.Module):
else MolmoDecoderLayer
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: decoder_layer(config, cache_config, quant_config),
lambda prefix: decoder_layer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers",
)

View File

@ -50,6 +50,7 @@ class MPTAttention(nn.Module):
config: MPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.d_model = config.d_model
@ -115,7 +116,8 @@ class MPTAttention(nn.Module):
alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -176,11 +178,15 @@ class MPTBlock(nn.Module):
config: MPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.d_model
self.norm_1 = nn.LayerNorm(hidden_size)
self.attn = MPTAttention(config, cache_config, quant_config)
self.attn = MPTAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.norm_2 = nn.LayerNorm(hidden_size)
self.ffn = MPTMLP(config, quant_config)
@ -224,7 +230,8 @@ class MPTModel(nn.Module):
)
self.start_layer, self.end_layer, self.blocks = make_layers(
config.n_layers,
lambda prefix: MPTBlock(config, cache_config, quant_config),
lambda prefix: MPTBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.blocks")
self.norm_f = nn.LayerNorm(config.d_model)
if config.no_bias:

View File

@ -195,7 +195,8 @@ class NemotronAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,

View File

@ -62,6 +62,7 @@ class OlmoAttention(nn.Module):
config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
@ -101,7 +102,8 @@ class OlmoAttention(nn.Module):
self.head_dim,
scale=self.scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
# Attention output projection.
self.o_proj = RowParallelLinear(
@ -184,10 +186,14 @@ class OlmoDecoderLayer(nn.Module):
def __init__(self,
config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
# Attention block.
self.self_attn = OlmoAttention(config, cache_config, quant_config)
self.self_attn = OlmoAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
# MLP block.
self.mlp = OlmoMLP(config, quant_config)
@ -238,8 +244,8 @@ class OlmoModel(nn.Module):
config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OlmoDecoderLayer(config, cache_config, quant_config
),
lambda prefix: OlmoDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False,

View File

@ -102,6 +102,7 @@ class OlmoeAttention(nn.Module):
max_position_embeddings: int = 4096,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
@ -156,7 +157,8 @@ class OlmoeAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -182,6 +184,7 @@ class OlmoeDecoderLayer(nn.Module):
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -199,6 +202,7 @@ class OlmoeDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = OlmoeMoE(
@ -260,8 +264,11 @@ class OlmoeModel(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OlmoeDecoderLayer(config, int(
prefix.split(".")[-1]), cache_config, quant_config),
lambda prefix: OlmoeDecoderLayer(config,
int(prefix.split(".")[-1]),
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=1e-5)

View File

@ -75,6 +75,7 @@ class OrionAttention(nn.Module):
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
@ -126,7 +127,8 @@ class OrionAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -150,6 +152,7 @@ class OrionDecoderLayer(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -166,6 +169,7 @@ class OrionDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = OrionMLP(
hidden_size=self.hidden_size,
@ -226,10 +230,7 @@ class OrionModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OrionDecoderLayer(
config,
cache_config,
quant_config,
),
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (

View File

@ -75,7 +75,8 @@ class PersimmonAttention(nn.Module):
def __init__(self,
config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.config = config
tensor_parallel_world_size = get_tensor_model_parallel_world_size()
@ -122,7 +123,8 @@ class PersimmonAttention(nn.Module):
self.head_dim,
scale=self.scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
# [seq_length, hidden_size] -> [seq_length, num_heads, head_dim]
@ -167,12 +169,14 @@ class PersimmonDecoderLayer(nn.Module):
def __init__(self,
config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = PersimmonAttention(config=config,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = PersimmonMLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
@ -226,8 +230,8 @@ class PersimmonModel(nn.Module):
config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: PersimmonDecoderLayer(config, cache_config,
quant_config),
lambda prefix: PersimmonDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)

View File

@ -69,7 +69,8 @@ class PhiAttention(nn.Module):
def __init__(self,
config: PhiConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
@ -116,7 +117,8 @@ class PhiAttention(nn.Module):
self.head_size,
scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -167,11 +169,15 @@ class PhiLayer(nn.Module):
def __init__(self,
config: PhiConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.self_attn = PhiAttention(config, cache_config, quant_config)
self.self_attn = PhiAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = PhiMLP(config, quant_config)
def forward(
@ -210,7 +216,8 @@ class PhiModel(nn.Module):
config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: PhiLayer(config, cache_config, quant_config),
lambda prefix: PhiLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)

View File

@ -117,6 +117,7 @@ class Phi3SmallSelfAttention(nn.Module):
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_idx = layer_idx
@ -214,15 +215,14 @@ class Phi3SmallSelfAttention(nn.Module):
"homo_head": self.homo_heads
}
self.attn = Attention(
self.num_heads_per_partition,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads_per_partion,
cache_config=cache_config,
quant_config=quant_config,
blocksparse_params=bs_params,
)
self.attn = Attention(self.num_heads_per_partition,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads_per_partion,
cache_config=cache_config,
quant_config=quant_config,
blocksparse_params=bs_params,
prefix=f"{prefix}.attn")
def forward(
self,
@ -259,13 +259,15 @@ class Phi3SmallDecoderLayer(nn.Module):
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Phi3SmallSelfAttention(config,
layer_idx,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = Phi3SmallMLP(config, quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
@ -315,7 +317,9 @@ class Phi3SmallModel(nn.Module):
config.num_hidden_layers,
lambda prefix: Phi3SmallDecoderLayer(config,
int(prefix.split('.')[-1]),
cache_config, quant_config),
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size,

View File

@ -294,6 +294,7 @@ class PhiMoEAttention(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[dict] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
@ -347,6 +348,7 @@ class PhiMoEAttention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def forward(
@ -371,6 +373,7 @@ class PhiMoEDecoderLayer(nn.Module):
config: PhiMoEConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -385,6 +388,7 @@ class PhiMoEDecoderLayer(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=config.rope_scaling,
prefix=f"{prefix}.self_attn",
)
self.block_sparse_moe = PhiMoE(
num_experts=config.num_local_experts,
@ -454,8 +458,8 @@ class PhiMoEModel(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: PhiMoEDecoderLayer(config, cache_config,
quant_config),
lambda prefix: PhiMoEDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size,
eps=config.rms_norm_eps,

View File

@ -442,6 +442,7 @@ class QWenAttention(nn.Module):
rope_scaling: Optional[Dict[str, Any]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = hidden_size
@ -478,7 +479,8 @@ class QWenAttention(nn.Module):
self.head_dim,
self.scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -502,6 +504,7 @@ class QWenBlock(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@ -514,7 +517,8 @@ class QWenBlock(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@ -568,7 +572,8 @@ class QWenModel(nn.Module):
)
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: QWenBlock(config, cache_config, quant_config),
lambda prefix: QWenBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h")
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (

View File

@ -168,6 +168,7 @@ class Qwen2MoeAttention(nn.Module):
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
@ -220,7 +221,8 @@ class Qwen2MoeAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -245,6 +247,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -261,6 +264,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
@ -336,7 +340,8 @@ class Qwen2MoeModel(nn.Module):
layer_idx=int(
prefix.split(".")[-1]),
cache_config=cache_config,
quant_config=quant_config),
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@ -167,6 +167,7 @@ class SolarAttention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def forward(

View File

@ -77,7 +77,8 @@ class StablelmAttention(nn.Module):
def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
@ -131,7 +132,8 @@ class StablelmAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_key_value_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -155,9 +157,13 @@ class StablelmDecoderLayer(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.self_attn = StablelmAttention(config, cache_config, quant_config)
self.self_attn = StablelmAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = StablelmMLP(config, quant_config)
norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05))
@ -207,8 +213,8 @@ class StableLMEpochModel(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: StablelmDecoderLayer(config, cache_config,
quant_config),
lambda prefix: StablelmDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
norm_eps = getattr(config, "norm_eps",

View File

@ -52,7 +52,8 @@ class Starcoder2Attention(nn.Module):
def __init__(self,
config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.config = config
@ -105,7 +106,8 @@ class Starcoder2Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -154,12 +156,14 @@ class Starcoder2DecoderLayer(nn.Module):
def __init__(self,
config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Starcoder2Attention(config,
cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = Starcoder2MLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon)
@ -213,7 +217,8 @@ class Starcoder2Model(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Starcoder2DecoderLayer(
config, cache_config, quant_config=quant_config),
config, cache_config, quant_config=quant_config, prefix=prefix
),
prefix=f"{prefix}.layers",
)
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)

View File

@ -93,6 +93,7 @@ class XverseAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
@ -138,7 +139,8 @@ class XverseAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@ -162,6 +164,7 @@ class XverseDecoderLayer(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -180,6 +183,7 @@ class XverseDecoderLayer(nn.Module):
quant_config=quant_config,
bias=getattr(config, "bias", False),
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = XverseMLP(
hidden_size=self.hidden_size,
@ -243,8 +247,8 @@ class XverseModel(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: XverseDecoderLayer(config, cache_config,
quant_config),
lambda prefix: XverseDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@ -20,6 +20,7 @@ logger = init_logger(__name__)
class CpuPlatform(Platform):
_enum = PlatformEnum.CPU
device_type: str = "cpu"
dispatch_key: str = "CPU"
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:

View File

@ -121,6 +121,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:
class CudaPlatform(Platform):
_enum = PlatformEnum.CUDA
device_type: str = "cuda"
dispatch_key: str = "CUDA"
@classmethod
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:

View File

@ -13,6 +13,7 @@ else:
class HpuPlatform(Platform):
_enum = PlatformEnum.HPU
device_type: str = "hpu"
dispatch_key: str = "HPU"
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

View File

@ -57,6 +57,10 @@ class DeviceCapability(NamedTuple):
class Platform:
_enum: PlatformEnum
device_type: str
# available dispatch keys:
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
# use "CPU" as a fallback for platforms not registered in PyTorch
dispatch_key: str = "CPU"
def is_cuda(self) -> bool:
return self._enum == PlatformEnum.CUDA
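
The comment above is the crux of the new `dispatch_key` attribute: each platform advertises the PyTorch dispatch key under which its custom-op kernels should be registered, with "CPU" as the fallback. Below is a self-contained sketch of that idea using plain `torch.library`; the ToyPlatform classes and the "toy_ops" namespace are illustrative stand-ins, not vLLM code.

# Toy sketch: a per-platform dispatch key driving torch.library registration.
from typing import Callable

import torch
from torch.library import Library


class ToyPlatform:
    device_type: str = "cpu"
    # PyTorch dispatch key used when registering custom-op kernels;
    # "CPU" is the safe fallback for platforms PyTorch does not know about.
    dispatch_key: str = "CPU"


class ToyCudaPlatform(ToyPlatform):
    device_type = "cuda"
    dispatch_key = "CUDA"


toy_lib = Library("toy_ops", "FRAGMENT")  # private namespace for the sketch


def register_toy_op(op_name: str, schema: str, op_func: Callable,
                    dispatch_key: str) -> None:
    # Define the schema once, then register the kernel under the dispatch
    # key advertised by the current platform.
    toy_lib.define(op_name + schema)
    toy_lib.impl(op_name, op_func, dispatch_key=dispatch_key)


def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor


platform = ToyPlatform()  # swap in ToyCudaPlatform() on a CUDA host
register_toy_op("scale", "(Tensor x, float factor) -> Tensor", scale,
                dispatch_key=platform.dispatch_key)
print(torch.ops.toy_ops.scale(torch.ones(2), 3.0))  # tensor([3., 3.])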

View File

@ -18,6 +18,7 @@ logger = init_logger(__name__)
class OpenVinoPlatform(Platform):
_enum = PlatformEnum.OPENVINO
device_type: str = "openvino"
dispatch_key: str = "CPU"
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

View File

@ -36,6 +36,7 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
device_type: str = "cuda"
dispatch_key: str = "CUDA"
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

View File

@ -17,6 +17,7 @@ logger = init_logger(__name__)
class TpuPlatform(Platform):
_enum = PlatformEnum.TPU
device_type: str = "tpu"
dispatch_key: str = "XLA"
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

View File

@ -17,6 +17,7 @@ logger = init_logger(__name__)
class XPUPlatform(Platform):
_enum = PlatformEnum.XPU
device_type: str = "xpu"
dispatch_key: str = "XPU"
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

View File

@ -273,7 +273,8 @@ class TP1DraftModelRunner(ModelRunner):
if previous_hidden_states is not None else {}
# Run model
with set_forward_context(model_input.attn_metadata):
with set_forward_context(model_input.attn_metadata,
self.vllm_config):
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,

View File

@ -1573,6 +1573,7 @@ def direct_register_custom_op(
mutates_args: List[str],
fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA",
):
"""
`torch.library.custom_op` can have significant overhead because it
@ -1601,7 +1602,7 @@ def direct_register_custom_op(
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
my_lib = target_lib or vllm_lib
my_lib.define(op_name + schema_str)
my_lib.impl(op_name, op_func, "CUDA")
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)
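
A hedged usage sketch of the widened helper follows. The leading `op_name` / `op_func` parameters are inferred from the hunk body rather than shown in the signature, and `toy_scale` is a made-up op used only for illustration; the point is that callers can now pass the current platform's dispatch key instead of relying on the hard-coded "CUDA".

# Hedged usage sketch (assumes vLLM is installed; `toy_scale` is invented).
import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def toy_scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor


def toy_scale_fake(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Shape-only implementation used for tracing / fake tensors.
    return torch.empty_like(x)


# Register under the platform's dispatch key instead of a hard-coded "CUDA",
# which is what the new `dispatch_key` parameter enables.
direct_register_custom_op(
    op_name="toy_scale",
    op_func=toy_scale,
    mutates_args=[],
    fake_impl=toy_scale_fake,
    dispatch_key=current_platform.dispatch_key,
)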

View File

@ -173,7 +173,8 @@ def unified_v1_flash_attention(
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
current_metadata = get_forward_context()
context = get_forward_context()
current_metadata = context.dynamic_forward_context
if current_metadata is None:
# Profiling run.
return
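
The v1 attention change above implies that `get_forward_context()` now returns a context object whose `dynamic_forward_context` field carries the per-step attention metadata, populated by `set_forward_context(attn_metadata, vllm_config)`. The toy re-implementation below sketches that shape only; apart from `dynamic_forward_context` and the two-argument call, the names and fields are assumptions for illustration, not vLLM's actual code.

# Toy sketch of the forward-context object implied by the hunks above.
# Only `dynamic_forward_context` and the (attn_metadata, vllm_config) call
# shape come from the diff; everything else is an assumption.
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyForwardContext:
    # Per-step state such as attention metadata (None during profiling runs).
    dynamic_forward_context: Any
    # Engine-wide state derived from vllm_config (field name assumed).
    static_forward_context: Any


_toy_forward_context: Optional[ToyForwardContext] = None


def toy_get_forward_context() -> Optional[ToyForwardContext]:
    return _toy_forward_context


@contextmanager
def toy_set_forward_context(attn_metadata: Any, vllm_config: Any):
    global _toy_forward_context
    prev = _toy_forward_context
    _toy_forward_context = ToyForwardContext(
        dynamic_forward_context=attn_metadata,
        static_forward_context=vllm_config)
    try:
        yield
    finally:
        _toy_forward_context = prev


# Usage mirroring the attention backend above: bail out on profiling runs.
with toy_set_forward_context(None, vllm_config={}):
    ctx = toy_get_forward_context()
    if ctx is None or ctx.dynamic_forward_context is None:
        print("profiling run: skip attention")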

View File

@ -447,7 +447,7 @@ class GPUModelRunner:
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata):
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=None,
positions=self.positions[:num_input_tokens],
@ -523,7 +523,7 @@ class GPUModelRunner:
num_tokens: int,
kv_caches: List[torch.Tensor],
) -> torch.Tensor:
with set_forward_context(None):
with set_forward_context(None, self.vllm_config):
hidden_states = model(
input_ids=None,
positions=self.positions[:num_tokens],

View File

@ -97,7 +97,7 @@ class EmbeddingModelRunner(
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()
with set_forward_context(model_input.attn_metadata):
with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,

View File

@ -176,7 +176,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
} if self.has_inner_state else {}
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
with set_forward_context(model_input.attn_metadata):
with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,

View File

@ -1503,7 +1503,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self._update_inputs_to_capture_for_enc_dec_model(
capture_inputs)
with set_forward_context(attn_metadata):
with set_forward_context(attn_metadata, self.vllm_config):
graph_runner.capture(**capture_inputs)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[virtual_engine][batch_size] = (
@ -1649,7 +1649,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()
with set_forward_context(model_input.attn_metadata):
with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,