mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-04 08:57:54 +08:00
[Misc] Enhance attention selector (#4751)
This commit is contained in:
parent
e7c46b9527
commit
0fca3cdcf2
@ -307,7 +307,6 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
|
|||||||
|
|
||||||
assert len(attn_metadata.slot_mapping) == len(input_tokens)
|
assert len(attn_metadata.slot_mapping) == len(input_tokens)
|
||||||
assert len(input_positions) == len(input_tokens)
|
assert len(input_positions) == len(input_tokens)
|
||||||
assert attn_metadata.kv_cache_dtype == "auto"
|
|
||||||
assert attn_metadata.num_prefills == prefill_batch_size
|
assert attn_metadata.num_prefills == prefill_batch_size
|
||||||
if enforce_eager:
|
if enforce_eager:
|
||||||
assert attn_metadata.num_decode_tokens == decode_batch_size
|
assert attn_metadata.num_decode_tokens == decode_batch_size
|
||||||
|
|||||||
@ -5,9 +5,9 @@ from vllm.attention.layer import Attention
|
|||||||
from vllm.attention.selector import get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"Attention",
|
||||||
"AttentionBackend",
|
"AttentionBackend",
|
||||||
"AttentionMetadata",
|
"AttentionMetadata",
|
||||||
"Attention",
|
|
||||||
"get_attn_backend",
|
|
||||||
"AttentionMetadataPerStage",
|
"AttentionMetadataPerStage",
|
||||||
|
"get_attn_backend",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -94,8 +94,6 @@ class AttentionMetadata(Generic[T]):
|
|||||||
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
||||||
# in block 0, and 1st slot in block 1, respectively.
|
# in block 0, and 1st slot in block 1, respectively.
|
||||||
slot_mapping: torch.Tensor
|
slot_mapping: torch.Tensor
|
||||||
# The kv cache's data type.
|
|
||||||
kv_cache_dtype: str
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.num_prefill_tokens > 0:
|
if self.num_prefill_tokens > 0:
|
||||||
@ -116,6 +114,7 @@ class AttentionImpl(ABC):
|
|||||||
num_kv_heads: Optional[int] = None,
|
num_kv_heads: Optional[int] = None,
|
||||||
alibi_slopes: Optional[List[float]] = None,
|
alibi_slopes: Optional[List[float]] = None,
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int] = None,
|
||||||
|
kv_cache_dtype: str = "auto",
|
||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -127,6 +126,6 @@ class AttentionImpl(ABC):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
kv_scale: float,
|
kv_scale: float = 1.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@ -140,16 +140,18 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
num_kv_heads: Optional[int] = None,
|
num_kv_heads: Optional[int] = None,
|
||||||
alibi_slopes: Optional[List[float]] = None,
|
alibi_slopes: Optional[List[float]] = None,
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int] = None,
|
||||||
|
kv_cache_dtype: str = "auto",
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||||
self.sliding_window = ((sliding_window, sliding_window)
|
|
||||||
if sliding_window is not None else (-1, -1))
|
|
||||||
if alibi_slopes is not None:
|
if alibi_slopes is not None:
|
||||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||||
self.alibi_slopes = alibi_slopes
|
self.alibi_slopes = alibi_slopes
|
||||||
|
self.sliding_window = ((sliding_window, sliding_window)
|
||||||
|
if sliding_window is not None else (-1, -1))
|
||||||
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
assert self.num_heads % self.num_kv_heads == 0
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
@ -167,7 +169,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: AttentionMetadata[FlashAttentionMetadata],
|
attn_metadata: AttentionMetadata[FlashAttentionMetadata],
|
||||||
kv_scale: float,
|
kv_scale: float = 1.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention and PagedAttention.
|
"""Forward pass with FlashAttention and PagedAttention.
|
||||||
|
|
||||||
@ -196,8 +198,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
attn_metadata.slot_mapping,
|
attn_metadata.slot_mapping,
|
||||||
attn_metadata.kv_cache_dtype,
|
self.kv_cache_dtype, kv_scale)
|
||||||
kv_scale)
|
|
||||||
|
|
||||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||||
@ -264,7 +265,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
decode_meta.block_tables,
|
decode_meta.block_tables,
|
||||||
decode_meta.seq_lens_tensor,
|
decode_meta.seq_lens_tensor,
|
||||||
decode_meta.max_seq_len,
|
decode_meta.max_seq_len,
|
||||||
attn_metadata.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.scale,
|
self.scale,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
|
|||||||
@ -149,20 +149,33 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
num_kv_heads: Optional[int] = None,
|
num_kv_heads: Optional[int] = None,
|
||||||
alibi_slopes: Optional[List[float]] = None,
|
alibi_slopes: Optional[List[float]] = None,
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int] = None,
|
||||||
|
kv_cache_dtype: str = "auto",
|
||||||
) -> None:
|
) -> None:
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_size = head_size
|
||||||
|
self.scale = float(scale)
|
||||||
|
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||||
|
if alibi_slopes is not None:
|
||||||
|
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||||
|
self.alibi_slopes = alibi_slopes
|
||||||
if sliding_window is not None:
|
if sliding_window is not None:
|
||||||
raise ValueError("Sliding window is not supported in FlashInfer.")
|
raise ValueError("Sliding window is not supported in FlashInfer.")
|
||||||
self.sliding_window = (-1, -1)
|
self.sliding_window = (-1, -1)
|
||||||
self.alibi_slopes = alibi_slopes
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
self.scale = scale
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.head_size = head_size
|
|
||||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
|
||||||
|
|
||||||
def forward(self, query: torch.Tensor, key: torch.Tensor,
|
assert self.num_heads % self.num_kv_heads == 0
|
||||||
value: torch.Tensor, kv_cache: Optional[torch.Tensor],
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
attn_metadata: AttentionMetadata[FlashInferMetadata],
|
|
||||||
kv_scale: float):
|
def forward(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: Optional[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata[FlashInferMetadata],
|
||||||
|
kv_scale: float = 1.0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert kv_scale == 1.0
|
||||||
num_tokens, hidden_size = query.shape
|
num_tokens, hidden_size = query.shape
|
||||||
query = query.view(-1, self.num_heads, self.head_size)
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||||
@ -183,7 +196,7 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
kv_cache[:, 0],
|
kv_cache[:, 0],
|
||||||
kv_cache[:, 1],
|
kv_cache[:, 1],
|
||||||
attn_metadata.slot_mapping.flatten(),
|
attn_metadata.slot_mapping.flatten(),
|
||||||
attn_metadata.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
if prefill_meta := attn_metadata.prefill_metadata:
|
if prefill_meta := attn_metadata.prefill_metadata:
|
||||||
|
|||||||
@ -138,25 +138,27 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
num_kv_heads: Optional[int] = None,
|
num_kv_heads: Optional[int] = None,
|
||||||
alibi_slopes: Optional[List[float]] = None,
|
alibi_slopes: Optional[List[float]] = None,
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int] = None,
|
||||||
|
kv_cache_dtype: str = "auto",
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||||
self.sliding_window = ((sliding_window, sliding_window)
|
|
||||||
if sliding_window is not None else (-1, -1))
|
|
||||||
if alibi_slopes is not None:
|
if alibi_slopes is not None:
|
||||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||||
self.alibi_slopes = alibi_slopes
|
self.alibi_slopes = alibi_slopes
|
||||||
|
self.sliding_window = ((sliding_window, sliding_window)
|
||||||
|
if sliding_window is not None else (-1, -1))
|
||||||
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
assert self.num_heads % self.num_kv_heads == 0
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
|
supported_head_sizes = PagedAttention.get_supported_head_sizes()
|
||||||
if head_size not in suppored_head_sizes:
|
if head_size not in supported_head_sizes:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Head size {head_size} is not supported by PagedAttention. "
|
f"Head size {head_size} is not supported by PagedAttention. "
|
||||||
f"Supported head sizes are: {suppored_head_sizes}.")
|
f"Supported head sizes are: {supported_head_sizes}.")
|
||||||
|
|
||||||
self.use_naive_attn = False
|
self.use_naive_attn = False
|
||||||
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
|
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
|
||||||
@ -229,7 +231,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
attn_metadata.slot_mapping,
|
attn_metadata.slot_mapping,
|
||||||
attn_metadata.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
kv_scale,
|
kv_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -323,7 +325,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
decode_meta.block_tables,
|
decode_meta.block_tables,
|
||||||
decode_meta.seq_lens_tensor,
|
decode_meta.seq_lens_tensor,
|
||||||
decode_meta.max_seq_len,
|
decode_meta.max_seq_len,
|
||||||
attn_metadata.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.scale,
|
self.scale,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
|
|||||||
@ -83,26 +83,32 @@ class TorchSDPABackendImpl(AttentionImpl):
|
|||||||
num_kv_heads: Optional[int] = None,
|
num_kv_heads: Optional[int] = None,
|
||||||
alibi_slopes: Optional[List[float]] = None,
|
alibi_slopes: Optional[List[float]] = None,
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int] = None,
|
||||||
|
kv_cache_dtype: str = "auto",
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||||
self.sliding_window = sliding_window
|
|
||||||
if alibi_slopes is not None:
|
if alibi_slopes is not None:
|
||||||
assert len(alibi_slopes) == num_heads
|
|
||||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||||
self.alibi_slopes = alibi_slopes
|
self.alibi_slopes = alibi_slopes
|
||||||
self.need_mask = (self.alibi_slopes is not None
|
self.sliding_window = sliding_window
|
||||||
or self.sliding_window is not None)
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
assert self.num_heads % self.num_kv_heads == 0
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
|
self.need_mask = (self.alibi_slopes is not None
|
||||||
if head_size not in suppored_head_sizes:
|
or self.sliding_window is not None)
|
||||||
|
|
||||||
|
supported_head_sizes = PagedAttention.get_supported_head_sizes()
|
||||||
|
if head_size not in supported_head_sizes:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Head size {head_size} is not supported by PagedAttention. "
|
f"Head size {head_size} is not supported by PagedAttention. "
|
||||||
f"Supported head sizes are: {suppored_head_sizes}.")
|
f"Supported head sizes are: {supported_head_sizes}.")
|
||||||
|
if kv_cache_dtype != "auto":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Torch SDPA backend does not support FP8 KV cache. "
|
||||||
|
"Please use xFormers backend instead.")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -111,7 +117,7 @@ class TorchSDPABackendImpl(AttentionImpl):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: Optional[torch.Tensor],
|
kv_cache: Optional[torch.Tensor],
|
||||||
attn_metadata: TorchSDPAMetadata, # type: ignore
|
attn_metadata: TorchSDPAMetadata, # type: ignore
|
||||||
kv_scale: float,
|
kv_scale: float = 1.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with torch SDPA and PagedAttention.
|
"""Forward pass with torch SDPA and PagedAttention.
|
||||||
|
|
||||||
@ -124,6 +130,7 @@ class TorchSDPABackendImpl(AttentionImpl):
|
|||||||
Returns:
|
Returns:
|
||||||
shape = [num_tokens, num_heads * head_size]
|
shape = [num_tokens, num_heads * head_size]
|
||||||
"""
|
"""
|
||||||
|
assert kv_scale == 1.0
|
||||||
num_tokens, hidden_size = query.shape
|
num_tokens, hidden_size = query.shape
|
||||||
# Reshape the query, key, and value tensors.
|
# Reshape the query, key, and value tensors.
|
||||||
query = query.view(-1, self.num_heads, self.head_size)
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
@ -136,8 +143,7 @@ class TorchSDPABackendImpl(AttentionImpl):
|
|||||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
attn_metadata.slot_mapping,
|
attn_metadata.slot_mapping,
|
||||||
attn_metadata.kv_cache_dtype,
|
self.kv_cache_dtype, kv_scale)
|
||||||
kv_scale)
|
|
||||||
|
|
||||||
if attn_metadata.is_prompt:
|
if attn_metadata.is_prompt:
|
||||||
assert attn_metadata.seq_lens is not None
|
assert attn_metadata.seq_lens is not None
|
||||||
@ -195,7 +201,7 @@ class TorchSDPABackendImpl(AttentionImpl):
|
|||||||
attn_metadata.block_tables,
|
attn_metadata.block_tables,
|
||||||
attn_metadata.seq_lens_tensor,
|
attn_metadata.seq_lens_tensor,
|
||||||
attn_metadata.max_seq_len,
|
attn_metadata.max_seq_len,
|
||||||
attn_metadata.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.scale,
|
self.scale,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
|
|||||||
@ -149,15 +149,17 @@ class XFormersImpl(AttentionImpl):
|
|||||||
num_kv_heads: Optional[int] = None,
|
num_kv_heads: Optional[int] = None,
|
||||||
alibi_slopes: Optional[List[float]] = None,
|
alibi_slopes: Optional[List[float]] = None,
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int] = None,
|
||||||
|
kv_cache_dtype: str = "auto",
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||||
self.sliding_window = sliding_window
|
|
||||||
if alibi_slopes is not None:
|
if alibi_slopes is not None:
|
||||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||||
self.alibi_slopes = alibi_slopes
|
self.alibi_slopes = alibi_slopes
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
assert self.num_heads % self.num_kv_heads == 0
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
@ -175,7 +177,7 @@ class XFormersImpl(AttentionImpl):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: Optional[torch.Tensor],
|
kv_cache: Optional[torch.Tensor],
|
||||||
attn_metadata: AttentionMetadata[XFormersMetadata],
|
attn_metadata: AttentionMetadata[XFormersMetadata],
|
||||||
kv_scale: float,
|
kv_scale: float = 1.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with xFormers and PagedAttention.
|
"""Forward pass with xFormers and PagedAttention.
|
||||||
|
|
||||||
@ -188,7 +190,6 @@ class XFormersImpl(AttentionImpl):
|
|||||||
Returns:
|
Returns:
|
||||||
shape = [num_tokens, num_heads * head_size]
|
shape = [num_tokens, num_heads * head_size]
|
||||||
"""
|
"""
|
||||||
num_tokens, hidden_size = query.shape
|
|
||||||
query = query.view(-1, self.num_heads, self.head_size)
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||||
@ -203,8 +204,7 @@ class XFormersImpl(AttentionImpl):
|
|||||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
attn_metadata.slot_mapping,
|
attn_metadata.slot_mapping,
|
||||||
attn_metadata.kv_cache_dtype,
|
self.kv_cache_dtype, kv_scale)
|
||||||
kv_scale)
|
|
||||||
|
|
||||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||||
@ -262,7 +262,7 @@ class XFormersImpl(AttentionImpl):
|
|||||||
decode_meta.block_tables,
|
decode_meta.block_tables,
|
||||||
decode_meta.seq_lens_tensor,
|
decode_meta.seq_lens_tensor,
|
||||||
decode_meta.max_seq_len,
|
decode_meta.max_seq_len,
|
||||||
attn_metadata.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.scale,
|
self.scale,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import torch.nn as nn
|
|||||||
from vllm.attention.backends.abstract import (AttentionMetadata,
|
from vllm.attention.backends.abstract import (AttentionMetadata,
|
||||||
AttentionMetadataPerStage)
|
AttentionMetadataPerStage)
|
||||||
from vllm.attention.selector import get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
from vllm.config import CacheConfig
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
@ -29,10 +30,24 @@ class Attention(nn.Module):
|
|||||||
num_kv_heads: Optional[int] = None,
|
num_kv_heads: Optional[int] = None,
|
||||||
alibi_slopes: Optional[List[float]] = None,
|
alibi_slopes: Optional[List[float]] = None,
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int] = None,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.backend = get_attn_backend(torch.get_default_dtype())
|
if cache_config is not None:
|
||||||
impl_cls = self.backend.get_impl_cls()
|
kv_cache_dtype = cache_config.cache_dtype
|
||||||
|
block_size = cache_config.block_size
|
||||||
|
else:
|
||||||
|
kv_cache_dtype = "auto"
|
||||||
|
block_size = 16
|
||||||
|
if num_kv_heads is None:
|
||||||
|
num_kv_heads = num_heads
|
||||||
|
# During model initialization, the default dtype is set as the model
|
||||||
|
# weight and activation dtype.
|
||||||
|
dtype = torch.get_default_dtype()
|
||||||
|
attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
|
||||||
|
sliding_window, dtype, kv_cache_dtype,
|
||||||
|
block_size)
|
||||||
|
impl_cls = attn_backend.get_impl_cls()
|
||||||
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
||||||
alibi_slopes, sliding_window)
|
alibi_slopes, sliding_window)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import enum
|
import enum
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -21,8 +21,18 @@ class _Backend(enum.Enum):
|
|||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
|
def get_attn_backend(
|
||||||
backend = _which_attn_to_use(dtype)
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
sliding_window: Optional[int],
|
||||||
|
dtype: torch.dtype,
|
||||||
|
kv_cache_dtype: Optional[str],
|
||||||
|
block_size: int,
|
||||||
|
) -> Type[AttentionBackend]:
|
||||||
|
backend = _which_attn_to_use(num_heads, head_size, num_kv_heads,
|
||||||
|
sliding_window, dtype, kv_cache_dtype,
|
||||||
|
block_size)
|
||||||
if backend == _Backend.FLASH_ATTN:
|
if backend == _Backend.FLASH_ATTN:
|
||||||
logger.info("Using FlashAttention-2 backend.")
|
logger.info("Using FlashAttention-2 backend.")
|
||||||
from vllm.attention.backends.flash_attn import ( # noqa: F401
|
from vllm.attention.backends.flash_attn import ( # noqa: F401
|
||||||
@ -44,14 +54,22 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
|
|||||||
return TorchSDPABackend
|
return TorchSDPABackend
|
||||||
elif backend == _Backend.FLASHINFER:
|
elif backend == _Backend.FLASHINFER:
|
||||||
logger.info("Using Flashinfer backend.")
|
logger.info("Using Flashinfer backend.")
|
||||||
logger.warning("Eager mode is enforced for the Flashinfer backend. ")
|
logger.warning("Eager mode is enforced for the Flashinfer backend.")
|
||||||
from vllm.attention.backends.flashinfer import FlashInferBackend
|
from vllm.attention.backends.flashinfer import FlashInferBackend
|
||||||
return FlashInferBackend
|
return FlashInferBackend
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid attention backend.")
|
raise ValueError("Invalid attention backend.")
|
||||||
|
|
||||||
|
|
||||||
def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
|
def _which_attn_to_use(
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
sliding_window: Optional[int],
|
||||||
|
dtype: torch.dtype,
|
||||||
|
kv_cache_dtype: Optional[str],
|
||||||
|
block_size: int,
|
||||||
|
) -> _Backend:
|
||||||
"""Returns which flash attention backend to use."""
|
"""Returns which flash attention backend to use."""
|
||||||
if is_cpu():
|
if is_cpu():
|
||||||
return _Backend.TORCH_SDPA
|
return _Backend.TORCH_SDPA
|
||||||
|
|||||||
@ -2,26 +2,29 @@ from typing import Optional
|
|||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
|
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||||
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
|
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||||
|
VisionLanguageConfig)
|
||||||
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
|
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
|
||||||
get_model_loader)
|
get_model_loader)
|
||||||
from vllm.model_executor.model_loader.utils import (
|
from vllm.model_executor.model_loader.utils import (
|
||||||
get_architecture_class_name, get_model_architecture)
|
get_architecture_class_name, get_model_architecture)
|
||||||
|
|
||||||
|
|
||||||
def get_model(
|
def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
|
||||||
*, model_config: ModelConfig, load_config: LoadConfig,
|
device_config: DeviceConfig, parallel_config: ParallelConfig,
|
||||||
device_config: DeviceConfig, parallel_config: ParallelConfig,
|
scheduler_config: SchedulerConfig,
|
||||||
scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig],
|
lora_config: Optional[LoRAConfig],
|
||||||
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
|
vision_language_config: Optional[VisionLanguageConfig],
|
||||||
|
cache_config: CacheConfig) -> nn.Module:
|
||||||
loader = get_model_loader(load_config)
|
loader = get_model_loader(load_config)
|
||||||
return loader.load_model(model_config=model_config,
|
return loader.load_model(model_config=model_config,
|
||||||
device_config=device_config,
|
device_config=device_config,
|
||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
vision_language_config=vision_language_config,
|
vision_language_config=vision_language_config,
|
||||||
parallel_config=parallel_config,
|
parallel_config=parallel_config,
|
||||||
scheduler_config=scheduler_config)
|
scheduler_config=scheduler_config,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@ -9,9 +9,9 @@ import huggingface_hub
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from vllm.config import (DeviceConfig, LoadConfig, LoadFormat, LoRAConfig,
|
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
|
||||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
LoRAConfig, ModelConfig, ParallelConfig,
|
||||||
VisionLanguageConfig)
|
SchedulerConfig, VisionLanguageConfig)
|
||||||
from vllm.envs import VLLM_USE_MODELSCOPE
|
from vllm.envs import VLLM_USE_MODELSCOPE
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
@ -77,15 +77,16 @@ def _get_model_initialization_kwargs(
|
|||||||
return extra_kwargs
|
return extra_kwargs
|
||||||
|
|
||||||
|
|
||||||
def _initialize_model(
|
def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
|
||||||
model_config: ModelConfig, load_config: LoadConfig,
|
lora_config: Optional[LoRAConfig],
|
||||||
lora_config: Optional[LoRAConfig],
|
vision_language_config: Optional[VisionLanguageConfig],
|
||||||
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
|
cache_config: CacheConfig) -> nn.Module:
|
||||||
"""Initialize a model with the given configurations."""
|
"""Initialize a model with the given configurations."""
|
||||||
model_class = get_model_architecture(model_config)[0]
|
model_class = get_model_architecture(model_config)[0]
|
||||||
quant_config = _get_quantization_config(model_config, load_config)
|
quant_config = _get_quantization_config(model_config, load_config)
|
||||||
|
|
||||||
return model_class(config=model_config.hf_config,
|
return model_class(config=model_config.hf_config,
|
||||||
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
**_get_model_initialization_kwargs(
|
**_get_model_initialization_kwargs(
|
||||||
model_class, lora_config, vision_language_config))
|
model_class, lora_config, vision_language_config))
|
||||||
@ -103,7 +104,8 @@ class BaseModelLoader(ABC):
|
|||||||
lora_config: Optional[LoRAConfig],
|
lora_config: Optional[LoRAConfig],
|
||||||
vision_language_config: Optional[VisionLanguageConfig],
|
vision_language_config: Optional[VisionLanguageConfig],
|
||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
scheduler_config: SchedulerConfig) -> nn.Module:
|
scheduler_config: SchedulerConfig,
|
||||||
|
cache_config: CacheConfig) -> nn.Module:
|
||||||
"""Load a model with the given configurations."""
|
"""Load a model with the given configurations."""
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -216,11 +218,13 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
lora_config: Optional[LoRAConfig],
|
lora_config: Optional[LoRAConfig],
|
||||||
vision_language_config: Optional[VisionLanguageConfig],
|
vision_language_config: Optional[VisionLanguageConfig],
|
||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
scheduler_config: SchedulerConfig) -> nn.Module:
|
scheduler_config: SchedulerConfig,
|
||||||
|
cache_config: CacheConfig) -> nn.Module:
|
||||||
with set_default_torch_dtype(model_config.dtype):
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
with torch.device(device_config.device):
|
with torch.device(device_config.device):
|
||||||
model = _initialize_model(model_config, self.load_config,
|
model = _initialize_model(model_config, self.load_config,
|
||||||
lora_config, vision_language_config)
|
lora_config, vision_language_config,
|
||||||
|
cache_config)
|
||||||
model.load_weights(
|
model.load_weights(
|
||||||
self._get_weights_iterator(model_config.model,
|
self._get_weights_iterator(model_config.model,
|
||||||
model_config.revision,
|
model_config.revision,
|
||||||
@ -253,11 +257,13 @@ class DummyModelLoader(BaseModelLoader):
|
|||||||
lora_config: Optional[LoRAConfig],
|
lora_config: Optional[LoRAConfig],
|
||||||
vision_language_config: Optional[VisionLanguageConfig],
|
vision_language_config: Optional[VisionLanguageConfig],
|
||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
scheduler_config: SchedulerConfig) -> nn.Module:
|
scheduler_config: SchedulerConfig,
|
||||||
|
cache_config: CacheConfig) -> nn.Module:
|
||||||
with set_default_torch_dtype(model_config.dtype):
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
with torch.device(device_config.device):
|
with torch.device(device_config.device):
|
||||||
model = _initialize_model(model_config, self.load_config,
|
model = _initialize_model(model_config, self.load_config,
|
||||||
lora_config, vision_language_config)
|
lora_config, vision_language_config,
|
||||||
|
cache_config)
|
||||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||||
# random values to the weights.
|
# random values to the weights.
|
||||||
initialize_dummy_weights(model)
|
initialize_dummy_weights(model)
|
||||||
@ -286,9 +292,12 @@ class TensorizerLoader(BaseModelLoader):
|
|||||||
return tensorizer_weights_iterator(tensorizer_args)
|
return tensorizer_weights_iterator(tensorizer_args)
|
||||||
|
|
||||||
def _load_model_unserialized(
|
def _load_model_unserialized(
|
||||||
self, model_config: ModelConfig, device_config: DeviceConfig,
|
self,
|
||||||
lora_config: Optional[LoRAConfig],
|
model_config: ModelConfig,
|
||||||
vision_language_config: Optional[VisionLanguageConfig]
|
device_config: DeviceConfig,
|
||||||
|
lora_config: Optional[LoRAConfig],
|
||||||
|
vision_language_config: Optional[VisionLanguageConfig],
|
||||||
|
cache_config: CacheConfig,
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
"""Load an unserialized model with tensorizer.
|
"""Load an unserialized model with tensorizer.
|
||||||
|
|
||||||
@ -299,15 +308,19 @@ class TensorizerLoader(BaseModelLoader):
|
|||||||
with set_default_torch_dtype(model_config.dtype):
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
with torch.device(device_config.device):
|
with torch.device(device_config.device):
|
||||||
model = _initialize_model(model_config, self.load_config,
|
model = _initialize_model(model_config, self.load_config,
|
||||||
lora_config, vision_language_config)
|
lora_config, vision_language_config,
|
||||||
|
cache_config)
|
||||||
|
|
||||||
model.load_weights(self._get_weights_iterator())
|
model.load_weights(self._get_weights_iterator())
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
|
||||||
def _load_model_serialized(
|
def _load_model_serialized(
|
||||||
self, model_config: ModelConfig, device_config: DeviceConfig,
|
self,
|
||||||
lora_config: Optional[LoRAConfig],
|
model_config: ModelConfig,
|
||||||
vision_language_config: Optional[VisionLanguageConfig]
|
device_config: DeviceConfig,
|
||||||
|
lora_config: Optional[LoRAConfig],
|
||||||
|
vision_language_config: Optional[VisionLanguageConfig],
|
||||||
|
cache_config: CacheConfig,
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
"""Load a serialized model with tensorizer.
|
"""Load a serialized model with tensorizer.
|
||||||
|
|
||||||
@ -321,6 +334,7 @@ class TensorizerLoader(BaseModelLoader):
|
|||||||
extra_kwargs = _get_model_initialization_kwargs(
|
extra_kwargs = _get_model_initialization_kwargs(
|
||||||
model_class, lora_config, vision_language_config)
|
model_class, lora_config, vision_language_config)
|
||||||
extra_kwargs["quant_config"] = quant_config
|
extra_kwargs["quant_config"] = quant_config
|
||||||
|
extra_kwargs["cache_config"] = cache_config
|
||||||
|
|
||||||
tensorizer_config = copy.copy(self.tensorizer_config)
|
tensorizer_config = copy.copy(self.tensorizer_config)
|
||||||
tensorizer_config.model_class = model_class
|
tensorizer_config.model_class = model_class
|
||||||
@ -335,16 +349,19 @@ class TensorizerLoader(BaseModelLoader):
|
|||||||
lora_config: Optional[LoRAConfig],
|
lora_config: Optional[LoRAConfig],
|
||||||
vision_language_config: Optional[VisionLanguageConfig],
|
vision_language_config: Optional[VisionLanguageConfig],
|
||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
scheduler_config: SchedulerConfig) -> nn.Module:
|
scheduler_config: SchedulerConfig,
|
||||||
|
cache_config: CacheConfig) -> nn.Module:
|
||||||
self._verify_config(model_config, parallel_config)
|
self._verify_config(model_config, parallel_config)
|
||||||
|
|
||||||
if is_vllm_serialized_tensorizer(self.tensorizer_config):
|
if is_vllm_serialized_tensorizer(self.tensorizer_config):
|
||||||
return self._load_model_serialized(model_config, device_config,
|
return self._load_model_serialized(model_config, device_config,
|
||||||
lora_config,
|
lora_config,
|
||||||
vision_language_config)
|
vision_language_config,
|
||||||
|
cache_config)
|
||||||
return self._load_model_unserialized(model_config, device_config,
|
return self._load_model_unserialized(model_config, device_config,
|
||||||
lora_config,
|
lora_config,
|
||||||
vision_language_config)
|
vision_language_config,
|
||||||
|
cache_config)
|
||||||
|
|
||||||
|
|
||||||
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
@ -215,6 +216,7 @@ class ArcticAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: ArcticConfig,
|
config: ArcticConfig,
|
||||||
layer_idx: Optional[int] = None,
|
layer_idx: Optional[int] = None,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -265,7 +267,8 @@ class ArcticAttention(nn.Module):
|
|||||||
self.attn = Attention(self.num_heads,
|
self.attn = Attention(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads)
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -288,6 +291,7 @@ class ArcticDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: ArcticConfig,
|
config: ArcticConfig,
|
||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -297,6 +301,7 @@ class ArcticDecoderLayer(nn.Module):
|
|||||||
self.use_residual = config.use_residual and is_moe_layer
|
self.use_residual = config.use_residual and is_moe_layer
|
||||||
self.self_attn = ArcticAttention(config,
|
self.self_attn = ArcticAttention(config,
|
||||||
layer_idx,
|
layer_idx,
|
||||||
|
cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config)
|
||||||
self.block_sparse_moe = ArcticMoE(
|
self.block_sparse_moe = ArcticMoE(
|
||||||
config,
|
config,
|
||||||
@ -356,6 +361,7 @@ class ArcticModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: ArcticConfig,
|
config: ArcticConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -366,7 +372,10 @@ class ArcticModel(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
org_num_embeddings=self.vocab_size)
|
org_num_embeddings=self.vocab_size)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
ArcticDecoderLayer(config, layer_idx, quant_config=quant_config)
|
ArcticDecoderLayer(config,
|
||||||
|
layer_idx,
|
||||||
|
cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
for layer_idx in range(config.num_hidden_layers)
|
for layer_idx in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self._attn_implementation = config._attn_implementation
|
self._attn_implementation = config._attn_implementation
|
||||||
@ -392,11 +401,12 @@ class ArcticForCausalLM(nn.Module):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: ArcticConfig,
|
config: ArcticConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.model = ArcticModel(config, quant_config)
|
self.model = ArcticModel(config, cache_config, quant_config)
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
self.lm_head = ParallelLMHead(
|
self.lm_head = ParallelLMHead(
|
||||||
self.vocab_size,
|
self.vocab_size,
|
||||||
|
|||||||
@ -26,7 +26,7 @@ from torch import nn
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
@ -111,6 +111,7 @@ class BaiChuanAttention(nn.Module):
|
|||||||
position_embedding: str,
|
position_embedding: str,
|
||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -162,7 +163,10 @@ class BaiChuanAttention(nn.Module):
|
|||||||
base=self.rope_theta,
|
base=self.rope_theta,
|
||||||
)
|
)
|
||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
self.attn = Attention(self.num_heads, self.head_dim, self.scaling)
|
self.attn = Attention(self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -185,6 +189,7 @@ class BaiChuanDecoderLayer(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
position_embedding: str,
|
position_embedding: str,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -197,6 +202,7 @@ class BaiChuanDecoderLayer(nn.Module):
|
|||||||
position_embedding=position_embedding,
|
position_embedding=position_embedding,
|
||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.mlp = BaiChuanMLP(
|
self.mlp = BaiChuanMLP(
|
||||||
@ -244,6 +250,7 @@ class BaiChuanModel(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
position_embedding: str,
|
position_embedding: str,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -255,7 +262,8 @@ class BaiChuanModel(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
BaiChuanDecoderLayer(config, position_embedding, quant_config)
|
BaiChuanDecoderLayer(config, position_embedding, cache_config,
|
||||||
|
quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@ -304,13 +312,15 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
position_embedding: str,
|
position_embedding: str,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = BaiChuanModel(config, position_embedding, quant_config)
|
self.model = BaiChuanModel(config, position_embedding, cache_config,
|
||||||
|
quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
@ -389,13 +399,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
):
|
):
|
||||||
if config.hidden_size == 4096: # baichuan2 7b
|
if config.hidden_size == 4096: # baichuan2 7b
|
||||||
super().__init__(config, "ROPE", quant_config, lora_config)
|
super().__init__(config, "ROPE", cache_config, quant_config,
|
||||||
|
lora_config)
|
||||||
else: # baichuan 13b, baichuan2 13b
|
else: # baichuan 13b, baichuan2 13b
|
||||||
super().__init__(config, "ALIBI", quant_config, lora_config)
|
super().__init__(config, "ALIBI", cache_config, quant_config,
|
||||||
|
lora_config)
|
||||||
|
|
||||||
|
|
||||||
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
|
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||||
@ -404,7 +417,9 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__(config, "ROPE", quant_config, lora_config)
|
super().__init__(config, "ROPE", cache_config, quant_config,
|
||||||
|
lora_config)
|
||||||
|
|||||||
@ -24,6 +24,7 @@ from torch import nn
|
|||||||
from transformers import BloomConfig
|
from transformers import BloomConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
@ -71,6 +72,7 @@ class BloomAttention(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: BloomConfig,
|
config: BloomConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -108,7 +110,8 @@ class BloomAttention(nn.Module):
|
|||||||
self.attn = Attention(self.num_heads,
|
self.attn = Attention(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
scaling,
|
scaling,
|
||||||
alibi_slopes=alibi_slopes)
|
alibi_slopes=alibi_slopes,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -158,6 +161,7 @@ class BloomBlock(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: BloomConfig,
|
config: BloomConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -165,7 +169,8 @@ class BloomBlock(nn.Module):
|
|||||||
|
|
||||||
self.input_layernorm = nn.LayerNorm(hidden_size,
|
self.input_layernorm = nn.LayerNorm(hidden_size,
|
||||||
eps=config.layer_norm_epsilon)
|
eps=config.layer_norm_epsilon)
|
||||||
self.self_attention = BloomAttention(config, quant_config)
|
self.self_attention = BloomAttention(config, cache_config,
|
||||||
|
quant_config)
|
||||||
self.post_attention_layernorm = nn.LayerNorm(
|
self.post_attention_layernorm = nn.LayerNorm(
|
||||||
hidden_size, eps=config.layer_norm_epsilon)
|
hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.mlp = BloomMLP(config, quant_config)
|
self.mlp = BloomMLP(config, quant_config)
|
||||||
@ -214,6 +219,7 @@ class BloomModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: BloomConfig,
|
config: BloomConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -229,7 +235,7 @@ class BloomModel(nn.Module):
|
|||||||
|
|
||||||
# Transformer blocks
|
# Transformer blocks
|
||||||
self.h = nn.ModuleList([
|
self.h = nn.ModuleList([
|
||||||
BloomBlock(config, quant_config)
|
BloomBlock(config, cache_config, quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
|
|
||||||
@ -262,12 +268,13 @@ class BloomForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: BloomConfig,
|
config: BloomConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.transformer = BloomModel(config, quant_config)
|
self.transformer = BloomModel(config, cache_config, quant_config)
|
||||||
self.lm_head_weight = self.transformer.word_embeddings.weight
|
self.lm_head_weight = self.transformer.word_embeddings.weight
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from torch import nn
|
|||||||
from torch.nn import LayerNorm
|
from torch.nn import LayerNorm
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
@ -34,6 +34,7 @@ class GLMAttention(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -90,6 +91,7 @@ class GLMAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -167,6 +169,7 @@ class GLMBlock(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -181,7 +184,7 @@ class GLMBlock(nn.Module):
|
|||||||
eps=config.layernorm_epsilon)
|
eps=config.layernorm_epsilon)
|
||||||
|
|
||||||
# Self attention.
|
# Self attention.
|
||||||
self.self_attention = GLMAttention(config, quant_config)
|
self.self_attention = GLMAttention(config, cache_config, quant_config)
|
||||||
self.hidden_dropout = config.hidden_dropout
|
self.hidden_dropout = config.hidden_dropout
|
||||||
|
|
||||||
# Layernorm on the attention output
|
# Layernorm on the attention output
|
||||||
@ -237,6 +240,7 @@ class GLMTransformer(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -246,8 +250,10 @@ class GLMTransformer(nn.Module):
|
|||||||
self.num_layers = config.num_layers
|
self.num_layers = config.num_layers
|
||||||
|
|
||||||
# Transformer layers.
|
# Transformer layers.
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList([
|
||||||
[GLMBlock(config, quant_config) for i in range(self.num_layers)])
|
GLMBlock(config, cache_config, quant_config)
|
||||||
|
for i in range(self.num_layers)
|
||||||
|
])
|
||||||
|
|
||||||
if self.post_layer_norm:
|
if self.post_layer_norm:
|
||||||
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
||||||
@ -282,6 +288,7 @@ class ChatGLMModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -292,7 +299,7 @@ class ChatGLMModel(nn.Module):
|
|||||||
self.num_layers = config.num_layers
|
self.num_layers = config.num_layers
|
||||||
self.multi_query_group_num = config.multi_query_group_num
|
self.multi_query_group_num = config.multi_query_group_num
|
||||||
self.kv_channels = config.kv_channels
|
self.kv_channels = config.kv_channels
|
||||||
self.encoder = GLMTransformer(config, quant_config)
|
self.encoder = GLMTransformer(config, cache_config, quant_config)
|
||||||
|
|
||||||
self.output_layer = ParallelLMHead(config.padded_vocab_size,
|
self.output_layer = ParallelLMHead(config.padded_vocab_size,
|
||||||
config.hidden_size)
|
config.hidden_size)
|
||||||
@ -334,13 +341,14 @@ class ChatGLMForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: ChatGLMConfig,
|
config: ChatGLMConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config: ChatGLMConfig = config
|
self.config: ChatGLMConfig = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.transformer = ChatGLMModel(config, quant_config)
|
self.transformer = ChatGLMModel(config, cache_config, quant_config)
|
||||||
self.lm_head_weight = self.transformer.output_layer.weight
|
self.lm_head_weight = self.transformer.output_layer.weight
|
||||||
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
|
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|||||||
@ -29,6 +29,7 @@ from torch.nn.parameter import Parameter
|
|||||||
from transformers import CohereConfig
|
from transformers import CohereConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
@ -124,6 +125,7 @@ class CohereAttention(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: CohereConfig,
|
config: CohereConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -180,6 +182,7 @@ class CohereAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config,
|
||||||
)
|
)
|
||||||
if self.use_qk_norm:
|
if self.use_qk_norm:
|
||||||
self.q_norm = LayerNorm(param_shape=(self.num_heads,
|
self.q_norm = LayerNorm(param_shape=(self.num_heads,
|
||||||
@ -219,11 +222,14 @@ class CohereDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: CohereConfig,
|
config: CohereConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
self.self_attn = CohereAttention(config, quant_config=quant_config)
|
self.self_attn = CohereAttention(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
|
||||||
self.mlp = CohereMLP(config, quant_config=quant_config)
|
self.mlp = CohereMLP(config, quant_config=quant_config)
|
||||||
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
|
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
|
||||||
@ -258,6 +264,7 @@ class CohereModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: CohereConfig,
|
config: CohereConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -266,7 +273,7 @@ class CohereModel(nn.Module):
|
|||||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||||
config.hidden_size)
|
config.hidden_size)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
CohereDecoderLayer(config, quant_config=quant_config)
|
CohereDecoderLayer(config, cache_config, quant_config=quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = LayerNorm(param_shape=(config.hidden_size),
|
self.norm = LayerNorm(param_shape=(config.hidden_size),
|
||||||
@ -299,6 +306,7 @@ class CohereForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: CohereConfig,
|
config: CohereConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -306,7 +314,7 @@ class CohereForCausalLM(nn.Module):
|
|||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size,
|
self.logits_processor = LogitsProcessor(config.vocab_size,
|
||||||
scale=config.logit_scale)
|
scale=config.logit_scale)
|
||||||
self.model = CohereModel(config, quant_config)
|
self.model = CohereModel(config, cache_config, quant_config)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
@ -166,6 +167,7 @@ class DbrxAttention(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: DbrxConfig,
|
config: DbrxConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -221,6 +223,7 @@ class DbrxAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -279,10 +282,12 @@ class DbrxBlock(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: DbrxConfig,
|
config: DbrxConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config)
|
self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
|
||||||
|
quant_config)
|
||||||
self.ffn = DbrxExperts(config, quant_config)
|
self.ffn = DbrxExperts(config, quant_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -308,6 +313,7 @@ class DbrxModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: DbrxConfig,
|
config: DbrxConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -315,8 +321,10 @@ class DbrxModel(nn.Module):
|
|||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
config.d_model,
|
config.d_model,
|
||||||
)
|
)
|
||||||
self.blocks = nn.ModuleList(
|
self.blocks = nn.ModuleList([
|
||||||
[DbrxBlock(config, quant_config) for _ in range(config.n_layers)])
|
DbrxBlock(config, cache_config, quant_config)
|
||||||
|
for _ in range(config.n_layers)
|
||||||
|
])
|
||||||
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
|
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
|
||||||
for module in self.modules():
|
for module in self.modules():
|
||||||
if hasattr(module, "bias") and isinstance(module.bias,
|
if hasattr(module, "bias") and isinstance(module.bias,
|
||||||
@ -349,13 +357,14 @@ class DbrxForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: DbrxConfig,
|
config: DbrxConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.unpadded_vocab_size = config.vocab_size
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
self.transformer = DbrxModel(config, quant_config)
|
self.transformer = DbrxModel(config, cache_config, quant_config)
|
||||||
self.lm_head = ParallelLMHead(
|
self.lm_head = ParallelLMHead(
|
||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
config.d_model,
|
config.d_model,
|
||||||
|
|||||||
@ -28,7 +28,7 @@ from typing import Iterable, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
@ -56,12 +56,14 @@ class DeciLMForCausalLM(LlamaForCausalLM):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Optional[PretrainedConfig] = None,
|
config: Optional[PretrainedConfig] = None,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
|
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
|
||||||
delattr(config, "num_key_value_heads_per_layer")
|
delattr(config, "num_key_value_heads_per_layer")
|
||||||
super().__init__(config=config,
|
super().__init__(config=config,
|
||||||
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
lora_config=lora_config)
|
lora_config=lora_config)
|
||||||
|
|
||||||
|
|||||||
@ -28,6 +28,7 @@ from torch import nn
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
@ -178,6 +179,7 @@ class DeepseekAttention(nn.Module):
|
|||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -229,7 +231,8 @@ class DeepseekAttention(nn.Module):
|
|||||||
self.attn = Attention(self.num_heads,
|
self.attn = Attention(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads)
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -252,6 +255,7 @@ class DeepseekDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -267,6 +271,7 @@ class DeepseekDecoderLayer(nn.Module):
|
|||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
if (config.n_routed_experts is not None
|
if (config.n_routed_experts is not None
|
||||||
@ -321,6 +326,7 @@ class DeepseekModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -332,7 +338,10 @@ class DeepseekModel(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config)
|
DeepseekDecoderLayer(config,
|
||||||
|
layer_idx,
|
||||||
|
cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
for layer_idx in range(config.num_hidden_layers)
|
for layer_idx in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@ -360,12 +369,13 @@ class DeepseekForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = DeepseekModel(config, quant_config)
|
self.model = DeepseekModel(config, cache_config, quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|||||||
@ -27,6 +27,7 @@ from torch.nn import LayerNorm
|
|||||||
from transformers import FalconConfig as HF_FalconConfig
|
from transformers import FalconConfig as HF_FalconConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
@ -77,6 +78,7 @@ class FalconAttention(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: FalconConfig,
|
config: FalconConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -168,7 +170,8 @@ class FalconAttention(nn.Module):
|
|||||||
self.attn = Attention(self.num_heads,
|
self.attn = Attention(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
scale=self.inv_norm_factor,
|
scale=self.inv_norm_factor,
|
||||||
num_kv_heads=self.num_kv_heads)
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -229,12 +232,14 @@ class FalconDecoderLayer(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: FalconConfig,
|
config: FalconConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
hidden_size = config.hidden_size
|
hidden_size = config.hidden_size
|
||||||
self.num_heads = config.num_attention_heads
|
self.num_heads = config.num_attention_heads
|
||||||
self.self_attention = FalconAttention(config, quant_config)
|
self.self_attention = FalconAttention(config, cache_config,
|
||||||
|
quant_config)
|
||||||
self.mlp = FalconMLP(config, quant_config)
|
self.mlp = FalconMLP(config, quant_config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@ -311,6 +316,7 @@ class FalconModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: FalconConfig,
|
config: FalconConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -327,7 +333,7 @@ class FalconModel(nn.Module):
|
|||||||
|
|
||||||
# Transformer blocks
|
# Transformer blocks
|
||||||
self.h = nn.ModuleList([
|
self.h = nn.ModuleList([
|
||||||
FalconDecoderLayer(config, quant_config)
|
FalconDecoderLayer(config, cache_config, quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
|
|
||||||
@ -359,12 +365,13 @@ class FalconForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: FalconConfig,
|
config: FalconConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.transformer = FalconModel(config, quant_config)
|
self.transformer = FalconModel(config, cache_config, quant_config)
|
||||||
self.lm_head_weight = self.transformer.word_embeddings.weight
|
self.lm_head_weight = self.transformer.word_embeddings.weight
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from torch import nn
|
|||||||
from transformers import GemmaConfig
|
from transformers import GemmaConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.activation import GeluAndMul
|
from vllm.model_executor.layers.activation import GeluAndMul
|
||||||
@ -107,6 +107,7 @@ class GemmaAttention(nn.Module):
|
|||||||
head_dim: int,
|
head_dim: int,
|
||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None) -> None:
|
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -155,7 +156,8 @@ class GemmaAttention(nn.Module):
|
|||||||
self.attn = Attention(self.num_heads,
|
self.attn = Attention(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads)
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -177,6 +179,7 @@ class GemmaDecoderLayer(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GemmaConfig,
|
config: GemmaConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -188,6 +191,7 @@ class GemmaDecoderLayer(nn.Module):
|
|||||||
head_dim=config.head_dim,
|
head_dim=config.head_dim,
|
||||||
max_position_embeddings=config.max_position_embeddings,
|
max_position_embeddings=config.max_position_embeddings,
|
||||||
rope_theta=config.rope_theta,
|
rope_theta=config.rope_theta,
|
||||||
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.mlp = GemmaMLP(
|
self.mlp = GemmaMLP(
|
||||||
@ -236,6 +240,7 @@ class GemmaModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GemmaConfig,
|
config: GemmaConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -246,7 +251,7 @@ class GemmaModel(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
GemmaDecoderLayer(config, quant_config)
|
GemmaDecoderLayer(config, cache_config, quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@ -309,6 +314,7 @@ class GemmaForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GemmaConfig,
|
config: GemmaConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -316,7 +322,7 @@ class GemmaForCausalLM(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = GemmaModel(config, quant_config)
|
self.model = GemmaModel(config, cache_config, quant_config)
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|
||||||
|
|||||||
@ -24,6 +24,7 @@ from torch import nn
|
|||||||
from transformers import GPT2Config
|
from transformers import GPT2Config
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
@ -45,6 +46,7 @@ class GPT2Attention(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GPT2Config,
|
config: GPT2Config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -70,7 +72,10 @@ class GPT2Attention(nn.Module):
|
|||||||
bias=True,
|
bias=True,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)
|
self.attn = Attention(self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
scale=self.scale,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -122,6 +127,7 @@ class GPT2Block(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GPT2Config,
|
config: GPT2Config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -130,7 +136,7 @@ class GPT2Block(nn.Module):
|
|||||||
hidden_size)
|
hidden_size)
|
||||||
|
|
||||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.attn = GPT2Attention(config, quant_config)
|
self.attn = GPT2Attention(config, cache_config, quant_config)
|
||||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.mlp = GPT2MLP(inner_dim, config, quant_config)
|
self.mlp = GPT2MLP(inner_dim, config, quant_config)
|
||||||
|
|
||||||
@ -163,6 +169,7 @@ class GPT2Model(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GPT2Config,
|
config: GPT2Config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -174,7 +181,7 @@ class GPT2Model(nn.Module):
|
|||||||
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
|
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
|
||||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||||
self.h = nn.ModuleList([
|
self.h = nn.ModuleList([
|
||||||
GPT2Block(config, quant_config)
|
GPT2Block(config, cache_config, quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||||
@ -203,12 +210,13 @@ class GPT2LMHeadModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GPT2Config,
|
config: GPT2Config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.transformer = GPT2Model(config, quant_config)
|
self.transformer = GPT2Model(config, cache_config, quant_config)
|
||||||
self.lm_head_weight = self.transformer.wte.weight
|
self.lm_head_weight = self.transformer.wte.weight
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|||||||
@ -25,6 +25,7 @@ from torch import nn
|
|||||||
from transformers import GPTBigCodeConfig
|
from transformers import GPTBigCodeConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
@ -46,6 +47,7 @@ class GPTBigCodeAttention(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GPTBigCodeConfig,
|
config: GPTBigCodeConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -85,7 +87,8 @@ class GPTBigCodeAttention(nn.Module):
|
|||||||
self.attn = Attention(self.num_heads,
|
self.attn = Attention(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
num_kv_heads=self.num_kv_heads)
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -143,6 +146,7 @@ class GPTBigCodeBlock(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GPTBigCodeConfig,
|
config: GPTBigCodeConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -151,7 +155,7 @@ class GPTBigCodeBlock(nn.Module):
|
|||||||
hidden_size)
|
hidden_size)
|
||||||
|
|
||||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.attn = GPTBigCodeAttention(config, quant_config)
|
self.attn = GPTBigCodeAttention(config, cache_config, quant_config)
|
||||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
|
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
|
||||||
|
|
||||||
@ -184,6 +188,7 @@ class GPTBigCodeModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GPTBigCodeConfig,
|
config: GPTBigCodeConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -195,7 +200,7 @@ class GPTBigCodeModel(nn.Module):
|
|||||||
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
|
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
|
||||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||||
self.h = nn.ModuleList([
|
self.h = nn.ModuleList([
|
||||||
GPTBigCodeBlock(config, quant_config)
|
GPTBigCodeBlock(config, cache_config, quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||||
@ -224,12 +229,13 @@ class GPTBigCodeForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GPTBigCodeConfig,
|
config: GPTBigCodeConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.transformer = GPTBigCodeModel(config, quant_config)
|
self.transformer = GPTBigCodeModel(config, cache_config, quant_config)
|
||||||
self.lm_head_weight = self.transformer.wte.weight
|
self.lm_head_weight = self.transformer.wte.weight
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|||||||
@ -23,6 +23,7 @@ from torch import nn
|
|||||||
from transformers import GPTJConfig
|
from transformers import GPTJConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
@ -45,6 +46,7 @@ class GPTJAttention(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GPTJConfig,
|
config: GPTJConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -83,7 +85,10 @@ class GPTJAttention(nn.Module):
|
|||||||
base=rope_theta,
|
base=rope_theta,
|
||||||
is_neox_style=False,
|
is_neox_style=False,
|
||||||
)
|
)
|
||||||
self.attn = Attention(self.num_heads, self.head_size, scaling)
|
self.attn = Attention(self.num_heads,
|
||||||
|
self.head_size,
|
||||||
|
scaling,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -135,13 +140,14 @@ class GPTJBlock(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GPTJConfig,
|
config: GPTJConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = (4 * config.n_embd
|
inner_dim = (4 * config.n_embd
|
||||||
if config.n_inner is None else config.n_inner)
|
if config.n_inner is None else config.n_inner)
|
||||||
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||||
self.attn = GPTJAttention(config, quant_config)
|
self.attn = GPTJAttention(config, cache_config, quant_config)
|
||||||
self.mlp = GPTJMLP(inner_dim, config, quant_config)
|
self.mlp = GPTJMLP(inner_dim, config, quant_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -169,6 +175,7 @@ class GPTJModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GPTJConfig,
|
config: GPTJConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -178,8 +185,10 @@ class GPTJModel(nn.Module):
|
|||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
self.embed_dim,
|
self.embed_dim,
|
||||||
)
|
)
|
||||||
self.h = nn.ModuleList(
|
self.h = nn.ModuleList([
|
||||||
[GPTJBlock(config, quant_config) for _ in range(config.n_layer)])
|
GPTJBlock(config, cache_config, quant_config)
|
||||||
|
for _ in range(config.n_layer)
|
||||||
|
])
|
||||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -207,13 +216,14 @@ class GPTJForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GPTJConfig,
|
config: GPTJConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
assert not config.tie_word_embeddings
|
assert not config.tie_word_embeddings
|
||||||
self.transformer = GPTJModel(config, quant_config)
|
self.transformer = GPTJModel(config, cache_config, quant_config)
|
||||||
self.lm_head = ParallelLMHead(
|
self.lm_head = ParallelLMHead(
|
||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
config.n_embd,
|
config.n_embd,
|
||||||
|
|||||||
@ -23,6 +23,7 @@ from torch import nn
|
|||||||
from transformers import GPTNeoXConfig
|
from transformers import GPTNeoXConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
@ -45,6 +46,7 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GPTNeoXConfig,
|
config: GPTNeoXConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -84,7 +86,10 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
max_position=max_position_embeddings,
|
max_position=max_position_embeddings,
|
||||||
base=rope_theta,
|
base=rope_theta,
|
||||||
)
|
)
|
||||||
self.attn = Attention(self.num_heads, self.head_size, scaling)
|
self.attn = Attention(self.num_heads,
|
||||||
|
self.head_size,
|
||||||
|
scaling,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -134,6 +139,7 @@ class GPTNeoXLayer(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GPTNeoXConfig,
|
config: GPTNeoXConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -142,7 +148,7 @@ class GPTNeoXLayer(nn.Module):
|
|||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
|
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
self.attention = GPTNeoXAttention(config, quant_config)
|
self.attention = GPTNeoXAttention(config, cache_config, quant_config)
|
||||||
self.mlp = GPTNeoXMLP(config, quant_config)
|
self.mlp = GPTNeoXMLP(config, quant_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -182,6 +188,7 @@ class GPTNeoXModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GPTNeoXConfig,
|
config: GPTNeoXConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -192,7 +199,7 @@ class GPTNeoXModel(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
GPTNeoXLayer(config, quant_config)
|
GPTNeoXLayer(config, cache_config, quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
|
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
|
||||||
@ -223,12 +230,13 @@ class GPTNeoXForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.gpt_neox = GPTNeoXModel(config, quant_config)
|
self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config)
|
||||||
self.embed_out = ParallelLMHead(
|
self.embed_out = ParallelLMHead(
|
||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from torch import nn
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
@ -64,6 +65,7 @@ class InternLM2Attention(nn.Module):
|
|||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -114,7 +116,8 @@ class InternLM2Attention(nn.Module):
|
|||||||
self.attn = Attention(self.num_heads,
|
self.attn = Attention(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads)
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -136,6 +139,7 @@ class InternLMDecoderLayer(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -151,6 +155,7 @@ class InternLMDecoderLayer(nn.Module):
|
|||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.feed_forward = InternLM2MLP(
|
self.feed_forward = InternLM2MLP(
|
||||||
@ -196,6 +201,7 @@ class InternLM2Model(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -207,7 +213,7 @@ class InternLM2Model(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
InternLMDecoderLayer(config, quant_config)
|
InternLMDecoderLayer(config, cache_config, quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@ -239,12 +245,13 @@ class InternLM2ForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = InternLM2Model(config, quant_config)
|
self.model = InternLM2Model(config, cache_config, quant_config)
|
||||||
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|||||||
@ -26,6 +26,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
@ -69,6 +70,7 @@ class JAISAttention(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: JAISConfig,
|
config: JAISConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -108,6 +110,7 @@ class JAISAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
alibi_slopes=alibi_slopes,
|
alibi_slopes=alibi_slopes,
|
||||||
|
cache_config=cache_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -170,6 +173,7 @@ class JAISBlock(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: JAISConfig,
|
config: JAISConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -178,7 +182,7 @@ class JAISBlock(nn.Module):
|
|||||||
hidden_size)
|
hidden_size)
|
||||||
|
|
||||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.attn = JAISAttention(config, quant_config)
|
self.attn = JAISAttention(config, cache_config, quant_config)
|
||||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.mlp = JAISMLP(inner_dim, config, quant_config)
|
self.mlp = JAISMLP(inner_dim, config, quant_config)
|
||||||
|
|
||||||
@ -211,6 +215,7 @@ class JAISModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: JAISConfig,
|
config: JAISConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -228,7 +233,7 @@ class JAISModel(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.embeddings_scale = config.mup_embeddings_scale
|
self.embeddings_scale = config.mup_embeddings_scale
|
||||||
self.h = nn.ModuleList([
|
self.h = nn.ModuleList([
|
||||||
JAISBlock(config, quant_config)
|
JAISBlock(config, cache_config, quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||||
@ -262,12 +267,13 @@ class JAISLMHeadModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: JAISConfig,
|
config: JAISConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.transformer = JAISModel(config, quant_config)
|
self.transformer = JAISModel(config, cache_config, quant_config)
|
||||||
self.lm_head_weight = self.transformer.wte.weight
|
self.lm_head_weight = self.transformer.wte.weight
|
||||||
if hasattr(config, "width_scale"):
|
if hasattr(config, "width_scale"):
|
||||||
self.output_logits_scale = config.width_scale
|
self.output_logits_scale = config.width_scale
|
||||||
|
|||||||
@ -28,7 +28,7 @@ from torch import nn
|
|||||||
from transformers import LlamaConfig
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
@ -94,6 +94,7 @@ class LlamaAttention(nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
bias: bool = False,
|
bias: bool = False,
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int] = None,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -153,7 +154,8 @@ class LlamaAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
sliding_window=sliding_window)
|
sliding_window=sliding_window,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -176,6 +178,7 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: LlamaConfig,
|
config: LlamaConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -204,6 +207,7 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
bias=attention_bias,
|
bias=attention_bias,
|
||||||
sliding_window=sliding_window,
|
sliding_window=sliding_window,
|
||||||
|
cache_config=cache_config,
|
||||||
)
|
)
|
||||||
self.mlp = LlamaMLP(
|
self.mlp = LlamaMLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
@ -251,6 +255,7 @@ class LlamaModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: LlamaConfig,
|
config: LlamaConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -267,7 +272,7 @@ class LlamaModel(nn.Module):
|
|||||||
org_num_embeddings=config.vocab_size,
|
org_num_embeddings=config.vocab_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
LlamaDecoderLayer(config, quant_config)
|
LlamaDecoderLayer(config, cache_config, quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@ -332,12 +337,16 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: LlamaConfig,
|
config: LlamaConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.model = LlamaModel(config, quant_config, lora_config=lora_config)
|
self.model = LlamaModel(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
lora_config=lora_config)
|
||||||
self.unpadded_vocab_size = config.vocab_size
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
if lora_config:
|
if lora_config:
|
||||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from torch import nn
|
|||||||
from transformers import CLIPVisionModel, LlavaConfig
|
from transformers import CLIPVisionModel, LlavaConfig
|
||||||
|
|
||||||
from vllm.attention import AttentionMetadata
|
from vllm.attention import AttentionMetadata
|
||||||
from vllm.config import VisionLanguageConfig
|
from vllm.config import CacheConfig, VisionLanguageConfig
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
@ -62,6 +62,7 @@ class LlavaForConditionalGeneration(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: "LlavaConfig",
|
config: "LlavaConfig",
|
||||||
vision_language_config: VisionLanguageConfig,
|
vision_language_config: VisionLanguageConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional["QuantizationConfig"] = None) -> None:
|
quant_config: Optional["QuantizationConfig"] = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -85,7 +86,8 @@ class LlavaForConditionalGeneration(nn.Module):
|
|||||||
projector_hidden_act=config.projector_hidden_act)
|
projector_hidden_act=config.projector_hidden_act)
|
||||||
|
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.language_model = LlamaModel(config.text_config, quant_config)
|
self.language_model = LlamaModel(config.text_config, cache_config,
|
||||||
|
quant_config)
|
||||||
self.unpadded_vocab_size = config.text_config.vocab_size
|
self.unpadded_vocab_size = config.text_config.vocab_size
|
||||||
self.lm_head = ParallelLMHead(
|
self.lm_head = ParallelLMHead(
|
||||||
self.unpadded_vocab_size,
|
self.unpadded_vocab_size,
|
||||||
|
|||||||
@ -28,7 +28,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
@ -181,6 +181,7 @@ class MiniCPMAttention(nn.Module):
|
|||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -234,7 +235,8 @@ class MiniCPMAttention(nn.Module):
|
|||||||
self.attn = Attention(self.num_heads,
|
self.attn = Attention(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads)
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -259,6 +261,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -275,6 +278,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
|||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.num_experts = getattr(self.config, "num_experts", 0)
|
self.num_experts = getattr(self.config, "num_experts", 0)
|
||||||
@ -330,6 +334,7 @@ class MiniCPMModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -346,7 +351,7 @@ class MiniCPMModel(nn.Module):
|
|||||||
org_num_embeddings=config.vocab_size,
|
org_num_embeddings=config.vocab_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
MiniCPMDecoderLayer(config, quant_config)
|
MiniCPMDecoderLayer(config, cache_config, quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@ -413,6 +418,7 @@ class MiniCPMForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -421,6 +427,7 @@ class MiniCPMForCausalLM(nn.Module):
|
|||||||
self.num_experts = getattr(self.config, "num_experts", 0)
|
self.num_experts = getattr(self.config, "num_experts", 0)
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = MiniCPMModel(config,
|
self.model = MiniCPMModel(config,
|
||||||
|
cache_config,
|
||||||
quant_config,
|
quant_config,
|
||||||
lora_config=lora_config)
|
lora_config=lora_config)
|
||||||
unpadded_vocab_size = config.vocab_size
|
unpadded_vocab_size = config.vocab_size
|
||||||
|
|||||||
@ -29,7 +29,7 @@ from transformers import MixtralConfig
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
@ -252,6 +252,7 @@ class MixtralAttention(nn.Module):
|
|||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
max_position: int = 4096 * 32,
|
max_position: int = 4096 * 32,
|
||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
sliding_window: Optional[int] = None) -> None:
|
sliding_window: Optional[int] = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -313,6 +314,7 @@ class MixtralAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
sliding_window=self.sliding_window,
|
sliding_window=self.sliding_window,
|
||||||
|
cache_config=cache_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -335,6 +337,7 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: MixtralConfig,
|
config: MixtralConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -348,6 +351,7 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
num_kv_heads=config.num_key_value_heads,
|
num_kv_heads=config.num_key_value_heads,
|
||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
sliding_window=config.sliding_window,
|
sliding_window=config.sliding_window,
|
||||||
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config)
|
||||||
self.block_sparse_moe = MixtralMoE(
|
self.block_sparse_moe = MixtralMoE(
|
||||||
num_experts=config.num_local_experts,
|
num_experts=config.num_local_experts,
|
||||||
@ -394,6 +398,7 @@ class MixtralModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: MixtralConfig,
|
config: MixtralConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -410,7 +415,9 @@ class MixtralModel(nn.Module):
|
|||||||
org_num_embeddings=config.vocab_size,
|
org_num_embeddings=config.vocab_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
MixtralDecoderLayer(config, quant_config=quant_config)
|
MixtralDecoderLayer(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@ -460,12 +467,14 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: MixtralConfig,
|
config: MixtralConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.model = MixtralModel(config,
|
self.model = MixtralModel(config,
|
||||||
|
cache_config,
|
||||||
quant_config,
|
quant_config,
|
||||||
lora_config=lora_config)
|
lora_config=lora_config)
|
||||||
self.unpadded_vocab_size = config.vocab_size
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
|
|||||||
@ -30,6 +30,7 @@ from torch import nn
|
|||||||
from transformers import MixtralConfig
|
from transformers import MixtralConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
@ -157,14 +158,17 @@ class MixtralMoE(nn.Module):
|
|||||||
|
|
||||||
class MixtralAttention(nn.Module):
|
class MixtralAttention(nn.Module):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
hidden_size: int,
|
self,
|
||||||
num_heads: int,
|
hidden_size: int,
|
||||||
num_kv_heads: int,
|
num_heads: int,
|
||||||
max_position: int = 4096 * 32,
|
num_kv_heads: int,
|
||||||
rope_theta: float = 10000,
|
max_position: int = 4096 * 32,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
rope_theta: float = 10000,
|
||||||
sliding_window: Optional[int] = None) -> None:
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
@ -215,6 +219,7 @@ class MixtralAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
sliding_window=self.sliding_window,
|
sliding_window=self.sliding_window,
|
||||||
|
cache_config=cache_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -237,6 +242,7 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: MixtralConfig,
|
config: MixtralConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -250,6 +256,7 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
num_kv_heads=config.num_key_value_heads,
|
num_kv_heads=config.num_key_value_heads,
|
||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
sliding_window=config.sliding_window,
|
sliding_window=config.sliding_window,
|
||||||
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config)
|
||||||
self.block_sparse_moe = MixtralMoE(config=config,
|
self.block_sparse_moe = MixtralMoE(config=config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config)
|
||||||
@ -292,6 +299,7 @@ class MixtralModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: MixtralConfig,
|
config: MixtralConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -303,7 +311,9 @@ class MixtralModel(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
MixtralDecoderLayer(config, quant_config=quant_config)
|
MixtralDecoderLayer(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@ -332,12 +342,13 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: MixtralConfig,
|
config: MixtralConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = MixtralModel(config, quant_config)
|
self.model = MixtralModel(config, cache_config, quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
@ -43,6 +44,7 @@ class MPTAttention(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: MPTConfig,
|
config: MPTConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -107,7 +109,8 @@ class MPTAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
scaling,
|
scaling,
|
||||||
alibi_slopes=alibi_slopes,
|
alibi_slopes=alibi_slopes,
|
||||||
num_kv_heads=self.num_kv_heads)
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -166,12 +169,13 @@ class MPTBlock(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: MPTConfig,
|
config: MPTConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
hidden_size = config.d_model
|
hidden_size = config.d_model
|
||||||
self.norm_1 = nn.LayerNorm(hidden_size)
|
self.norm_1 = nn.LayerNorm(hidden_size)
|
||||||
self.attn = MPTAttention(config, quant_config)
|
self.attn = MPTAttention(config, cache_config, quant_config)
|
||||||
self.norm_2 = nn.LayerNorm(hidden_size)
|
self.norm_2 = nn.LayerNorm(hidden_size)
|
||||||
self.ffn = MPTMLP(config, quant_config)
|
self.ffn = MPTMLP(config, quant_config)
|
||||||
|
|
||||||
@ -201,6 +205,7 @@ class MPTModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: MPTConfig,
|
config: MPTConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -211,8 +216,10 @@ class MPTModel(nn.Module):
|
|||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
config.d_model,
|
config.d_model,
|
||||||
)
|
)
|
||||||
self.blocks = nn.ModuleList(
|
self.blocks = nn.ModuleList([
|
||||||
[MPTBlock(config, quant_config) for _ in range(config.n_layers)])
|
MPTBlock(config, cache_config, quant_config)
|
||||||
|
for _ in range(config.n_layers)
|
||||||
|
])
|
||||||
self.norm_f = nn.LayerNorm(config.d_model)
|
self.norm_f = nn.LayerNorm(config.d_model)
|
||||||
if config.no_bias:
|
if config.no_bias:
|
||||||
for module in self.modules():
|
for module in self.modules():
|
||||||
@ -246,6 +253,7 @@ class MPTForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: MPTConfig,
|
config: MPTConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -253,7 +261,7 @@ class MPTForCausalLM(nn.Module):
|
|||||||
assert config.tie_word_embeddings
|
assert config.tie_word_embeddings
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
self.transformer = MPTModel(config, quant_config)
|
self.transformer = MPTModel(config, cache_config, quant_config)
|
||||||
self.lm_head_weight = self.transformer.wte.weight
|
self.lm_head_weight = self.transformer.wte.weight
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|||||||
@ -28,6 +28,7 @@ from torch import nn
|
|||||||
from transformers import OlmoConfig
|
from transformers import OlmoConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||||
@ -55,6 +56,7 @@ class OlmoAttention(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: OlmoConfig,
|
config: OlmoConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -93,7 +95,8 @@ class OlmoAttention(nn.Module):
|
|||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
self.attn = Attention(self.num_heads,
|
self.attn = Attention(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
scale=self.scaling)
|
scale=self.scaling,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
# Attention output projection.
|
# Attention output projection.
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = RowParallelLinear(
|
||||||
@ -175,10 +178,11 @@ class OlmoDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: OlmoConfig,
|
config: OlmoConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Attention block.
|
# Attention block.
|
||||||
self.self_attn = OlmoAttention(config, quant_config)
|
self.self_attn = OlmoAttention(config, cache_config, quant_config)
|
||||||
|
|
||||||
# MLP block.
|
# MLP block.
|
||||||
self.mlp = OlmoMLP(config, quant_config)
|
self.mlp = OlmoMLP(config, quant_config)
|
||||||
@ -217,6 +221,7 @@ class OlmoModel(nn.Module):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: OlmoConfig,
|
config: OlmoConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -224,7 +229,7 @@ class OlmoModel(nn.Module):
|
|||||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||||
config.hidden_size)
|
config.hidden_size)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
OlmoDecoderLayer(config, quant_config)
|
OlmoDecoderLayer(config, cache_config, quant_config)
|
||||||
for layer_idx in range(config.num_hidden_layers)
|
for layer_idx in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = nn.LayerNorm(config.hidden_size,
|
self.norm = nn.LayerNorm(config.hidden_size,
|
||||||
@ -271,10 +276,11 @@ class OlmoForCausalLM(nn.Module):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: OlmoConfig,
|
config: OlmoConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.model = OlmoModel(config, quant_config)
|
self.model = OlmoModel(config, cache_config, quant_config)
|
||||||
if config.tie_word_embeddings:
|
if config.tie_word_embeddings:
|
||||||
self.lm_head_weight = self.model.embed_tokens.weight
|
self.lm_head_weight = self.model.embed_tokens.weight
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -24,6 +24,7 @@ from torch import nn
|
|||||||
from transformers import OPTConfig
|
from transformers import OPTConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
@ -61,6 +62,7 @@ class OPTAttention(nn.Module):
|
|||||||
embed_dim: int,
|
embed_dim: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -88,7 +90,8 @@ class OPTAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
self.attn = Attention(self.num_heads,
|
self.attn = Attention(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
scale=self.scaling)
|
scale=self.scaling,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -108,6 +111,7 @@ class OPTDecoderLayer(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: OPTConfig,
|
config: OPTConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -117,6 +121,7 @@ class OPTDecoderLayer(nn.Module):
|
|||||||
embed_dim=self.embed_dim,
|
embed_dim=self.embed_dim,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
bias=config.enable_bias,
|
bias=config.enable_bias,
|
||||||
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.do_layer_norm_before = config.do_layer_norm_before
|
self.do_layer_norm_before = config.do_layer_norm_before
|
||||||
@ -181,6 +186,7 @@ class OPTDecoder(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: OPTConfig,
|
config: OPTConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -226,7 +232,7 @@ class OPTDecoder(nn.Module):
|
|||||||
self.final_layer_norm = None
|
self.final_layer_norm = None
|
||||||
|
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
OPTDecoderLayer(config, quant_config)
|
OPTDecoderLayer(config, cache_config, quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
|
|
||||||
@ -259,10 +265,11 @@ class OPTModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: OPTConfig,
|
config: OPTConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.decoder = OPTDecoder(config, quant_config)
|
self.decoder = OPTDecoder(config, cache_config, quant_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -279,12 +286,13 @@ class OPTForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = OPTModel(config, quant_config)
|
self.model = OPTModel(config, cache_config, quant_config)
|
||||||
self.lm_head_weight = self.model.decoder.embed_tokens.weight
|
self.lm_head_weight = self.model.decoder.embed_tokens.weight
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from torch import nn
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||||
@ -68,6 +69,7 @@ class OrionAttention(nn.Module):
|
|||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -118,7 +120,8 @@ class OrionAttention(nn.Module):
|
|||||||
self.attn = Attention(self.num_heads,
|
self.attn = Attention(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads)
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -140,6 +143,7 @@ class OrionDecoderLayer(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -155,6 +159,7 @@ class OrionDecoderLayer(nn.Module):
|
|||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.mlp = OrionMLP(
|
self.mlp = OrionMLP(
|
||||||
@ -202,6 +207,7 @@ class OrionModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -213,7 +219,7 @@ class OrionModel(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
OrionDecoderLayer(config, quant_config)
|
OrionDecoderLayer(config, cache_config, quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@ -245,12 +251,13 @@ class OrionForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = OrionModel(config, quant_config)
|
self.model = OrionModel(config, cache_config, quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|||||||
@ -42,6 +42,7 @@ from torch import nn
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
@ -63,6 +64,7 @@ class PhiAttention(nn.Module):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.total_num_heads = config.num_attention_heads
|
self.total_num_heads = config.num_attention_heads
|
||||||
@ -105,7 +107,10 @@ class PhiAttention(nn.Module):
|
|||||||
max_position=max_position_embeddings,
|
max_position=max_position_embeddings,
|
||||||
base=rope_theta,
|
base=rope_theta,
|
||||||
)
|
)
|
||||||
self.attn = Attention(self.num_heads, self.head_size, scaling)
|
self.attn = Attention(self.num_heads,
|
||||||
|
self.head_size,
|
||||||
|
scaling,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -155,11 +160,12 @@ class PhiLayer(nn.Module):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
self.self_attn = PhiAttention(config, quant_config)
|
self.self_attn = PhiAttention(config, cache_config, quant_config)
|
||||||
self.mlp = PhiMLP(config, quant_config)
|
self.mlp = PhiMLP(config, quant_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -186,6 +192,7 @@ class PhiModel(nn.Module):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -193,7 +200,7 @@ class PhiModel(nn.Module):
|
|||||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||||
config.hidden_size)
|
config.hidden_size)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
PhiLayer(config, quant_config)
|
PhiLayer(config, cache_config, quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.final_layernorm = nn.LayerNorm(config.hidden_size,
|
self.final_layernorm = nn.LayerNorm(config.hidden_size,
|
||||||
@ -225,12 +232,13 @@ class PhiForCausalLM(nn.Module):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
self.model = PhiModel(config, quant_config)
|
self.model = PhiModel(config, cache_config, quant_config)
|
||||||
|
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from torch import nn
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
@ -68,6 +69,7 @@ class QWenAttention(nn.Module):
|
|||||||
max_position_embeddings: int,
|
max_position_embeddings: int,
|
||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -101,7 +103,10 @@ class QWenAttention(nn.Module):
|
|||||||
base=rope_theta,
|
base=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
)
|
)
|
||||||
self.attn = Attention(self.num_heads, self.head_dim, self.scaling)
|
self.attn = Attention(self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -123,6 +128,7 @@ class QWenBlock(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -135,6 +141,7 @@ class QWenBlock(nn.Module):
|
|||||||
config.max_position_embeddings,
|
config.max_position_embeddings,
|
||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
|
cache_config=cache_config,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config)
|
||||||
|
|
||||||
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
@ -175,6 +182,7 @@ class QWenModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -186,7 +194,7 @@ class QWenModel(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
self.h = nn.ModuleList([
|
self.h = nn.ModuleList([
|
||||||
QWenBlock(config, quant_config)
|
QWenBlock(config, cache_config, quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
@ -218,12 +226,13 @@ class QWenLMHeadModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.transformer = QWenModel(config, quant_config)
|
self.transformer = QWenModel(config, cache_config, quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|||||||
@ -29,7 +29,7 @@ from torch import nn
|
|||||||
from transformers import Qwen2Config
|
from transformers import Qwen2Config
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
@ -87,6 +87,7 @@ class Qwen2Attention(nn.Module):
|
|||||||
max_position: int = 4096 * 32,
|
max_position: int = 4096 * 32,
|
||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
use_sliding_window: bool = False,
|
use_sliding_window: bool = False,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
sliding_window: Optional[int] = None) -> None:
|
sliding_window: Optional[int] = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -137,7 +138,8 @@ class Qwen2Attention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
sliding_window=self.sliding_window)
|
sliding_window=self.sliding_window,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -160,6 +162,7 @@ class Qwen2DecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: Qwen2Config,
|
config: Qwen2Config,
|
||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -175,6 +178,7 @@ class Qwen2DecoderLayer(nn.Module):
|
|||||||
num_kv_heads=config.num_key_value_heads,
|
num_kv_heads=config.num_key_value_heads,
|
||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
use_sliding_window=use_sliding_window,
|
use_sliding_window=use_sliding_window,
|
||||||
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
sliding_window=config.sliding_window)
|
sliding_window=config.sliding_window)
|
||||||
self.mlp = Qwen2MLP(
|
self.mlp = Qwen2MLP(
|
||||||
@ -222,6 +226,7 @@ class Qwen2Model(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Qwen2Config,
|
config: Qwen2Config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -234,7 +239,7 @@ class Qwen2Model(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
Qwen2DecoderLayer(config, layer_idx, quant_config)
|
Qwen2DecoderLayer(config, layer_idx, cache_config, quant_config)
|
||||||
for layer_idx in range(config.num_hidden_layers)
|
for layer_idx in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@ -287,6 +292,7 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Qwen2Config,
|
config: Qwen2Config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -294,7 +300,7 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = Qwen2Model(config, quant_config)
|
self.model = Qwen2Model(config, cache_config, quant_config)
|
||||||
|
|
||||||
if config.tie_word_embeddings:
|
if config.tie_word_embeddings:
|
||||||
self.lm_head_weight = self.model.embed_tokens.weight
|
self.lm_head_weight = self.model.embed_tokens.weight
|
||||||
|
|||||||
@ -30,6 +30,7 @@ from torch import nn
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
@ -187,6 +188,7 @@ class Qwen2MoeAttention(nn.Module):
|
|||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -238,7 +240,8 @@ class Qwen2MoeAttention(nn.Module):
|
|||||||
self.attn = Attention(self.num_heads,
|
self.attn = Attention(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads)
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -261,6 +264,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -276,6 +280,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
|||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
if (config.num_experts is not None
|
if (config.num_experts is not None
|
||||||
@ -328,6 +333,7 @@ class Qwen2MoeModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -339,7 +345,10 @@ class Qwen2MoeModel(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config)
|
Qwen2MoeDecoderLayer(config,
|
||||||
|
layer_idx,
|
||||||
|
cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
for layer_idx in range(config.num_hidden_layers)
|
for layer_idx in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@ -369,12 +378,13 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = Qwen2MoeModel(config, quant_config)
|
self.model = Qwen2MoeModel(config, cache_config, quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|||||||
@ -26,6 +26,7 @@ from torch import nn
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||||
@ -72,6 +73,7 @@ class StablelmAttention(nn.Module):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None) -> None:
|
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -124,7 +126,8 @@ class StablelmAttention(nn.Module):
|
|||||||
self.attn = Attention(self.num_heads,
|
self.attn = Attention(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_key_value_heads)
|
num_kv_heads=self.num_key_value_heads,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -146,10 +149,11 @@ class StablelmDecoderLayer(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self_attn = StablelmAttention(config)
|
self.self_attn = StablelmAttention(config, cache_config, quant_config)
|
||||||
self.mlp = StablelmMLP(config, quant_config)
|
self.mlp = StablelmMLP(config, quant_config)
|
||||||
norm_eps = getattr(config, "norm_eps",
|
norm_eps = getattr(config, "norm_eps",
|
||||||
getattr(config, "layer_norm_eps", 1e-05))
|
getattr(config, "layer_norm_eps", 1e-05))
|
||||||
@ -188,6 +192,7 @@ class StableLMEpochModel(nn.Module):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None) -> None:
|
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_tokens = VocabParallelEmbedding(
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
@ -195,7 +200,7 @@ class StableLMEpochModel(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
StablelmDecoderLayer(config, quant_config)
|
StablelmDecoderLayer(config, cache_config, quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
norm_eps = getattr(config, "norm_eps",
|
norm_eps = getattr(config, "norm_eps",
|
||||||
@ -227,12 +232,13 @@ class StablelmForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = StableLMEpochModel(config, quant_config)
|
self.model = StableLMEpochModel(config, cache_config, quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|||||||
@ -25,6 +25,7 @@ from torch import nn
|
|||||||
from transformers import Starcoder2Config
|
from transformers import Starcoder2Config
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
@ -46,6 +47,7 @@ class Starcoder2Attention(nn.Module):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: Starcoder2Config,
|
config: Starcoder2Config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -101,6 +103,7 @@ class Starcoder2Attention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
sliding_window=self.sliding_window,
|
sliding_window=self.sliding_window,
|
||||||
|
cache_config=cache_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -150,10 +153,13 @@ class Starcoder2DecoderLayer(nn.Module):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: Starcoder2Config,
|
config: Starcoder2Config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.self_attn = Starcoder2Attention(config, quant_config=quant_config)
|
self.self_attn = Starcoder2Attention(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
self.mlp = Starcoder2MLP(config, quant_config=quant_config)
|
self.mlp = Starcoder2MLP(config, quant_config=quant_config)
|
||||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||||
eps=config.norm_epsilon)
|
eps=config.norm_epsilon)
|
||||||
@ -191,6 +197,7 @@ class Starcoder2Model(nn.Module):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: Starcoder2Config,
|
config: Starcoder2Config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -201,7 +208,9 @@ class Starcoder2Model(nn.Module):
|
|||||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||||
config.hidden_size)
|
config.hidden_size)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
Starcoder2DecoderLayer(config, quant_config=quant_config)
|
Starcoder2DecoderLayer(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
|
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
|
||||||
@ -226,10 +235,13 @@ class Starcoder2ForCausalLM(nn.Module):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: Starcoder2Config,
|
config: Starcoder2Config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.model = Starcoder2Model(config, quant_config=quant_config)
|
self.model = Starcoder2Model(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
self.unpadded_vocab_size = config.vocab_size
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
if config.tie_word_embeddings:
|
if config.tie_word_embeddings:
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from torch import nn
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
@ -89,6 +89,7 @@ class XverseAttention(nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
bias: bool = False,
|
bias: bool = False,
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int] = None,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -133,7 +134,8 @@ class XverseAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
sliding_window=sliding_window)
|
sliding_window=sliding_window,
|
||||||
|
cache_config=cache_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -155,6 +157,7 @@ class XverseDecoderLayer(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -175,6 +178,7 @@ class XverseDecoderLayer(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
bias=getattr(config, "bias", False),
|
bias=getattr(config, "bias", False),
|
||||||
sliding_window=sliding_window,
|
sliding_window=sliding_window,
|
||||||
|
cache_config=cache_config,
|
||||||
)
|
)
|
||||||
self.mlp = XverseMLP(
|
self.mlp = XverseMLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
@ -221,6 +225,7 @@ class XverseModel(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -237,7 +242,7 @@ class XverseModel(nn.Module):
|
|||||||
org_num_embeddings=config.vocab_size,
|
org_num_embeddings=config.vocab_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
XverseDecoderLayer(config, quant_config)
|
XverseDecoderLayer(config, cache_config, quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@ -295,13 +300,14 @@ class XverseForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
lora_config=None,
|
lora_config=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = XverseModel(config, quant_config)
|
self.model = XverseModel(config, cache_config, quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|||||||
@ -31,7 +31,7 @@ class CacheEngine:
|
|||||||
|
|
||||||
self.head_size = model_config.get_head_size()
|
self.head_size = model_config.get_head_size()
|
||||||
self.num_layers = model_config.get_num_layers(parallel_config)
|
self.num_layers = model_config.get_num_layers(parallel_config)
|
||||||
self.num_heads = model_config.get_num_kv_heads(parallel_config)
|
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
|
||||||
|
|
||||||
self.block_size = cache_config.block_size
|
self.block_size = cache_config.block_size
|
||||||
self.num_gpu_blocks = cache_config.num_gpu_blocks
|
self.num_gpu_blocks = cache_config.num_gpu_blocks
|
||||||
@ -43,7 +43,15 @@ class CacheEngine:
|
|||||||
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
||||||
|
|
||||||
# Get attention backend.
|
# Get attention backend.
|
||||||
self.attn_backend = get_attn_backend(model_config.dtype)
|
self.attn_backend = get_attn_backend(
|
||||||
|
model_config.get_num_attention_heads(parallel_config),
|
||||||
|
self.head_size,
|
||||||
|
self.num_kv_heads,
|
||||||
|
model_config.get_sliding_window(),
|
||||||
|
model_config.dtype,
|
||||||
|
cache_config.cache_dtype,
|
||||||
|
self.block_size,
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize the cache.
|
# Initialize the cache.
|
||||||
self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda")
|
self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda")
|
||||||
@ -56,7 +64,7 @@ class CacheEngine:
|
|||||||
) -> List[torch.Tensor]:
|
) -> List[torch.Tensor]:
|
||||||
"""Allocates KV cache on the specified device."""
|
"""Allocates KV cache on the specified device."""
|
||||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||||
num_blocks, self.block_size, self.num_heads, self.head_size)
|
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
|
||||||
pin_memory = is_pin_memory_available() if device == "cpu" else False
|
pin_memory = is_pin_memory_available() if device == "cpu" else False
|
||||||
kv_cache: List[torch.Tensor] = []
|
kv_cache: List[torch.Tensor] = []
|
||||||
for _ in range(self.num_layers):
|
for _ in range(self.num_layers):
|
||||||
|
|||||||
@ -53,7 +53,15 @@ class CPUModelRunner:
|
|||||||
self.kv_cache_dtype = kv_cache_dtype
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
self.sliding_window = model_config.get_sliding_window()
|
self.sliding_window = model_config.get_sliding_window()
|
||||||
self.block_size = cache_config.block_size
|
self.block_size = cache_config.block_size
|
||||||
self.attn_backend = get_attn_backend(self.model_config.dtype)
|
self.attn_backend = get_attn_backend(
|
||||||
|
self.model_config.get_num_attention_heads(self.parallel_config),
|
||||||
|
self.model_config.get_head_size(),
|
||||||
|
self.model_config.get_num_kv_heads(self.parallel_config),
|
||||||
|
self.model_config.get_sliding_window(),
|
||||||
|
self.model_config.dtype,
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
self.block_size,
|
||||||
|
)
|
||||||
|
|
||||||
# Lazy initialization.
|
# Lazy initialization.
|
||||||
self.model: nn.Module # Set after init_Model
|
self.model: nn.Module # Set after init_Model
|
||||||
@ -66,7 +74,8 @@ class CPUModelRunner:
|
|||||||
vision_language_config=self.vision_language_config,
|
vision_language_config=self.vision_language_config,
|
||||||
lora_config=self.lora_config,
|
lora_config=self.lora_config,
|
||||||
parallel_config=self.parallel_config,
|
parallel_config=self.parallel_config,
|
||||||
scheduler_config=self.scheduler_config)
|
scheduler_config=self.scheduler_config,
|
||||||
|
cache_config=self.cache_config)
|
||||||
|
|
||||||
def _prepare_prompt(
|
def _prepare_prompt(
|
||||||
self,
|
self,
|
||||||
@ -158,7 +167,6 @@ class CPUModelRunner:
|
|||||||
decode_metadata=None,
|
decode_metadata=None,
|
||||||
block_tables=torch.tensor([]),
|
block_tables=torch.tensor([]),
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
kv_cache_dtype=self.kv_cache_dtype,
|
|
||||||
)
|
)
|
||||||
return (input_tokens, input_positions, attn_metadata, seq_lens,
|
return (input_tokens, input_positions, attn_metadata, seq_lens,
|
||||||
multi_modal_input)
|
multi_modal_input)
|
||||||
@ -242,7 +250,6 @@ class CPUModelRunner:
|
|||||||
prefill_metadata=None,
|
prefill_metadata=None,
|
||||||
decode_metadata=None,
|
decode_metadata=None,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
kv_cache_dtype=self.kv_cache_dtype,
|
|
||||||
)
|
)
|
||||||
return (
|
return (
|
||||||
input_tokens,
|
input_tokens,
|
||||||
|
|||||||
@ -53,7 +53,15 @@ class CPUCacheEngine:
|
|||||||
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
||||||
|
|
||||||
# Get attention backend.
|
# Get attention backend.
|
||||||
self.attn_backend = get_attn_backend(model_config.dtype)
|
self.attn_backend = get_attn_backend(
|
||||||
|
self.model_config.get_num_attention_heads(self.parallel_config),
|
||||||
|
self.model_config.get_head_size(),
|
||||||
|
self.model_config.get_num_kv_heads(self.parallel_config),
|
||||||
|
self.model_config.get_sliding_window(),
|
||||||
|
self.model_config.dtype,
|
||||||
|
cache_config.cache_dtype,
|
||||||
|
self.block_size,
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize the cache.
|
# Initialize the cache.
|
||||||
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)
|
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)
|
||||||
|
|||||||
@ -235,7 +235,6 @@ class EmbeddingModelRunner(ModelRunner):
|
|||||||
num_decode_tokens=num_decode_tokens,
|
num_decode_tokens=num_decode_tokens,
|
||||||
prefill_metadata=prefill_attn_metadata,
|
prefill_metadata=prefill_attn_metadata,
|
||||||
decode_metadata=decode_attn_metadata,
|
decode_metadata=decode_attn_metadata,
|
||||||
kv_cache_dtype=self.kv_cache_dtype,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return (input_tokens, input_positions, attn_metadata, pooling_metadata,
|
return (input_tokens, input_positions, attn_metadata, pooling_metadata,
|
||||||
|
|||||||
@ -141,10 +141,18 @@ class ModelRunner:
|
|||||||
self.graph_block_tables = np.zeros(
|
self.graph_block_tables = np.zeros(
|
||||||
(max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
|
(max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
|
||||||
dtype=np.int32)
|
dtype=np.int32)
|
||||||
self.attn_backend = get_attn_backend(self.model_config.dtype)
|
self.attn_backend = get_attn_backend(
|
||||||
|
self.model_config.get_num_attention_heads(self.parallel_config),
|
||||||
|
self.model_config.get_head_size(),
|
||||||
|
self.model_config.get_num_kv_heads(self.parallel_config),
|
||||||
|
self.model_config.get_sliding_window(),
|
||||||
|
self.model_config.dtype,
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
self.block_size,
|
||||||
|
)
|
||||||
|
|
||||||
# Lazy initialization
|
# Lazy initialization
|
||||||
self.model: torch.nn.Module # Set after load_model
|
self.model: nn.Module # Set after load_model
|
||||||
# Set if the backend is flashinfer.
|
# Set if the backend is flashinfer.
|
||||||
self.flashinfer_workspace_buffer: torch.Tensor
|
self.flashinfer_workspace_buffer: torch.Tensor
|
||||||
# Set after load_model.
|
# Set after load_model.
|
||||||
@ -160,6 +168,7 @@ class ModelRunner:
|
|||||||
vision_language_config=self.vision_language_config,
|
vision_language_config=self.vision_language_config,
|
||||||
parallel_config=self.parallel_config,
|
parallel_config=self.parallel_config,
|
||||||
scheduler_config=self.scheduler_config,
|
scheduler_config=self.scheduler_config,
|
||||||
|
cache_config=self.cache_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_memory_usage = m.consumed_memory
|
self.model_memory_usage = m.consumed_memory
|
||||||
@ -753,7 +762,6 @@ class ModelRunner:
|
|||||||
num_decode_tokens=num_decode_tokens,
|
num_decode_tokens=num_decode_tokens,
|
||||||
prefill_metadata=prefill_attn_metadata,
|
prefill_metadata=prefill_attn_metadata,
|
||||||
decode_metadata=decode_attn_metadata,
|
decode_metadata=decode_attn_metadata,
|
||||||
kv_cache_dtype=self.kv_cache_dtype,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return (input_tokens, input_positions, attn_metadata,
|
return (input_tokens, input_positions, attn_metadata,
|
||||||
@ -965,7 +973,6 @@ class ModelRunner:
|
|||||||
slot_mapping=slot_mapping[:batch_size],
|
slot_mapping=slot_mapping[:batch_size],
|
||||||
prefill_metadata=None,
|
prefill_metadata=None,
|
||||||
decode_metadata=decode_metadata,
|
decode_metadata=decode_metadata,
|
||||||
kv_cache_dtype=self.kv_cache_dtype,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.lora_config:
|
if self.lora_config:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user