From 89e08d6d180c76019daa5aa1bbf7759dfaedde2e Mon Sep 17 00:00:00 2001 From: Shane A Date: Fri, 12 Sep 2025 20:26:21 -0700 Subject: [PATCH] [Model] Add Olmo3 model implementation (#24534) Signed-off-by: Shane A Co-authored-by: Isotr0py --- docs/models/supported_models.md | 1 + tests/models/registry.py | 1 + vllm/model_executor/models/olmo2.py | 42 +++++++---- vllm/model_executor/models/registry.py | 1 + vllm/transformers_utils/config.py | 1 + vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/olmo3.py | 80 +++++++++++++++++++++ 7 files changed, 114 insertions(+), 14 deletions(-) create mode 100644 vllm/transformers_utils/configs/olmo3.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index db3dd2c252c76..6295a2aa8dc2f 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -389,6 +389,7 @@ th { | `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ | | `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | | `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `OLMo3ForCausalLM` | OLMo3 | TBA | ✅︎ | ✅︎ | ✅︎ | | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | | `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ | | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index 0c77ec5ef13c3..b268bf12a3f30 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -301,6 +301,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True), "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"), "Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"), + "Olmo3ForCausalLM": _HfExamplesInfo("shanearora/2025-sep-a-base-model"), "OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"), "OPTForCausalLM": _HfExamplesInfo("facebook/opt-125m", {"1b": "facebook/opt-iml-max-1.3b"}), diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index bccd1b87043a5..3e4c580a11211 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -52,10 +52,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP from vllm.model_executor.models.utils import ( - AutoWeightsLoader, is_pp_missing_parameter, + AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Olmo3Config class Olmo2Attention(nn.Module): @@ -68,7 +69,7 @@ class Olmo2Attention(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - assert isinstance(self.config, Olmo2Config) + assert isinstance(self.config, (Olmo2Config, Olmo3Config)) hidden_size = self.config.hidden_size self.tp_size = get_tensor_model_parallel_world_size() @@ -111,14 +112,14 @@ class Olmo2Attention(nn.Module): self.q_norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - # Rotary embeddings. - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_theta, # type: ignore - ) self.scaling = self.head_dim**-0.5 + + layer_idx = extract_layer_index(prefix) + sliding_window = None + if ((layer_types := getattr(self.config, "layer_types", None)) + is not None and layer_types[layer_idx] == "sliding_attention"): + sliding_window = self.config.sliding_window + self.attn = Attention( self.num_heads, self.head_dim, @@ -126,7 +127,20 @@ class Olmo2Attention(nn.Module): num_kv_heads=self.num_kv_heads, cache_config=vllm_config.cache_config, quant_config=vllm_config.quant_config, - prefix=prefix, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn", + ) + + # Rotary embeddings. Rope scaling is only applied on full attention + # layers. + self.rope_scaling = (self.config.rope_scaling + if sliding_window is None else None) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, # type: ignore + rope_scaling=self.rope_scaling, ) # Attention output projection. @@ -176,7 +190,7 @@ class Olmo2MLP(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, (Olmo2Config, Olmo3Config)) hidden_size = config.hidden_size intermediate_size = config.intermediate_size @@ -221,7 +235,7 @@ class Olmo2DecoderLayer(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, (Olmo2Config, Olmo3Config)) # Attention block. self.self_attn = Olmo2Attention(vllm_config=vllm_config, prefix=f"{prefix}.self_attn") @@ -261,7 +275,7 @@ class Olmo2Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - assert isinstance(self.config, Olmo2Config) + assert isinstance(self.config, (Olmo2Config, Olmo3Config)) self.embed_tokens = VocabParallelEmbedding( self.config.vocab_size, @@ -376,7 +390,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, (Olmo2Config, Olmo3Config)) self.config = config self.model = Olmo2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 7d7654e846e1c..85759df369850 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -120,6 +120,7 @@ _TEXT_GENERATION_MODELS = { "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), + "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 2852d16ec53f4..a76f9504eccef 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -75,6 +75,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( eagle="EAGLEConfig", speculators="SpeculatorsConfig", nemotron="NemotronConfig", + olmo3="Olmo3Config", ovis="OvisConfig", ultravox="UltravoxConfig", step3_vl="Step3VLConfig", diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index cdae59ccc24e0..ca0d5def760a8 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -23,6 +23,7 @@ from vllm.transformers_utils.configs.moonvit import MoonViTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config +from vllm.transformers_utils.configs.olmo3 import Olmo3Config from vllm.transformers_utils.configs.ovis import OvisConfig from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig @@ -45,6 +46,7 @@ __all__ = [ "NemotronConfig", "NemotronHConfig", "Nemotron_Nano_VL_Config", + "Olmo3Config", "OvisConfig", "SpeculatorsConfig", "UltravoxConfig", diff --git a/vllm/transformers_utils/configs/olmo3.py b/vllm/transformers_utils/configs/olmo3.py new file mode 100644 index 0000000000000..874507db43a7f --- /dev/null +++ b/vllm/transformers_utils/configs/olmo3.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from transformers.configuration_utils import PretrainedConfig + + +class Olmo3Config(PretrainedConfig): + + model_type = "olmo3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50304, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + rms_norm_eps=1e-5, + sliding_window=4096, + layer_types=None, + **kwargs, + ): + # This model uses Olmo3ForCausalLM in transformers but Olmo2ForCausalLM + # in vLLM. + if "architectures" not in kwargs: + kwargs["architectures"] = ["Olmo2ForCausalLM"] + elif "Olmo3ForCausalLM" in kwargs["architectures"]: + kwargs["architectures"].remove("Olmo3ForCausalLM") + kwargs["architectures"].append("Olmo2ForCausalLM") + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + self.rms_norm_eps = rms_norm_eps + + self.sliding_window = sliding_window + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if (i + 1) % 4 != 0 else "full_attention" + for i in range(self.num_hidden_layers) + ]