[5/N] pass the whole config to model (#9983)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-11-08 22:17:28 -08:00 committed by GitHub
parent 49d2a41a86
commit 1a95f10ee7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
75 changed files with 583 additions and 654 deletions

View File

@ -9,8 +9,7 @@ import math
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple, from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
Type, cast)
import gguf import gguf
import huggingface_hub import huggingface_hub
@ -18,20 +17,17 @@ import numpy as np
import torch import torch
from huggingface_hub import HfApi, hf_hub_download from huggingface_hub import HfApi, hf_hub_download
from torch import nn from torch import nn
from transformers import AutoModelForCausalLM, PretrainedConfig from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import (CacheConfig, LoadConfig, LoadFormat, LoRAConfig, from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig,
ModelConfig, MultiModalConfig, ParallelConfig, VllmConfig)
PoolerConfig, SchedulerConfig, VllmConfig)
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ReplicatedLinear, from vllm.model_executor.layers.linear import (ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.tensorizer import ( from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
serialize_vllm_model, tensorizer_weights_iterator) serialize_vllm_model, tensorizer_weights_iterator)
@ -43,8 +39,6 @@ from vllm.model_executor.model_loader.weight_utils import (
get_gguf_extra_tensor_names, gguf_quant_weights_iterator, get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
safetensors_weights_iterator) safetensors_weights_iterator)
from vllm.model_executor.models import (has_inner_state, supports_lora,
supports_multimodal)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
@ -94,85 +88,11 @@ def device_loading_context(module: torch.nn.Module,
logger = init_logger(__name__) logger = init_logger(__name__)
def _get_model_initialization_kwargs(
model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig] = None,
pooler_config: Optional[PoolerConfig] = None) -> Dict[str, Any]:
"""Get extra kwargs for model initialization."""
extra_kwargs: Dict[str, Any] = {}
if supports_lora(model_class):
# lora_config=None is used to disable LoRA
extra_kwargs["lora_config"] = lora_config
elif lora_config:
raise ValueError(
f"Model {model_class.__name__} does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")
if supports_multimodal(model_class):
assert multimodal_config is not None
extra_kwargs["multimodal_config"] = multimodal_config
if has_inner_state(model_class) and scheduler_config:
extra_kwargs["scheduler_config"] = scheduler_config
if pooler_config:
extra_kwargs["pooler_config"] = pooler_config
return extra_kwargs
def build_model(model_class: Type[nn.Module],
vllm_config: Optional[VllmConfig],
hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
*,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig],
prefix: Optional[str] = None,
pooler_config: Optional[PoolerConfig] = None) -> nn.Module:
extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
multimodal_config,
scheduler_config,
pooler_config)
if prefix:
extra_kwargs["prefix"] = prefix
# TODO: unify all the module initialization code
# to only take the `VllmConfig` object as input
from vllm.plugins import set_vllm_config
set_vllm_config(vllm_config)
return model_class(config=hf_config,
cache_config=cache_config,
quant_config=quant_config,
**extra_kwargs)
def _initialize_model(vllm_config: VllmConfig) -> nn.Module: def _initialize_model(vllm_config: VllmConfig) -> nn.Module:
"""Initialize a model with the given configurations.""" """Initialize a model with the given configurations."""
model_config = vllm_config.model_config model_config = vllm_config.model_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
cache_config = vllm_config.cache_config
model_class, _ = get_model_architecture(model_config) model_class, _ = get_model_architecture(model_config)
return model_class(vllm_config=vllm_config)
return build_model(
model_class,
vllm_config,
model_config.hf_config,
cache_config=cache_config,
quant_config=vllm_config.quant_config,
lora_config=lora_config,
multimodal_config=model_config.multimodal_config,
scheduler_config=scheduler_config,
pooler_config=model_config.pooler_config,
)
class BaseModelLoader(ABC): class BaseModelLoader(ABC):
@ -486,24 +406,18 @@ class TensorizerLoader(BaseModelLoader):
device_config = vllm_config.device_config device_config = vllm_config.device_config
model_config = vllm_config.model_config model_config = vllm_config.model_config
lora_config = vllm_config.lora_config
cache_config = vllm_config.cache_config
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model_class = get_model_architecture(model_config)[0] model_class = get_model_architecture(model_config)[0]
quant_config = vllm_config.quant_config
extra_kwargs = _get_model_initialization_kwargs(
model_class, lora_config, model_config.multimodal_config)
extra_kwargs["quant_config"] = quant_config
extra_kwargs["cache_config"] = cache_config
tensorizer_config = copy.copy(self.tensorizer_config) tensorizer_config = copy.copy(self.tensorizer_config)
tensorizer_config.model_class = model_class tensorizer_config.model_class = model_class
tensorizer_config.hf_config = model_config.hf_config tensorizer_config.hf_config = model_config.hf_config
tensorizer_config.dtype = model_config.dtype tensorizer_config.dtype = model_config.dtype
model = load_with_tensorizer(tensorizer_config, **extra_kwargs) model = load_with_tensorizer(tensorizer_config,
vllm_config=vllm_config)
return model.eval() return model.eval()
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:

View File

@ -17,8 +17,6 @@ from vllm.config import ModelConfig, ParallelConfig
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
@ -268,8 +266,7 @@ class TensorizerAgent:
in vllm/model_executor/model_loader/weight_utils.py in vllm/model_executor/model_loader/weight_utils.py
""" """
def __init__(self, tensorizer_config: TensorizerConfig, def __init__(self, tensorizer_config: TensorizerConfig, vllm_config):
quant_config: QuantizationConfig, **extra_kwargs):
if tensorizer_error_msg is not None: if tensorizer_error_msg is not None:
raise ImportError( raise ImportError(
"Tensorizer is not installed. Please install tensorizer " "Tensorizer is not installed. Please install tensorizer "
@ -279,11 +276,7 @@ class TensorizerAgent:
self.tensorizer_config = tensorizer_config self.tensorizer_config = tensorizer_config
self.tensorizer_args = ( self.tensorizer_args = (
self.tensorizer_config._construct_tensorizer_args()) self.tensorizer_config._construct_tensorizer_args())
self.extra_kwargs = extra_kwargs self.vllm_config = vllm_config
if extra_kwargs.get("quant_config") is not None:
self.quant_config = extra_kwargs["quant_config"]
else:
self.quant_config = quant_config
self.model = self._init_model() self.model = self._init_model()
def _init_model(self): def _init_model(self):
@ -293,9 +286,7 @@ class TensorizerAgent:
assert self.tensorizer_config.model_class is not None assert self.tensorizer_config.model_class is not None
with no_init_or_tensor(): with no_init_or_tensor():
return self.tensorizer_config.model_class( return self.tensorizer_config.model_class(
config=model_args, vllm_config=self.vllm_config, )
quant_config=self.quant_config,
**self.extra_kwargs)
def _resize_lora_embeddings(self): def _resize_lora_embeddings(self):
"""Modify LoRA embedding layers to use bigger tensors """Modify LoRA embedding layers to use bigger tensors

View File

@ -6,7 +6,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
@ -415,14 +415,16 @@ class ArcticModel(nn.Module):
class ArcticForCausalLM(nn.Module, SupportsPP): class ArcticForCausalLM(nn.Module, SupportsPP):
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
config: ArcticConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
**kwargs) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.model = ArcticModel(config, cache_config, quant_config) self.model = ArcticModel(config,
cache_config,
quant_config,
prefix=prefix)
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.vocab_size, self.vocab_size,

View File

@ -26,7 +26,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
@ -332,14 +332,15 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
position_embedding: str, prefix: str = "",
cache_config: Optional[CacheConfig] = None, position_embedding: str = "ROPE",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
@ -439,17 +440,14 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
): ):
config = vllm_config.model_config.hf_config
if config.hidden_size == 4096: # baichuan2 7b if config.hidden_size == 4096: # baichuan2 7b
super().__init__(config, "ROPE", cache_config, quant_config, super().__init__(vllm_config, prefix, "ROPE")
lora_config)
else: # baichuan 13b, baichuan2 13b else: # baichuan 13b, baichuan2 13b
super().__init__(config, "ALIBI", cache_config, quant_config, super().__init__(vllm_config, prefix, "ALIBI")
lora_config)
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
@ -459,10 +457,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__(config, "ROPE", cache_config, quant_config, super().__init__(vllm_config, prefix, "ROPE")
lora_config)

View File

@ -25,7 +25,7 @@ from transformers import BartConfig
from transformers.utils import logging from transformers.utils import logging
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -810,13 +810,13 @@ class BartModel(nn.Module):
class BartForConditionalGeneration(nn.Module): class BartForConditionalGeneration(nn.Module):
base_model_prefix = "model" base_model_prefix = "model"
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
# currently all existing BART models have `tie_word_embeddings` enabled # currently all existing BART models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings assert config.tie_word_embeddings
self.config = config self.config = config

View File

@ -6,7 +6,7 @@ from transformers import BertConfig
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.xformers import XFormersImpl from vllm.attention.backends.xformers import XFormersImpl
from vllm.config import CacheConfig, PoolerConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -384,12 +384,14 @@ class BertEmbeddingModel(nn.Module):
def __init__( def __init__(
self, self,
config: BertConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
pooler_config: Optional[PoolerConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
self.model = BertModel(config, cache_config, quant_config) self.model = BertModel(config, cache_config, quant_config)
self._pooler = Pooler.from_config_with_defaults( self._pooler = Pooler.from_config_with_defaults(
pooler_config, pooler_config,

View File

@ -8,7 +8,7 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
apply_chunking_to_forward) apply_chunking_to_forward)
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
@ -483,14 +483,17 @@ def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2) @INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(
config: Blip2Config, self,
multimodal_config: MultiModalConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
@ -513,8 +516,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
cache_config, vllm_config=vllm_config,
quant_config,
prefix="language_model") prefix="language_model")
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (

View File

@ -24,7 +24,7 @@ from transformers import BloomConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
@ -283,11 +283,13 @@ class BloomForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: BloomConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = BloomModel(config, cache_config, quant_config) self.transformer = BloomModel(config, cache_config, quant_config)

View File

@ -9,7 +9,7 @@ from torch import nn
from transformers import ChameleonConfig, ChameleonVQVAEConfig from transformers import ChameleonConfig, ChameleonVQVAEConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
@ -926,12 +926,14 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
def __init__( def __init__(
self, self,
config: ChameleonConfig, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.model = ChameleonModel(config, cache_config, quant_config) self.model = ChameleonModel(config, cache_config, quant_config)

View File

@ -11,7 +11,7 @@ from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
@ -595,14 +595,15 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
def __init__( def __init__(
self, self,
config: ChatGLMConfig, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config

View File

@ -28,7 +28,7 @@ from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@ -334,12 +334,14 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: CohereConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
# currently all existing command R models have `tie_word_embeddings` # currently all existing command R models have `tie_word_embeddings`
# enabled # enabled

View File

@ -4,7 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
@ -352,11 +352,13 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: DbrxConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
if config.tie_word_embeddings: if config.tie_word_embeddings:
raise ValueError( raise ValueError(

View File

@ -22,13 +22,11 @@
# limitations under the License. # limitations under the License.
"""Inference-only DeciLM model compatible with HuggingFace weights.""" """Inference-only DeciLM model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Tuple from typing import Iterable, Tuple
import torch import torch
from transformers import LlamaConfig
from vllm.config import CacheConfig, LoRAConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
@ -55,17 +53,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
def __init__( def __init__(
self, self,
config: LlamaConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
config = vllm_config.model_config.hf_config
config.num_key_value_heads = max(config.num_key_value_heads_per_layer) config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer") delattr(config, "num_key_value_heads_per_layer")
super().__init__(config=config, super().__init__(vllm_config=vllm_config)
cache_config=cache_config,
quant_config=quant_config,
lora_config=lora_config)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [

View File

@ -27,7 +27,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
@ -385,11 +385,13 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = DeepseekModel(config, cache_config, quant_config) self.model = DeepseekModel(config, cache_config, quant_config)

View File

@ -28,7 +28,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
@ -481,11 +481,13 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = DeepseekV2Model(config, self.model = DeepseekV2Model(config,

View File

@ -4,6 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -12,7 +13,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.eagle import EAGLEConfig
class EAGLE(nn.Module): class EAGLE(nn.Module):
@ -34,14 +34,15 @@ class EAGLE(nn.Module):
in the draft checkpoint (using key token_map). Also, the draft config in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute.""" needs to have truncated_vocab_size (=k) as an attribute."""
def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None: def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
self.config = config self.config = config
architectures = getattr(self.config.model, "architectures", []) architectures = getattr(self.config.model, "architectures", [])
model_cls, _ = ModelRegistry.resolve_model_cls(architectures) model_cls, _ = ModelRegistry.resolve_model_cls(architectures)
self.model = model_cls(self.config.model, *args, **kwargs) self.model = model_cls(vllm_config, prefix)
self.fc = nn.Linear(config.model.hidden_size * 2, self.fc = nn.Linear(config.model.hidden_size * 2,
config.model.hidden_size, config.model.hidden_size,
bias=getattr(self.config, "eagle_fc_bias", False)) bias=getattr(self.config, "eagle_fc_bias", False))

View File

@ -29,7 +29,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
@ -440,12 +440,14 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: ExaoneConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config

View File

@ -27,7 +27,7 @@ from transformers import FalconConfig as HF_FalconConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
@ -403,11 +403,13 @@ class FalconForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: FalconConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = FalconModel(config, cache_config, quant_config) self.transformer = FalconModel(config, cache_config, quant_config)

View File

@ -6,7 +6,7 @@ import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
@ -189,11 +189,11 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
class Florence2ForConditionalGeneration(nn.Module): class Florence2ForConditionalGeneration(nn.Module):
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
# TODO(Isotr0py): Add vision backbone # TODO(Isotr0py): Add vision backbone
self.language_model = Florence2LanguageForConditionalGeneration( self.language_model = Florence2LanguageForConditionalGeneration(

View File

@ -22,14 +22,13 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint import torch.utils.checkpoint
from PIL import Image from PIL import Image
from transformers import FuyuConfig, FuyuImageProcessor from transformers import FuyuImageProcessor
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -227,12 +226,12 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu) @INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
config: FuyuConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config

View File

@ -22,7 +22,7 @@ from transformers import GemmaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
@ -374,13 +374,14 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: GemmaConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled # currently all existing Gemma models have `tie_word_embeddings` enabled

View File

@ -21,7 +21,7 @@ from transformers import Gemma2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
@ -245,12 +245,13 @@ class Gemma2Model(nn.Module):
def __init__( def __init__(
self, self,
config: Gemma2Config, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
@ -400,11 +401,13 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: Gemma2Config, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
del lora_config # Unused. del lora_config # Unused.
super().__init__() super().__init__()
self.config = config self.config = config
@ -470,14 +473,14 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
pooler_config: Optional[PoolerConfig] = None, vllm_config: VllmConfig,
**kwargs, prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.model = Gemma2Model(**kwargs) self.model = Gemma2Model(vllm_config, prefix)
self._pooler = Pooler.from_config_with_defaults( self._pooler = Pooler.from_config_with_defaults(
pooler_config, vllm_config.model_config.pooler_config,
pooling_type=PoolingType.LAST, pooling_type=PoolingType.LAST,
normalize=True, normalize=True,
softmax=False) softmax=False)

View File

@ -24,7 +24,7 @@ from transformers import GPT2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_world_size) get_pp_group, get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
@ -242,11 +242,13 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: GPT2Config, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = GPT2Model(config, self.transformer = GPT2Model(config,

View File

@ -25,7 +25,7 @@ from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -260,12 +260,14 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: GPTBigCodeConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config

View File

@ -23,7 +23,7 @@ from transformers import GPTJConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -231,11 +231,13 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: GPTJConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
assert not config.tie_word_embeddings assert not config.tie_word_embeddings

View File

@ -23,7 +23,7 @@ from transformers import GPTNeoXConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -244,11 +244,13 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: GPTNeoXConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config) self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config)

View File

@ -28,7 +28,7 @@ from transformers import GraniteConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
@ -372,12 +372,14 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: GraniteConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config

View File

@ -28,7 +28,7 @@ from transformers.models.granitemoe import GraniteMoeConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -335,12 +335,14 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: GraniteMoeConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config

View File

@ -26,7 +26,7 @@ from transformers import PretrainedConfig as Idefics3Config
from transformers import ProcessorMixin as Idefics3ImageProcessor from transformers import ProcessorMixin as Idefics3ImageProcessor
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
@ -615,13 +615,16 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__( def __init__(
self, self,
config: Idefics3Config, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config

View File

@ -11,9 +11,8 @@ from vllm.utils import supports_kw
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import PoolerOutput from vllm.model_executor.layers.pooler import PoolerOutput
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -39,10 +38,8 @@ class VllmModel(Protocol[C_co, T_co]):
def __init__( def __init__(
self, self,
config: C_co, vllm_config: "VllmConfig",
*, prefix: str = "",
cache_config: Optional["CacheConfig"],
quant_config: Optional["QuantizationConfig"],
) -> None: ) -> None:
... ...
@ -58,20 +55,7 @@ class VllmModel(Protocol[C_co, T_co]):
def _check_vllm_model_init(model: Union[Type[object], object]) -> bool: def _check_vllm_model_init(model: Union[Type[object], object]) -> bool:
model_init = model.__init__ model_init = model.__init__
vllm_kws = ("cache_config", "quant_config") return supports_kw(model_init, "vllm_config")
missing_kws = tuple(kw for kw in vllm_kws
if not supports_kw(model_init, kw))
if missing_kws and (isinstance(model, type)
and issubclass(model, nn.Module)):
logger.warning(
"The model (%s) is missing "
"vLLM-specific keywords from its initializer: %s",
model,
missing_kws,
)
return len(missing_kws) == 0
def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:

View File

@ -7,7 +7,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, split_tensor_along_last_dim,
@ -319,12 +319,13 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = InternLM2Model(config, self.model = InternLM2Model(config,

View File

@ -5,7 +5,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
@ -161,11 +161,12 @@ class InternLM2VEForCausalLM(InternLM2ForCausalLM):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
super().__init__(config, cache_config, quant_config) super().__init__(config, cache_config, quant_config)
self.model = InternLM2VEModel(config, self.model = InternLM2VEModel(config,
cache_config, cache_config,

View File

@ -16,7 +16,7 @@ from PIL import Image
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.quantization import (AWQConfig, from vllm.model_executor.layers.quantization import (AWQConfig,
@ -410,13 +410,13 @@ input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor) @INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self._patch_quant_config(config, quant_config) self._patch_quant_config(config, quant_config)
@ -440,8 +440,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
cache_config, vllm_config=vllm_config,
quant_config,
prefix="language_model") prefix="language_model")
self.mlp1 = self._init_mlp1(config) self.mlp1 = self._init_mlp1(config)

View File

@ -26,7 +26,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -288,11 +288,13 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: JAISConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = JAISModel(config, cache_config, quant_config) self.transformer = JAISModel(config, cache_config, quant_config)

View File

@ -7,7 +7,7 @@ from transformers import JambaConfig
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -350,12 +350,14 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
def __init__( def __init__(
self, self,
config: JambaConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
) -> None: ) -> None:
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \ assert not cache_config.enable_prefix_caching, \
"Jamba currently does not support prefix caching" "Jamba currently does not support prefix caching"

View File

@ -28,7 +28,7 @@ from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
@ -494,15 +494,15 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: LlamaConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "", prefix: str = "",
pooler_config: Optional[PoolerConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
@ -654,12 +654,22 @@ class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
pooler_config: Optional[PoolerConfig] = None, vllm_config: VllmConfig,
**kwargs, prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.model = LlamaModel(**kwargs) config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
self.model = LlamaModel(config,
cache_config,
quant_config,
lora_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults( self._pooler = Pooler.from_config_with_defaults(
pooler_config, pooler_config,
pooling_type=PoolingType.LAST, pooling_type=PoolingType.LAST,

View File

@ -9,7 +9,7 @@ from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
PretrainedConfig, SiglipVisionConfig) PretrainedConfig, SiglipVisionConfig)
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext) InputContext)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
@ -258,13 +258,13 @@ def init_vision_tower_for_llava(
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
config: LlavaConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
@ -290,8 +290,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
cache_config, vllm_config=vllm_config,
quant_config,
prefix="language_model") prefix="language_model")
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (

View File

@ -11,11 +11,10 @@ from transformers.models.llava_next.modeling_llava_next import (
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig, PoolerConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext) InputContext)
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -282,13 +281,12 @@ def input_processor_for_llava_next(ctx: InputContext,
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
config: LlavaNextConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
pooler_config: Optional[PoolerConfig] = None) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
@ -308,8 +306,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
cache_config, vllm_config=vllm_config,
quant_config,
prefix="language_model") prefix="language_model")
# The same model class supports both language generation and embedding # The same model class supports both language generation and embedding

View File

@ -10,11 +10,10 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
SiglipVisionConfig) SiglipVisionConfig)
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -254,12 +253,11 @@ class LlavaNextMultiModalProjector(nn.Module):
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
config: LlavaNextVideoConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
@ -277,8 +275,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
projector_hidden_act=config.projector_hidden_act) projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
cache_config, vllm_config=vllm_config,
quant_config,
prefix="language_model") prefix="language_model")
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (

View File

@ -14,11 +14,10 @@ from transformers.models.llava_onevision.modeling_llava_onevision import (
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
@ -405,12 +404,11 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
config: LlavaOnevisionConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
@ -424,8 +422,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
cache_config, vllm_config=vllm_config,
quant_config,
prefix="language_model") prefix="language_model")
self.image_newline = nn.Parameter( self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size)) torch.empty(config.text_config.hidden_size))

View File

@ -6,7 +6,7 @@ from torch import nn
from transformers import MambaConfig from transformers import MambaConfig
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -132,12 +132,14 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def __init__( def __init__(
self, self,
config: MambaConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
) -> None: ) -> None:
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \ assert not cache_config.enable_prefix_caching, \
"Mamba does not support prefix caching" "Mamba does not support prefix caching"

View File

@ -3,13 +3,13 @@ from typing import Iterable, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.transformers_utils.configs.medusa import MedusaConfig
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
@ -44,7 +44,8 @@ class Medusa(nn.Module):
in the draft checkpoint (using key token_map). Also, the draft config in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute.""" needs to have truncated_vocab_size (=k) as an attribute."""
def __init__(self, config: MedusaConfig, **_) -> None: def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
config = vllm_config.model_config.hf_config
super().__init__() super().__init__()
self.config = config self.config = config
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([

View File

@ -29,7 +29,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
@ -463,12 +463,14 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config

View File

@ -34,7 +34,7 @@ from transformers import PretrainedConfig
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import CacheConfig, VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -385,11 +385,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
): ):
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
super().__init__() super().__init__()
# All MiniCPM-V models disable `tie_word_embeddings` but # All MiniCPM-V models disable `tie_word_embeddings` but
# `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
@ -701,12 +703,10 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__(config, multimodal_config, cache_config, quant_config) super().__init__(vllm_config)
assert self.version == (2, 0) assert self.version == (2, 0)
def init_llm( def init_llm(
@ -867,13 +867,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__(config, multimodal_config, cache_config, quant_config) super().__init__(vllm_config)
assert self.version == (2, 5) assert self.version == (2, 5)
def init_llm( def init_llm(
@ -1017,12 +1014,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__(config, multimodal_config, cache_config, quant_config) super().__init__(vllm_config)
assert self.version == (2, 6) assert self.version == (2, 6)
def init_llm( def init_llm(
@ -1141,12 +1136,8 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
def __new__(cls, def __new__(cls, vllm_config: VllmConfig, prefix: str = ""):
config: PretrainedConfig, config = vllm_config.model_config.hf_config
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None):
if not hasattr(config, "version"): if not hasattr(config, "version"):
if config.hidden_size == 2304 and config.query_num == 64: if config.hidden_size == 2304 and config.query_num == 64:
version = (2, 0) version = (2, 0)
@ -1160,5 +1151,4 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
if instance_class is None: if instance_class is None:
raise ValueError( raise ValueError(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6") "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
return instance_class(config, multimodal_config, cache_config, return instance_class(vllm_config, prefix=prefix)
quant_config)

View File

@ -28,7 +28,7 @@ from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -334,13 +334,14 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: MixtralConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config

View File

@ -29,7 +29,7 @@ from torch import nn
from transformers import MixtralConfig from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
@ -352,11 +352,13 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: MixtralConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = MixtralModel(config, cache_config, quant_config) self.model = MixtralModel(config, cache_config, quant_config)

View File

@ -33,7 +33,7 @@ from transformers.models.mllama.processing_mllama import (
import vllm.distributed.parallel_state as ps import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs, from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
InputContext, TokenInputs, token_inputs) InputContext, TokenInputs, token_inputs)
@ -1108,12 +1108,15 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
"up_proj": ("gate_up_proj", 1), "up_proj": ("gate_up_proj", 1),
} }
def __init__(self, def __init__(
config: config_mllama.MllamaConfig, self,
multimodal_config: MultiModalConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None): ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.vocab_size = config.text_config.vocab_size self.vocab_size = config.text_config.vocab_size
self.hidden_size = config.text_config.hidden_size self.hidden_size = config.text_config.hidden_size
self.max_num_tiles = config.vision_config.max_num_tiles self.max_num_tiles = config.vision_config.max_num_tiles

View File

@ -3,8 +3,7 @@ import re
from array import array from array import array
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import (Any, Iterable, List, Mapping, Optional, Tuple, TypedDict, from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union
Union)
import torch import torch
from einops import rearrange from einops import rearrange
@ -16,7 +15,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.attention.selector import _Backend from vllm.attention.selector import _Backend
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, split_tensor_along_last_dim,
@ -1027,13 +1026,14 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
multimodal_config: Optional[MultiModalConfig] = None, prefix: str = "",
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[Mapping[str, Any]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config

View File

@ -7,7 +7,7 @@ import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
@ -269,11 +269,13 @@ class MPTForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: MPTConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None, ) -> None:
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
assert config.tie_word_embeddings assert config.tie_word_embeddings
self.quant_config = quant_config self.quant_config = quant_config

View File

@ -27,7 +27,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -403,13 +403,14 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: NemotronConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
assert isinstance(config, NemotronConfig) assert isinstance(config, NemotronConfig)
self.config = config self.config = config

View File

@ -28,7 +28,7 @@ from transformers import OlmoConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@ -291,11 +291,15 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
Extremely barebones HF model wrapper. Extremely barebones HF model wrapper.
""" """
def __init__(self, def __init__(
config: OlmoConfig, self,
cache_config: Optional[CacheConfig] = None, vllm_config: VllmConfig,
quant_config: Optional[QuantizationConfig] = None): prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.model = OlmoModel(config, cache_config, quant_config) self.model = OlmoModel(config, cache_config, quant_config)
if config.tie_word_embeddings: if config.tie_word_embeddings:

View File

@ -18,7 +18,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -311,11 +311,13 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = OlmoeModel(config, cache_config, quant_config) self.model = OlmoeModel(config, cache_config, quant_config)

View File

@ -24,7 +24,7 @@ from transformers import OPTConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -344,11 +344,13 @@ class OPTForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: OPTConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
): ) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config

View File

@ -11,7 +11,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@ -270,11 +270,13 @@ class OrionForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = OrionModel(config, cache_config, quant_config) self.model = OrionModel(config, cache_config, quant_config)

View File

@ -6,13 +6,11 @@ from torch import nn
from transformers import PaliGemmaConfig from transformers import PaliGemmaConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.gemma import GemmaForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
@ -21,7 +19,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens) dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import AutoWeightsLoader, merge_multimodal_embeddings from .utils import (AutoWeightsLoader, init_vllm_registered_model,
merge_multimodal_embeddings)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -132,13 +131,15 @@ class PaliGemmaMultiModalProjector(nn.Module):
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
def __init__(self, def __init__(
config: PaliGemmaConfig, self,
multimodal_config: MultiModalConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
@ -150,10 +151,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
projection_dim=config.vision_config.projection_dim) projection_dim=config.vision_config.projection_dim)
self.quant_config = quant_config self.quant_config = quant_config
self.language_model = GemmaForCausalLM(config.text_config, config.text_config.architectures = ["GemmaForCausalLM"]
cache_config, self.language_model = init_vllm_registered_model(
quant_config, config.text_config,
prefix="language_model") vllm_config=vllm_config,
prefix="language_model")
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.scale *= logit_scale self.language_model.logits_processor.scale *= logit_scale

View File

@ -27,7 +27,7 @@ from transformers import PersimmonConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -265,11 +265,15 @@ class PersimmonModel(nn.Module):
class PersimmonForCausalLM(nn.Module, SupportsPP): class PersimmonForCausalLM(nn.Module, SupportsPP):
def __init__(self, def __init__(
config: PersimmonConfig, self,
cache_config: Optional[CacheConfig] = None, vllm_config: VllmConfig,
quant_config: Optional[QuantizationConfig] = None): prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.model = PersimmonModel(config, self.model = PersimmonModel(config,

View File

@ -42,7 +42,7 @@ from transformers import PhiConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -279,13 +279,14 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: PhiConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None, ) -> None:
lora_config: Optional[LoRAConfig] = None,
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
# lm_head use bias, cannot share word embeddings # lm_head use bias, cannot share word embeddings
assert not config.tie_word_embeddings assert not config.tie_word_embeddings

View File

@ -6,7 +6,7 @@ from torch import nn
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@ -365,12 +365,13 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None, ) -> None:
lora_config: Optional[LoRAConfig] = None,
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = Phi3SmallModel(config, cache_config, quant_config) self.model = Phi3SmallModel(config, cache_config, quant_config)

View File

@ -25,8 +25,7 @@ from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig from transformers import CLIPVisionConfig, PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import (CacheConfig, ModelConfig, MultiModalConfig, from vllm.config import ModelConfig, VllmConfig
PoolerConfig)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
@ -526,14 +525,16 @@ def input_processor_for_phi3v(ctx: InputContext,
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v) @INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(
config: PretrainedConfig, self,
multimodal_config: MultiModalConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None, ) -> None:
pooler_config: Optional[PoolerConfig] = None) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.image_token_id = _IMAGE_TOKEN_ID self.image_token_id = _IMAGE_TOKEN_ID
@ -552,8 +553,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# The prefix is empty intentionally because default prefix of # The prefix is empty intentionally because default prefix of
# LlamaForCausalLM is "model" # LlamaForCausalLM is "model"
self.language_model = LlamaForCausalLM(config, cache_config, self.language_model = LlamaForCausalLM(vllm_config=vllm_config,
quant_config) prefix="")
# The same model class supports both language generation and embedding # The same model class supports both language generation and embedding
# because the architecture name is the same # because the architecture name is the same

View File

@ -28,7 +28,7 @@ from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -531,13 +531,14 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: PhiMoEConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config

View File

@ -9,14 +9,14 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk from mistral_common.protocol.instruct.messages import ImageChunk
from PIL import Image from PIL import Image
from transformers import PixtralVisionConfig, PretrainedConfig from transformers import PixtralVisionConfig
from transformers.models.pixtral.image_processing_pixtral import ( from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens) _num_image_tokens)
from transformers.models.pixtral.modeling_pixtral import ( from transformers.models.pixtral.modeling_pixtral import (
PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid) PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig from vllm.config import ModelConfig, VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.activation import get_act_and_mul_fn
@ -152,13 +152,14 @@ def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
def __init__(self, def __init__(
config: PretrainedConfig, self,
multimodal_config: MultiModalConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
@ -174,8 +175,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# init MistralForCausalLM # init MistralForCausalLM
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
cache_config, vllm_config=vllm_config,
quant_config,
prefix="language_model") prefix="language_model")
self.vision_encoder = VisionTransformer(self.vision_args) self.vision_encoder = VisionTransformer(self.vision_args)

View File

@ -20,7 +20,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
@ -867,13 +867,14 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None, ) -> None:
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.quant_config = quant_config self.quant_config = quant_config
@ -1064,17 +1065,13 @@ class QWenLMHeadModel(QWenBaseModel, SupportsLoRA):
def __new__( def __new__(
cls, cls,
config: PretrainedConfig, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None, ) -> None:
quant_config: Optional[QuantizationConfig] = None, config = vllm_config.model_config.hf_config
lora_config: Optional[LoRAConfig] = None,
):
# Initialize VL # Initialize VL
if hasattr(config, "visual"): if hasattr(config, "visual"):
return QWenVL(config, multimodal_config, cache_config, return QWenVL(vllm_config)
quant_config, lora_config)
# Initialize LLM # Initialize LLM
else: else:
return QWenLLM(config, multimodal_config, cache_config, return QWenLLM(vllm_config)
quant_config, lora_config)

View File

@ -29,7 +29,7 @@ from transformers import Qwen2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -405,12 +405,14 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: Qwen2Config, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
# TODO (@robertgshaw2): see if this can be moved out # TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")): and hasattr(config, "max_window_layers")):
@ -423,8 +425,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
config.num_hidden_layers, config.num_hidden_layers,
)) ))
super().__init__()
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config

View File

@ -26,16 +26,14 @@ import librosa
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import Qwen2AudioConfig, Qwen2AudioEncoder from transformers import Qwen2AudioEncoder
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
@ -266,13 +264,16 @@ def input_mapper_for_qwen2_audio(
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
def __init__(self, def __init__(
config: Qwen2AudioConfig, self,
multimodal_config: MultiModalConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config

View File

@ -8,14 +8,11 @@ from typing import Iterable, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import Qwen2Config
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
@ -48,12 +45,15 @@ class Qwen2ForSequenceClassification(nn.Module):
def __init__( def __init__(
self, self,
config: Qwen2Config, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
pooler_config: Optional[PoolerConfig] = None,
) -> None: ) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
# TODO (@robertgshaw2): see if this can be moved out # TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")): and hasattr(config, "max_window_layers")):
@ -66,8 +66,6 @@ class Qwen2ForSequenceClassification(nn.Module):
config.num_hidden_layers, config.num_hidden_layers,
)) ))
super().__init__()
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config

View File

@ -30,7 +30,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
@ -379,11 +379,13 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = Qwen2MoeModel(config, cache_config, quant_config) self.model = Qwen2MoeModel(config, cache_config, quant_config)

View File

@ -7,14 +7,12 @@ from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import Qwen2Config
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
@ -59,12 +57,15 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: Qwen2Config, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
pooler_config: Optional[PoolerConfig] = None,
) -> None: ) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
# TODO (@robertgshaw2): see if this can be moved out # TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")): and hasattr(config, "max_window_layers")):
@ -77,8 +78,6 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
config.num_hidden_layers, config.num_hidden_layers,
)) ))
super().__init__()
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config

View File

@ -40,7 +40,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.attention.selector import _Backend from vllm.attention.selector import _Backend
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import get_pp_group, parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
@ -966,15 +966,16 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
def __init__(self, def __init__(
config: Qwen2VLConfig, self,
multimodal_config: MultiModalConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None, ) -> None:
lora_config: Optional[LoRAConfig] = None) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
assert not cache_config.enable_prefix_caching, \ assert not cache_config.enable_prefix_caching, \
"Qwen2-VL currently does not support prefix caching" "Qwen2-VL currently does not support prefix caching"

View File

@ -29,7 +29,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
@ -411,13 +411,14 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config

View File

@ -25,7 +25,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@ -247,11 +247,13 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = StableLMEpochModel(config, cache_config, quant_config) self.model = StableLMEpochModel(config, cache_config, quant_config)

View File

@ -25,7 +25,7 @@ from transformers import Starcoder2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -245,11 +245,15 @@ class Starcoder2Model(nn.Module):
class Starcoder2ForCausalLM(nn.Module, SupportsPP): class Starcoder2ForCausalLM(nn.Module, SupportsPP):
def __init__(self, def __init__(
config: Starcoder2Config, self,
cache_config: Optional[CacheConfig] = None, vllm_config: VllmConfig,
quant_config: Optional[QuantizationConfig] = None): prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.model = Starcoder2Model(config, self.model = Starcoder2Model(config,
cache_config, cache_config,

View File

@ -15,12 +15,11 @@ from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder from transformers.models.whisper.modeling_whisper import WhisperEncoder
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -340,12 +339,14 @@ class ModifiedWhisperEncoder(WhisperEncoder):
@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox) @INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(
config: UltravoxConfig, self,
multimodal_config: MultiModalConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional["QuantizationConfig"] = None): ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multi_modal_config = multimodal_config self.multi_modal_config = multimodal_config
assert self.multi_modal_config assert self.multi_modal_config
@ -361,10 +362,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
)) ))
self.multi_modal_projector = UltravoxProjector(config) self.multi_modal_projector = UltravoxProjector(config)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config, vllm_config, prefix="language_model")
cache_config,
quant_config,
prefix="language_model")
if config.text_model_id is not None: if config.text_model_id is not None:
self.secondary_weights.append( self.secondary_weights.append(
DefaultModelLoader.Source(model_or_path=config.text_model_id, DefaultModelLoader.Source(model_or_path=config.text_model_id,

View File

@ -11,11 +11,8 @@ from transformers import PretrainedConfig
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.selector import (_Backend, backend_name_to_enum, from vllm.attention.selector import (_Backend, backend_name_to_enum,
get_global_forced_attn_backend) get_global_forced_attn_backend)
from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig, from vllm.config import VllmConfig
SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.multimodal.base import MultiModalPlaceholderMap, NestedTensors from vllm.multimodal.base import MultiModalPlaceholderMap, NestedTensors
@ -236,12 +233,7 @@ class AutoWeightsLoader:
def init_vllm_registered_model( def init_vllm_registered_model(
hf_config: PretrainedConfig, hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig], vllm_config: VllmConfig,
quant_config: Optional[QuantizationConfig],
*,
lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
prefix: str = "", prefix: str = "",
) -> nn.Module: ) -> nn.Module:
""" """
@ -249,18 +241,11 @@ def init_vllm_registered_model(
based on the arguments passed to the outer vLLM model. based on the arguments passed to the outer vLLM model.
""" """
model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures) model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)
import copy
copied_config = copy.deepcopy(vllm_config)
copied_config.model_config.hf_config = hf_config
return build_model( return model_class(vllm_config=copied_config, prefix=prefix)
model_class,
None,
hf_config,
cache_config,
quant_config,
lora_config=lora_config,
multimodal_config=multimodal_config,
scheduler_config=scheduler_config,
prefix=prefix,
)
@overload @overload

View File

@ -27,7 +27,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -225,13 +225,14 @@ class XverseModel(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
@ -316,13 +317,16 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config

View File

@ -61,15 +61,3 @@ def set_compilation_config(config: Optional[CompilationConfig]):
def get_compilation_config() -> Optional[CompilationConfig]: def get_compilation_config() -> Optional[CompilationConfig]:
return _compilation_config return _compilation_config
_vllm_config: Optional[VllmConfig] = None
def set_vllm_config(config: Optional[VllmConfig]):
global _vllm_config
_vllm_config = config
def get_vllm_config() -> Optional[VllmConfig]:
return _vllm_config