[Model] Add FlexOlmo model implementation (#24923)

Signed-off-by: Shane A <shanea@allenai.org>
Shane A 2025-10-10 09:43:15 -07:00 committed by GitHub
parent b2155ed317
commit 8d2b8c0ff2
8 changed files with 286 additions and 46 deletions

View File

@@ -363,6 +363,7 @@ th {
| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ |
| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ |
| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `FlexOlmoForCausalLM` | FlexOlmo | `allenai/FlexOlmo-7x7B-1T`, `allenai/FlexOlmo-7x7B-1T-RT`, etc. | | ✅︎ | ✅︎ |
| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
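For reference, the new FlexOlmo row can be exercised through vLLM's offline LLM API once this change lands. The snippet below is a minimal sketch, not part of the commit: the prompt, max_tokens, and tensor_parallel_size values are illustrative, and a 7x7B MoE checkpoint will generally need several GPUs (or quantization) to fit.

from vllm import LLM, SamplingParams

# Load one of the checkpoints listed in the table above; adjust
# tensor_parallel_size to match the available GPUs.
llm = LLM(model="allenai/FlexOlmo-7x7B-1T", tensor_parallel_size=4)
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["FlexOlmo is a mixture-of-experts model that"], params)
print(outputs[0].outputs[0].text)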

View File

@@ -250,6 +250,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
"FlexOlmoForCausalLM": _HfExamplesInfo("allenai/Flex-reddit-2x7B-1T"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),

View File

@@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only FlexOlmo model compatible with HuggingFace weights."""
from typing import Optional
import torch
from torch import nn
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.olmoe import OlmoeAttention, OlmoeForCausalLM
from vllm.transformers_utils.configs import FlexOlmoConfig
logger = init_logger(__name__)
class FlexOlmoAttention(OlmoeAttention):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
hf_config = vllm_config.model_config.hf_config
assert isinstance(hf_config, FlexOlmoConfig)
self.k_norm = RMSNorm(
self.total_num_kv_heads * self.head_dim, eps=hf_config.rms_norm_eps
)
self.q_norm = RMSNorm(
self.total_num_heads * self.head_dim, eps=hf_config.rms_norm_eps
)
class FlexOlmoMoE(nn.Module):
"""A tensor-parallel MoE implementation for FlexOlmo that shards each expert
across all ranks.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
hf_config = vllm_config.model_config.hf_config
assert isinstance(hf_config, FlexOlmoConfig)
tp_size = get_tensor_model_parallel_world_size()
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(
hf_config.hidden_size,
hf_config.num_experts,
bias=False,
return_bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)
self.experts = FusedMoE(
num_experts=hf_config.num_experts,
top_k=hf_config.num_experts_per_tok,
hidden_size=hf_config.hidden_size,
intermediate_size=hf_config.intermediate_size,
reduce_results=True,
renormalize=False,
quant_config=None,
tp_size=tp_size,
prefix=f"{prefix}.experts",
)
self.top_k = hf_config.num_experts_per_tok
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
# Warning: The experts mutate the hidden state input! This messes up
# basic things like the residual stream.
final_hidden_states = self.experts(
hidden_states=hidden_states.detach().clone(),
router_logits=router_logits.float(),
)
return final_hidden_states.view(orig_shape)
class FlexOlmoDecoderLayer(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
hf_config = vllm_config.model_config.hf_config
assert isinstance(hf_config, FlexOlmoConfig)
self.self_attn = FlexOlmoAttention(
vllm_config=vllm_config, prefix=f"{prefix}.self_attn"
)
self.post_attention_layernorm = RMSNorm(
hf_config.hidden_size, eps=hf_config.rms_norm_eps
)
self.post_feedforward_layernorm = RMSNorm(
hf_config.hidden_size, eps=hf_config.rms_norm_eps
)
self.mlp = FlexOlmoMoE(vllm_config=vllm_config, prefix=f"{prefix}.mlp")
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
# Attention block.
residual = hidden_states
hidden_states = self.self_attn(positions, hidden_states)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = hidden_states + residual
# MLP block.
residual = hidden_states
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, None
class FlexOlmoForCausalLM(OlmoeForCausalLM):
fall_back_to_pt_during_load = False
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[nn.Module] = FlexOlmoDecoderLayer,
):
super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
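As a reading aid for the FlexOlmoMoE forward pass above, the sketch below spells out what the routing reduces to with renormalize=False and a float32 gate: a plain softmax over all experts followed by a top-k selection whose weights are not rescaled. This is an illustrative standalone approximation of the fused kernel's routing, not code from this commit; top_k=5 and 7 experts come from the FlexOlmoConfig defaults further down.

import torch

def flex_olmo_routing(router_logits: torch.Tensor, top_k: int = 5):
    # Softmax over all experts, in float32 (the model casts router logits
    # to float before handing them to FusedMoE).
    probs = torch.softmax(router_logits.float(), dim=-1)
    # renormalize=False: the selected top-k weights are used as-is and are
    # NOT rescaled to sum to one.
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    return topk_weights, topk_ids

# Example: 4 tokens routed over 7 experts, 5 experts per token.
logits = torch.randn(4, 7)
weights, ids = flex_olmo_routing(logits)
print(weights.shape, ids.shape)  # torch.Size([4, 5]) torch.Size([4, 5])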

View File

@@ -17,15 +17,14 @@
from collections.abc import Iterable
from functools import partial
from itertools import islice
from typing import Any, Optional, Union
from typing import Optional, Union
import torch
from torch import nn
from transformers import OlmoeConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
@@ -117,20 +116,21 @@ class OlmoeMoE(nn.Module):
class OlmoeAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 4096,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
self.hidden_size = hidden_size
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
num_heads = config.num_attention_heads
num_kv_heads = config.num_key_value_heads
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
@@ -145,7 +145,7 @@ class OlmoeAttention(nn.Module):
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.head_dim = self.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
@@ -153,7 +153,7 @@ class OlmoeAttention(nn.Module):
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
@@ -166,7 +166,7 @@ class OlmoeAttention(nn.Module):
self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, eps=1e-5)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
self.hidden_size,
bias=False,
quant_config=quant_config,
)
@@ -218,28 +218,15 @@ class OlmoeAttention(nn.Module):
class OlmoeDecoderLayer(nn.Module):
def __init__(
self,
config: OlmoeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
self.self_attn = OlmoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
vllm_config=vllm_config,
prefix=f"{prefix}.self_attn",
)
@@ -280,12 +267,16 @@ class OlmoeDecoderLayer(nn.Module):
@support_torch_compile
class OlmoeModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[nn.Module] = OlmoeDecoderLayer,
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.vocab_size = config.vocab_size
self.config = config
@@ -295,9 +286,7 @@ class OlmoeModel(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OlmoeDecoderLayer(
config, cache_config, quant_config, prefix=prefix
),
lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
@@ -339,7 +328,10 @@ class OlmoeModel(nn.Module):
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
if residual is not None:
hidden_states, _ = self.norm(hidden_states, residual)
else:
hidden_states = self.norm(hidden_states)
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
@@ -455,14 +447,22 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[nn.Module] = OlmoeDecoderLayer,
):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = OlmoeModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
layer_type=layer_type,
)
self.lm_head = ParallelLMHead(
config.vocab_size,
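The net effect of this refactor is that OlmoeModel and OlmoeForCausalLM now read their configuration from vllm_config and expose a layer_type hook, which is exactly what FlexOlmoForCausalLM above relies on. A hedged sketch of the pattern, with hypothetical class names:

from vllm.config import VllmConfig
from vllm.model_executor.models.olmoe import OlmoeDecoderLayer, OlmoeForCausalLM

class MyDecoderLayer(OlmoeDecoderLayer):
    # Hypothetical variant; override pieces of the decoder layer here.
    pass

class MyOlmoeVariant(OlmoeForCausalLM):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # The new layer_type keyword swaps the decoder layer without
        # re-implementing the rest of the model.
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            layer_type=MyDecoderLayer,
        )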

View File

@@ -90,6 +90,7 @@ _TEXT_GENERATION_MODELS = {
"Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
"FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),

View File

@@ -74,6 +74,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
deepseek_vl_v2="DeepseekVLV2Config",
deepseek_v3="DeepseekV3Config",
deepseek_v32="DeepseekV3Config",
flex_olmo="FlexOlmoConfig",
kimi_vl="KimiVLConfig",
Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config",
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)

View File

@@ -17,6 +17,7 @@ from vllm.transformers_utils.configs.eagle import EAGLEConfig
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
from vllm.transformers_utils.configs.lfm2_moe import Lfm2MoeConfig
@@ -45,6 +46,7 @@ __all__ = [
"DeepseekV3Config",
"DotsOCRConfig",
"EAGLEConfig",
"FlexOlmoConfig",
"RWConfig",
"JAISConfig",
"Lfm2MoeConfig",

View File

@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers.configuration_utils import PretrainedConfig
class FlexOlmoConfig(PretrainedConfig):
model_type = "flex_olmo"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=100352,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-06,
use_cache=True,
pad_token_id=100277,
bos_token_id=None,
eos_token_id=100257,
tie_word_embeddings=False,
rope_theta=500000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
num_experts_per_tok=5,
num_experts=7,
output_router_logits=False,
router_aux_loss_coef=0.01,
norm_topk_prob=False,
**kwargs,
):
if "architectures" not in kwargs:
kwargs["architectures"] = ["FlexOlmoForCausalLM"]
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.norm_topk_prob = norm_topk_prob
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
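A quick sanity check of the new config class, assuming only the defaults defined above; whether those defaults correspond exactly to the released FlexOlmo checkpoints is not asserted here.

from vllm.transformers_utils.configs import FlexOlmoConfig

cfg = FlexOlmoConfig()
assert cfg.model_type == "flex_olmo"
assert cfg.architectures == ["FlexOlmoForCausalLM"]
assert (cfg.num_experts, cfg.num_experts_per_tok) == (7, 5)
# num_key_value_heads falls back to num_attention_heads when unset.
assert cfg.num_key_value_heads == cfg.num_attention_heads == 32
print(cfg.hidden_size, cfg.max_position_embeddings)  # 4096 4096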