diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 7908e4238710..7308d0010690 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -384,8 +384,8 @@ th {
 | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ |
 | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ |
-| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ |
+| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ |
 | `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ |
 | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ |
diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py
index 1dc4df85c1bc..01639d398126 100644
--- a/vllm/model_executor/models/olmo.py
+++ b/vllm/model_executor/models/olmo.py
@@ -47,7 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -91,6 +91,7 @@ class OlmoAttention(nn.Module):
             self.total_num_heads,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
 
         # Rotary embeddings.
@@ -114,6 +115,7 @@ class OlmoAttention(nn.Module):
             self.hidden_size,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
     def forward(
@@ -142,6 +144,7 @@ class OlmoMLP(nn.Module):
         self,
         config: OlmoConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -154,6 +157,7 @@ class OlmoMLP(nn.Module):
             [self.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
 
         # Activation function.
@@ -165,6 +169,7 @@ class OlmoMLP(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
 
     def forward(
@@ -197,7 +202,7 @@ class OlmoDecoderLayer(nn.Module):
                                        prefix=f"{prefix}.self_attn")
 
         # MLP block.
-        self.mlp = OlmoMLP(config, quant_config)
+        self.mlp = OlmoMLP(config, quant_config, prefix=f"{prefix}.mlp")
 
         # LayerNorm
         self.input_layernorm = nn.LayerNorm(config.hidden_size,
@@ -326,10 +331,21 @@ class OlmoModel(nn.Module):
         return loaded_params
 
 
-class OlmoForCausalLM(nn.Module, SupportsPP):
+class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     """
     Extremely barebones HF model wrapper.
     """
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py
index 499e6d30ed6b..66a0f9115585 100644
--- a/vllm/model_executor/models/olmo2.py
+++ b/vllm/model_executor/models/olmo2.py
@@ -33,6 +33,7 @@ from torch import nn
 from transformers import Olmo2Config
 
 from vllm.attention import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.distributed.communication_op import tensor_model_parallel_all_gather
@@ -48,7 +49,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import SupportsPP
+from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader, is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
@@ -253,6 +254,7 @@ class Olmo2DecoderLayer(nn.Module):
         return hidden_states
 
 
+@support_torch_compile
 class Olmo2Model(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -354,10 +356,21 @@ class Olmo2Model(nn.Module):
         return loaded_params
 
 
-class Olmo2ForCausalLM(nn.Module, SupportsPP):
+class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     """
     Extremely barebones HF model wrapper.
     """
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()