[6/N] pass whole config to inner model (#10205)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Authored by youkaichao on 2024-11-10 22:41:46 -08:00, committed by GitHub
parent f0f2e5638e
commit f89d18ff74
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
69 changed files with 681 additions and 963 deletions
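
The change repeated across all of these files: inner model modules stop taking config, cache_config, quant_config, lora_config, etc. as separate constructor arguments and instead take the whole VllmConfig as a keyword-only argument, unpacking the pieces they need inside __init__. A minimal before/after sketch of the pattern (the MyModel name is illustrative, not from any file below):

import torch.nn as nn
from vllm.config import VllmConfig

# Before: every sub-config threaded through the signature by hand.
class MyModel(nn.Module):
    def __init__(self, config, cache_config=None, quant_config=None,
                 prefix: str = ""):
        super().__init__()
        self.config = config

# After: a single keyword-only VllmConfig, unpacked where it is needed.
class MyModel(nn.Module):  # same name redefined, purely for illustration
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config    # forwarded to inner layers
        quant_config = vllm_config.quant_config    # forwarded to inner layers
        self.config = config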

View File

@ -34,7 +34,8 @@ from vllm.transformers_utils.configs.arctic import ArcticConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
@ -364,14 +365,13 @@ class ArcticDecoderLayer(nn.Module):
@support_torch_compile
class ArcticModel(nn.Module):
def __init__(
self,
config: ArcticConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
@ -418,13 +418,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.model = ArcticModel(config,
cache_config,
quant_config,
prefix=prefix)
self.model = ArcticModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.vocab_size,
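
Most call sites in this commit also switch from a bare parent prefix (or a hard-coded string like "model") to maybe_prefix(prefix, "model"), imported from .utils. The assumed behaviour is simply to join the parent prefix and the submodule name, dropping the dot when the parent prefix is empty; a sketch under that assumption (the real helper lives in vllm/model_executor/models/utils.py and may differ in detail):

def maybe_prefix(prefix: str, name: str) -> str:
    # Assumed behaviour:
    #   maybe_prefix("", "model")               -> "model"
    #   maybe_prefix("language_model", "model") -> "language_model.model"
    return name if not prefix else f"{prefix}.{name}"

This is what lets parameter names nest correctly when a model is itself used as a submodule of a larger wrapper, instead of always registering under a top-level prefix.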

View File

@ -253,13 +253,18 @@ class BaiChuanDecoderLayer(nn.Module):
@support_torch_compile
class BaiChuanModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
position_embedding: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
position_embedding: str = "ROPE",
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@ -332,21 +337,22 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
position_embedding: str = "ROPE",
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, cache_config,
quant_config)
self.model = BaiChuanModel(vllm_config=vllm_config,
prefix=prefix,
position_embedding=position_embedding)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
@ -438,16 +444,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
NOTE: the class name has a lower case 'c'.
"""
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
if config.hidden_size == 4096: # baichuan2 7b
super().__init__(vllm_config, prefix, "ROPE")
super().__init__(vllm_config=vllm_config,
prefix=prefix,
position_embedding="ROPE")
else: # baichuan 13b, baichuan2 13b
super().__init__(vllm_config, prefix, "ALIBI")
super().__init__(vllm_config=vllm_config,
prefix=prefix,
position_embedding="ALIBI")
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
@ -455,9 +461,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
NOTE: the class name has an upper case 'C'.
"""
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__(vllm_config, prefix, "ROPE")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config,
prefix=prefix,
position_embedding="ROPE")
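
A side effect of the bare * that these signatures gain: callers such as the Baichuan subclasses above can no longer pass vllm_config, prefix, and position_embedding positionally; the arguments must be named. A tiny self-contained illustration of the constraint (hypothetical Base class, not from the diff):

class Base:
    def __init__(self, *, vllm_config, prefix: str = ""):
        self.vllm_config = vllm_config
        self.prefix = prefix

try:
    Base(object())                  # old-style positional call
except TypeError as err:
    print(err)                      # rejected: vllm_config is keyword-only

Base(vllm_config=object())          # new-style keyword call works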

View File

@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
logger = logging.get_logger(__name__)
@ -739,13 +741,14 @@ class BartModel(nn.Module):
"encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
]
def __init__(self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.padding_idx = config.pad_token_id
@ -810,20 +813,16 @@ class BartModel(nn.Module):
class BartForConditionalGeneration(nn.Module):
base_model_prefix = "model"
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
# currently all existing BART models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.config = config
self.model = BartModel(config,
cache_config,
quant_config,
lora_config=lora_config)
self.model = BartModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:

View File

@ -21,6 +21,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from .utils import maybe_prefix
class BertEmbedding(nn.Module):
@ -309,12 +311,13 @@ class BertOutput(nn.Module):
class BertModel(nn.Module):
def __init__(self,
config: BertConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.embeddings = BertEmbedding(config)
self.encoder = BertEncoder(config,
cache_config,
@ -382,17 +385,11 @@ class BertEmbeddingModel(nn.Module):
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
self.model = BertModel(config, cache_config, quant_config)
self.model = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.CLS,

View File

@ -23,7 +23,7 @@ from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
merge_multimodal_embeddings)
maybe_prefix, merge_multimodal_embeddings)
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
@ -483,11 +483,7 @@ def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@ -517,7 +513,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix="language_model")
prefix=maybe_prefix(prefix, "language_model"))
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
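
Multimodal wrappers such as Blip2 build their text backbone through init_vllm_registered_model, now passing the whole vllm_config plus a properly nested prefix. Pieced together from the EAGLE and Florence2 hunks elsewhere in this commit, the helper plausibly amounts to the following; this is a sketch of the assumed behaviour, not the actual implementation in .utils:

import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.models import ModelRegistry

def init_vllm_registered_model_sketch(hf_config, vllm_config: VllmConfig,
                                      prefix: str = "") -> nn.Module:
    # Resolve the registered class for the inner architecture, then build it
    # with the same unified signature, swapping the inner hf_config in via
    # with_hf_config (see the Florence2 hunk below).
    model_cls, _ = ModelRegistry.resolve_model_cls(
        getattr(hf_config, "architectures", []))
    return model_cls(vllm_config=vllm_config.with_hf_config(hf_config),
                     prefix=prefix)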

View File

@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@ -221,14 +222,13 @@ class BloomBlock(nn.Module):
@support_torch_compile
class BloomModel(nn.Module):
def __init__(
self,
config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.embed_dim = config.hidden_size
# Embedding + LN Embedding
@ -288,11 +288,12 @@ class BloomForCausalLM(nn.Module, SupportsPP):
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.transformer = BloomModel(config, cache_config, quant_config)
self.transformer = BloomModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.word_embeddings
else:

View File

@ -37,7 +37,8 @@ from vllm.utils import print_warning_once
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
# These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now.
@ -831,14 +832,13 @@ class ChameleonImageVocabularyMapping:
class ChameleonModel(nn.Module):
def __init__(
self,
config: ChameleonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@ -924,19 +924,14 @@ class ChameleonModel(nn.Module):
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.model = ChameleonModel(config, cache_config, quant_config)
self.model = ChameleonModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,

View File

@ -39,7 +39,8 @@ from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
@ -481,14 +482,13 @@ class GLMTransformer(nn.Module):
class ChatGLMModel(nn.Module):
def __init__(
self,
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
@ -600,7 +600,6 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
multimodal_config = vllm_config.model_config.multimodal_config
@ -611,7 +610,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length",
8192)
self.transformer = ChatGLMModel(config, cache_config, quant_config)
self.transformer = ChatGLMModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
if self.config.tie_word_embeddings:
self.transformer.output_layer.weight = (
self.transformer.embedding.weight)

View File

@ -28,7 +28,7 @@ from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@ -49,7 +49,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@torch.compile
@ -253,15 +254,14 @@ class CohereDecoderLayer(nn.Module):
@support_torch_compile
class CohereModel(nn.Module):
def __init__(
self,
config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
@ -332,14 +332,9 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules = {"embed_tokens": "input_embeddings"}
embedding_padding_modules = []
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
@ -353,10 +348,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
scale=config.logit_scale)
self.model = CohereModel(config,
cache_config,
quant_config,
lora_config=lora_config)
self.model = CohereModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

View File

@ -25,7 +25,8 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class DbrxRouter(nn.Module):
@ -294,14 +295,13 @@ class DbrxBlock(nn.Module):
class DbrxModel(nn.Module):
def __init__(
self,
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.wte = VocabParallelEmbedding(
config.vocab_size,
config.d_model,
@ -357,7 +357,6 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
if config.tie_word_embeddings:
@ -365,7 +364,9 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
"tie_word_embeddings is not supported for Dbrx models.")
self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel(config, cache_config, quant_config)
self.transformer = DbrxModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
self.lm_head = ParallelLMHead(
config.vocab_size,
config.d_model,

View File

@ -51,11 +51,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
instead.
"""
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer")

View File

@ -50,7 +50,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class DeepseekMLP(nn.Module):
@ -326,14 +327,13 @@ class DeepseekModel(nn.Module):
fall_back_to_pt_during_load = False
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@ -383,18 +383,14 @@ class DeepseekModel(nn.Module):
class DeepseekForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = DeepseekModel(config, cache_config, quant_config)
self.model = DeepseekModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)

View File

@ -51,7 +51,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class DeepseekV2MLP(nn.Module):
@ -408,14 +409,13 @@ class DeepseekV2Model(nn.Module):
fall_back_to_pt_during_load = False
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@ -479,21 +479,14 @@ class DeepseekV2Model(nn.Module):
class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = DeepseekV2Model(config,
cache_config,
quant_config,
prefix="model")
self.model = DeepseekV2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)

View File

@ -14,6 +14,8 @@ from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
class EAGLE(nn.Module):
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
@ -42,7 +44,8 @@ class EAGLE(nn.Module):
architectures = getattr(self.config.model, "architectures", [])
model_cls, _ = ModelRegistry.resolve_model_cls(architectures)
self.model = model_cls(vllm_config, prefix)
self.model = model_cls(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.fc = nn.Linear(config.model.hidden_size * 2,
config.model.hidden_size,
bias=getattr(self.config, "eagle_fc_bias", False))

View File

@ -29,7 +29,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
@ -54,7 +54,8 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class ExaoneGatedMLP(nn.Module):
@ -314,15 +315,14 @@ class ExaoneDecoderLayer(nn.Module):
@support_torch_compile
class ExaoneModel(nn.Module):
def __init__(
self,
config: ExaoneConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size *
@ -438,14 +438,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"c_fc_1": ("gate_up_proj", 1),
}
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -453,11 +448,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config
self.transformer = ExaoneModel(
config,
cache_config,
quant_config,
lora_config=lora_config,
prefix="model",
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size

View File

@ -48,7 +48,8 @@ from vllm.transformers_utils.configs import RWConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
FalconConfig = Union[HF_FalconConfig, RWConfig]
@ -332,14 +333,13 @@ class FalconDecoderLayer(nn.Module):
@support_torch_compile
class FalconModel(nn.Module):
def __init__(
self,
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
@ -408,11 +408,12 @@ class FalconForCausalLM(nn.Module, SupportsPP):
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.transformer = FalconModel(config, cache_config, quant_config)
self.transformer = FalconModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
# only Falcon-11B doesn't share lm_head weight with word embeddings
# and previous Falcon model doesn't have tie_word_embeddings config
# so we set tie_word_embeddings to True by default

View File

@ -3,13 +3,10 @@ from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
@ -23,11 +20,13 @@ from .utils import AutoWeightsLoader
class Florence2LanguageModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = config.pad_token_id
@ -93,15 +92,14 @@ class Florence2LanguageModel(nn.Module):
class Florence2LanguageForConditionalGeneration(nn.Module):
def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
self.model = Florence2LanguageModel(config,
cache_config=cache_config,
quant_config=quant_config)
self.model = Florence2LanguageModel(vllm_config=vllm_config,
prefix=prefix)
embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0
@ -189,17 +187,15 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
class Florence2ForConditionalGeneration(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
# TODO(Isotr0py): Add vision backbone
self.language_model = Florence2LanguageForConditionalGeneration(
config=config.text_config,
cache_config=cache_config,
quant_config=quant_config)
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=prefix,
)
@property
def sampler(self):
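
Florence2 is the one place in this commit where with_hf_config appears directly: the wrapper hands its language submodule a VllmConfig whose hf_config has been swapped for config.text_config, so the submodule's standard unpacking picks up the text config. A hypothetical stand-in for what that method presumably does (a sketch, not the vLLM source):

import copy

def with_hf_config_sketch(vllm_config, hf_config):
    # Copy the wrapper's VllmConfig and point model_config.hf_config at the
    # inner sub-config, so `vllm_config.model_config.hf_config` inside the
    # submodule resolves to the text config rather than the multimodal one.
    new_config = copy.copy(vllm_config)
    new_config.model_config = copy.copy(vllm_config.model_config)
    new_config.model_config.hf_config = hf_config
    return new_config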

View File

@ -258,14 +258,13 @@ class GemmaDecoderLayer(nn.Module):
@support_torch_compile
class GemmaModel(nn.Module):
def __init__(
self,
config: GemmaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_tokens = VocabParallelEmbedding(
@ -372,14 +371,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -389,9 +383,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config
self.quant_config = quant_config
self.model = GemmaModel(config,
cache_config,
quant_config,
self.model = GemmaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()

View File

@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
@ -243,11 +244,7 @@ class Gemma2DecoderLayer(nn.Module):
@support_torch_compile
class Gemma2Model(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@ -399,13 +396,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
del lora_config # Unused.
@ -414,7 +406,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.quant_config = quant_config
self.model = Gemma2Model(config, cache_config, quant_config)
self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.logits_processor = LogitsProcessor(
config.vocab_size, soft_cap=config.final_logit_softcapping)
self.sampler = get_sampler()
@ -471,14 +464,11 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.model = Gemma2Model(vllm_config, prefix)
self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults(
vllm_config.model_config.pooler_config,
pooling_type=PoolingType.LAST,

View File

@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class GPT2Attention(nn.Module):
@ -184,14 +185,13 @@ class GPT2Block(nn.Module):
@support_torch_compile
class GPT2Model(nn.Module):
def __init__(
self,
config: GPT2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
assert not config.add_cross_attention
assert not config.scale_attn_by_inverse_layer_idx
@ -247,14 +247,12 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.transformer = GPT2Model(config,
cache_config,
quant_config,
prefix="transformer")
self.transformer = GPT2Model(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:

View File

@ -25,7 +25,7 @@ from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -189,15 +189,14 @@ class GPTBigCodeBlock(nn.Module):
@support_torch_compile
class GPTBigCodeModel(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
assert not config.add_cross_attention
@ -265,7 +264,6 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -273,8 +271,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config
self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
lora_config)
self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
prefix=prefix)
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:

View File

@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class GPTJAttention(nn.Module):
@ -177,14 +178,13 @@ class GPTJBlock(nn.Module):
@support_torch_compile
class GPTJModel(nn.Module):
def __init__(
self,
config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_dim = config.n_embd
self.wte = VocabParallelEmbedding(
@ -236,12 +236,13 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
assert not config.tie_word_embeddings
self.transformer = GPTJModel(config, cache_config, quant_config)
self.transformer = GPTJModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
self.lm_head = ParallelLMHead(
config.vocab_size,
config.n_embd,

View File

@ -41,7 +41,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class GPTNeoXAttention(nn.Module):
@ -189,14 +190,13 @@ class GPTNeoXLayer(nn.Module):
@support_torch_compile
class GPTNeoXModel(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_in = VocabParallelEmbedding(
@ -249,11 +249,11 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config)
self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "gpt_neox"))
self.embed_out = ParallelLMHead(
config.vocab_size,
config.hidden_size,

View File

@ -28,7 +28,7 @@ from transformers import GraniteConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
@ -52,7 +52,8 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
from .utils import (PPMissingLayer, is_pp_missing_parameter, make_layers,
maybe_prefix)
class GraniteMLP(nn.Module):
@ -257,15 +258,14 @@ class GraniteDecoderLayer(nn.Module):
@support_torch_compile
class GraniteModel(nn.Module):
def __init__(
self,
config: GraniteConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
@ -370,25 +370,17 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.model = GraniteModel(config,
cache_config,
quant_config,
lora_config=lora_config,
prefix="model")
self.model = GraniteModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:

View File

@ -28,7 +28,7 @@ from transformers.models.granitemoe import GraniteMoeConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
@ -47,7 +47,7 @@ from vllm.sequence import IntermediateTensors
from . import mixtral
from .interfaces import SupportsLoRA, SupportsPP
from .utils import make_layers
from .utils import make_layers, maybe_prefix
class GraniteMoeMoE(nn.Module):
@ -247,15 +247,14 @@ class GraniteMoeDecoderLayer(nn.Module):
@support_torch_compile
class GraniteMoeModel(nn.Module):
def __init__(
self,
config: GraniteMoeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
@ -333,25 +332,17 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.model = GraniteMoeModel(config,
cache_config,
quant_config,
lora_config=lora_config,
prefix="model")
self.model = GraniteMoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

View File

@ -22,17 +22,15 @@ import torch.utils.checkpoint
from PIL import Image
from torch import nn
# Temporary solution for transformers below 4.46.0.
from transformers import PretrainedConfig as Idefics3Config
from transformers import ProcessorMixin as Idefics3ImageProcessor
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -48,7 +46,8 @@ from .idefics2_vision_model import (
# yapf: enable
from .interfaces import SupportsMultiModal
from .llama import LlamaModel
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
logger = init_logger(__name__)
@ -417,13 +416,13 @@ class Idefics3Connector(nn.Module):
class Idefics3Model(nn.Module):
def __init__(
self,
config: Idefics3Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = self.config.text_config.pad_token_id
self.vocab_size = self.config.text_config.vocab_size
@ -613,22 +612,18 @@ class Idefics3Model(nn.Module):
@INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.model = Idefics3Model(config, cache_config, quant_config)
self.model = Idefics3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.image_token_id = self.config.image_token_id
self.lm_head = ParallelLMHead(

View File

@ -250,14 +250,13 @@ class InternLMDecoderLayer(nn.Module):
@support_torch_compile
class InternLM2Model(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@ -317,20 +316,13 @@ class InternLM2Model(nn.Module):
class InternLM2ForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = InternLM2Model(config,
cache_config,
quant_config,
self.model = InternLM2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.output = ParallelLMHead(config.vocab_size,
config.hidden_size,

View File

@ -104,14 +104,13 @@ class InternLM2VEDecoderLayer(nn.Module):
class InternLM2VEModel(InternLM2Model):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config, cache_config, quant_config)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: InternLM2VEDecoderLayer(
@ -159,12 +158,8 @@ class InternLM2VEModel(InternLM2Model):
class InternLM2VEForCausalLM(InternLM2ForCausalLM):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__(vllm_config, prefix=prefix)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config

View File

@ -35,7 +35,7 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches)
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings)
maybe_prefix, merge_multimodal_embeddings)
IMG_START = '<img>'
IMG_END = '</img>'
@ -435,13 +435,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
config,
quant_config=quant_config,
is_mono=self.is_mono,
prefix="vision_model",
prefix=maybe_prefix(prefix, "vision_model"),
)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix="language_model")
prefix=maybe_prefix(prefix, "language_model"))
self.mlp1 = self._init_mlp1(config)

View File

@ -44,7 +44,8 @@ from vllm.transformers_utils.configs import JAISConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class SwiGLUActivation(nn.Module):
@ -215,14 +216,13 @@ class JAISBlock(nn.Module):
@support_torch_compile
class JAISModel(nn.Module):
def __init__(
self,
config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
assert not config.add_cross_attention
assert not config.scale_attn_by_inverse_layer_idx
@ -293,11 +293,12 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.transformer = JAISModel(config, cache_config, quant_config)
self.transformer = JAISModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:

View File

@ -7,7 +7,7 @@ from transformers import JambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
@ -29,6 +29,7 @@ from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size)
from .interfaces import HasInnerState, SupportsLoRA
from .utils import maybe_prefix
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -258,14 +259,14 @@ ALL_DECODER_LAYER_TYPES = {
class JambaModel(nn.Module):
def __init__(
self,
config: JambaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size *
@ -348,14 +349,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \
@ -364,10 +360,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
super().__init__()
self.config = config
self.scheduler_config = scheduler_config
self.model = JambaModel(config,
cache_config=cache_config,
quant_config=quant_config,
lora_config=lora_config)
self.model = JambaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

View File

@ -28,7 +28,7 @@ from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
@ -271,15 +271,14 @@ class LlamaDecoderLayer(nn.Module):
@support_torch_compile
class LlamaModel(nn.Module):
def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
@ -492,24 +491,16 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"norm": "model.norm"
}
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
self.config = config
self.lora_config = lora_config
self.model = LlamaModel(config,
cache_config,
quant_config,
lora_config=lora_config,
self.model = LlamaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
@ -652,23 +643,12 @@ class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
}
embedding_padding_modules = []
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
self.model = LlamaModel(config,
cache_config,
quant_config,
lora_config,
self.model = LlamaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults(
pooler_config,

View File

@ -32,7 +32,7 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings)
maybe_prefix, merge_multimodal_embeddings)
class LlavaImagePixelInputs(TypedDict):
@ -282,7 +282,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
config,
quant_config,
require_post_norm=False,
prefix="vision_tower")
prefix=maybe_prefix(prefix, "vision_tower"))
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
@ -291,7 +291,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix="language_model")
prefix=maybe_prefix(prefix, "language_model"))
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

View File

@ -31,7 +31,7 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
init_vllm_registered_model)
init_vllm_registered_model, maybe_prefix)
class LlavaNextImagePixelInputs(TypedDict):
@ -296,7 +296,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
config,
quant_config,
require_post_norm=False,
prefix="vision_tower")
prefix=maybe_prefix(prefix, "vision_tower"))
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
self.multi_modal_projector = LlavaMultiModalProjector(
@ -307,7 +307,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix="language_model")
prefix=maybe_prefix(prefix, "language_model"))
# The same model class supports both language generation and embedding
# because the architecture name is the same

View File

@ -29,7 +29,7 @@ from .llava import init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip)
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
merge_multimodal_embeddings)
maybe_prefix, merge_multimodal_embeddings)
# For profile run
_MAX_FRAMES_PER_VIDEO = 32
@ -267,7 +267,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
config,
quant_config,
require_post_norm=False,
prefix="vision_tower")
prefix=maybe_prefix(prefix, "vision_tower"))
self.vision_resampler = LlavaNextVideoPooler(config)
self.multi_modal_projector = LlavaNextMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
@ -276,7 +276,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix="language_model")
prefix=maybe_prefix(prefix, "language_model"))
self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors)


@ -35,7 +35,7 @@ from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
dummy_video_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings)
maybe_prefix, merge_multimodal_embeddings)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
@ -418,12 +418,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
config,
quant_config,
require_post_norm=False,
prefix="vision_tower")
prefix=maybe_prefix(prefix, "vision_tower"))
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix="language_model")
prefix=maybe_prefix(prefix, "language_model"))
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))


@ -6,7 +6,7 @@ from torch import nn
from transformers import MambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -26,6 +26,8 @@ from vllm.sequence import IntermediateTensors
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size)
from .utils import maybe_prefix
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -73,14 +75,14 @@ class MambaDecoderLayer(nn.Module):
class MambaModel(nn.Module):
def __init__(
self,
config: MambaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size *
@ -130,14 +132,9 @@ class MambaModel(nn.Module):
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \
@ -146,10 +143,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
super().__init__()
self.config = config
self.scheduler_config = scheduler_config
self.backbone = MambaModel(config,
cache_config=cache_config,
quant_config=quant_config,
lora_config=lora_config)
self.backbone = MambaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "backbone"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size


@ -29,7 +29,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
@ -53,7 +53,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class MiniCPMMoE(nn.Module):
@ -351,15 +352,14 @@ class MiniCPMDecoderLayer(nn.Module):
@support_torch_compile
class MiniCPMModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.cache_config = cache_config
self.quant_config = quant_config
@ -461,24 +461,22 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.prefix = prefix
self.vllm_config = vllm_config
self.config = config
self.lora_config = lora_config
self.cache_config = cache_config
self.quant_config = quant_config
self.num_experts = getattr(self.config, "num_experts", 0)
self._init_model()
self._init_model(vllm_config=vllm_config, prefix=prefix)
unpadded_vocab_size = config.vocab_size
if lora_config:
unpadded_vocab_size += lora_config.lora_extra_vocab_size
@ -502,11 +500,9 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def _init_model(self):
self.model = MiniCPMModel(config=self.config,
cache_config=self.cache_config,
quant_config=self.quant_config,
lora_config=self.lora_config)
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = MiniCPMModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
def forward(
self,


@ -28,7 +28,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -40,7 +40,7 @@ from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
MiniCPMForCausalLM,
MiniCPMModel)
from .utils import make_layers
from .utils import make_layers, maybe_prefix
class MiniCPM3Attention(nn.Module):
@ -238,8 +238,6 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
# `embedding_modules` and `embedding_padding_modules`
# are inherited from MiniCPMForCausalLM
def _init_model(self):
self.model = MiniCPM3Model(config=self.config,
cache_config=self.cache_config,
quant_config=self.quant_config,
lora_config=self.lora_config)
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = MiniCPM3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

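The MiniCPM and MiniCPM3 hunks above keep the existing _init_model hook but give it the same keyword-only signature as the constructors, so a subclass can swap the inner model class without redoing any config plumbing. The shape of the pattern, with hypothetical Base*/Derived* names and the imports from the earlier sketch:

class BaseForCausalLM(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        # The hook receives exactly what the constructor received.
        self._init_model(vllm_config=vllm_config, prefix=prefix)

    def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
        self.model = BaseModel(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))


class DerivedForCausalLM(BaseForCausalLM):

    def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
        self.model = DerivedModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))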

@ -34,7 +34,7 @@ from transformers import PretrainedConfig
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -59,7 +59,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import is_pp_missing_parameter
from .utils import is_pp_missing_parameter, maybe_prefix
_KEYS_TO_MODIFY_MAPPING = {
"llm.lm_head": "lm_head",
@ -390,7 +390,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
):
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
super().__init__()
# All MiniCPM-V models disable `tie_word_embeddings` but
@ -401,11 +400,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self.multimodal_config = multimodal_config
self.version = get_version_by_config(self.config)
self.llm = self.init_llm(config,
cache_config,
quant_config,
prefix="llm")
self.vpm = self.init_vision_module(config, quant_config, prefix="vpm")
self.llm = self.init_llm(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "llm"))
self.vpm = self.init_vision_module(config,
quant_config,
prefix=maybe_prefix(prefix, "vpm"))
param_dtype = torch.get_default_dtype()
self.vpm.to(dtype=param_dtype)
self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
@ -414,13 +413,15 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self.resampler = self.init_resampler(self.embed_dim,
self.vision_dim,
quant_config=quant_config,
prefix="resampler")
prefix=maybe_prefix(
prefix, "resampler"))
self.resampler.to(device="cuda", dtype=param_dtype)
# TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix="llm.lm_head")
prefix=maybe_prefix(
prefix, "llm.lm_head"))
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
@ -661,9 +662,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
def init_llm(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
raise NotImplementedError
@ -711,16 +710,10 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
def init_llm(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(MiniCPMModel(config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
return LLMWrapper(MiniCPMModel(vllm_config=vllm_config, prefix=prefix),
name="model")
def init_vision_module(
@ -875,15 +868,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
def init_llm(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(LlamaModel(config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
return LLMWrapper(LlamaModel(vllm_config=vllm_config, prefix=prefix),
name="model")
def init_vision_module(
@ -1022,16 +1010,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
def init_llm(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(Qwen2Model(config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
return LLMWrapper(Qwen2Model(vllm_config=vllm_config, prefix=prefix),
name="model")
def init_vision_module(
@ -1151,4 +1133,4 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
if instance_class is None:
raise ValueError(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
return instance_class(vllm_config, prefix=prefix)
return instance_class(vllm_config=vllm_config, prefix=prefix)


@ -28,7 +28,7 @@ from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
@ -48,7 +48,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class MixtralMoE(nn.Module):
@ -248,15 +249,14 @@ class MixtralDecoderLayer(nn.Module):
@support_torch_compile
class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
@ -332,24 +332,16 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.model = MixtralModel(config,
cache_config,
quant_config,
lora_config=lora_config,
prefix="model")
self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size


@ -49,7 +49,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class MixtralMLP(nn.Module):
@ -293,14 +294,13 @@ class MixtralDecoderLayer(nn.Module):
class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@ -350,18 +350,14 @@ class MixtralModel(nn.Module):
class MixtralForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = MixtralModel(config, cache_config, quant_config)
self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)


@ -33,7 +33,7 @@ from transformers.models.mllama.processing_mllama import (
import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import CacheConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
InputContext, TokenInputs, token_inputs)
@ -56,6 +56,7 @@ from vllm.utils import is_list_of
from .clip import CLIPMLP
from .interfaces import SupportsMultiModal
from .llama import LlamaDecoderLayer, LlamaMLP
from .utils import maybe_prefix
logger = init_logger(__name__)
MLLAMA_IMAGE_TOKEN_ID = 128256
@ -939,15 +940,13 @@ class MllamaTextModel(nn.Module):
config_class = config_mllama.MllamaTextConfig
base_model_prefix = "model"
def __init__(
self,
config: config_mllama.MllamaTextConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config.text_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
@ -1029,18 +1028,14 @@ class MllamaForCausalLM(nn.Module):
"MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"
]
def __init__(
self,
config: config_mllama.MllamaTextConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config.text_config
quant_config = vllm_config.quant_config
self.vocab_size = config.vocab_size
self.model = MllamaTextModel(config,
cache_config,
quant_config,
self.model = MllamaTextModel(vllm_config=vllm_config,
prefix=f"{prefix}.model")
self.lm_head = ParallelLMHead(
config.vocab_size,
@ -1108,14 +1103,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.vocab_size = config.text_config.vocab_size
self.hidden_size = config.text_config.hidden_size
@ -1127,12 +1117,11 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.vision_model = MllamaVisionModel(config.vision_config,
quant_config,
prefix="vision_model")
prefix=maybe_prefix(
prefix, "vision_model"))
self.language_model = MllamaForCausalLM(
config.text_config,
cache_config=cache_config,
quant_config=quant_config,
prefix="language_model",
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.multi_modal_projector = ColumnParallelLinear(
config.vision_config.vision_output_dim,
@ -1140,7 +1129,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
bias=True,
quant_config=quant_config,
gather_output=True,
prefix="multi_modal_projector",
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
self.logits_processor = LogitsProcessor(config.output_hidden_states,
config.text_config.vocab_size)


@ -44,7 +44,8 @@ from vllm.transformers_utils.processor import get_processor
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (get_vit_attn_backend,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9]
@ -716,14 +717,13 @@ class MolmoVisionBackbone(nn.Module):
@support_torch_compile
class MolmoModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embedding_size = config.embedding_size or config.vocab_size
@ -1024,14 +1024,9 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
@ -1040,7 +1035,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
vision_config = VisionBackboneConfig()
self.vision_backbone = MolmoVisionBackbone(config, vision_config,
quant_config)
self.model = MolmoModel(config, cache_config, quant_config)
self.model = MolmoModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if self.config.weight_tying:
self.lm_head = self.model.transformer.wte


@ -26,7 +26,8 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
def _get_alibi_slopes(
@ -207,14 +208,13 @@ class MPTBlock(nn.Module):
@support_torch_compile
class MPTModel(nn.Module):
def __init__(
self,
config: MPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
assert config.embedding_fraction == 1.0
assert config.norm_type == "low_precision_layernorm"
@ -267,20 +267,16 @@ class MPTModel(nn.Module):
class MPTForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
assert config.tie_word_embeddings
self.quant_config = quant_config
self.transformer = MPTModel(config, cache_config, quant_config)
self.transformer = MPTModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "transformer"))
self.lm_head = self.transformer.wte
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()


@ -27,7 +27,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -47,7 +47,8 @@ from vllm.transformers_utils.configs import NemotronConfig
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
# The architecture is pretty similar to Llama, with these changes:
# - There is no gate_proj, just up_proj
@ -293,15 +294,14 @@ class NemotronDecoderLayer(nn.Module):
@support_torch_compile
class NemotronModel(nn.Module):
def __init__(
self,
config: NemotronConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
@ -401,14 +401,9 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"v_proj": ("qkv_proj", 2),
}
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
assert isinstance(config, NemotronConfig)
@ -416,11 +411,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.config = config
self.lora_config = lora_config
self.model = NemotronModel(config,
cache_config,
quant_config,
lora_config=lora_config,
prefix="model")
self.model = NemotronModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:


@ -46,7 +46,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class OlmoAttention(nn.Module):
@ -224,12 +225,13 @@ class OlmoDecoderLayer(nn.Module):
@support_torch_compile
class OlmoModel(nn.Module):
def __init__(self,
config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
@ -291,17 +293,13 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
Extremely barebones HF model wrapper.
"""
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.model = OlmoModel(config, cache_config, quant_config)
self.model = OlmoModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:


@ -38,7 +38,8 @@ from vllm.utils import print_warning_once
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class OlmoeMoE(nn.Module):
@ -243,14 +244,13 @@ class OlmoeDecoderLayer(nn.Module):
@support_torch_compile
class OlmoeModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@ -309,18 +309,14 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = OlmoeModel(config, cache_config, quant_config)
self.model = OlmoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)


@ -293,14 +293,13 @@ class OPTDecoder(nn.Module):
@support_torch_compile
class OPTModel(nn.Module):
def __init__(
self,
config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.decoder = OPTDecoder(config,
cache_config,
quant_config,
@ -342,21 +341,14 @@ class OPTForCausalLM(nn.Module, SupportsPP):
".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
]
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = OPTModel(config,
cache_config,
quant_config,
self.model = OPTModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if self.config.tie_word_embeddings:
self.lm_head = self.model.decoder.embed_tokens


@ -29,7 +29,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class OrionMLP(nn.Module):
@ -208,14 +209,13 @@ class OrionDecoderLayer(nn.Module):
@support_torch_compile
class OrionModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@ -268,18 +268,14 @@ class OrionModel(nn.Module):
class OrionForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = OrionModel(config, cache_config, quant_config)
self.model = OrionModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)


@ -20,7 +20,7 @@ from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
merge_multimodal_embeddings)
maybe_prefix, merge_multimodal_embeddings)
logger = init_logger(__name__)
@ -131,11 +131,7 @@ class PaliGemmaMultiModalProjector(nn.Module):
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
@ -145,7 +141,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
self.vision_tower = SiglipVisionModel(config.vision_config,
quant_config,
prefix="vision_tower")
prefix=maybe_prefix(
prefix, "vision_tower"))
self.multi_modal_projector = PaliGemmaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
projection_dim=config.vision_config.projection_dim)
@ -155,7 +152,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix="language_model")
prefix=maybe_prefix(prefix, "language_model"))
logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.scale *= logit_scale

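Several of the multimodal wrappers above hand their text backbone to init_vllm_registered_model with the text sub-config, the whole vllm_config, and a prefixed name. The helper's body is not part of this diff; a sketch of what the new calling convention implies it does (signature approximate, and lookup_registered_architecture is a made-up stand-in for the real registry lookup):

def init_vllm_registered_model(hf_config, vllm_config, prefix: str = ""):
    # Placeholder for however the real helper maps hf_config.architectures
    # to a registered model class.
    model_cls = lookup_registered_architecture(hf_config.architectures)
    # Construct it exactly like a top-level model: whole config in (with the
    # sub-config swapped in as hf_config), dotted prefix in.
    return model_cls(vllm_config=vllm_config.with_hf_config(hf_config),
                     prefix=prefix)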

@ -45,7 +45,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class PersimmonMLP(nn.Module):
@ -212,12 +213,13 @@ class PersimmonDecoderLayer(nn.Module):
@support_torch_compile
class PersimmonModel(nn.Module):
def __init__(self,
config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
@ -265,20 +267,13 @@ class PersimmonModel(nn.Module):
class PersimmonForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.vocab_size = config.vocab_size
self.model = PersimmonModel(config,
cache_config=cache_config,
quant_config=quant_config)
self.model = PersimmonModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=False)


@ -60,7 +60,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class PhiAttention(nn.Module):
@ -196,12 +197,13 @@ class PhiLayer(nn.Module):
@support_torch_compile
class PhiModel(nn.Module):
def __init__(self,
config: PhiConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
@ -277,14 +279,9 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
@ -294,7 +291,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.quant_config = quant_config
self.model = PhiModel(config, cache_config, quant_config)
self.model = PhiModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,


@ -24,7 +24,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
def load_column_parallel_weight(param: torch.nn.Parameter,
@ -299,14 +300,13 @@ class Phi3SmallDecoderLayer(nn.Module):
class Phi3SmallModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
@ -363,18 +363,14 @@ class Phi3SmallModel(nn.Module):
class Phi3SmallForCausalLM(nn.Module, SupportsPP):
_tied_weights_keys = ["lm_head.weight"]
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = Phi3SmallModel(config, cache_config, quant_config)
self.model = Phi3SmallModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.vocab_size = config.vocab_size
self.mup_width_multiplier = config.mup_width_multiplier
self.lm_head = ParallelLMHead(


@ -45,7 +45,7 @@ from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
logger = init_logger(__name__)
@ -525,11 +525,7 @@ def input_processor_for_phi3v(ctx: InputContext,
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
@ -544,12 +540,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
prefix="model.embed_tokens",
prefix=maybe_prefix(prefix, "model.embed_tokens"),
)
# TODO: Optionally initializes this for supporting input embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(
config, quant_config, prefix="model.vision_embed_tokens")
config,
quant_config,
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"))
# The prefix is empty intentionally because default prefix of
# LlamaForCausalLM is "model"


@ -28,7 +28,7 @@ from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -48,7 +48,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class PhiMoEConfig(PretrainedConfig):
@ -432,15 +433,14 @@ class PhiMoEDecoderLayer(nn.Module):
@support_torch_compile
class PhiMoEModel(nn.Module):
def __init__(
self,
config: PhiMoEConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
@ -529,23 +529,15 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.model = PhiMoEModel(config,
cache_config,
quant_config,
lora_config=lora_config)
self.model = PhiMoEModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size


@ -38,7 +38,7 @@ from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import init_vllm_registered_model
from .utils import init_vllm_registered_model, maybe_prefix
try:
from xformers import ops as xops
@ -152,11 +152,7 @@ def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
@ -176,7 +172,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix="language_model")
prefix=maybe_prefix(prefix, "language_model"))
self.vision_encoder = VisionTransformer(self.vision_args)
self.vision_language_adapter = VisionLanguageAdapter(


@ -50,7 +50,8 @@ from vllm.utils import is_list_of
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
@ -552,14 +553,13 @@ class QWenBlock(nn.Module):
@support_torch_compile
class QWenModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.vocab_size = config.vocab_size
@ -865,20 +865,17 @@ def dummy_data_for_qwen(
class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.quant_config = quant_config
self.transformer = QWenModel(config, cache_config, quant_config)
self.transformer = QWenModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)


@ -240,14 +240,13 @@ class Qwen2DecoderLayer(nn.Module):
@support_torch_compile
class Qwen2Model(nn.Module):
def __init__(
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@ -403,11 +402,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@ -429,9 +424,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config
self.quant_config = quant_config
self.model = Qwen2Model(config,
cache_config,
quant_config,
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if config.tie_word_embeddings:


@ -264,14 +264,9 @@ def input_mapper_for_qwen2_audio(
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
@ -283,8 +278,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
self.quant_config = quant_config
self.language_model = Qwen2Model(config.text_config, cache_config,
quant_config)
self.language_model = Qwen2Model(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=prefix)
self.unpadded_vocab_size = config.text_config.vocab_size
if config.text_config.tie_word_embeddings:
self.lm_head = self.language_model.embed_tokens

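The Qwen2-Audio hunk just above shows the idiom for reusing a text backbone whose HF config is only a sub-config of the outer model: derive a VllmConfig whose hf_config has been swapped for the sub-config, then pass that derived config down unchanged. A condensed sketch of the call site (attribute names as in the hunk; with_hf_config is assumed to copy the config with only hf_config replaced):

# Inside the wrapper's __init__ (sketch):
config = vllm_config.model_config.hf_config            # outer audio+text config
text_vllm_config = vllm_config.with_hf_config(config.text_config)

# The inner Qwen2Model now sees the text sub-config as its hf_config, while
# cache, quantization and LoRA settings are inherited from the outer config.
self.language_model = Qwen2Model(vllm_config=text_vllm_config,
                                 prefix=prefix)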

@ -17,7 +17,7 @@ from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from .utils import AutoWeightsLoader
from .utils import AutoWeightsLoader, maybe_prefix
class Qwen2ForSequenceClassification(nn.Module):
@ -43,11 +43,7 @@ class Qwen2ForSequenceClassification(nn.Module):
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@ -70,7 +66,8 @@ class Qwen2ForSequenceClassification(nn.Module):
self.lora_config = lora_config
self.quant_config = quant_config
self.model = Qwen2Model(config, cache_config, quant_config)
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.score = RowParallelLinear(config.hidden_size,
config.num_labels,


@ -54,7 +54,8 @@ from vllm.utils import print_warning_once
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class Qwen2MoeMLP(nn.Module):
@ -315,14 +316,13 @@ class Qwen2MoeDecoderLayer(nn.Module):
@support_torch_compile
class Qwen2MoeModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@ -377,18 +377,14 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = Qwen2MoeModel(config, cache_config, quant_config)
self.model = Qwen2MoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)


@ -18,7 +18,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsPP
from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader
from .utils import AutoWeightsLoader, maybe_prefix
class ReLU(nn.Module):
@ -55,11 +55,7 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@ -82,7 +78,8 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
self.lora_config = lora_config
self.quant_config = quant_config
self.model = Qwen2Model(config, cache_config, quant_config)
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.score = nn.Sequential(
ColumnParallelLinear(config.hidden_size,


@ -70,7 +70,7 @@ from vllm.transformers_utils.processor import cached_get_processor
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (PPMissingLayer, get_vit_attn_backend,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory)
make_empty_intermediate_tensors_factory, maybe_prefix)
logger = init_logger(__name__)
@ -966,11 +966,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@ -986,13 +982,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix="visual",
prefix=maybe_prefix(prefix, "visual"),
)
self.model = Qwen2Model(config,
cache_config,
quant_config,
prefix="model")
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
@ -1001,7 +995,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix="lm_head")
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()

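Most of the prefix="..." -> maybe_prefix(prefix, "...") edits above are there so that the prefix strings handed to vLLM layers stay in step with the fully-qualified parameter names PyTorch itself produces, which is what things like quantization ignore-lists and weight loading are keyed on. A torch-only toy showing how those names are composed:

import torch.nn as nn


class Inner(nn.Module):

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)


class Outer(nn.Module):

    def __init__(self):
        super().__init__()
        self.visual = Inner()
        self.model = Inner()


print([name for name, _ in Outer().named_parameters()])
# ['visual.proj.weight', 'visual.proj.bias',
#  'model.proj.weight', 'model.proj.bias']

A hard-coded prefix="visual" was only correct while the module sat at the top level; once it is nested under another prefix, maybe_prefix(prefix, "visual") keeps the string matching the real name.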

@ -29,7 +29,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
@ -53,7 +53,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class SolarMLP(nn.Module):
@ -266,15 +267,14 @@ class SolarDecoderLayer(nn.Module):
@support_torch_compile
class SolarModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size *
@ -409,25 +409,17 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.model = SolarModel(
config,
cache_config,
quant_config,
lora_config=lora_config,
prefix="model",
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size


@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class StablelmMLP(nn.Module):
@ -193,12 +194,13 @@ class StablelmDecoderLayer(nn.Module):
class StableLMEpochModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '') -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
@ -245,18 +247,14 @@ class StableLMEpochModel(nn.Module):
class StablelmForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = StableLMEpochModel(config, cache_config, quant_config)
self.model = StableLMEpochModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)


@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class Starcoder2Attention(nn.Module):
@ -195,12 +196,13 @@ class Starcoder2DecoderLayer(nn.Module):
@support_torch_compile
class Starcoder2Model(nn.Module):
def __init__(self,
config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@ -245,19 +247,13 @@ class Starcoder2Model(nn.Module):
class Starcoder2ForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.model = Starcoder2Model(config,
cache_config,
quant_config=quant_config)
self.model = Starcoder2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings:
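Editor's note: a practical upshot of the uniform signature above, shown as a hypothetical sketch (the registry and build_model helper below are illustrative, not vLLM's actual loader API): once every architecture accepts (*, vllm_config, prefix), the code that instantiates a model can treat all of them identically.

from typing import Any, Callable, Dict


class ToyStarcoder2ForCausalLM:

    def __init__(self, *, vllm_config: Any, prefix: str = "") -> None:
        self.vllm_config = vllm_config
        self.prefix = prefix


_TOY_REGISTRY: Dict[str, Callable[..., Any]] = {
    "Starcoder2ForCausalLM": ToyStarcoder2ForCausalLM,
}


def build_model(architecture: str, vllm_config: Any) -> Any:
    model_cls = _TOY_REGISTRY[architecture]
    # Every registered class is constructed the same way; no per-architecture
    # keyword plumbing for cache/quant/LoRA configs.
    return model_cls(vllm_config=vllm_config, prefix="")


model = build_model("Starcoder2ForCausalLM", vllm_config={"stub": True})
assert model.prefix == ""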

@ -34,7 +34,7 @@ from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings_from_map)
_AUDIO_PLACEHOLDER_TOKEN = 128002
@ -339,11 +339,7 @@ class ModifiedWhisperEncoder(WhisperEncoder):
@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
@ -354,6 +350,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
self.secondary_weights = []
self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
if config.audio_model_id is not None:
# this prefix is not for initialization, but for loading weights
# note the trailing dot
self.secondary_weights.append(
DefaultModelLoader.Source(
model_or_path=config.audio_model_id,
@ -362,8 +360,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
))
self.multi_modal_projector = UltravoxProjector(config)
self.language_model = init_vllm_registered_model(
config.text_config, vllm_config, prefix="language_model")
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
if config.text_model_id is not None:
# this prefix is not for initialization, but for loading weights
# note the trailing dot
self.secondary_weights.append(
DefaultModelLoader.Source(model_or_path=config.text_model_id,
revision=None,
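Editor's note: the two comments added above ("this prefix is not for initialization, but for loading weights" / "note the trailing dot") refer to the prefix carried by DefaultModelLoader.Source in the surrounding context not shown in this hunk, presumably something like "audio_tower." for the audio checkpoint. The toy sketch below (not the actual vLLM loader, and the tensor names are made up) shows why the trailing dot matters when a secondary checkpoint's tensors are merged into the composite model's namespace.

from typing import Iterable, Iterator, Tuple


def prefix_secondary_weights(
    weights: Iterable[Tuple[str, object]],
    prefix: str,
) -> Iterator[Tuple[str, object]]:
    # The checkpoint stores e.g. "layers.0.self_attn.k_proj.weight"; the
    # composite model expects "audio_tower.layers.0.self_attn.k_proj.weight".
    # Without the trailing dot the two names would be glued together.
    for name, tensor in weights:
        yield prefix + name, tensor


checkpoint = [("layers.0.self_attn.k_proj.weight", None)]
renamed = dict(prefix_secondary_weights(checkpoint, "audio_tower."))
assert "audio_tower.layers.0.self_attn.k_proj.weight" in renamed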

@ -46,7 +46,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class XverseMLP(nn.Module):
@ -223,11 +224,7 @@ class XverseDecoderLayer(nn.Module):
@support_torch_compile
class XverseModel(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@ -315,15 +312,10 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -331,7 +323,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config
self.quant_config = quant_config
self.model = XverseModel(config, cache_config, quant_config)
self.model = XverseModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)