mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 14:26:07 +08:00
143 lines
5.4 KiB
Python
143 lines
5.4 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from itertools import islice
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import PretrainedConfig
|
|
|
|
from vllm.config import CacheConfig, VllmConfig
|
|
from vllm.distributed import get_pp_group
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.models.internlm2 import (
|
|
InternLM2Attention,
|
|
InternLM2ForCausalLM,
|
|
InternLM2MLP,
|
|
InternLM2Model,
|
|
)
|
|
from vllm.sequence import IntermediateTensors
|
|
|
|
|
|
class InternLM2VEDecoderLayer(nn.Module):
    """InternLM2 decoder layer with an extra "visual expert" feed-forward MLP.

    Attention is shared by all tokens. After the FFN norm, tokens flagged by
    ``visual_token_mask`` are sent through ``feed_forward_ve`` and every other
    token through ``feed_forward``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Optional config fields fall back to these defaults when absent.
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        # Text-token MLP.
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        # Visual-token ("visual expert") MLP; same shape as the text MLP.
        self.feed_forward_ve = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward_ve",
        )
        self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        visual_token_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and the (possibly routed) feed-forward sub-layer.

        Args:
            positions: Position ids forwarded to the attention layer.
            hidden_states: Token activations, indexed by token row below
                (assumed ``(num_tokens, hidden_size)`` -- TODO confirm).
            residual: Running residual from the previous layer, or ``None``,
                in which case ``hidden_states`` itself becomes the residual.
            visual_token_mask: Optional per-token flags. It must contain
                exactly one element per token row of ``hidden_states``, for
                example shape ``(num_tokens,)`` or ``(num_tokens, 1)``.
                Truthy entries are routed through ``feed_forward_ve`` and the
                rest through ``feed_forward``. When ``None`` or all-false,
                every token uses ``feed_forward``.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        # Self attention (pre-norm; the norm also folds in the residual).
        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        else:
            hidden_states, residual = self.attention_norm(hidden_states, residual)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully connected.
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        if visual_token_mask is not None and visual_token_mask.any():
            # One boolean per token row. Selecting whole rows directly avoids
            # materialising a (num_tokens, hidden_size) mask and the
            # flatten/reshape round-trip, while selecting the same elements.
            visual_rows = visual_token_mask.reshape(-1).bool()
            text_rows = ~visual_rows
            hidden_states[visual_rows] = self.feed_forward_ve(
                hidden_states[visual_rows]
            )
            # Skip the text MLP entirely when every token is visual.
            if text_rows.any():
                hidden_states[text_rows] = self.feed_forward(hidden_states[text_rows])
        else:
            hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual
|
|
|
|
|
|
class InternLM2VEModel(InternLM2Model):
    """InternLM2 backbone built from ``InternLM2VEDecoderLayer`` layers.

    ``forward`` passes an optional ``visual_token_mask`` through to every
    decoder layer in ``[self.start_layer, self.end_layer)``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            layer_type=InternLM2VEDecoderLayer,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        visual_token_mask: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline stage's slice of decoder layers.

        Args:
            input_ids: Token ids; embedded via ``tok_embeddings`` on the first
                pipeline rank when ``inputs_embeds`` is not supplied.
            positions: Position ids handed to each decoder layer.
            intermediate_tensors: ``hidden_states``/``residual`` received from
                the previous pipeline rank; required on non-first ranks.
            inputs_embeds: Optional precomputed embeddings (first rank only).
            visual_token_mask: Optional per-token flags selecting the visual
                expert MLP in each decoder layer.

        Returns:
            ``IntermediateTensors`` holding ``hidden_states`` and ``residual``
            on every rank but the last. On the last rank, the final normed
            hidden states.
        """
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = (
                inputs_embeds
                if inputs_embeds is not None
                else self.tok_embeddings(input_ids)
            )
            residual = None

        for decoder_layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = decoder_layer(
                positions,
                hidden_states,
                residual,
                visual_token_mask=visual_token_mask,
            )

        if pp_group.is_last_rank:
            normed, _ = self.norm(hidden_states, residual)
            return normed

        # Intermediate rank: hand both streams to the next stage.
        return IntermediateTensors(
            {"hidden_states": hidden_states, "residual": residual}
        )
|
|
|
|
|
|
class InternLM2VEForCausalLM(InternLM2ForCausalLM):
    """``InternLM2ForCausalLM`` variant whose backbone is ``InternLM2VEModel``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        backbone_cls = InternLM2VEModel
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            model_type=backbone_cls,
        )
|