[Model] Add Olmo3 model implementation (#24534)
Signed-off-by: Shane A <shanea@allenai.org>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 7f2ea7074e
commit 89e08d6d18
@@ -389,6 +389,7 @@ th {
 | `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `OLMo3ForCausalLM` | OLMo3 | TBA | ✅︎ | ✅︎ | ✅︎ |
 | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ |
 | `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ |
 | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ |
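For context, a minimal offline-inference sketch with vLLM's public LLM API, using one of the OLMo checkpoints already listed in the table above (Olmo3 checkpoints are still TBA, so allenai/OLMo-2-0425-1B stands in purely for illustration):

# Minimal sketch: offline generation with an OLMo-family model in vLLM.
# The checkpoint name is an example taken from the table above; swap in an
# Olmo3 checkpoint once one is published (currently TBA).
from vllm import LLM, SamplingParams

llm = LLM(model="allenai/OLMo-2-0425-1B")
params = SamplingParams(temperature=0.8, max_tokens=64)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)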
@@ -301,6 +301,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
                                           trust_remote_code=True),
     "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
     "Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"),
+    "Olmo3ForCausalLM": _HfExamplesInfo("shanearora/2025-sep-a-base-model"),
     "OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"),
     "OPTForCausalLM": _HfExamplesInfo("facebook/opt-125m",
                                       {"1b": "facebook/opt-iml-max-1.3b"}),
@@ -52,10 +52,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
 from vllm.model_executor.models.utils import (
-    AutoWeightsLoader, is_pp_missing_parameter,
+    AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs import Olmo3Config


 class Olmo2Attention(nn.Module):
@@ -68,7 +69,7 @@ class Olmo2Attention(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         self.config = vllm_config.model_config.hf_config
-        assert isinstance(self.config, Olmo2Config)
+        assert isinstance(self.config, (Olmo2Config, Olmo3Config))

         hidden_size = self.config.hidden_size
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -111,14 +112,14 @@
         self.q_norm = RMSNorm(self.config.hidden_size,
                               eps=self.config.rms_norm_eps)

-        # Rotary embeddings.
-        self.rotary_emb = get_rope(
-            self.head_dim,
-            rotary_dim=self.head_dim,
-            max_position=self.max_position_embeddings,
-            base=self.rope_theta,  # type: ignore
-        )
         self.scaling = self.head_dim**-0.5
+
+        layer_idx = extract_layer_index(prefix)
+        sliding_window = None
+        if ((layer_types := getattr(self.config, "layer_types", None))
+                is not None and layer_types[layer_idx] == "sliding_attention"):
+            sliding_window = self.config.sliding_window
+
         self.attn = Attention(
             self.num_heads,
             self.head_dim,
@@ -126,7 +127,20 @@
             num_kv_heads=self.num_kv_heads,
             cache_config=vllm_config.cache_config,
             quant_config=vllm_config.quant_config,
-            prefix=prefix,
+            per_layer_sliding_window=sliding_window,
+            prefix=f"{prefix}.attn",
         )
+
+        # Rotary embeddings. Rope scaling is only applied on full attention
+        # layers.
+        self.rope_scaling = (self.config.rope_scaling
+                             if sliding_window is None else None)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,  # type: ignore
+            rope_scaling=self.rope_scaling,
+        )

         # Attention output projection.
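To make the new control flow above easier to follow, here is a standalone sketch of the same per-layer decision: layers marked "sliding_attention" in config.layer_types get the sliding window and skip rope scaling, while full-attention layers keep rope scaling. The helper name and toy values are illustrative only, not vLLM API.

def select_attention_settings(layer_idx, layer_types, sliding_window, rope_scaling):
    # Mirrors the hunk above: sliding-attention layers use the window and
    # drop rope scaling; full-attention layers do the opposite.
    if layer_types is not None and layer_types[layer_idx] == "sliding_attention":
        return sliding_window, None
    return None, rope_scaling

# Example: only layer 3 (every 4th layer) is full attention and keeps rope scaling.
layer_types = ["sliding_attention"] * 3 + ["full_attention"]
for i in range(4):
    print(i, select_attention_settings(i, layer_types, 4096, {"rope_type": "yarn"}))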
@@ -176,7 +190,7 @@ class Olmo2MLP(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        assert isinstance(config, Olmo2Config)
+        assert isinstance(config, (Olmo2Config, Olmo3Config))
         hidden_size = config.hidden_size
         intermediate_size = config.intermediate_size
@@ -221,7 +235,7 @@ class Olmo2DecoderLayer(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        assert isinstance(config, Olmo2Config)
+        assert isinstance(config, (Olmo2Config, Olmo3Config))
         # Attention block.
         self.self_attn = Olmo2Attention(vllm_config=vllm_config,
                                         prefix=f"{prefix}.self_attn")
@@ -261,7 +275,7 @@ class Olmo2Model(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         self.config = vllm_config.model_config.hf_config
-        assert isinstance(self.config, Olmo2Config)
+        assert isinstance(self.config, (Olmo2Config, Olmo3Config))

         self.embed_tokens = VocabParallelEmbedding(
             self.config.vocab_size,
@@ -376,7 +390,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        assert isinstance(config, Olmo2Config)
+        assert isinstance(config, (Olmo2Config, Olmo3Config))
         self.config = config
         self.model = Olmo2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
@@ -120,6 +120,7 @@ _TEXT_GENERATION_MODELS = {
     "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
     "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
     "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
+    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
     "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
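Note that both architecture strings above resolve to the same implementation; a plain-dict illustration of what the new entry maps to (this is not the registry's actual loading machinery):

# Olmo3 checkpoints reuse the existing Olmo2 implementation:
# module "olmo2", class "Olmo2ForCausalLM".
_TEXT_GENERATION_MODELS = {
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
}
module_name, class_name = _TEXT_GENERATION_MODELS["Olmo3ForCausalLM"]
print(module_name, class_name)  # -> olmo2 Olmo2ForCausalLM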
@@ -75,6 +75,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
     eagle="EAGLEConfig",
     speculators="SpeculatorsConfig",
     nemotron="NemotronConfig",
+    olmo3="Olmo3Config",
     ovis="OvisConfig",
     ultravox="UltravoxConfig",
     step3_vl="Step3VLConfig",
@@ -23,6 +23,7 @@ from vllm.transformers_utils.configs.moonvit import MoonViTConfig
 from vllm.transformers_utils.configs.nemotron import NemotronConfig
 from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
 from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
+from vllm.transformers_utils.configs.olmo3 import Olmo3Config
 from vllm.transformers_utils.configs.ovis import OvisConfig
 from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
 from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig
@@ -45,6 +46,7 @@ __all__ = [
     "NemotronConfig",
     "NemotronHConfig",
     "Nemotron_Nano_VL_Config",
+    "Olmo3Config",
     "OvisConfig",
     "SpeculatorsConfig",
     "UltravoxConfig",
vllm/transformers_utils/configs/olmo3.py (new file, 80 lines)
@@ -0,0 +1,80 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class Olmo3Config(PretrainedConfig):
+
+    model_type = "olmo3"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=50304,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        hidden_act="silu",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        use_cache=True,
+        pad_token_id=1,
+        bos_token_id=None,
+        eos_token_id=50279,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        rms_norm_eps=1e-5,
+        sliding_window=4096,
+        layer_types=None,
+        **kwargs,
+    ):
+        # This model uses Olmo3ForCausalLM in transformers but Olmo2ForCausalLM
+        # in vLLM.
+        if "architectures" not in kwargs:
+            kwargs["architectures"] = ["Olmo2ForCausalLM"]
+        elif "Olmo3ForCausalLM" in kwargs["architectures"]:
+            kwargs["architectures"].remove("Olmo3ForCausalLM")
+            kwargs["architectures"].append("Olmo2ForCausalLM")
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+
+        self.rms_norm_eps = rms_norm_eps
+
+        self.sliding_window = sliding_window
+        self.layer_types = layer_types
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention" if (i + 1) % 4 != 0 else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]
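A small worked example of the default layer_types pattern constructed at the end of Olmo3Config.__init__ above: every 4th layer uses full attention, all others use sliding attention (num_hidden_layers is shrunk here just for illustration).

# Default pattern from Olmo3Config: (i + 1) % 4 != 0 -> sliding attention.
num_hidden_layers = 8  # illustrative; the config default is 32
layer_types = [
    "sliding_attention" if (i + 1) % 4 != 0 else "full_attention"
    for i in range(num_hidden_layers)
]
print(layer_types)
# ['sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention',
#  'sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention']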