[Model] Add FlexOlmo model implementation (#24923)
Signed-off-by: Shane A <shanea@allenai.org>
This commit is contained in:
parent: b2155ed317
commit: 8d2b8c0ff2
@@ -363,6 +363,7 @@ th {
 | `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ |
 | `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ |
 | `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `FlexOlmoForCausalLM` | FlexOlmo | `allenai/FlexOlmo-7x7B-1T`, `allenai/FlexOlmo-7x7B-1T-RT`, etc. | | ✅︎ | ✅︎ |
 | `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
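For reference, a minimal offline-inference sketch against the checkpoints listed above (not part of the change itself; the sampling settings are illustrative):

# Editor's sketch, not part of the diff. Assumes a GPU setup large enough for
# the 7x7B MoE checkpoint; sampling values are illustrative.
from vllm import LLM, SamplingParams

llm = LLM(model="allenai/FlexOlmo-7x7B-1T")
sampling = SamplingParams(temperature=0.7, max_tokens=64)
outputs = llm.generate(["FlexOlmo is"], sampling)
print(outputs[0].outputs[0].text)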
@@ -250,6 +250,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),
     "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
     "FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
+    "FlexOlmoForCausalLM": _HfExamplesInfo("allenai/Flex-reddit-2x7B-1T"),
     "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
     "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
     "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
vllm/model_executor/models/flex_olmo.py (new file, 157 lines)
@@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only FlexOlmo model compatible with HuggingFace weights."""

from typing import Optional

import torch
from torch import nn

from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.olmoe import OlmoeAttention, OlmoeForCausalLM
from vllm.transformers_utils.configs import FlexOlmoConfig

logger = init_logger(__name__)


class FlexOlmoAttention(OlmoeAttention):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        hf_config = vllm_config.model_config.hf_config
        assert isinstance(hf_config, FlexOlmoConfig)

        self.k_norm = RMSNorm(
            self.total_num_kv_heads * self.head_dim, eps=hf_config.rms_norm_eps
        )
        self.q_norm = RMSNorm(
            self.total_num_heads * self.head_dim, eps=hf_config.rms_norm_eps
        )
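FlexOlmoAttention only overrides the parent's QK-norms so that they use the checkpoint's rms_norm_eps; as in OlmoeAttention, each norm is applied over the full projection (all heads concatenated) before the heads are split. A toy shape sketch, assuming PyTorch's built-in RMSNorm (sizes are made up):

# Editor's sketch, not part of the diff; torch.nn.RMSNorm needs PyTorch >= 2.4.
import torch
from torch import nn

num_heads, head_dim, eps = 4, 8, 1e-6
q_norm = nn.RMSNorm(num_heads * head_dim, eps=eps)

q = torch.randn(2, num_heads * head_dim)  # (num_tokens, num_heads * head_dim)
q = q_norm(q)                             # normalized over the concatenated heads
q = q.view(2, num_heads, head_dim)        # split into per-head vectors afterwards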

class FlexOlmoMoE(nn.Module):
    """A tensor-parallel MoE implementation for FlexOlmo that shards each
    expert across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        assert isinstance(hf_config, FlexOlmoConfig)

        tp_size = get_tensor_model_parallel_world_size()

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            hf_config.hidden_size,
            hf_config.num_experts,
            bias=False,
            return_bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        self.experts = FusedMoE(
            num_experts=hf_config.num_experts,
            top_k=hf_config.num_experts_per_tok,
            hidden_size=hf_config.hidden_size,
            intermediate_size=hf_config.intermediate_size,
            reduce_results=True,
            renormalize=False,
            quant_config=None,
            tp_size=tp_size,
            prefix=f"{prefix}.experts",
        )

        self.top_k = hf_config.num_experts_per_tok

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)

        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        # Warning: The experts mutate the hidden state input! This messes up
        # basic things like the residual stream.
        final_hidden_states = self.experts(
            hidden_states=hidden_states.detach().clone(),
            router_logits=router_logits.float(),
        )

        return final_hidden_states.view(orig_shape)
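Since renormalize=False, the softmax weights of the selected top-k experts are used as-is instead of being re-normalized to sum to one. A naive dense reference for roughly what the fused expert path computes, as a sketch only (the expert callables stand in for the real sharded expert MLPs):

# Editor's sketch, not part of the diff.
import torch

def naive_moe(hidden_states, router_logits, experts, top_k):
    # experts: list of callables, one per expert MLP
    probs = torch.softmax(router_logits.float(), dim=-1)  # (num_tokens, n_experts)
    topk_w, topk_idx = probs.topk(top_k, dim=-1)          # raw weights, not re-normalized
    out = torch.zeros_like(hidden_states)
    for slot in range(top_k):
        for e, expert in enumerate(experts):
            mask = topk_idx[:, slot] == e
            if mask.any():
                w = topk_w[mask, slot, None].to(hidden_states.dtype)
                out[mask] += w * expert(hidden_states[mask])
    return out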

class FlexOlmoDecoderLayer(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        assert isinstance(hf_config, FlexOlmoConfig)

        self.self_attn = FlexOlmoAttention(
            vllm_config=vllm_config, prefix=f"{prefix}.self_attn"
        )
        self.post_attention_layernorm = RMSNorm(
            hf_config.hidden_size, eps=hf_config.rms_norm_eps
        )
        self.post_feedforward_layernorm = RMSNorm(
            hf_config.hidden_size, eps=hf_config.rms_norm_eps
        )

        self.mlp = FlexOlmoMoE(vllm_config=vllm_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Attention block.
        residual = hidden_states
        hidden_states = self.self_attn(positions, hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = hidden_states + residual

        # MLP block.
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states, None
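Note the post-norm wiring: each sublayer's output is normalized before the residual add, and the layer always returns None as the residual, which is why OlmoeModel's final norm gains the `if residual is not None` branch later in this diff. A minimal contrast of the two wirings (function names are illustrative only):

# Editor's sketch; `sublayer` stands for attention or the MoE MLP.
def post_norm_block(x, sublayer, norm):   # FlexOlmo wiring above
    return x + norm(sublayer(x))

def pre_norm_block(x, sublayer, norm):    # the more common pre-norm wiring
    return x + sublayer(norm(x))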

class FlexOlmoForCausalLM(OlmoeForCausalLM):
    fall_back_to_pt_during_load = False

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = FlexOlmoDecoderLayer,
    ):
        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
The hunks below refactor the existing Olmoe implementation in vllm/model_executor/models/olmoe.py (the module FlexOlmo subclasses) to take its configuration from VllmConfig and to expose a pluggable decoder-layer class.

@@ -17,15 +17,14 @@
 from collections.abc import Iterable
 from functools import partial
 from itertools import islice
-from typing import Any, Optional, Union
+from typing import Optional, Union

 import torch
 from torch import nn
-from transformers import OlmoeConfig

 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
@@ -117,20 +116,21 @@ class OlmoeMoE(nn.Module):


 class OlmoeAttention(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        num_kv_heads: int,
-        rope_theta: float = 10000,
-        rope_scaling: Optional[dict[str, Any]] = None,
-        max_position_embeddings: int = 4096,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
-        self.hidden_size = hidden_size
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
+
+        num_heads = config.num_attention_heads
+        num_kv_heads = config.num_key_value_heads
+
         tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
@@ -145,7 +145,7 @@ class OlmoeAttention(nn.Module):
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -153,7 +153,7 @@ class OlmoeAttention(nn.Module):
         self.max_position_embeddings = max_position_embeddings

         self.qkv_proj = QKVParallelLinear(
-            hidden_size,
+            self.hidden_size,
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
@@ -166,7 +166,7 @@ class OlmoeAttention(nn.Module):
         self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, eps=1e-5)
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
-            hidden_size,
+            self.hidden_size,
             bias=False,
             quant_config=quant_config,
         )
@@ -218,28 +218,15 @@ class OlmoeAttention(nn.Module):


 class OlmoeDecoderLayer(nn.Module):
-    def __init__(
-        self,
-        config: OlmoeConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+
         self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings", 4096)

         self.self_attn = OlmoeAttention(
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
+            vllm_config=vllm_config,
             prefix=f"{prefix}.self_attn",
         )

@@ -280,12 +267,16 @@ class OlmoeDecoderLayer(nn.Module):

 @support_torch_compile
 class OlmoeModel(nn.Module):
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        layer_type: type[nn.Module] = OlmoeDecoderLayer,
+    ):
         super().__init__()

         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config

         self.vocab_size = config.vocab_size
         self.config = config
@@ -295,9 +286,7 @@ class OlmoeModel(nn.Module):
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: OlmoeDecoderLayer(
-                config, cache_config, quant_config, prefix=prefix
-            ),
+            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
             prefix=f"{prefix}.layers",
         )
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)
@@ -339,7 +328,10 @@ class OlmoeModel(nn.Module):
                 {"hidden_states": hidden_states, "residual": residual}
             )

-        hidden_states, _ = self.norm(hidden_states, residual)
+        if residual is not None:
+            hidden_states, _ = self.norm(hidden_states, residual)
+        else:
+            hidden_states = self.norm(hidden_states)
         return hidden_states

     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
@@ -455,14 +447,22 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
         ],
     }

-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        layer_type: type[nn.Module] = OlmoeDecoderLayer,
+    ):
         super().__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
         self.model = OlmoeModel(
-            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
+            vllm_config=vllm_config,
+            prefix=maybe_prefix(prefix, "model"),
+            layer_type=layer_type,
         )
         self.lm_head = ParallelLMHead(
             config.vocab_size,
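The new layer_type parameter is the extension point that lets FlexOlmo reuse the whole Olmoe stack while swapping in its own decoder layer. A hypothetical variant (class names invented for illustration) would plug in the same way FlexOlmoForCausalLM does above:

# Editor's sketch; MyDecoderLayer / MyOlmoeVariantForCausalLM are hypothetical.
from torch import nn

from vllm.config import VllmConfig
from vllm.model_executor.models.olmoe import OlmoeDecoderLayer, OlmoeForCausalLM


class MyDecoderLayer(OlmoeDecoderLayer):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # ...override norms, attention, or the MoE block here...


class MyOlmoeVariantForCausalLM(OlmoeForCausalLM):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(
            vllm_config=vllm_config, prefix=prefix, layer_type=MyDecoderLayer
        )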
@@ -90,6 +90,7 @@ _TEXT_GENERATION_MODELS = {
     "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
     "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
     "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
+    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
     "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
     "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
     "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
@@ -74,6 +74,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
     deepseek_vl_v2="DeepseekVLV2Config",
     deepseek_v3="DeepseekV3Config",
     deepseek_v32="DeepseekV3Config",
+    flex_olmo="FlexOlmoConfig",
     kimi_vl="KimiVLConfig",
     Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config",
     RefinedWeb="RWConfig",  # For tiiuae/falcon-40b(-instruct)
@@ -17,6 +17,7 @@ from vllm.transformers_utils.configs.eagle import EAGLEConfig
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
+from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig
 from vllm.transformers_utils.configs.jais import JAISConfig
 from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
 from vllm.transformers_utils.configs.lfm2_moe import Lfm2MoeConfig
@@ -45,6 +46,7 @@ __all__ = [
     "DeepseekV3Config",
     "DotsOCRConfig",
     "EAGLEConfig",
+    "FlexOlmoConfig",
     "RWConfig",
     "JAISConfig",
     "Lfm2MoeConfig",
vllm/transformers_utils/configs/flex_olmo.py (new file, 77 lines)
@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from transformers.configuration_utils import PretrainedConfig


class FlexOlmoConfig(PretrainedConfig):
    model_type = "flex_olmo"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=100352,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-06,
        use_cache=True,
        pad_token_id=100277,
        bos_token_id=None,
        eos_token_id=100257,
        tie_word_embeddings=False,
        rope_theta=500000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        num_experts_per_tok=5,
        num_experts=7,
        output_router_logits=False,
        router_aux_loss_coef=0.01,
        norm_topk_prob=False,
        **kwargs,
    ):
        if "architectures" not in kwargs:
            kwargs["architectures"] = ["FlexOlmoForCausalLM"]

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.norm_topk_prob = norm_topk_prob
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
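A small usage sketch for the config class above (the overridden argument value is illustrative):

# Editor's sketch, not part of the diff.
from vllm.transformers_utils.configs import FlexOlmoConfig

cfg = FlexOlmoConfig(num_hidden_layers=16)
assert cfg.model_type == "flex_olmo"
assert cfg.architectures == ["FlexOlmoForCausalLM"]        # injected default
assert cfg.num_key_value_heads == cfg.num_attention_heads  # MHA when unset
assert cfg.num_experts == 7 and cfg.num_experts_per_tok == 5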