From 1a95f10ee7d2ffa538a6d210b53bf363e039feee Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 8 Nov 2024 22:17:28 -0800
Subject: [PATCH] [5/N] pass the whole config to model (#9983)

Signed-off-by: youkaichao
---
 vllm/model_executor/model_loader/loader.py    | 100 ++----------------
 .../model_executor/model_loader/tensorizer.py |  15 +--
 vllm/model_executor/models/arctic.py          |  16 +--
 vllm/model_executor/models/baichuan.py        |  37 +++----
 vllm/model_executor/models/bart.py            |  12 +--
 vllm/model_executor/models/bert.py            |  12 ++-
 vllm/model_executor/models/blip2.py           |  20 ++--
 vllm/model_executor/models/bloom.py           |  10 +-
 vllm/model_executor/models/chameleon.py       |  12 ++-
 vllm/model_executor/models/chatglm.py         |  15 +--
 vllm/model_executor/models/commandr.py        |  12 ++-
 vllm/model_executor/models/dbrx.py            |  10 +-
 vllm/model_executor/models/decilm.py          |  18 ++--
 vllm/model_executor/models/deepseek.py        |  10 +-
 vllm/model_executor/models/deepseek_v2.py     |  10 +-
 vllm/model_executor/models/eagle.py           |   7 +-
 vllm/model_executor/models/exaone.py          |  12 ++-
 vllm/model_executor/models/falcon.py          |  10 +-
 vllm/model_executor/models/florence2.py       |  10 +-
 vllm/model_executor/models/fuyu.py            |  15 ++-
 vllm/model_executor/models/gemma.py           |  11 +-
 vllm/model_executor/models/gemma2.py          |  27 ++---
 vllm/model_executor/models/gpt2.py            |  10 +-
 vllm/model_executor/models/gpt_bigcode.py     |  12 ++-
 vllm/model_executor/models/gpt_j.py           |  10 +-
 vllm/model_executor/models/gpt_neox.py        |  10 +-
 vllm/model_executor/models/granite.py         |  12 ++-
 vllm/model_executor/models/granitemoe.py      |  12 ++-
 vllm/model_executor/models/idefics3.py        |  13 ++-
 vllm/model_executor/models/interfaces_base.py |  24 +----
 vllm/model_executor/models/internlm2.py       |   9 +-
 vllm/model_executor/models/internlm2_ve.py    |   9 +-
 vllm/model_executor/models/internvl.py        |  15 ++-
 vllm/model_executor/models/jais.py            |  10 +-
 vllm/model_executor/models/jamba.py           |  14 +--
 vllm/model_executor/models/llama.py           |  30 ++++--
 vllm/model_executor/models/llava.py           |  15 ++-
 vllm/model_executor/models/llava_next.py      |  17 ++-
 .../model_executor/models/llava_next_video.py |  15 ++-
 vllm/model_executor/models/llava_onevision.py |  15 ++-
 vllm/model_executor/models/mamba.py           |  14 +--
 vllm/model_executor/models/medusa.py          |   5 +-
 vllm/model_executor/models/minicpm.py         |  12 ++-
 vllm/model_executor/models/minicpmv.py        |  48 ++++-----
 vllm/model_executor/models/mixtral.py         |  13 +--
 vllm/model_executor/models/mixtral_quant.py   |  10 +-
 vllm/model_executor/models/mllama.py          |  15 +--
 vllm/model_executor/models/molmo.py           |  16 +--
 vllm/model_executor/models/mpt.py             |  12 ++-
 vllm/model_executor/models/nemotron.py        |  13 +--
 vllm/model_executor/models/olmo.py            |  14 ++-
 vllm/model_executor/models/olmoe.py           |  10 +-
 vllm/model_executor/models/opt.py             |  12 ++-
 vllm/model_executor/models/orion.py           |  10 +-
 vllm/model_executor/models/paligemma.py       |  30 +++---
 vllm/model_executor/models/persimmon.py       |  14 ++-
 vllm/model_executor/models/phi.py             |  15 +--
 vllm/model_executor/models/phi3_small.py      |  13 +--
 vllm/model_executor/models/phi3v.py           |  23 ++--
 vllm/model_executor/models/phimoe.py          |  13 +--
 vllm/model_executor/models/pixtral.py         |  20 ++--
 vllm/model_executor/models/qwen.py            |  31 +++---
 vllm/model_executor/models/qwen2.py           |  14 +--
 vllm/model_executor/models/qwen2_audio.py     |  21 ++--
 vllm/model_executor/models/qwen2_cls.py       |  20 ++--
 vllm/model_executor/models/qwen2_moe.py       |  10 +-
 vllm/model_executor/models/qwen2_rm.py        |  19 ++--
 vllm/model_executor/models/qwen2_vl.py        |  19 ++--
 vllm/model_executor/models/solar.py           |  13 +--
 vllm/model_executor/models/stablelm.py        |  10 +-
vllm/model_executor/models/starcoder2.py | 14 ++- vllm/model_executor/models/ultravox.py | 20 ++-- vllm/model_executor/models/utils.py | 27 ++--- vllm/model_executor/models/xverse.py | 22 ++-- vllm/plugins/__init__.py | 12 --- 75 files changed, 583 insertions(+), 654 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 464915248c9ad..8d3024534734b 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -9,8 +9,7 @@ import math import os from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple, - Type, cast) +from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast import gguf import huggingface_hub @@ -18,20 +17,17 @@ import numpy as np import torch from huggingface_hub import HfApi, hf_hub_download from torch import nn -from transformers import AutoModelForCausalLM, PretrainedConfig +from transformers import AutoModelForCausalLM from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.config import (CacheConfig, LoadConfig, LoadFormat, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, - PoolerConfig, SchedulerConfig, VllmConfig) +from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig, + VllmConfig) from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ReplicatedLinear, RowParallelLinear) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) from vllm.model_executor.model_loader.tensorizer import ( TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, serialize_vllm_model, tensorizer_weights_iterator) @@ -43,8 +39,6 @@ from vllm.model_executor.model_loader.weight_utils import ( get_gguf_extra_tensor_names, gguf_quant_weights_iterator, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) -from vllm.model_executor.models import (has_inner_state, supports_lora, - supports_multimodal) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available @@ -94,85 +88,11 @@ def device_loading_context(module: torch.nn.Module, logger = init_logger(__name__) -def _get_model_initialization_kwargs( - model_class: Type[nn.Module], - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - scheduler_config: Optional[SchedulerConfig] = None, - pooler_config: Optional[PoolerConfig] = None) -> Dict[str, Any]: - """Get extra kwargs for model initialization.""" - extra_kwargs: Dict[str, Any] = {} - - if supports_lora(model_class): - # lora_config=None is used to disable LoRA - extra_kwargs["lora_config"] = lora_config - elif lora_config: - raise ValueError( - f"Model {model_class.__name__} does not support LoRA, " - "but LoRA is enabled. Support for this model may " - "be added in the future. 
If this is important to you, " - "please open an issue on github.") - - if supports_multimodal(model_class): - assert multimodal_config is not None - - extra_kwargs["multimodal_config"] = multimodal_config - - if has_inner_state(model_class) and scheduler_config: - extra_kwargs["scheduler_config"] = scheduler_config - if pooler_config: - extra_kwargs["pooler_config"] = pooler_config - return extra_kwargs - - -def build_model(model_class: Type[nn.Module], - vllm_config: Optional[VllmConfig], - hf_config: PretrainedConfig, - cache_config: Optional[CacheConfig], - quant_config: Optional[QuantizationConfig], - *, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - scheduler_config: Optional[SchedulerConfig], - prefix: Optional[str] = None, - pooler_config: Optional[PoolerConfig] = None) -> nn.Module: - extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config, - multimodal_config, - scheduler_config, - pooler_config) - if prefix: - extra_kwargs["prefix"] = prefix - - # TODO: unify all the module initialization code - # to only take the `VllmConfig` object as input - from vllm.plugins import set_vllm_config - set_vllm_config(vllm_config) - - return model_class(config=hf_config, - cache_config=cache_config, - quant_config=quant_config, - **extra_kwargs) - - def _initialize_model(vllm_config: VllmConfig) -> nn.Module: """Initialize a model with the given configurations.""" model_config = vllm_config.model_config - lora_config = vllm_config.lora_config - scheduler_config = vllm_config.scheduler_config - cache_config = vllm_config.cache_config model_class, _ = get_model_architecture(model_config) - - return build_model( - model_class, - vllm_config, - model_config.hf_config, - cache_config=cache_config, - quant_config=vllm_config.quant_config, - lora_config=lora_config, - multimodal_config=model_config.multimodal_config, - scheduler_config=scheduler_config, - pooler_config=model_config.pooler_config, - ) + return model_class(vllm_config=vllm_config) class BaseModelLoader(ABC): @@ -486,24 +406,18 @@ class TensorizerLoader(BaseModelLoader): device_config = vllm_config.device_config model_config = vllm_config.model_config - lora_config = vllm_config.lora_config - cache_config = vllm_config.cache_config with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model_class = get_model_architecture(model_config)[0] - quant_config = vllm_config.quant_config - extra_kwargs = _get_model_initialization_kwargs( - model_class, lora_config, model_config.multimodal_config) - extra_kwargs["quant_config"] = quant_config - extra_kwargs["cache_config"] = cache_config tensorizer_config = copy.copy(self.tensorizer_config) tensorizer_config.model_class = model_class tensorizer_config.hf_config = model_config.hf_config tensorizer_config.dtype = model_config.dtype - model = load_with_tensorizer(tensorizer_config, **extra_kwargs) + model = load_with_tensorizer(tensorizer_config, + vllm_config=vllm_config) return model.eval() def download_model(self, model_config: ModelConfig) -> None: diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 437d2772e1f28..c48b287ed181a 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -17,8 +17,6 @@ from vllm.config import ModelConfig, ParallelConfig from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.logger import init_logger -from 
vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.utils import FlexibleArgumentParser @@ -268,8 +266,7 @@ class TensorizerAgent: in vllm/model_executor/model_loader/weight_utils.py """ - def __init__(self, tensorizer_config: TensorizerConfig, - quant_config: QuantizationConfig, **extra_kwargs): + def __init__(self, tensorizer_config: TensorizerConfig, vllm_config): if tensorizer_error_msg is not None: raise ImportError( "Tensorizer is not installed. Please install tensorizer " @@ -279,11 +276,7 @@ class TensorizerAgent: self.tensorizer_config = tensorizer_config self.tensorizer_args = ( self.tensorizer_config._construct_tensorizer_args()) - self.extra_kwargs = extra_kwargs - if extra_kwargs.get("quant_config") is not None: - self.quant_config = extra_kwargs["quant_config"] - else: - self.quant_config = quant_config + self.vllm_config = vllm_config self.model = self._init_model() def _init_model(self): @@ -293,9 +286,7 @@ class TensorizerAgent: assert self.tensorizer_config.model_class is not None with no_init_or_tensor(): return self.tensorizer_config.model_class( - config=model_args, - quant_config=self.quant_config, - **self.extra_kwargs) + vllm_config=self.vllm_config, ) def _resize_lora_embeddings(self): """Modify LoRA embedding layers to use bigger tensors diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 4fec314a70aa4..997554f7dcccd 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -6,7 +6,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -415,14 +415,16 @@ class ArcticModel(nn.Module): class ArcticForCausalLM(nn.Module, SupportsPP): - def __init__(self, - config: ArcticConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - **kwargs) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config - self.model = ArcticModel(config, cache_config, quant_config) + self.model = ArcticModel(config, + cache_config, + quant_config, + prefix=prefix) self.vocab_size = config.vocab_size self.lm_head = ParallelLMHead( self.vocab_size, diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index cce182da4820f..8e1dab71b1f39 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -26,7 +26,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -332,14 +332,15 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__( self, - 
config: PretrainedConfig, - position_embedding: str, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", + position_embedding: str = "ROPE", ): super().__init__() - + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config @@ -439,17 +440,14 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ): + config = vllm_config.model_config.hf_config if config.hidden_size == 4096: # baichuan2 7b - super().__init__(config, "ROPE", cache_config, quant_config, - lora_config) + super().__init__(vllm_config, prefix, "ROPE") else: # baichuan 13b, baichuan2 13b - super().__init__(config, "ALIBI", cache_config, quant_config, - lora_config) + super().__init__(vllm_config, prefix, "ALIBI") class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): @@ -459,10 +457,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ): - super().__init__(config, "ROPE", cache_config, quant_config, - lora_config) + super().__init__(vllm_config, prefix, "ROPE") diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index fd600adceb21c..c6da6a590cf5a 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -25,7 +25,7 @@ from transformers import BartConfig from transformers.utils import logging from vllm.attention import Attention, AttentionMetadata, AttentionType -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -810,13 +810,13 @@ class BartModel(nn.Module): class BartForConditionalGeneration(nn.Module): base_model_prefix = "model" - def __init__(self, - config: BartConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None): + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config # currently all existing BART models have `tie_word_embeddings` enabled assert config.tie_word_embeddings self.config = config diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index bfed2929d57d2..2b0f45c5603f5 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -6,7 +6,7 @@ from transformers import BertConfig from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention.backends.xformers import XFormersImpl -from vllm.config import CacheConfig, PoolerConfig +from vllm.config import 
CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -384,12 +384,14 @@ class BertEmbeddingModel(nn.Module): def __init__( self, - config: BertConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - pooler_config: Optional[PoolerConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + pooler_config = vllm_config.model_config.pooler_config self.model = BertModel(config, cache_config, quant_config) self._pooler = Pooler.from_config_with_defaults( pooler_config, diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index efd24e7cf40f6..cdc30eda2ab3c 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -8,7 +8,7 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig, apply_chunking_to_forward) from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import CacheConfig, VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_fn @@ -483,14 +483,17 @@ def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs): @INPUT_REGISTRY.register_input_processor(input_processor_for_blip2) class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, - config: Blip2Config, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: super().__init__() - + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config @@ -513,8 +516,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): self.language_model = init_vllm_registered_model( config.text_config, - cache_config, - quant_config, + vllm_config=vllm_config, prefix="language_model") self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index c2440ee75d588..7540bc23efd88 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -24,7 +24,7 @@ from transformers import BloomConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn @@ -283,11 +283,13 @@ class BloomForCausalLM(nn.Module, SupportsPP): def __init__( self, - config: BloomConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ): super().__init__() + config = 
vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config self.transformer = BloomModel(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 58841f177ec22..f79bad6190708 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -9,7 +9,7 @@ from torch import nn from transformers import ChameleonConfig, ChameleonVQVAEConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) @@ -926,12 +926,14 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, def __init__( self, - config: ChameleonConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config self.model = ChameleonModel(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index eb9c3e3ae785d..c14f2fcb15063 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -11,7 +11,7 @@ from torch import nn from torch.nn import LayerNorm from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) @@ -595,14 +595,15 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, def __init__( self, - config: ChatGLMConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ): super().__init__() - + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.lora_config = lora_config self.multimodal_config = multimodal_config diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 718f26bed443f..e921fa50b099e 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -28,7 +28,7 @@ from transformers import CohereConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from 
vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -334,12 +334,14 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__( self, - config: CohereConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config self.config = config # currently all existing command R models have `tie_word_embeddings` # enabled diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index ae43383155ffc..e3b3164cacde3 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.fused_moe import FusedMoE @@ -352,11 +352,13 @@ class DbrxForCausalLM(nn.Module, SupportsPP): def __init__( self, - config: DbrxConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ): super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config if config.tie_word_embeddings: raise ValueError( diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py index 8c9653463858b..3e7005efb39ca 100644 --- a/vllm/model_executor/models/decilm.py +++ b/vllm/model_executor/models/decilm.py @@ -22,13 +22,11 @@ # limitations under the License. 
"""Inference-only DeciLM model compatible with HuggingFace weights.""" -from typing import Iterable, Optional, Tuple +from typing import Iterable, Tuple import torch -from transformers import LlamaConfig -from vllm.config import CacheConfig, LoRAConfig -from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.config import VllmConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaForCausalLM @@ -55,17 +53,13 @@ class DeciLMForCausalLM(LlamaForCausalLM): def __init__( self, - config: LlamaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: + config = vllm_config.model_config.hf_config config.num_key_value_heads = max(config.num_key_value_heads_per_layer) delattr(config, "num_key_value_heads_per_layer") - super().__init__(config=config, - cache_config=cache_config, - quant_config=quant_config, - lora_config=lora_config) + super().__init__(vllm_config=vllm_config) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 53a1c7cfbfef4..c90d3d250e4c5 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -27,7 +27,7 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -385,11 +385,13 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): def __init__( self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config self.model = DeepseekModel(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 95bbf4fb59c6a..0f391d8329a8e 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -28,7 +28,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -481,11 +481,13 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): def __init__( self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config self.model = DeepseekV2Model(config, diff --git 
a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index a87e1c0228627..6bd73d20d340d 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -12,7 +13,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models import ModelRegistry from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs.eagle import EAGLEConfig class EAGLE(nn.Module): @@ -34,14 +34,15 @@ class EAGLE(nn.Module): in the draft checkpoint (using key token_map). Also, the draft config needs to have truncated_vocab_size (=k) as an attribute.""" - def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + config = vllm_config.model_config.hf_config self.config = config architectures = getattr(self.config.model, "architectures", []) model_cls, _ = ModelRegistry.resolve_model_cls(architectures) - self.model = model_cls(self.config.model, *args, **kwargs) + self.model = model_cls(vllm_config, prefix) self.fc = nn.Linear(config.model.hidden_size * 2, config.model.hidden_size, bias=getattr(self.config, "eagle_fc_bias", False)) diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index a8d591b921cd6..fa6dbfe35b3ad 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -29,7 +29,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -440,12 +440,14 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__( self, - config: ExaoneConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index daf49521637b0..96ae119042277 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -27,7 +27,7 @@ from transformers import FalconConfig as HF_FalconConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -403,11 +403,13 
@@ class FalconForCausalLM(nn.Module, SupportsPP): def __init__( self, - config: FalconConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ): super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config self.transformer = FalconModel(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 184bee5f65671..b0d970d9fb572 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -6,7 +6,7 @@ import torch.nn as nn from transformers import PretrainedConfig from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -189,11 +189,11 @@ class Florence2LanguageForConditionalGeneration(nn.Module): class Florence2ForConditionalGeneration(nn.Module): - def __init__(self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config # TODO(Isotr0py): Add vision backbone self.language_model = Florence2LanguageForConditionalGeneration( diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 653d5d60ea178..cac10f505df67 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -22,14 +22,13 @@ import torch import torch.nn as nn import torch.utils.checkpoint from PIL import Image -from transformers import FuyuConfig, FuyuImageProcessor +from transformers import FuyuImageProcessor from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.model_executor.layers.linear import ColumnParallelLinear -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -227,12 +226,12 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object): @INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu) class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, - config: FuyuConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 
1cc3ea679c553..4e0cbfb9cbf58 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -22,7 +22,7 @@ from transformers import GemmaConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul @@ -374,13 +374,14 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__( self, - config: GemmaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config self.config = config # currently all existing Gemma models have `tie_word_embeddings` enabled diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 16e0d6b30713a..773d3b72ec418 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -21,7 +21,7 @@ from transformers import Gemma2Config from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig, PoolerConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul @@ -245,12 +245,13 @@ class Gemma2Model(nn.Module): def __init__( self, - config: Gemma2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config self.embed_tokens = VocabParallelEmbedding( @@ -400,11 +401,13 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__( self, - config: Gemma2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config del lora_config # Unused. 
super().__init__() self.config = config @@ -470,14 +473,14 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP): def __init__( self, - pooler_config: Optional[PoolerConfig] = None, - **kwargs, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() - self.model = Gemma2Model(**kwargs) + self.model = Gemma2Model(vllm_config, prefix) self._pooler = Pooler.from_config_with_defaults( - pooler_config, + vllm_config.model_config.pooler_config, pooling_type=PoolingType.LAST, normalize=True, softmax=False) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 7f81bbff94932..c3fc47db79986 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -24,7 +24,7 @@ from transformers import GPT2Config from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed.parallel_state import ( get_pp_group, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn @@ -242,11 +242,13 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): def __init__( self, - config: GPT2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ): super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config self.transformer = GPT2Model(config, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 4be8e4199f04d..ea1614d966365 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -25,7 +25,7 @@ from transformers import GPTBigCodeConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -260,12 +260,14 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__( self, - config: GPTBigCodeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ): super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 834b4aff2e4ba..58cff67c69051 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -23,7 +23,7 @@ from transformers import GPTJConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation 
import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -231,11 +231,13 @@ class GPTJForCausalLM(nn.Module, SupportsPP): def __init__( self, - config: GPTJConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ): super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config assert not config.tie_word_embeddings diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 1903156d7efe1..27b2577a8cdca 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -23,7 +23,7 @@ from transformers import GPTNeoXConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -244,11 +244,13 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): def __init__( self, - config: GPTNeoXConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ): super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 8a75b9cb1d55d..c3e23b7138e7f 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -28,7 +28,7 @@ from transformers import GraniteConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -372,12 +372,14 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__( self, - config: GraniteConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index b4da986efabe3..73f7c106e3d39 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -28,7 +28,7 @@ from transformers.models.granitemoe import GraniteMoeConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig +from 
vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -335,12 +335,14 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__( self, - config: GraniteMoeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 8004367f8dc08..b676171b556a7 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -26,7 +26,7 @@ from transformers import PretrainedConfig as Idefics3Config from transformers import ProcessorMixin as Idefics3ImageProcessor from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import CacheConfig, VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.logger import init_logger @@ -615,13 +615,16 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__( self, - config: Idefics3Config, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config self.multimodal_config = multimodal_config diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 8d2d422f9891c..7bb43beff255c 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -11,9 +11,8 @@ from vllm.utils import supports_kw if TYPE_CHECKING: from vllm.attention import AttentionMetadata - from vllm.config import CacheConfig + from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import PoolerOutput - from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -39,10 +38,8 @@ class VllmModel(Protocol[C_co, T_co]): def __init__( self, - config: C_co, - *, - cache_config: Optional["CacheConfig"], - quant_config: Optional["QuantizationConfig"], + vllm_config: "VllmConfig", + prefix: str = "", ) -> None: ... 
@@ -58,20 +55,7 @@ class VllmModel(Protocol[C_co, T_co]): def _check_vllm_model_init(model: Union[Type[object], object]) -> bool: model_init = model.__init__ - vllm_kws = ("cache_config", "quant_config") - missing_kws = tuple(kw for kw in vllm_kws - if not supports_kw(model_init, kw)) - - if missing_kws and (isinstance(model, type) - and issubclass(model, nn.Module)): - logger.warning( - "The model (%s) is missing " - "vLLM-specific keywords from its initializer: %s", - model, - missing_kws, - ) - - return len(missing_kws) == 0 + return supports_kw(model_init, "vllm_config") def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 7ddb1e2a1ab10..cbedd0c8a0130 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -7,7 +7,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, @@ -319,12 +319,13 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP): def __init__( self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config self.model = InternLM2Model(config, diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py index 108fc8382049d..f7bc823574034 100644 --- a/vllm/model_executor/models/internlm2_ve.py +++ b/vllm/model_executor/models/internlm2_ve.py @@ -5,7 +5,7 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig @@ -161,11 +161,12 @@ class InternLM2VEForCausalLM(InternLM2ForCausalLM): def __init__( self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", ) -> None: + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config super().__init__(config, cache_config, quant_config) self.model = InternLM2VEModel(config, cache_config, diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 335b11d293acd..42bccf71273b3 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -16,7 +16,7 @@ from PIL import Image from transformers import PretrainedConfig from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.model_executor.layers.quantization import (AWQConfig, @@ 
-410,13 +410,13 @@ input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT) @INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor) class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, - config: PretrainedConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config self.multimodal_config = multimodal_config self._patch_quant_config(config, quant_config) @@ -440,8 +440,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): self.language_model = init_vllm_registered_model( config.text_config, - cache_config, - quant_config, + vllm_config=vllm_config, prefix="language_model") self.mlp1 = self._init_mlp1(config) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 23fdca09493b7..ae3f5b01d5cce 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -26,7 +26,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -288,11 +288,13 @@ class JAISLMHeadModel(nn.Module, SupportsPP): def __init__( self, - config: JAISConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ): super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config self.transformer = JAISModel(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 9b18a1b68f9d3..72eb1017c2868 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -7,7 +7,7 @@ from transformers import JambaConfig from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -350,12 +350,14 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): def __init__( self, - config: JambaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - scheduler_config: Optional[SchedulerConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config assert not 
cache_config.enable_prefix_caching, \ "Jamba currently does not support prefix caching" diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 9e8a403b2f1fc..b765912387e2e 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -28,7 +28,7 @@ from transformers import LlamaConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig, PoolerConfig +from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -494,15 +494,15 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__( self, - config: LlamaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, prefix: str = "", - pooler_config: Optional[PoolerConfig] = None, ) -> None: super().__init__() - + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + pooler_config = vllm_config.model_config.pooler_config self.config = config self.lora_config = lora_config @@ -654,12 +654,22 @@ class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): def __init__( self, - pooler_config: Optional[PoolerConfig] = None, - **kwargs, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() - self.model = LlamaModel(**kwargs) + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + pooler_config = vllm_config.model_config.pooler_config + + self.model = LlamaModel(config, + cache_config, + quant_config, + lora_config, + prefix=maybe_prefix(prefix, "model")) self._pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.LAST, diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index bdd67b12a06d8..c98462537728a 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -9,7 +9,7 @@ from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig, PretrainedConfig, SiglipVisionConfig) from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext) from vllm.model_executor.layers.activation import get_act_fn @@ -258,13 +258,13 @@ def init_vision_tower_for_llava( @INPUT_REGISTRY.register_input_processor(input_processor_for_llava) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, - config: LlavaConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config self.multimodal_config = multimodal_config @@ -290,8 +290,7 @@ class 
LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): self.language_model = init_vllm_registered_model( config.text_config, - cache_config, - quant_config, + vllm_config=vllm_config, prefix="language_model") self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 37b8baa8c6be0..f187f8105b96a 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -11,11 +11,10 @@ from transformers.models.llava_next.modeling_llava_next import ( from typing_extensions import NotRequired from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig, PoolerConfig +from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext) from vllm.model_executor.layers.pooler import Pooler, PoolingType -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -282,13 +281,12 @@ def input_processor_for_llava_next(ctx: InputContext, class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, - config: LlavaNextConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - pooler_config: Optional[PoolerConfig] = None) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + pooler_config = vllm_config.model_config.pooler_config + multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config @@ -308,8 +306,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, self.language_model = init_vllm_registered_model( config.text_config, - cache_config, - quant_config, + vllm_config=vllm_config, prefix="language_model") # The same model class supports both language generation and embedding diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 69bfc80a4372c..eceb0c0ab52df 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -10,11 +10,10 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig, SiglipVisionConfig) from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -254,12 +253,11 @@ class LlavaNextMultiModalProjector(nn.Module): class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, - config: LlavaNextVideoConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: 
Optional[QuantizationConfig] = None) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config @@ -277,8 +275,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, projector_hidden_act=config.projector_hidden_act) self.language_model = init_vllm_registered_model( config.text_config, - cache_config, - quant_config, + vllm_config=vllm_config, prefix="language_model") self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index ad5d551ee0834..64d373ce91509 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -14,11 +14,10 @@ from transformers.models.llava_onevision.modeling_llava_onevision import ( from typing_extensions import NotRequired from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -405,12 +404,11 @@ class LlavaOnevisionMultiModalProjector(nn.Module): class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, - config: LlavaOnevisionConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config @@ -424,8 +422,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) self.language_model = init_vllm_registered_model( config.text_config, - cache_config, - quant_config, + vllm_config=vllm_config, prefix="language_model") self.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size)) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 91161957642f9..49e43f8cc683c 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -6,7 +6,7 @@ from torch import nn from transformers import MambaConfig from vllm.attention.backends.abstract import AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -132,12 +132,14 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): def __init__( self, - config: MambaConfig, - cache_config: 
Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - scheduler_config: Optional[SchedulerConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config assert not cache_config.enable_prefix_caching, \ "Mamba does not support prefix caching" diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index 619a5cd00d6b6..4cb1b4a929b9f 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -3,13 +3,13 @@ from typing import Iterable, List, Optional, Tuple import torch import torch.nn as nn +from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.transformers_utils.configs.medusa import MedusaConfig class ResidualBlock(nn.Module): @@ -44,7 +44,8 @@ class Medusa(nn.Module): in the draft checkpoint (using key token_map). Also, the draft config needs to have truncated_vocab_size (=k) as an attribute.""" - def __init__(self, config: MedusaConfig, **_) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + config = vllm_config.model_config.hf_config super().__init__() self.config = config self.blocks = nn.ModuleList([ diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 7704431a4d90a..559d9c4dd35bf 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -29,7 +29,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -463,12 +463,14 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__( self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index f8006095e2eb2..9458204c5a038 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -34,7 +34,7 @@ from transformers import PretrainedConfig from typing_extensions import NotRequired from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig +from vllm.config import CacheConfig, VllmConfig from vllm.inputs import 
(INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -385,11 +385,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): def __init__( self, - config: PretrainedConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ): + config = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config super().__init__() # All MiniCPM-V models disable `tie_word_embeddings` but # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot @@ -701,12 +703,10 @@ class MiniCPMV2_0(MiniCPMVBaseModel): def __init__( self, - config: PretrainedConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ): - super().__init__(config, multimodal_config, cache_config, quant_config) + super().__init__(vllm_config) assert self.version == (2, 0) def init_llm( @@ -867,13 +867,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): def __init__( self, - config: PretrainedConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ): - super().__init__(config, multimodal_config, cache_config, quant_config) + super().__init__(vllm_config) assert self.version == (2, 5) def init_llm( @@ -1017,12 +1014,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): def __init__( self, - config: PretrainedConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ): - super().__init__(config, multimodal_config, cache_config, quant_config) + super().__init__(vllm_config) assert self.version == (2, 6) def init_llm( @@ -1141,12 +1136,8 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA): embedding_modules = {} embedding_padding_modules = [] - def __new__(cls, - config: PretrainedConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None): + def __new__(cls, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config if not hasattr(config, "version"): if config.hidden_size == 2304 and config.query_num == 64: version = (2, 0) @@ -1160,5 +1151,4 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA): if instance_class is None: raise ValueError( "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6") - return instance_class(config, multimodal_config, cache_config, - quant_config) + return instance_class(vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index f5c28e7d74811..91ec3228c0d48 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -28,7 +28,7 @@ from transformers import MixtralConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, 
LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -334,13 +334,14 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__( self, - config: MixtralConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() - + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 007c4e2eabc90..aeac326776392 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -29,7 +29,7 @@ from torch import nn from transformers import MixtralConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -352,11 +352,13 @@ class MixtralForCausalLM(nn.Module, SupportsPP): def __init__( self, - config: MixtralConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config self.model = MixtralModel(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 18e38daadc93a..14aa515570f38 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -33,7 +33,7 @@ from transformers.models.mllama.processing_mllama import ( import vllm.distributed.parallel_state as ps from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention.ops.paged_attn import PagedAttention -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs, InputContext, TokenInputs, token_inputs) @@ -1108,12 +1108,15 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): "up_proj": ("gate_up_proj", 1), } - def __init__(self, - config: config_mllama.MllamaConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.vocab_size = config.text_config.vocab_size self.hidden_size = config.text_config.hidden_size self.max_num_tiles = config.vision_config.max_num_tiles diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 
5f2f61cc610b3..cd462c4d0495e 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -3,8 +3,7 @@ import re from array import array from dataclasses import dataclass from functools import lru_cache, partial -from typing import (Any, Iterable, List, Mapping, Optional, Tuple, TypedDict, - Union) +from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union import torch from einops import rearrange @@ -16,7 +15,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.attention.selector import _Backend from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, @@ -1027,13 +1026,14 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def __init__( self, - config: PretrainedConfig, - multimodal_config: Optional[MultiModalConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[Mapping[str, Any]] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() - + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index b3977812cb273..672c8e9c22260 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -7,7 +7,7 @@ import torch.nn as nn from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn @@ -269,11 +269,13 @@ class MPTForCausalLM(nn.Module, SupportsPP): def __init__( self, - config: MPTConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ): + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config assert config.tie_word_embeddings self.quant_config = quant_config diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 8d128a42b14b8..5991cce642981 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -27,7 +27,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -403,13 +403,14 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__( self, - config: NemotronConfig, - cache_config: 
Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() - + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config assert isinstance(config, NemotronConfig) self.config = config diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 545d86eebb5ec..6905f8521a8c3 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -28,7 +28,7 @@ from transformers import OlmoConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -291,11 +291,15 @@ class OlmoForCausalLM(nn.Module, SupportsPP): Extremely barebones HF model wrapper. """ - def __init__(self, - config: OlmoConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config self.model = OlmoModel(config, cache_config, quant_config) if config.tie_word_embeddings: diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index de30b5270e7e8..8fa90d17003af 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -18,7 +18,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -311,11 +311,13 @@ class OlmoeForCausalLM(nn.Module, SupportsPP): def __init__( self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config self.model = OlmoeModel(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index a453376d02552..d378956b68cfc 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -24,7 +24,7 @@ from transformers import OPTConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from 
vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -344,11 +344,12 @@ class OPTForCausalLM(nn.Module, SupportsPP): def __init__( self, - config: OPTConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", - ): + ) -> None: + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config super().__init__() self.config = config self.quant_config = quant_config diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index d6ec1fb602f05..b400d4e3f5228 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -11,7 +11,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -270,11 +270,13 @@ class OrionForCausalLM(nn.Module, SupportsPP): def __init__( self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config self.model = OrionModel(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 4b6061e113cb2..69b7fe9d56847 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -6,13 +6,11 @@ from torch import nn from transformers import PaliGemmaConfig from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.models.gemma import GemmaForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import cached_get_tokenizer @@ -21,7 +19,8 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal, SupportsPP from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens) -from .utils import AutoWeightsLoader, merge_multimodal_embeddings +from .utils import (AutoWeightsLoader, init_vllm_registered_model, + merge_multimodal_embeddings) logger = init_logger(__name__) @@ -132,13 +131,15 @@ class PaliGemmaMultiModalProjector(nn.Module): class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, - config: PaliGemmaConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + def __init__( + self, +
vllm_config: VllmConfig, + prefix: str = "", + ) -> None: super().__init__() - + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config @@ -150,10 +151,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, projection_dim=config.vision_config.projection_dim) self.quant_config = quant_config - self.language_model = GemmaForCausalLM(config.text_config, - cache_config, - quant_config, - prefix="language_model") + config.text_config.architectures = ["GemmaForCausalLM"] + self.language_model = init_vllm_registered_model( + config.text_config, + vllm_config=vllm_config, + prefix="language_model") logit_scale = getattr(config, "logit_scale", 1.0) self.language_model.logits_processor.scale *= logit_scale diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 11e7c8abd4888..a86e2c1b4e4a1 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -27,7 +27,7 @@ from transformers import PersimmonConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -265,11 +265,15 @@ class PersimmonModel(nn.Module): class PersimmonForCausalLM(nn.Module, SupportsPP): - def __init__(self, - config: PersimmonConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config self.vocab_size = config.vocab_size self.model = PersimmonModel(config, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 4dae6e323654b..fef921528b042 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -42,7 +42,7 @@ from transformers import PhiConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -279,13 +279,14 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__( self, - config: PhiConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ): + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: super().__init__() - + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config self.config = config # lm_head use bias, cannot share word embeddings assert not config.tie_word_embeddings diff --git a/vllm/model_executor/models/phi3_small.py 
b/vllm/model_executor/models/phi3_small.py index 92bf0e61448e5..de1b09eba6c6d 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -6,7 +6,7 @@ from torch import nn from transformers.configuration_utils import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -365,12 +365,13 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP): def __init__( self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ): + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config self.model = Phi3SmallModel(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index a84d6b317b479..65131d61673a3 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -25,8 +25,7 @@ from PIL import Image from transformers import CLIPVisionConfig, PretrainedConfig from vllm.attention import AttentionMetadata -from vllm.config import (CacheConfig, ModelConfig, MultiModalConfig, - PoolerConfig) +from vllm.config import ModelConfig, VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.logger import init_logger @@ -526,14 +525,16 @@ def input_processor_for_phi3v(ctx: InputContext, @INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v) class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, - config: PretrainedConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - pooler_config: Optional[PoolerConfig] = None) -> None: + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: super().__init__() - + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + pooler_config = vllm_config.model_config.pooler_config + multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config self.image_token_id = _IMAGE_TOKEN_ID @@ -552,8 +553,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): # The prefix is empty intentionally because default prefix of # LlamaForCausalLM is "model" - self.language_model = LlamaForCausalLM(config, cache_config, - quant_config) + self.language_model = LlamaForCausalLM(vllm_config=vllm_config, + prefix="") # The same model class supports both language generation and embedding # because the architecture name is the same diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 19e2621ead996..17d00c0ede2b2 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -28,7 +28,7 @@ from transformers.configuration_utils import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators 
import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -531,13 +531,14 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__( self, - config: PhiMoEConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() - + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index de935fc420472..93919c9c051c0 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -9,14 +9,14 @@ import torch.nn as nn import torch.nn.functional as F from mistral_common.protocol.instruct.messages import ImageChunk from PIL import Image -from transformers import PixtralVisionConfig, PretrainedConfig +from transformers import PixtralVisionConfig from transformers.models.pixtral.image_processing_pixtral import ( _num_image_tokens) from transformers.models.pixtral.modeling_pixtral import ( PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid) from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, ModelConfig, MultiModalConfig +from vllm.config import ModelConfig, VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_and_mul_fn @@ -152,13 +152,14 @@ def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs): class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, - config: PretrainedConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: super().__init__() - + config = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config @@ -174,8 +175,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, # init MistralForCausalLM self.language_model = init_vllm_registered_model( config.text_config, - cache_config, - quant_config, + vllm_config=vllm_config, prefix="language_model") self.vision_encoder = VisionTransformer(self.vision_args) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 1db7e2ba1cc12..d3f10ee7c85ca 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -20,7 +20,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.inputs import 
(INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) @@ -867,13 +867,14 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): def __init__( self, - config: PretrainedConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ): + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config self.quant_config = quant_config @@ -1064,17 +1065,13 @@ class QWenLMHeadModel(QWenBaseModel, SupportsLoRA): def __new__( cls, - config: PretrainedConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ): + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + config = vllm_config.model_config.hf_config # Initialize VL if hasattr(config, "visual"): - return QWenVL(config, multimodal_config, cache_config, - quant_config, lora_config) + return QWenVL(vllm_config) # Initialize LLM else: - return QWenLLM(config, multimodal_config, cache_config, - quant_config, lora_config) + return QWenLLM(vllm_config) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 1e99c1b13b31f..b0156a25ca5cf 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -29,7 +29,7 @@ from transformers import Qwen2Config from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -405,12 +405,14 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__( self, - config: Qwen2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + vllm_config: VllmConfig, prefix: str = "", ) -> None: + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config # TODO (@robertgshaw2): see if this can be moved out if (cache_config.sliding_window is not None and hasattr(config, "max_window_layers")): @@ -423,8 +425,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): config.num_hidden_layers, )) - super().__init__() - self.config = config self.lora_config = lora_config diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 18cf45b3939f7..1057720e8c308 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -26,16 +26,14 @@ import librosa import numpy as np import torch import torch.nn as nn -from transformers import Qwen2AudioConfig, Qwen2AudioEncoder +from transformers import Qwen2AudioEncoder from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig 
+from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( @@ -266,13 +264,16 @@ def input_mapper_for_qwen2_audio( class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, - config: Qwen2AudioConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: super().__init__() - + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index b9e3b74c477e2..25ecf76e35f22 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -8,14 +8,11 @@ from typing import Iterable, List, Optional, Tuple import torch from torch import nn -from transformers import Qwen2Config from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig, PoolerConfig +from vllm.config import VllmConfig from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.pooler import Pooler, PoolingType -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput @@ -48,12 +45,15 @@ class Qwen2ForSequenceClassification(nn.Module): def __init__( self, - config: Qwen2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - pooler_config: Optional[PoolerConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + pooler_config = vllm_config.model_config.pooler_config # TODO (@robertgshaw2): see if this can be moved out if (cache_config.sliding_window is not None and hasattr(config, "max_window_layers")): @@ -66,8 +66,6 @@ class Qwen2ForSequenceClassification(nn.Module): config.num_hidden_layers, )) - super().__init__() - self.config = config self.lora_config = lora_config diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index c8c48c0894c36..b1177f9c59063 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -30,7 +30,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig +from vllm.config import 
CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -379,11 +379,13 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP): def __init__( self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config self.model = Qwen2MoeModel(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 0fbf305da8b94..1f9411241bdd6 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -7,14 +7,12 @@ from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn -from transformers import Qwen2Config from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig, PoolerConfig +from vllm.config import VllmConfig from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.pooler import Pooler, PoolingType -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput @@ -59,12 +57,15 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP): def __init__( self, - config: Qwen2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - pooler_config: Optional[PoolerConfig] = None, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + pooler_config = vllm_config.model_config.pooler_config # TODO (@robertgshaw2): see if this can be moved out if (cache_config.sliding_window is not None and hasattr(config, "max_window_layers")): @@ -77,8 +78,6 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP): config.num_hidden_layers, )) - super().__init__() - self.config = config self.lora_config = lora_config diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 8dd75c9ee7e7b..ab80c1494d067 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -40,7 +40,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( from vllm.attention import AttentionMetadata from vllm.attention.selector import _Backend -from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig +from vllm.config import VllmConfig from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import utils as dist_utils from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, @@ -966,15 +966,16 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, embedding_modules = {} embedding_padding_modules = [] - def __init__(self, - config: Qwen2VLConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None) -> None: - + def 
__init__(
+        self,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
-
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        multimodal_config = vllm_config.model_config.multimodal_config
         assert not cache_config.enable_prefix_caching, \
             "Qwen2-VL currently does not support prefix caching"
diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py
index 931e48a44f631..ffabac8292dbd 100644
--- a/vllm/model_executor/models/solar.py
+++ b/vllm/model_executor/models/solar.py
@@ -29,7 +29,7 @@ from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import CacheConfig, LoRAConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -411,13 +411,14 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def __init__(
         self,
-        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
+        vllm_config: VllmConfig,
+        prefix: str = "",
     ) -> None:
         super().__init__()
-
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
         self.config = config
         self.lora_config = lora_config
diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index 4cb55506bb237..975d316977c37 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -25,7 +25,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -247,11 +247,13 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
 
     def __init__(
         self,
-        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: VllmConfig,
+        prefix: str = "",
     ) -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
         self.model = StableLMEpochModel(config, cache_config, quant_config)
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index 0b0e3f21065b4..ae61aa4e248a5 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -25,7 +25,7 @@ from transformers import Starcoder2Config
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -245,11 +245,15 @@ class Starcoder2Model(nn.Module):
 
 class Starcoder2ForCausalLM(nn.Module, SupportsPP):
 
-    def __init__(self,
-                 config: Starcoder2Config,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
         self.config = config
         self.model = Starcoder2Model(config, cache_config,
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 6b7a638585ad9..d47f0091e0f9f 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -15,12 +15,11 @@ from transformers.models.whisper import WhisperFeatureExtractor
 from transformers.models.whisper.modeling_whisper import WhisperEncoder
 
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, MultiModalConfig
+from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
 from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -340,12 +339,14 @@ class ModifiedWhisperEncoder(WhisperEncoder):
 @INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
 class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
 
-    def __init__(self,
-                 config: UltravoxConfig,
-                 multimodal_config: MultiModalConfig,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional["QuantizationConfig"] = None):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        multimodal_config = vllm_config.model_config.multimodal_config
         self.config = config
         self.multi_modal_config = multimodal_config
         assert self.multi_modal_config
@@ -361,10 +362,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
             ))
         self.multi_modal_projector = UltravoxProjector(config)
         self.language_model = init_vllm_registered_model(
-            config.text_config,
-            cache_config,
-            quant_config,
-            prefix="language_model")
+            config.text_config, vllm_config, prefix="language_model")
         if config.text_model_id is not None:
             self.secondary_weights.append(
                 DefaultModelLoader.Source(model_or_path=config.text_model_id,
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index fee97e8922a76..60eeceb18bcf0 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -11,11 +11,8 @@ from transformers import PretrainedConfig
 import vllm.envs as envs
 from vllm.attention.selector import (_Backend, backend_name_to_enum,
                                      get_global_forced_attn_backend)
-from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
-                         SchedulerConfig)
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.model_loader.loader import build_model
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import ModelRegistry
 from vllm.multimodal.base import MultiModalPlaceholderMap, NestedTensors
@@ -236,12 +233,7 @@ class AutoWeightsLoader:
 
 def init_vllm_registered_model(
     hf_config: PretrainedConfig,
-    cache_config: Optional[CacheConfig],
-    quant_config: Optional[QuantizationConfig],
-    *,
-    lora_config: Optional[LoRAConfig] = None,
-    multimodal_config: Optional[MultiModalConfig] = None,
-    scheduler_config: Optional[SchedulerConfig] = None,
+    vllm_config: VllmConfig,
     prefix: str = "",
 ) -> nn.Module:
     """
@@ -249,18 +241,11 @@ def init_vllm_registered_model(
     based on the arguments passed to the outer vLLM model.
     """
     model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)
+    import copy
+    copied_config = copy.deepcopy(vllm_config)
+    copied_config.model_config.hf_config = hf_config
 
-    return build_model(
-        model_class,
-        None,
-        hf_config,
-        cache_config,
-        quant_config,
-        lora_config=lora_config,
-        multimodal_config=multimodal_config,
-        scheduler_config=scheduler_config,
-        prefix=prefix,
-    )
+    return model_class(vllm_config=copied_config, prefix=prefix)
 
 
 @overload
diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py
index 1d08b382b0b00..7afb99176077b 100644
--- a/vllm/model_executor/models/xverse.py
+++ b/vllm/model_executor/models/xverse.py
@@ -27,7 +27,7 @@ from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -225,13 +225,14 @@ class XverseModel(nn.Module):
 
     def __init__(
         self,
-        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ) -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
         self.config = config
         self.padding_idx = config.pad_token_id
         lora_vocab = (lora_config.lora_extra_vocab_size *
@@ -316,13 +317,16 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def __init__(
        self,
-        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
+        vllm_config: VllmConfig,
+        prefix: str = "",
     ) -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
         self.config = config
         self.lora_config = lora_config
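For reference, the constructor shape that the per-model hunks above converge on is sketched below in plain Python. The class name and layer wiring are hypothetical placeholders; only the (vllm_config, prefix) signature and the attribute lookups on VllmConfig are taken from this patch.

from torch import nn

from vllm.config import VllmConfig


class MyForCausalLM(nn.Module):
    """Hypothetical model class illustrating the post-patch constructor."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        # All sub-configs are pulled from the single VllmConfig object
        # instead of being passed as separate constructor arguments.
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config
        # ... build submodules from config / cache_config / quant_config ...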
diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py
index 3336569f59467..8373e11cfff9f 100644
--- a/vllm/plugins/__init__.py
+++ b/vllm/plugins/__init__.py
@@ -61,15 +61,3 @@ def set_compilation_config(config: Optional[CompilationConfig]):
 
 def get_compilation_config() -> Optional[CompilationConfig]:
     return _compilation_config
-
-
-_vllm_config: Optional[VllmConfig] = None
-
-
-def set_vllm_config(config: Optional[VllmConfig]):
-    global _vllm_config
-    _vllm_config = config
-
-
-def get_vllm_config() -> Optional[VllmConfig]:
-    return _vllm_config
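With the module-level set_vllm_config/get_vllm_config accessors removed from vllm.plugins, configuration now reaches sub-models only through the explicitly passed vllm_config. A minimal sketch of composing an inner language model through the reworked init_vllm_registered_model follows; the outer class is hypothetical, and it assumes the HF config exposes a text_config the way UltravoxConfig does.

from torch import nn

from vllm.config import VllmConfig
from vllm.model_executor.models.utils import init_vllm_registered_model


class MyMultiModalLM(nn.Module):
    """Hypothetical composite model; only the init_vllm_registered_model
    call shape comes from this patch."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        # init_vllm_registered_model deep-copies vllm_config and swaps in
        # the inner HF config, so the sub-model sees the same cache, quant,
        # and LoRA settings as the outer model.
        self.language_model = init_vllm_registered_model(
            config.text_config, vllm_config, prefix="language_model")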