diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index a98c5b19f56e..85d917e6d3b5 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -19,6 +19,7 @@ _MODEL_REGISTRY = { "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM, "GPTJForCausalLM": GPTJForCausalLM, "GPTNeoXForCausalLM": GPTNeoXForCausalLM, + "InternLMForCausalLM": InternLMForCausalLM, "LlamaForCausalLM": LlamaForCausalLM, "LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-* "MPTForCausalLM": MPTForCausalLM, diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 9e89c463593a..6c51f9ccd3d7 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -6,14 +6,24 @@ from vllm.model_executor.models.gpt2 import GPT2LMHeadModel from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM from vllm.model_executor.models.gpt_j import GPTJForCausalLM from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM +from vllm.model_executor.models.internlm import InternLMForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.mpt import MPTForCausalLM from vllm.model_executor.models.opt import OPTForCausalLM from vllm.model_executor.models.qwen import QWenLMHeadModel __all__ = [ - "BaiChuanForCausalLM", "BaichuanForCausalLM", "BloomForCausalLM", - "FalconForCausalLM", "GPT2LMHeadModel", "GPTBigCodeForCausalLM", - "GPTJForCausalLM", "GPTNeoXForCausalLM", "LlamaForCausalLM", - "MPTForCausalLM", "OPTForCausalLM", "QWenLMHeadModel" + "BaiChuanForCausalLM", + "BaichuanForCausalLM", + "BloomForCausalLM", + "FalconForCausalLM", + "GPT2LMHeadModel", + "GPTBigCodeForCausalLM", + "GPTJForCausalLM", + "GPTNeoXForCausalLM", + "InternLMForCausalLM", + "LlamaForCausalLM", + "MPTForCausalLM", + "OPTForCausalLM", + "QWenLMHeadModel", ] diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py new file mode 100644 index 000000000000..e2fb3f2ff064 --- /dev/null +++ b/vllm/model_executor/models/internlm.py @@ -0,0 +1,299 @@ +# -*- coding: utf-8 -*- +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn +from transformers import LlamaConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.weight_utils import (hf_model_weights_iterator, + load_tensor_parallel_weights) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.tensor_parallel import ( + VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) +from vllm.sequence import SequenceOutputs + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class InternLMMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.gate_up_proj = ColumnParallelLinear(hidden_size, + 2 * intermediate_size, + bias=True, + gather_output=False, + perform_initialization=False) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=True, + input_is_parallel=True, + perform_initialization=False) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class InternLMAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + ): + super().__init__() + self.hidden_size = hidden_size + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + self.total_num_heads = num_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = hidden_size // self.total_num_heads + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = ColumnParallelLinear( + hidden_size, + 3 * self.total_num_heads * self.head_dim, + bias=True, + gather_output=False, + perform_initialization=False, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=True, + input_is_parallel=True, + perform_initialization=False, + ) + self.attn = PagedAttentionWithRoPE(self.num_heads, + self.head_dim, + self.scaling, + rotary_dim=self.head_dim) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + k_cache, v_cache = kv_cache + attn_output = self.attn(positions, q, k, v, k_cache, v_cache, + input_metadata, cache_event) + output, _ = self.o_proj(attn_output) + return output + + +class InternLMDecoderLayer(nn.Module): + + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = InternLMAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + ) + self.mlp = InternLMMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + cache_event=cache_event, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class InternLMModel(nn.Module): + + def __init__(self, config: LlamaConfig): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + vocab_size = ((config.vocab_size + 63) // 64) * 64 + self.embed_tokens = VocabParallelEmbedding( + vocab_size, config.hidden_size, perform_initialization=False) + self.layers = nn.ModuleList([ + InternLMDecoderLayer(config) + for _ in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + if cache_events is None: + cache_event = None + else: + cache_event = cache_events[i] + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + cache_event, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class InternLMForCausalLM(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.model = InternLMModel(config) + vocab_size = ((config.vocab_size + 63) // 64) * 64 + self.lm_head = ColumnParallelLinear(config.hidden_size, + vocab_size, + bias=False, + gather_output=False, + perform_initialization=False) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> Dict[int, SequenceOutputs]: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata, cache_events) + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + input_metadata) + return next_tokens + + _column_parallel_weights = [ + "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight", + "gate_proj.weight", "up_proj.weight" + ] + _row_parallel_weights = ["o_proj.weight", "down_proj.weight"] + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + use_np_cache: bool = False): + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + state_dict = self.state_dict() + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, use_np_cache): + if "rotary_emb.inv_freq" in name: + continue + + if "embed_tokens" in name or "lm_head" in name: + param = state_dict[name] + # Consider padding in the vocab size. + padded_vocab_size = (param.shape[0] * + tensor_model_parallel_world_size) + num_extra_rows = padded_vocab_size - self.config.vocab_size + extra_rows = torch.empty(num_extra_rows, + loaded_weight.shape[1]) + extra_rows = extra_rows.to(loaded_weight) + loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) + + is_attention_weight = False + for stride_id, att_weight_name in enumerate( + ["q_proj", "k_proj", "v_proj"]): + if att_weight_name not in name: + continue + param = state_dict[name.replace(att_weight_name, "qkv_proj")] + shard_size = param.shape[0] // 3 + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank:shard_size * + (tensor_model_parallel_rank + 1)] + param_slice = param.data[shard_size * stride_id:shard_size * + (stride_id + 1)] + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_attention_weight = True + break + if is_attention_weight: + continue + + is_gate_up_weight = False + for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "gate_up_proj")] + shard_size = param.shape[0] // 2 + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank:shard_size * + (tensor_model_parallel_rank + 1)] + param_slice = param.data[shard_size * stride_id:shard_size * + (stride_id + 1)] + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_gate_up_weight = True + break + if is_gate_up_weight: + continue + + param = state_dict[name] + load_tensor_parallel_weights(param, loaded_weight, name, + self._column_parallel_weights, + self._row_parallel_weights, + tensor_model_parallel_rank)