[Model] Improve olmo and olmo2 (#23228)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-08-20 20:47:05 +08:00 committed by GitHub
parent 7cd17e22d7
commit c6d80a7a96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 36 additions and 7 deletions

View File

@ -384,8 +384,8 @@ th {
| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ |
| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ |
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ |
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ |
| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ |
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ |

View File

@ -47,7 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@ -91,6 +91,7 @@ class OlmoAttention(nn.Module):
self.total_num_heads,
bias=config.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
# Rotary embeddings.
@ -114,6 +115,7 @@ class OlmoAttention(nn.Module):
self.hidden_size,
bias=config.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
def forward(
@ -142,6 +144,7 @@ class OlmoMLP(nn.Module):
self,
config: OlmoConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
@ -154,6 +157,7 @@ class OlmoMLP(nn.Module):
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
# Activation function.
@ -165,6 +169,7 @@ class OlmoMLP(nn.Module):
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
def forward(
@ -197,7 +202,7 @@ class OlmoDecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn")
# MLP block.
self.mlp = OlmoMLP(config, quant_config)
self.mlp = OlmoMLP(config, quant_config, prefix=f"{prefix}.mlp")
# LayerNorm
self.input_layernorm = nn.LayerNorm(config.hidden_size,
@ -326,10 +331,21 @@ class OlmoModel(nn.Module):
return loaded_params
class OlmoForCausalLM(nn.Module, SupportsPP):
class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
"""
Extremely barebones HF model wrapper.
"""
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@ -33,6 +33,7 @@ from torch import nn
from transformers import Olmo2Config
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.communication_op import tensor_model_parallel_all_gather
@ -48,7 +49,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.utils import (
AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
@ -253,6 +254,7 @@ class Olmo2DecoderLayer(nn.Module):
return hidden_states
@support_torch_compile
class Olmo2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@ -354,10 +356,21 @@ class Olmo2Model(nn.Module):
return loaded_params
class Olmo2ForCausalLM(nn.Module, SupportsPP):
class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
"""
Extremely barebones HF model wrapper.
"""
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()